
Bayesian Optimization with Ax

Following this Distill post on Bayesian Optimization very closely, I wanted to recreate the gold-finding toy example using Ax and BoTorch. Since the OP already gives a very in-depth explanation of Bayesian Optimization, I will focus on the code implementation and only summarize the mathematical details of each block of code where applicable.

# For the purposes of the blog, hide the imports
from imports import *

See this repo for import details.
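
For reference, here is a minimal sketch of what the hidden imports likely include, assuming recent Ax/BoTorch module layouts; the exact paths may differ from the repo.

# A sketch of the hidden imports (assumed module paths, not copied from the repo).
from typing import Dict, List, Tuple, cast

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from ax.core.observation import ObservationFeatures
from ax.core.types import TEvaluationOutcome, TModelPredict, TParameterization
from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.torch import TorchModelBridge
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.plot.contour import interact_contour_plotly
from ax.plot.diagnostic import interact_cross_validation_plotly
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.common.typeutils import not_none
from botorch.acquisition import (
    ExpectedImprovement,
    ProbabilityOfImprovement,
    UpperConfidenceBound,
)
from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
from botorch.models import SingleTaskGP

# Plotting helpers (render_plotly_html, plot_groundtruth, plot_predictions, plot_observed,
# plot_surrogate_and_acquisition, get_line_kwargs, get_scatter_kwargs, print_markdown)
# are custom utilities defined in the linked repo.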

Mining Gold!

I used this website to recreate the gold content function from the OP.

def gold(x: float) -> float:
    return (
        2.85
        + 9.20 * np.power(x, 1)
        + -8.32e0 * np.power(x, 2)
        + 2.39 * np.power(x, 3)
        + -2.10e-1 * np.power(x, 4)
    )
# Gold can occur between 0 and 6.
x = np.linspace(0, 6)
# Setup our observation features to use for model predictions.
observation_features = [
    ObservationFeatures(
        parameters={"x": i},
    )
    for i in x
]
y = [gold(i) for i in x]

Plot the gold content function over our domain.

render_plotly_html(
    px.line(
        pd.DataFrame({"X": x, "Gold Content": y}),
        x="X",
        y="Gold Content",
        title="Ground Truth for Gold Content",
    )
)

Active Learning

Ax Helper Functions

First, define some helper functions for common Ax operations.

def evaluate(parameterization: TParameterization) -> TEvaluationOutcome:
    """Evaluate a parameterization on our groundtruth gold function."""
    x = cast(float, parameterization.get("x", 0.0))
    y = gold(x)
    return {"gold": (y, None)}


def create_generation_strategy(
    botorch_acqf_class, random_class=Models.SOBOL, random_trials: int = 3
) -> GenerationStrategy:
    """Strategy consisting of random search followed by surrogate/acquisition."""
    return GenerationStrategy(
        steps=[
            # Always include random steps.
            GenerationStep(
                model=random_class,
                num_trials=random_trials,
            ),
            # Combine an acquisition class with a GP surrogate function.
            GenerationStep(
                model=Models.BOTORCH_MODULAR,
                num_trials=-1,
                model_kwargs={
                    "botorch_acqf_class": botorch_acqf_class,
                    "surrogate": Surrogate(SingleTaskGP),
                },
            ),
        ]
    )


def create_client(generation_strategy: GenerationStrategy) -> AxClient:
    """Create a ax client for the gold content experiment."""
    ax_client = AxClient(
        generation_strategy=generation_strategy,
        random_seed=1,
        verbose_logging=False,
    )
    ax_client.create_experiment(
        name="gold_content",
        parameters=[
            {
                "name": "x",
                "type": "range",
                "bounds": [0.0, 6.0],
                "value_type": "float",
                "log_scale": False,
            },
        ],
        # We want to get the most gold possible.
        objectives={"gold": ObjectiveProperties(minimize=False)},
    )
    return ax_client


Cache = Tuple[Dict[int, TModelPredict], Dict[int, List[List[float]]]]


def run_ax(n: int, ax_client: AxClient, acqf_n: int = 1) -> Cache:
    """Run an ax client for a specified number of iterations."""
    predictions: Dict[int, TModelPredict] = {}
    acquisitions: Dict[int, List[List[float]]] = {}
    for i in range(n):
        # Get the next trial's parameterization.
        parameterization, trial_index = ax_client.get_next_trial()
        # Extract the model for caching this iteration.
        model = cast(
            TorchModelBridge,
            not_none(ax_client.generation_strategy.model),
        )
        # We won't always have a TorchModelBridge that supports evaluation.
        # If we do, run evaluation and cache the results for visualizing.
        try:
            predictions[i] = model.predict(
                observation_features=observation_features,
            )
            acquisitions[i] = [
                model.evaluate_acquisition_function(
                    observation_features=observation_features
                )
                for _ in range(acqf_n)
            ]
        except Exception:
            pass
        # Complete the trial by evaluating the true gold content.
        ax_client.complete_trial(
            trial_index=trial_index, raw_data=evaluate(parameterization)
        )
    return predictions, acquisitions

Store results of each optimization so we can compare the different methods at the end of this section.

results: Dict[str, pd.DataFrame] = {}

Surrogate Function

Sample uniformly from the domain to show how we learn the gold content with our surrogate function. Our surrogate function will be a Gaussian Process:

\[\begin{align*} \begin{bmatrix} f(x_1) \\ \vdots \\ f(x_m) \end{bmatrix} &\sim \mathcal N \Big(\begin{bmatrix} m(x_1) \\ \vdots \\ m(x_m) \end{bmatrix}, \begin{bmatrix} k(x_1, x_1) & \dots & k(x_1, x_m) \\ \vdots & \ddots & \vdots \\ k(x_m, x_1) & \dots & k(x_m, x_m) \end{bmatrix} \Big) \\ m(x) &= \mathbb E[f(x)] \\ k(x, x') &= \mathbb E[(f(x) - m(x))(f(x') - m(x'))] \\ y^{(i)} &= f(x^{(i)}) + \epsilon^{(i)},\quad i=1,\dots,m, \quad \epsilon^{(i)}\sim \mathcal N(0, \sigma^2) \end{align*}\]
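
Under the hood, Ax's BOTORCH_MODULAR step fits this GP with BoTorch. As a standalone illustration (not part of the Ax flow below), here is a minimal sketch of fitting a SingleTaskGP to a few gold observations and querying its posterior, assuming a recent BoTorch that exposes fit_gpytorch_mll:

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# A few observations of the ground-truth gold function (standalone sketch).
train_x = torch.tensor([[0.5], [2.0], [4.5]], dtype=torch.double)
train_y = torch.tensor([[gold(0.5)], [gold(2.0)], [gold(4.5)]], dtype=torch.double)

# Fit the GP hyperparameters by maximizing the marginal log likelihood.
gp = SingleTaskGP(train_X=train_x, train_Y=train_y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

# Posterior mean and standard deviation over the domain.
test_x = torch.linspace(0, 6, 50, dtype=torch.double).unsqueeze(-1)
posterior = gp.posterior(test_x)
mu = posterior.mean.squeeze(-1)
sigma = posterior.variance.sqrt().squeeze(-1)

Back in Ax, the same surrogate is combined with uniform random initialization and an acquisition function through a generation strategy:
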
generation_strategy = create_generation_strategy(
    botorch_acqf_class=ProbabilityOfImprovement,
    random_class=Models.UNIFORM,
    random_trials=5,
)
ax_client = create_client(generation_strategy=generation_strategy)
predictions, acquisitions = run_ax(n=10, ax_client=ax_client)
[WARNING 08-12 08:02:14] ax.service.ax_client: Random seed set to 1. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 08-12 08:02:14] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[0.0, 6.0])], parameter_constraints=[]).

A Plotly figure can be animated by adding a slider to its layout. Note that the first few points are the uniform random samples; these are followed by alternating frames of the updated surrogate function and the next query point.

fig = go.Figure()
# Plot our groundtruth function for comparison.
fig.add_trace(plot_groundtruth(x=x, y=y, kwargs=get_line_kwargs()))
# Iterate through the trials.
n = len(ax_client.generation_strategy._generator_runs)
j = len(fig.data)
for i in range(n):
    # Not every iteration has predictions.
    if i in predictions:
        fig.add_trace(
            plot_predictions(x=x, predictions=predictions, i=i),
        )
    # Finally, add where we observed this iteration.
    fig.add_trace(
        plot_observed(
            ax_client=ax_client,
            gold_fn=gold,
            i=i,
            kwargs=get_scatter_kwargs(),
        )
    )
# Create the animation steps.
steps = [
    dict(
        method="update",
        label=fig.data[i].name,
        args=[
            {
                "visible": [True] * (j + i) + [False] * (len(fig.data)),
                "title": str(i),
            }
        ],
    )
    for i in range(len(fig.data))
]
# Format the layout of the plot
fig.update_yaxes(range=[1, 9])
fig.update_xaxes(range=[-1, 7])
fig.update_layout(
    title="Gold Search",
    xaxis=dict(title="X"),
    # Add the animation steps onto the layout.
    sliders=[dict(active=0, steps=steps)],
)
render_plotly_html(fig)

Cross Validation

# Extract the final model
model = cast(
    TorchModelBridge,
    not_none(ax_client.generation_strategy.model),
)
# Perform cross validation to estimate out of sample error.
cv = cross_validate(model)
diagnostics = compute_diagnostics(cv)
print_markdown(pd.DataFrame(diagnostics).to_markdown())
|      | Mean prediction CI | MAPE | wMAPE | Total raw effect | Correlation coefficient | Rank correlation | Fisher exact test p | Log likelihood | MSE |
| ---- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| gold | 0.239113 | 0.0495505 | 0.0414965 | 1.62874 | 0.938666 | 0.833333 | 0.166667 | 8.39369 | 0.293955 |
fig = interact_cross_validation_plotly(cv)
fig.update_layout(dict(width=None))
render_plotly_html(fig)

Bayesian Optimization

Probability of Improvement (PI)

Choose the next point with the highest probability of improving over the current maximum.

\[\begin{align*} x_{t+1} &= \arg\max_x P(f(x) \geq f(x^+)) \\ & = \arg\max_x \Phi\Big( \frac{\mu_t(x) - f(x^+)}{\sigma_t(x)} \Big) \end{align*}\]
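
BoTorch's analytic ProbabilityOfImprovement takes the fitted model and the incumbent value best_f. Continuing the standalone GP sketch from the surrogate section (an illustration, not the Ax code path), evaluating it on a grid looks like:

from botorch.acquisition.analytic import ProbabilityOfImprovement

# best_f is the best gold value observed so far in the standalone sketch.
pi = ProbabilityOfImprovement(model=gp, best_f=train_y.max())
# Analytic acquisition functions expect inputs of shape (batch, q=1, d).
pi_values = pi(test_x.unsqueeze(1))
x_next = test_x[pi_values.argmax()]

The same acquisition function, wired into an Ax generation strategy:
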
generation_strategy = create_generation_strategy(
    botorch_acqf_class=ProbabilityOfImprovement
)
ax_client = create_client(generation_strategy=generation_strategy)
predictions, acquisitions = run_ax(n=10, ax_client=ax_client)
results["PI"] = ax_client.get_trials_data_frame()
[WARNING 08-12 08:02:15] ax.service.ax_client: Random seed set to 1. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 08-12 08:02:15] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[0.0, 6.0])], parameter_constraints=[]).
[WARNING 08-12 08:02:16] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
render_plotly_html(
    plot_surrogate_and_acquisition(
        x=x,
        y=y,
        ax_client=ax_client,
        gold_fn=gold,
        predictions=predictions,
        acquisitions=acquisitions,
        line_kwargs=get_line_kwargs(),
    )
)

Expected Improvement (EI)

Choose the next point that has the highest expected improvement over the current max.

\[\begin{align*} x_{t+1} &= \arg\max_x \mathrm{EI}(x), \quad \mathrm{EI}(x) = \mathbb E[\max\{0, h_{t+1}(x) - f(x^+)\}] \\ \mathrm{EI}(x) &= (\mu_t(x) - f(x^+))\Phi(Z) + \sigma_t(x)\phi(Z) \\ Z &= \frac{\mu_t(x) - f(x^+)}{\sigma_t(x)} \end{align*}\]
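
The closed form above is easy to spell out directly from the surrogate's posterior. A sketch using mu, sigma, test_x, and train_y from the standalone GP example (BoTorch's ExpectedImprovement computes the same quantity):

from torch.distributions import Normal

standard_normal = Normal(
    torch.zeros(1, dtype=torch.double), torch.ones(1, dtype=torch.double)
)
best_f = train_y.max()

# Analytic EI evaluated on the grid from the surrogate sketch.
z = (mu - best_f) / sigma
improvement = mu - best_f
ei = improvement * standard_normal.cdf(z) + sigma * torch.exp(standard_normal.log_prob(z))
x_next = test_x[ei.argmax()]

And the Ax version:
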
generation_strategy = create_generation_strategy(
    botorch_acqf_class=ExpectedImprovement,
)
ax_client = create_client(generation_strategy=generation_strategy)
predictions, acquisitions = run_ax(n=10, ax_client=ax_client)
results["EI"] = ax_client.get_trials_data_frame()
[WARNING 08-12 08:02:16] ax.service.ax_client: Random seed set to 1. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 08-12 08:02:16] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[0.0, 6.0])], parameter_constraints=[]).
[WARNING 08-12 08:02:18] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
render_plotly_html(
    plot_surrogate_and_acquisition(
        x=x,
        y=y,
        ax_client=ax_client,
        gold_fn=gold,
        predictions=predictions,
        acquisitions=acquisitions,
        line_kwargs=get_line_kwargs(),
    )
)

Upper Confidence Bound (UCB)

Choose the next point by trading off between the surrogate’s posterior mean and its uncertainty.

\[\begin{align*} x_{t+1} = \arg\max_x \mu_t(x) + \sqrt{\beta} \sigma_t(x) \end{align*}\]
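
UCB is just the posterior mean plus a scaled posterior standard deviation; BoTorch's UpperConfidenceBound takes beta directly. Continuing the standalone sketch (beta = 2.0 is an arbitrary choice for illustration):

from botorch.acquisition.analytic import UpperConfidenceBound

beta = 2.0  # exploration/exploitation trade-off (arbitrary for this sketch)
ucb = UpperConfidenceBound(model=gp, beta=beta)
ucb_values = ucb(test_x.unsqueeze(1))
# Equivalent closed form from the posterior computed earlier:
ucb_manual = mu + beta**0.5 * sigma
x_next = test_x[ucb_values.argmax()]

The Ax version:
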
generation_strategy = create_generation_strategy(
    botorch_acqf_class=UpperConfidenceBound
)
ax_client = create_client(generation_strategy=generation_strategy)
predictions, acquisitions = run_ax(n=10, ax_client=ax_client)
results["UCB"] = ax_client.get_trials_data_frame()
[WARNING 08-12 08:02:18] ax.service.ax_client: Random seed set to 1. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 08-12 08:02:18] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[0.0, 6.0])], parameter_constraints=[]).
[WARNING 08-12 08:02:19] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
render_plotly_html(
    plot_surrogate_and_acquisition(
        x=x,
        y=y,
        ax_client=ax_client,
        gold_fn=gold,
        predictions=predictions,
        acquisitions=acquisitions,
        line_kwargs=get_line_kwargs(),
    )
)

Thompson Sampling

Choose the next point by sampling from the surrogate’s posterior (or a sum over a batch of samples) and then optimizing along this trajectory.
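
A minimal sketch of the idea with the standalone gp from the surrogate section: draw one function sample from the posterior over the grid and take its maximizer. (Ax's PathwiseThompsonSampling used below works with continuous pathwise samples rather than a fixed grid, so this is only an approximation of what it does.)

# Draw a single posterior sample over the grid and maximize it.
posterior = gp.posterior(test_x)
sample = posterior.rsample(torch.Size([1])).squeeze()
x_next = test_x[sample.argmax()]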

generation_strategy = create_generation_strategy(
    botorch_acqf_class=PathwiseThompsonSampling
)
ax_client = create_client(generation_strategy=generation_strategy)
predictions, acquisitions = run_ax(n=10, ax_client=ax_client, acqf_n=5)
results["Thompson"] = ax_client.get_trials_data_frame()
[WARNING 08-12 08:02:19] ax.service.ax_client: Random seed set to 1. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 08-12 08:02:19] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[0.0, 6.0])], parameter_constraints=[]).
[WARNING 08-12 08:02:21] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
render_plotly_html(
    plot_surrogate_and_acquisition(
        x=x,
        y=y,
        ax_client=ax_client,
        gold_fn=gold,
        predictions=predictions,
        acquisitions=acquisitions,
        line_kwargs=get_line_kwargs(color="red"),
    )
)

Random

Choose the next point by uniformly sampling a point from the domain.

generation_strategy = create_generation_strategy(
    botorch_acqf_class=None,
    random_class=Models.UNIFORM,
    random_trials=10,
)
ax_client = create_client(generation_strategy=generation_strategy)
predictions, acquisitions = run_ax(n=10, ax_client=ax_client, acqf_n=0)
results["Random"] = ax_client.get_trials_data_frame()
[WARNING 08-12 08:02:21] ax.service.ax_client: Random seed set to 1. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 08-12 08:02:21] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x', parameter_type=FLOAT, range=[0.0, 6.0])], parameter_constraints=[]).
[WARNING 08-12 08:02:21] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.

Comparison

Compare the acquisition methods above by plotting the maximum gold found so far against the number of trials in each experiment.

max_gold = {"Method": [], "Trial": [], "Max Gold": []}
for k, v in results.items():
    max_gold["Method"].extend([k] * (len(v) + 1))
    max_gold["Trial"].extend([0] + (v.trial_index + 1).to_list())
    max_gold["Max Gold"].extend([0] + v["gold"].cummax().tolist())
render_plotly_html(px.line(max_gold, x="Trial", y="Max Gold", color="Method"))

Hyperparameter Tuning

# Import moons and SVC from sklearn
from sklearn.datasets import make_moons
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
def create_svc_data() -> Tuple[np.ndarray, np.ndarray]:
    return make_moons(n_samples=1000, noise=0.5)
X_train, _, y_train, _ = train_test_split(*create_svc_data())
render_plotly_html(px.scatter(x=X_train[:, 0], y=X_train[:, 1], color=y_train))
results_svc: Dict[str, pd.DataFrame] = {}

Ax Helper Functions

def evaluate_svc(parameterization: TParameterization) -> TEvaluationOutcome:
    _C = parameterization.get("C", 1.0)
    _gamma = parameterization.get("gamma", 1.0)
    X_train, X_test, y_train, y_test = train_test_split(*create_svc_data())
    svc = SVC(C=_C, gamma=_gamma).fit(X=X_train, y=y_train)
    y_pred = svc.predict(X=X_test)
    return {"accuracy": (accuracy_score(y_true=y_test, y_pred=y_pred), None)}


def create_svc_client(
    botorch_acqf_class,
    random_class=Models.SOBOL,
    random_trials: int = 3,
) -> AxClient:
    ax_client = AxClient(
        verbose_logging=False,
        random_seed=1,
        generation_strategy=create_generation_strategy(
            botorch_acqf_class=botorch_acqf_class,
            random_class=random_class,
            random_trials=random_trials,
        ),
    )
    ax_client.create_experiment(
        name="moons_dataset_svc",
        parameters=[
            {
                "name": "C",
                "type": "range",
                "bounds": [1e-9, 1e6],
                "value_type": "float",
                "log_scale": True,
            },
            {
                "name": "gamma",
                "type": "range",
                "bounds": [1e-6, 1e6],
                "value_type": "float",
                "log_scale": True,
            },
        ],
        objectives={"accuracy": ObjectiveProperties(minimize=False)},
    )
    return ax_client

Probability of Improvement (PI)

ax_client = create_svc_client(botorch_acqf_class=ProbabilityOfImprovement)
for i in range(10):
    parameterization, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(
        trial_index=trial_index, raw_data=evaluate_svc(parameterization)
    )
results_svc["PI"] = ax_client.get_trials_data_frame()
[WARNING 08-12 08:02:21] ax.service.ax_client: Random seed set to 1. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 08-12 08:02:21] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='C', parameter_type=FLOAT, range=[1e-09, 1000000.0], log_scale=True), RangeParameter(name='gamma', parameter_type=FLOAT, range=[1e-06, 1000000.0], log_scale=True)], parameter_constraints=[]).
[WARNING 08-12 08:02:24] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
model = cast(
    TorchModelBridge,
    not_none(ax_client.generation_strategy.model),
)
fig = interact_contour_plotly(model=model, metric_name="accuracy")
fig.update_layout(dict(width=None, autosize=True))
render_plotly_html(fig)

Expected Improvement (EI)

ax_client = create_svc_client(botorch_acqf_class=ExpectedImprovement)
for i in range(10):
    parameterization, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(
        trial_index=trial_index, raw_data=evaluate_svc(parameterization)
    )
results_svc["EI"] = ax_client.get_trials_data_frame()
[WARNING 08-12 08:02:25] ax.service.ax_client: Random seed set to 1. Note that this setting only affects the Sobol quasi-random generator and BoTorch-powered Bayesian optimization models. For the latter models, setting random seed to the same number for two optimizations will make the generated trials similar, but not exactly the same, and over time the trials will diverge more.
[INFO 08-12 08:02:25] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='C', parameter_type=FLOAT, range=[1e-09, 1000000.0], log_scale=True), RangeParameter(name='gamma', parameter_type=FLOAT, range=[1e-06, 1000000.0], log_scale=True)], parameter_constraints=[]).
[WARNING 08-12 08:02:35] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.
model = cast(
    TorchModelBridge,
    not_none(ax_client.generation_strategy.model),
)
fig = interact_contour_plotly(model=model, metric_name="accuracy")
fig.update_layout(dict(width=None, autosize=True))
render_plotly_html(fig)

Comparison

max_accuracy = {"Method": [], "Trial": [], "Accuracy": []}
for k, v in results_svc.items():
    max_accuracy["Method"].extend([k] * (len(v) + 1))
    max_accuracy["Trial"].extend([0] + (v.trial_index + 1).to_list())
    max_accuracy["Accuracy"].extend([0] + v["accuracy"].cummax().tolist())
render_plotly_html(px.line(max_accuracy, x="Trial", y="Accuracy", color="Method"))