Note
Go to the end to download the full example code.
Study.optimize 用のコールバック
このチュートリアルでは、optimize()
で使用する Optuna の Callback
の使用方法と実装方法を説明します。
Callback
は objective
の評価ごとに呼び出され、
Study
と FrozenTrial
を引数として受け取り、何らかの処理を行います。
MLflowCallback は優れた使用例です。
一定のトライアルが連続して削除されたら最適化を停止
この例では、一定のトライアルが連続して削除された場合に最適化を停止するステートフルなコールバックを実装します。
連続して削除されるトライアルの数は threshold
で指定します。
import optuna
class StopWhenTrialKeepBeingPrunedCallback:
def __init__(self, threshold: int):
self.threshold = threshold
self._consequtive_pruned_count = 0
def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
if trial.state == optuna.trial.TrialState.PRUNED:
self._consequtive_pruned_count += 1
else:
self._consequtive_pruned_count = 0
if self._consequtive_pruned_count >= self.threshold:
study.stop()
この目的関数は、最初の5トライアルを除くすべてのトライアルを削除します(trial.number
は0から始まります)。
def objective(trial):
if trial.number > 4:
raise optuna.TrialPruned
return trial.suggest_float("x", 0, 1)
ここでは閾値を 2
に設定しています。2トライアルが連続して削除されると最適化が終了します。
したがって、このスタディは7トライアル後に停止することが期待されます。
import logging
import sys
# メッセージを表示するため標準出力のストリームハンドラを追加
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_stop_cb = StopWhenTrialKeepBeingPrunedCallback(2)
study = optuna.create_study()
study.optimize(objective, n_trials=10, callbacks=[study_stop_cb])
A new study created in memory with name: no-name-8f76cc7e-736f-4bc3-a181-4df47e0b759b
Trial 0 finished with value: 0.6379973011159417 and parameters: {'x': 0.6379973011159417}. Best is trial 0 with value: 0.6379973011159417.
Trial 1 finished with value: 0.34148280556180055 and parameters: {'x': 0.34148280556180055}. Best is trial 1 with value: 0.34148280556180055.
Trial 2 finished with value: 0.1532498537853697 and parameters: {'x': 0.1532498537853697}. Best is trial 2 with value: 0.1532498537853697.
Trial 3 finished with value: 0.9056429274294634 and parameters: {'x': 0.9056429274294634}. Best is trial 2 with value: 0.1532498537853697.
Trial 4 finished with value: 0.4825927810677111 and parameters: {'x': 0.4825927810677111}. Best is trial 2 with value: 0.1532498537853697.
Trial 5 pruned.
Trial 6 pruned.
上記のログからわかるように、スタディは期待通り7トライアル後に停止しました。
Total running time of the script: (0 minutes 0.005 seconds)