RDBバックエンドを使用したスタディの保存/再開

RDBバックエンドを使用すると、スタディの永続化(保存と再開)やスタディ履歴の参照が可能になります。 さらに、この機能を利用してマルチノード環境での最適化タスクを実行できます。詳細は 並列化の容易さ を参照してください。

このセクションでは、SQLiteデータベースを使用したローカル環境での簡単な例を試してみましょう。

Note

PostgreSQLやMySQLなどの他のRDBバックエンドを使用する場合は、storage引数にDBのURLを指定します。 URLの設定方法については、SQLAlchemyのドキュメント を参照してください。

新しいスタディの作成

create_study() 関数を使用して、以下のように永続化可能なスタディを作成できます。 SQLiteファイル example.db が自動的に初期化され、新しいスタディレコードが作成されます。

import logging
import sys

import optuna

# 標準出力にログを表示するストリームハンドラを追加
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "example-study"  # スタディの一意の識別子
storage_name = "sqlite:///{}.db".format(study_name)
study = optuna.create_study(study_name=study_name, storage=storage_name)
A new study created in RDB with name: example-study

スタディを実行するには、目的関数を引数として optimize() メソッドを呼び出します。

def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2

study.optimize(objective, n_trials=3)
Trial 0 finished with value: 54.20662386878493 and parameters: {'x': -5.362514778850018}. Best is trial 0 with value: 54.20662386878493.
Trial 1 finished with value: 20.803831203674598 and parameters: {'x': -2.561121704545341}. Best is trial 1 with value: 20.803831203674598.
Trial 2 finished with value: 50.426732795905124 and parameters: {'x': 9.101178268140092}. Best is trial 1 with value: 20.803831203674598.

スタディの再開

スタディを再開するには、スタディ名 example-study とDB URL sqlite:///example-study.db を指定して Study オブジェクトをインスタンス化します。

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
study.optimize(objective, n_trials=3)
Using an existing study with name 'example-study' instead of creating a new one.
Trial 3 finished with value: 35.542655131975735 and parameters: {'x': -3.96176610845945}. Best is trial 1 with value: 20.803831203674598.
Trial 4 finished with value: 0.05530669688892123 and parameters: {'x': 1.7648262410707325}. Best is trial 4 with value: 0.05530669688892123.
Trial 5 finished with value: 50.47414279721775 and parameters: {'x': -5.104515662395133}. Best is trial 4 with value: 0.05530669688892123.

注意: samplerspruners のインスタンス状態はストレージに保存されません。 再現性のために seed 引数を指定したサンプラーでスタディを再開する場合、pickle を使用して 以下のようにサンプラーを復元する必要があります:

import pickle

# 後で読み込むためにサンプラーをpickleで保存
with open("sampler.pkl", "wb") as fout:
    pickle.dump(study.sampler, fout)

restored_sampler = pickle.load(open("sampler.pkl", "rb"))
study = optuna.create_study(
    study_name=study_name, storage=storage_name, load_if_exists=True, sampler=restored_sampler
)
study.optimize(objective, n_trials=3)

実験履歴

このセクションでは Pandas のインストールが必要です:

$ pip install pandas

Study クラスを使用してスタディとトライアルの履歴にアクセスできます。 例えば、example-study の全トライアルを取得するには以下のようにします:

study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True)
df = study.trials_dataframe(attrs=("number", "value", "params", "state"))
Using an existing study with name 'example-study' instead of creating a new one.

メソッド trials_dataframe() は以下のような pandas データフレームを返します:

print(df)
   number      value  params_x     state
0       0  54.206624 -5.362515  COMPLETE
1       1  20.803831 -2.561122  COMPLETE
2       2  50.426733  9.101178  COMPLETE
3       3  35.542655 -3.961766  COMPLETE
4       4   0.055307  1.764826  COMPLETE
5       5  50.474143 -5.104516  COMPLETE

Study オブジェクトは trials, best_value, best_params などのプロパティも提供します (詳細は 軽量かつ汎用性が高く、プラットフォームに依存しないアーキテクチャ を参照)。

print("Best params: ", study.best_params)
print("Best value: ", study.best_value)
print("Best Trial: ", study.best_trial)
print("Trials: ", study.trials)
Best params:  {'x': 1.7648262410707325}
Best value:  0.05530669688892123
Best Trial:  FrozenTrial(number=4, state=1, values=[0.05530669688892123], datetime_start=datetime.datetime(2025, 6, 9, 16, 14, 36, 151898), datetime_complete=datetime.datetime(2025, 6, 9, 16, 14, 36, 207656), params={'x': 1.7648262410707325}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=5, value=None)
Trials:  [FrozenTrial(number=0, state=1, values=[54.20662386878493], datetime_start=datetime.datetime(2025, 6, 9, 16, 14, 35, 517150), datetime_complete=datetime.datetime(2025, 6, 9, 16, 14, 35, 597466), params={'x': -5.362514778850018}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=1, value=None), FrozenTrial(number=1, state=1, values=[20.803831203674598], datetime_start=datetime.datetime(2025, 6, 9, 16, 14, 35, 662083), datetime_complete=datetime.datetime(2025, 6, 9, 16, 14, 35, 727951), params={'x': -2.561121704545341}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=2, value=None), FrozenTrial(number=2, state=1, values=[50.426732795905124], datetime_start=datetime.datetime(2025, 6, 9, 16, 14, 35, 797485), datetime_complete=datetime.datetime(2025, 6, 9, 16, 14, 35, 865820), params={'x': 9.101178268140092}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=3, value=None), FrozenTrial(number=3, state=1, values=[35.542655131975735], datetime_start=datetime.datetime(2025, 6, 9, 16, 14, 36, 20536), datetime_complete=datetime.datetime(2025, 6, 9, 16, 14, 36, 92706), params={'x': -3.96176610845945}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=4, value=None), FrozenTrial(number=4, state=1, values=[0.05530669688892123], datetime_start=datetime.datetime(2025, 6, 9, 16, 14, 36, 151898), datetime_complete=datetime.datetime(2025, 6, 9, 16, 14, 36, 207656), params={'x': 1.7648262410707325}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=5, value=None), FrozenTrial(number=5, state=1, values=[50.47414279721775], datetime_start=datetime.datetime(2025, 6, 9, 16, 14, 36, 268828), datetime_complete=datetime.datetime(2025, 6, 9, 16, 14, 36, 324919), params={'x': -5.104515662395133}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'x': FloatDistribution(high=10.0, log=False, low=-10.0, step=None)}, trial_id=6, value=None)]

Total running time of the script: (0 minutes 2.830 seconds)

Gallery generated by Sphinx-Gallery