Study.optimize 用のコールバック

このチュートリアルでは、optimize() で使用する Optuna の Callback の使用方法と実装方法を説明します。

Callbackobjective の評価ごとに呼び出され、 StudyFrozenTrial を引数として受け取り、何らかの処理を行います。

MLflowCallback は優れた使用例です。

一定のトライアルが連続して削除されたら最適化を停止

この例では、一定のトライアルが連続して削除された場合に最適化を停止するステートフルなコールバックを実装します。 連続して削除されるトライアルの数は threshold で指定します。

import optuna


class StopWhenTrialKeepBeingPrunedCallback:
    def __init__(self, threshold: int):
        self.threshold = threshold
        self._consequtive_pruned_count = 0

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        if trial.state == optuna.trial.TrialState.PRUNED:
            self._consequtive_pruned_count += 1
        else:
            self._consequtive_pruned_count = 0

        if self._consequtive_pruned_count >= self.threshold:
            study.stop()

この目的関数は、最初の5トライアルを除くすべてのトライアルを削除します(trial.number は0から始まります)。

def objective(trial):
    if trial.number > 4:
        raise optuna.TrialPruned

    return trial.suggest_float("x", 0, 1)

ここでは閾値を 2 に設定しています。2トライアルが連続して削除されると最適化が終了します。 したがって、このスタディは7トライアル後に停止することが期待されます。

import logging
import sys

# メッセージを表示するため標準出力のストリームハンドラを追加
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

study_stop_cb = StopWhenTrialKeepBeingPrunedCallback(2)
study = optuna.create_study()
study.optimize(objective, n_trials=10, callbacks=[study_stop_cb])
A new study created in memory with name: no-name-8f76cc7e-736f-4bc3-a181-4df47e0b759b
Trial 0 finished with value: 0.6379973011159417 and parameters: {'x': 0.6379973011159417}. Best is trial 0 with value: 0.6379973011159417.
Trial 1 finished with value: 0.34148280556180055 and parameters: {'x': 0.34148280556180055}. Best is trial 1 with value: 0.34148280556180055.
Trial 2 finished with value: 0.1532498537853697 and parameters: {'x': 0.1532498537853697}. Best is trial 2 with value: 0.1532498537853697.
Trial 3 finished with value: 0.9056429274294634 and parameters: {'x': 0.9056429274294634}. Best is trial 2 with value: 0.1532498537853697.
Trial 4 finished with value: 0.4825927810677111 and parameters: {'x': 0.4825927810677111}. Best is trial 2 with value: 0.1532498537853697.
Trial 5 pruned.
Trial 6 pruned.

上記のログからわかるように、スタディは期待通り7トライアル後に停止しました。

Total running time of the script: (0 minutes 0.005 seconds)

Gallery generated by Sphinx-Gallery