ユーザー定義プルーナー

optuna.pruners モジュールでは、目的関数がオプションでプルーニング機能を 呼び出せるようにすることで、Optuna が中間結果が有望でない場合に最適化トライアルを 終了できることを説明しました。このドキュメントでは、独自のプルーナー(トライアルを 終了するタイミングを決定するカスタム戦略)を実装する方法について説明します。

プルーニングインターフェースの概要

create_study() コンストラクタは、オプション引数として BasePruner を継承したプルーナーを受け取ります。プルーナーは 抽象メソッド prune() を実装する必要があり、このメソッドは 関連付けられた StudyTrial の 引数を受け取り、トライアルを終了する場合は True を、そうでない場合は False を返します。Study と Trial オブジェクトを使用すると、 get_trials() メソッドで他のすべてのトライアルにアクセスでき、 トライアルからは intermediate_values() を通じて報告された中間値に アクセスできます(これは整数 step を浮動小数点数にマップする辞書です)。

組み込みの Optuna プルーナーのソースコードをテンプレートとして参照できます。この ドキュメントでは、例として、同じステップで完了したトライアルと比較して最下位の トライアルを終了するシンプルな(しかし積極的な)プルーナーの実装と使用方法を 説明します。

Note

より堅牢なプルーナー実装の例(エラーチェックや複雑な内部ロジックを含む)については、 BasePruner または ThresholdPrunerPercentilePruner のドキュメントを参照してください。

例: LastPlacePruner の実装

sklearn の iris データセットで実行される確率的勾配降下法分類器 (SGDClassifier) の lossalpha ハイパーパラメータを最適化することを目的とします。同じステップで 完了したトライアルと比較して最下位のトライアルを終了するプルーナーを実装します。1 回の トレーニングステップと 5 回の完了トライアルを「ウォームアップ」として考慮します。説明 のために、pruneTrue を返す(プルーニングを示す)直前に診断メッセージを print() します。

重要な点として、ホールドアウトセットで評価される SGDClassifier のスコアは、 過学習により十分なトレーニングステップで減少します。これは、トライアルが以前の トレーニングセットで好ましい(高い)値を持っていた場合でも終了される可能性があることを 意味します。プルーニング後、Optuna はトライアルの値として最後に報告された中間値を 使用します。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDClassifier

import optuna
from optuna.pruners import BasePruner
from optuna.trial._state import TrialState


class LastPlacePruner(BasePruner):
    def __init__(self, warmup_steps, warmup_trials):
        self._warmup_steps = warmup_steps
        self._warmup_trials = warmup_trials

    def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
        # このトライアルから報告された最新のスコアを取得
        step = trial.last_step

        if step:  # trial.last_step == None はスコアがまだ報告されていない場合
            this_score = trial.intermediate_values[step]

            # 同じステップで報告された他のトライアルのスコアを取得
            completed_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
            other_scores = [
                t.intermediate_values[step]
                for t in completed_trials
                if step in t.intermediate_values
            ]
            other_scores = sorted(other_scores)

            # このトライアルのスコアが完了トライアルの中で最低の場合にプルーニング
            # ステップ番号はobjective関数の定義で0から始まることに注意
            if step >= self._warmup_steps and len(other_scores) > self._warmup_trials:
                if this_score < other_scores[0]:
                    print(f"prune() True: Trial {trial.number}, Step {step}, Score {this_score}")
                    return True

        return False

最後に、簡単なハイパーパラメータ最適化で実装が正しいことを確認します。

def objective(trial):
    iris = load_iris()
    classes = np.unique(iris.target)
    X_train, X_valid, y_train, y_valid = train_test_split(
        iris.data, iris.target, train_size=100, test_size=50, random_state=0
    )

    loss = trial.suggest_categorical("loss", ["hinge", "log_loss", "perceptron"])
    alpha = trial.suggest_float("alpha", 0.00001, 0.001, log=True)
    clf = SGDClassifier(loss=loss, alpha=alpha, random_state=0)
    score = 0

    for step in range(0, 5):
        clf.partial_fit(X_train, y_train, classes=classes)
        score = clf.score(X_valid, y_valid)

        trial.report(score, step)

        if trial.should_prune():
            raise optuna.TrialPruned()

    return score


pruner = LastPlacePruner(warmup_steps=1, warmup_trials=5)
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=50)
prune() True: Trial 14, Step 4, Score 0.68
prune() True: Trial 15, Step 2, Score 0.62
prune() True: Trial 18, Step 2, Score 0.62
prune() True: Trial 21, Step 3, Score 0.58
prune() True: Trial 23, Step 4, Score 0.7
prune() True: Trial 25, Step 3, Score 0.42
prune() True: Trial 26, Step 4, Score 0.68
prune() True: Trial 32, Step 3, Score 0.64
prune() True: Trial 33, Step 1, Score 0.32
prune() True: Trial 35, Step 4, Score 0.68
prune() True: Trial 38, Step 4, Score 0.7
prune() True: Trial 40, Step 4, Score 0.7
prune() True: Trial 49, Step 2, Score 0.54

Total running time of the script: (0 minutes 0.714 seconds)

Gallery generated by Sphinx-Gallery