optuna.pruners.MedianPruner

class optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=0, interval_steps=1, *, n_min_trials=1)[source]

中央値停止規則を用いたプルーナー

現在のトライアルの中間結果が、これまでのトライアルの中間結果の中央値よりも悪い場合にプルーニングを行う。 完了したトライアルの中間結果の中央値と比較して、見込みのないトライアルを早期に停止する。

使用例

中央値停止規則を用いて目的関数を最小化する。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

import optuna

X, y = load_iris(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
classes = np.unique(y)


def objective(trial):
    alpha = trial.suggest_float("alpha", 0.0, 1.0)
    clf = SGDClassifier(alpha=alpha)
    n_train_iter = 100

    for step in range(n_train_iter):
        clf.partial_fit(X_train, y_train, classes=classes)

        intermediate_value = clf.score(X_valid, y_valid)
        trial.report(intermediate_value, step)

        if trial.should_prune():
            raise optuna.TrialPruned()

    return clf.score(X_valid, y_valid)


study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(
        n_startup_trials=5, n_warmup_steps=30, interval_steps=10
    ),
)
study.optimize(objective, n_trials=20)
Parameters:
  • n_startup_trials (int) – 同じスタディ内で指定されたトライアル数が終了するまでプルーニングは無効。

  • n_warmup_steps (int) – トライアルが指定されたステップ数を超えるまでプルーニングは無効。 この機能は step が0から始まることを前提としている。

  • interval_steps (int) – プルーニングチェック間のステップ間隔。ウォームアップステップ分オフセットされる。 プルーニングチェック時点で値が報告されていない場合、そのチェックは値が報告されるまで延期される。

  • n_min_trials (int) – プルーニング判定に必要な報告トライアル結果の最小数。 現在のステップで全てのトライアルから報告された中間値の数が n_min_trials 未満の場合、 そのトライアルはプルーニングされない。これにより、プルーニングされずに完了するトライアルの 最低数を確保できる。

メソッド

prune(study, trial)

Judge whether the trial should be pruned based on the reported values.

prune(study, trial)

報告された値に基づいてトライアルをプルーニングすべきかどうかを判断する。

このメソッドはライブラリ利用者が直接呼び出すことを想定されていない。代わりに、 optuna.trial.Trial.report()optuna.trial.Trial.should_prune() が 目的関数内でプルーニング機構を実装するためのインターフェースを提供する。

Parameters:
  • study (Study) – 対象スタディのスタディオブジェクト。

  • trial (FrozenTrial) – 対象トライアルのFrozenTrialオブジェクト。 このオブジェクトを変更する前にコピーを取得すること。

Returns:

トライアルをプルーニングすべきかどうかを表す真偽値。

Return type:

bool