optuna.pruners.PercentilePruner
- class optuna.pruners.PercentilePruner(percentile, n_startup_trials=5, n_warmup_steps=0, interval_steps=1, *, n_min_trials=1)[source]
指定したパーセンタイルのトライアルを保持するプルーナー。
同じステップのトライアル中で最良の中間値が下位パーセンタイルに属する場合にプルーニングを行う。
使用例
import numpy as np from sklearn.datasets import load_iris from sklearn.linear_model import SGDClassifier from sklearn.model_selection import train_test_split import optuna X, y = load_iris(return_X_y=True) X_train, X_valid, y_train, y_valid = train_test_split(X, y) classes = np.unique(y) def objective(trial): alpha = trial.suggest_float("alpha", 0.0, 1.0) clf = SGDClassifier(alpha=alpha) n_train_iter = 100 for step in range(n_train_iter): clf.partial_fit(X_train, y_train, classes=classes) intermediate_value = clf.score(X_valid, y_valid) trial.report(intermediate_value, step) if trial.should_prune(): raise optuna.TrialPruned() return clf.score(X_valid, y_valid) study = optuna.create_study( direction="maximize", pruner=optuna.pruners.PercentilePruner( 25.0, n_startup_trials=5, n_warmup_steps=30, interval_steps=10 ), ) study.optimize(objective, n_trials=20)
- Parameters:
percentile (float) – 保持するトライアルのパーセンタイル値(0から100の範囲) (例: 25.0を指定すると、上位25パーセンタイルのトライアルが保持される)
n_startup_trials (int) – 指定したトライアル数が終了するまでプルーニングを無効化
n_warmup_steps (int) – 指定したステップ数までプルーニングを無効化。 この機能は
step
が0から始まることを前提とするinterval_steps (int) – プルーニングチェックの間隔(ウォームアップステップをオフセット) チェック時点で値が報告されていない場合、値が報告されるまで延期される 値は少なくとも1以上でなければならない
n_min_trials (int) – プルーニング判定に必要なトライアル数の最小値 現在のステップで報告された中間値のトライアル数が
n_min_trials
未満の場合、 トライアルはプルーニングされない。これにより、プルーニングされずに完了するトライアルの 最低数を確保できる
メソッド
prune
(study, trial)Judge whether the trial should be pruned based on the reported values.
- prune(study, trial)[source]
報告された値に基づいてトライアルをプルーニングするかどうかを判定
このメソッドはライブラリ利用者が直接呼び出すことを想定していない。代わりに、
optuna.trial.Trial.report()
とoptuna.trial.Trial.should_prune()
が 目的関数内でプルーニング機構を実装するためのインターフェースを提供する- Parameters:
study (Study) – 対象スタディのStudyオブジェクト
trial (FrozenTrial) – 対象トライアルのFrozenTrialオブジェクト このオブジェクトを変更する前にコピーを取得すること
- Returns:
トライアルをプルーニングするかどうかを表すブール値
- Return type: