Optuna による多目的最適化

本チュートリアルでは、Optuna の多目的最適化機能を、Fashion MNIST データセットの 検証精度と PyTorch で実装したモデルの FLOPS の最適化を通じて紹介します。

FLOPS の計測には fvcore を使用します。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from fvcore.nn import FlopCountAnalysis

import optuna


DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10


def define_model(trial):
    n_layers = trial.suggest_int("n_layers", 1, 3)
    layers = []

    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float("dropout_{}".format(i), 0.2, 0.5)
        layers.append(nn.Dropout(p))

        in_features = out_features

    layers.append(nn.Linear(in_features, 10))
    layers.append(nn.LogSoftmax(dim=1))

    return nn.Sequential(*layers)


# 学習と評価の定義
def train_model(model, optimizer, train_loader):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        F.nll_loss(model(data), target).backward()
        optimizer.step()


def eval_model(model, valid_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / N_VALID_EXAMPLES

    flops = FlopCountAnalysis(model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),)).total()
    return flops, accuracy

多目的目的関数の定義 目的は FLOPS と精度

def objective(trial):
    train_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=True, download=True, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    val_dataset = torchvision.datasets.FashionMNIST(
        DIR, train=False, transform=torchvision.transforms.ToTensor()
    )
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    model = define_model(trial).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    )

    for epoch in range(10):
        train_model(model, optimizer, train_loader)
    flops, accuracy = eval_model(model, val_loader)
    return flops, accuracy

多目的最適化の実行

最適化問題が多目的の場合、Optuna では各目的に対する最適化方向を指定します。 本例では、FLOPS を最小化(より高速なモデル)し、精度を最大化したいので、 directions["minimize", "maximize"] に設定します。

study = optuna.create_study(directions=["minimize", "maximize"])
study.optimize(objective, n_trials=30, timeout=300)

print("終了したトライアル数: ", len(study.trials))
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
Unsupported operator aten::log_softmax encountered 1 time(s)
終了したトライアル数:  30

以下のセクションでは、可視化のために Plotly と ハイパーパラメータ重要度計算のために scikit-learn のインストールが必要です。

$ pip install plotly
$ pip install scikit-learn
$ pip install nbformat  # Jupyter Notebook でこのチュートリアルを実行する場合に必要

Pareto フロント上のトライアルを視覚的に確認します。

optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])


best_trials を使用して Pareto フロント上のトライアルリストを取得します。

例えば、以下のコードは Pareto フロント上のトライアル数を表示し、 最高精度のトライアルを選択します。

print(f"Pareto フロント上のトライアル数: {len(study.best_trials)}")

trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[1])
print("最高精度のトライアル: ")
print(f"\t番号: {trial_with_highest_accuracy.number}")
print(f"\tパラメータ: {trial_with_highest_accuracy.params}")
print(f"\t値: {trial_with_highest_accuracy.values}")
Pareto フロント上のトライアル数: 8
最高精度のトライアル:
        番号: 23
        パラメータ: {'n_layers': 1, 'n_units_l0': 84, 'dropout_0': 0.4654215905962879, 'lr': 0.00531644419648229}
        値: [66696.0, 0.84375]

ハイパーパラメータ重要度を使用して、FLOPS に最も影響を与えるハイパーパラメータを特定します。

optuna.visualization.plot_param_importances(
    study, target=lambda t: t.values[0], target_name="flops"
)


Total running time of the script: (1 minutes 35.802 seconds)

Gallery generated by Sphinx-Gallery