from __future__ import annotations
from collections.abc import Container
from collections.abc import Sequence
import copy
from datetime import datetime
import threading
from typing import Any
import uuid
import optuna
from optuna import distributions # NOQA
from optuna._typing import JSONSerializable
from optuna.exceptions import DuplicatedStudyError
from optuna.storages import BaseStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.study._frozen import FrozenStudy
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
# Module-level logger, obtained through Optuna's logging helper and named after this module.
_logger = optuna.logging.get_logger(__name__)
# (A "[docs]" link marker left over from the rendered documentation page was removed here.)
class InMemoryStorage(BaseStorage):
    """Storage class that stores data in memory of the Python process.

    Example:

        Create an :class:`~optuna.storages.InMemoryStorage` instance.

        .. testcode::

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", -100, 100)
                return x**2


            storage = optuna.storages.InMemoryStorage()

            study = optuna.create_study(storage=storage)
            study.optimize(objective, n_trials=10)
    """

    def __init__(self) -> None:
        # trial_id -> (study_id, trial number inside that study).
        self._trial_id_to_study_id_and_number: dict[int, tuple[int, int]] = {}
        self._study_name_to_id: dict[str, int] = {}
        self._studies: dict[int, _StudyInfo] = {}
        # Most recently assigned ids. New ids are ``max + 1``, so the first id handed out is 0.
        self._max_study_id = -1
        self._max_trial_id = -1
        # Reentrant lock: methods that hold it call other methods that acquire it again
        # (e.g. `get_best_trial` -> `get_trial`, `_update_cache` -> `get_study_directions`).
        self._lock = threading.RLock()
        # study_id -> index in the study's trial list at which the WAITING-only scan of
        # `get_all_trials` starts.
        self._prev_waiting_trial_number: dict[int, int] = {}

    def __getstate__(self) -> dict[Any, Any]:
        """Return a shallow copy of the instance state without the lock object."""
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state: dict[Any, Any]) -> None:
        """Restore the instance state and create a fresh lock for this instance."""
        self.__dict__.update(state)
        self._lock = threading.RLock()

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: str | None = None
    ) -> int:
        """Register a new study and return its id.

        Args:
            directions: Optimization directions of the study; stored as a new list.
            study_name: Name of the study. When ``None``, a name is generated from
                ``DEFAULT_STUDY_NAME_PREFIX`` and a random UUID.

        Returns:
            The id assigned to the new study.

        Raises:
            DuplicatedStudyError: If ``study_name`` is already registered. The storage is
                left untouched in that case (no study id is consumed).
        """
        with self._lock:
            if study_name is not None:
                if study_name in self._study_name_to_id:
                    raise DuplicatedStudyError
            else:
                study_uuid = str(uuid.uuid4())
                study_name = DEFAULT_STUDY_NAME_PREFIX + study_uuid

            # Allocate the id only after validation so a rejected call does not mutate state.
            study_id = self._max_study_id + 1
            self._max_study_id = study_id

            self._studies[study_id] = _StudyInfo(study_name, list(directions))
            self._study_name_to_id[study_name] = study_id
            self._prev_waiting_trial_number[study_id] = 0

            _logger.info("A new study created in memory with name: {}".format(study_name))

            return study_id

    def delete_study(self, study_id: int) -> None:
        """Remove a study together with every index entry that refers to it.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            for trial in self._studies[study_id].trials:
                del self._trial_id_to_study_id_and_number[trial._trial_id]
            study_name = self._studies[study_id].name
            del self._study_name_to_id[study_name]
            del self._studies[study_id]
            del self._prev_waiting_trial_number[study_id]

    def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None:
        """Set ``key`` to ``value`` in the study's user attributes.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            self._studies[study_id].user_attrs[key] = value

    def set_study_system_attr(self, study_id: int, key: str, value: JSONSerializable) -> None:
        """Set ``key`` to ``value`` in the study's system attributes.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            self._studies[study_id].system_attrs[key] = value

    def get_study_id_from_name(self, study_name: str) -> int:
        """Return the id of the study registered under ``study_name``.

        Raises:
            KeyError: If no study with that name exists.
        """
        with self._lock:
            if study_name not in self._study_name_to_id:
                raise KeyError("No such study {}.".format(study_name))

            return self._study_name_to_id[study_name]

    def get_study_name_from_id(self, study_id: int) -> str:
        """Return the name of the study with ``study_id``.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].name

    def get_study_directions(self, study_id: int) -> list[StudyDirection]:
        """Return the study's directions.

        The list stored for the study is returned as is, not a copy.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].directions

    def get_study_user_attrs(self, study_id: int) -> dict[str, Any]:
        """Return the study's user attributes.

        The dict stored for the study is returned as is, not a copy.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].user_attrs

    def get_study_system_attrs(self, study_id: int) -> dict[str, Any]:
        """Return the study's system attributes.

        The dict stored for the study is returned as is, not a copy.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].system_attrs

    def get_all_studies(self) -> list[FrozenStudy]:
        """Return a :class:`FrozenStudy` for every study held by this storage."""
        with self._lock:
            return [self._build_frozen_study(study_id) for study_id in self._studies]

    def _build_frozen_study(self, study_id: int) -> FrozenStudy:
        """Build a :class:`FrozenStudy` for ``study_id``.

        User and system attributes are deep-copied; the directions list is passed through
        without copying. The caller is expected to hold the lock and to pass a valid id.
        """
        study = self._studies[study_id]
        return FrozenStudy(
            study_name=study.name,
            direction=None,
            directions=study.directions,
            user_attrs=copy.deepcopy(study.user_attrs),
            system_attrs=copy.deepcopy(study.system_attrs),
            study_id=study_id,
        )

    def create_new_trial(self, study_id: int, template_trial: FrozenTrial | None = None) -> int:
        """Append a new trial to a study and return its id.

        Args:
            study_id: Id of the study that receives the trial.
            template_trial: Trial to deep-copy as the starting point. When ``None``, a fresh
                trial in the ``RUNNING`` state is created.

        Returns:
            The id assigned to the new trial.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            if template_trial is None:
                trial = self._create_running_trial()
            else:
                trial = copy.deepcopy(template_trial)

            trial_id = self._max_trial_id + 1
            self._max_trial_id += 1
            # The trial number doubles as the trial's index in the study's trial list.
            trial.number = len(self._studies[study_id].trials)
            trial._trial_id = trial_id
            self._trial_id_to_study_id_and_number[trial_id] = (study_id, trial.number)
            self._studies[study_id].trials.append(trial)
            # A template may already be COMPLETE, so the best-trial cache must be refreshed.
            self._update_cache(trial_id, study_id)
            return trial_id

    @staticmethod
    def _create_running_trial() -> FrozenTrial:
        """Return an empty ``RUNNING`` trial whose id and number are placeholders."""
        return FrozenTrial(
            trial_id=-1,  # dummy value.
            number=-1,  # dummy value.
            state=TrialState.RUNNING,
            params={},
            distributions={},
            user_attrs={},
            system_attrs={},
            value=None,
            intermediate_values={},
            datetime_start=datetime.now(),
            datetime_complete=None,
        )

    def set_trial_param(
        self,
        trial_id: int,
        param_name: str,
        param_value_internal: float,
        distribution: distributions.BaseDistribution,
    ) -> None:
        """Record a parameter value and its distribution on a trial.

        Args:
            trial_id: Id of the trial to update.
            param_name: Name of the parameter.
            param_value_internal: Value in internal representation; converted with
                ``distribution.to_external_repr`` before being stored.
            distribution: Distribution of the parameter. It also becomes the distribution
                remembered for ``param_name`` at study level.

        Raises:
            KeyError: If no trial with ``trial_id`` exists.

        NOTE(review): ``check_trial_is_updatable`` and
        ``distributions.check_distribution_compatibility`` are defined outside this class;
        presumably they raise for finished trials and incompatible distributions
        respectively -- confirm against their definitions.
        """
        with self._lock:
            trial = self._get_trial(trial_id)

            self.check_trial_is_updatable(trial_id, trial.state)

            study_id = self._trial_id_to_study_id_and_number[trial_id][0]
            # Check param distribution compatibility with previous trial(s).
            if param_name in self._studies[study_id].param_distribution:
                distributions.check_distribution_compatibility(
                    self._studies[study_id].param_distribution[param_name], distribution
                )

            # Set param distribution.
            self._studies[study_id].param_distribution[param_name] = distribution

            # Set param. The stored trial is replaced by a modified copy instead of being
            # mutated in place.
            trial = copy.copy(trial)
            trial.params = copy.copy(trial.params)
            trial.params[param_name] = distribution.to_external_repr(param_value_internal)
            trial.distributions = copy.copy(trial.distributions)
            trial.distributions[param_name] = distribution
            self._set_trial(trial_id, trial)

    def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
        """Return the id of the trial with number ``trial_number`` in study ``study_id``.

        Raises:
            KeyError: If the study does not exist, or if ``trial_number`` is negative or not
                smaller than the number of trials in the study.
        """
        with self._lock:
            self._check_study_id(study_id)

            trials = self._studies[study_id].trials
            # Reject negative numbers explicitly: a bare upper-bound check would let them
            # through and silently index the list from its end.
            if not 0 <= trial_number < len(trials):
                raise KeyError(
                    "No trial with trial number {} exists in study with study_id {}.".format(
                        trial_number, study_id
                    )
                )

            trial = trials[trial_number]
            # Internal invariant: a trial's number equals its index in the list.
            assert trial.number == trial_number

            return trial._trial_id

    def get_trial_number_from_id(self, trial_id: int) -> int:
        """Return the trial number of the trial with ``trial_id``.

        Raises:
            KeyError: If no trial with ``trial_id`` exists.
        """
        with self._lock:
            self._check_trial_id(trial_id)

            return self._trial_id_to_study_id_and_number[trial_id][1]

    def get_best_trial(self, study_id: int) -> FrozenTrial:
        """Return the cached best trial of a single-objective study.

        Raises:
            KeyError: If no study with ``study_id`` exists.
            ValueError: If no best trial has been recorded for the study yet.
            RuntimeError: If the study has more than one direction.
        """
        with self._lock:
            self._check_study_id(study_id)

            best_trial_id = self._studies[study_id].best_trial_id

            if best_trial_id is None:
                raise ValueError("No trials are completed yet.")
            elif len(self._studies[study_id].directions) > 1:
                raise RuntimeError(
                    "Best trial can be obtained only for single-objective optimization."
                )
            return self.get_trial(best_trial_id)

    def get_trial_param(self, trial_id: int, param_name: str) -> float:
        """Return a trial's parameter value in internal representation.

        Raises:
            KeyError: If no trial with ``trial_id`` exists, or if the trial has no entry
                for ``param_name``.
        """
        with self._lock:
            trial = self._get_trial(trial_id)

            distribution = trial.distributions[param_name]
            return distribution.to_internal_repr(trial.params[param_name])

    def set_trial_state_values(
        self, trial_id: int, state: TrialState, values: Sequence[float] | None = None
    ) -> bool:
        """Update a trial's state and, optionally, its values.

        Args:
            trial_id: Id of the trial to update.
            state: New state of the trial.
            values: New values of the trial; left unchanged when ``None``.

        Returns:
            ``False`` if ``state`` is ``RUNNING`` while the trial is not ``WAITING`` (nothing
            is changed in that case), ``True`` otherwise.

        Raises:
            KeyError: If no trial with ``trial_id`` exists.

        NOTE(review): ``check_trial_is_updatable`` is defined outside this class; presumably
        it raises when the trial is already finished -- confirm against its definition.
        """
        with self._lock:
            # Work on a copy; the stored trial is swapped out at the end.
            trial = copy.copy(self._get_trial(trial_id))
            self.check_trial_is_updatable(trial_id, trial.state)

            if state == TrialState.RUNNING and trial.state != TrialState.WAITING:
                return False

            trial.state = state

            if values is not None:
                trial.values = values

            if state == TrialState.RUNNING:
                trial.datetime_start = datetime.now()

            finished = state.is_finished()
            if finished:
                trial.datetime_complete = datetime.now()

            self._set_trial(trial_id, trial)

            if finished:
                # Only a finished trial can change the cached best trial.
                study_id = self._trial_id_to_study_id_and_number[trial_id][0]
                self._update_cache(trial_id, study_id)

            return True

    def _update_cache(self, trial_id: int, study_id: int) -> None:
        """Refresh the study's cached best trial id after ``trial_id`` changed.

        Only ``COMPLETE`` trials are considered. The first ``COMPLETE`` trial is always
        recorded. Afterwards the cache is updated only for studies with a single direction,
        and only when the new value is strictly better than the cached one.
        """
        trial = self._get_trial(trial_id)

        if trial.state != TrialState.COMPLETE:
            return

        best_trial_id = self._studies[study_id].best_trial_id
        if best_trial_id is None:
            self._studies[study_id].best_trial_id = trial_id
            return

        _directions = self.get_study_directions(study_id)
        if len(_directions) > 1:
            return
        direction = _directions[0]

        best_trial = self._get_trial(best_trial_id)
        assert best_trial is not None
        if best_trial.value is None:
            self._studies[study_id].best_trial_id = trial_id
            return
        # Complete trials do not have `None` values.
        assert trial.value is not None
        best_value = best_trial.value
        new_value = trial.value

        if direction == StudyDirection.MAXIMIZE:
            if best_value < new_value:
                self._studies[study_id].best_trial_id = trial_id
        else:
            if best_value > new_value:
                self._studies[study_id].best_trial_id = trial_id

    def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
        """Set ``key`` to ``value`` in the trial's user attributes.

        Raises:
            KeyError: If no trial with ``trial_id`` exists.
        """
        with self._lock:
            # `_get_trial` already validates `trial_id`.
            trial = self._get_trial(trial_id)
            self.check_trial_is_updatable(trial_id, trial.state)

            # Replace the stored trial with a modified copy instead of mutating it in place.
            trial = copy.copy(trial)
            trial.user_attrs = copy.copy(trial.user_attrs)
            trial.user_attrs[key] = value
            self._set_trial(trial_id, trial)

    def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
        """Set ``key`` to ``value`` in the trial's system attributes.

        Raises:
            KeyError: If no trial with ``trial_id`` exists.
        """
        with self._lock:
            trial = self._get_trial(trial_id)
            self.check_trial_is_updatable(trial_id, trial.state)

            # Replace the stored trial with a modified copy instead of mutating it in place.
            trial = copy.copy(trial)
            trial.system_attrs = copy.copy(trial.system_attrs)
            trial.system_attrs[key] = value
            self._set_trial(trial_id, trial)

    def get_trial(self, trial_id: int) -> FrozenTrial:
        """Return the stored trial object for ``trial_id`` (not a copy).

        Raises:
            KeyError: If no trial with ``trial_id`` exists.
        """
        with self._lock:
            return self._get_trial(trial_id)

    def _get_trial(self, trial_id: int) -> FrozenTrial:
        """Look up the stored trial for ``trial_id`` without acquiring the lock.

        Raises:
            KeyError: If no trial with ``trial_id`` exists.
        """
        self._check_trial_id(trial_id)
        study_id, trial_number = self._trial_id_to_study_id_and_number[trial_id]
        return self._studies[study_id].trials[trial_number]

    def _set_trial(self, trial_id: int, trial: FrozenTrial) -> None:
        """Replace the stored trial for ``trial_id`` with ``trial``."""
        study_id, trial_number = self._trial_id_to_study_id_and_number[trial_id]
        self._studies[study_id].trials[trial_number] = trial

    def get_all_trials(
        self,
        study_id: int,
        deepcopy: bool = True,
        states: Container[TrialState] | None = None,
    ) -> list[FrozenTrial]:
        """Return the trials of a study, optionally filtered by state.

        Args:
            study_id: Id of the study.
            deepcopy: When true, the returned trials are deep copies. Otherwise only the
                list is new and its elements are the stored trial objects.
            states: Container of states to keep; ``None`` keeps every trial.

        Returns:
            A new list of trials.

        Raises:
            KeyError: If no study with ``study_id`` exists.
        """
        with self._lock:
            self._check_study_id(study_id)

            study_trials = self._studies[study_id].trials

            # Optimized retrieval of trials in the WAITING state to improve performance
            # for the call, `get_all_trials(states=(TrialState.WAITING,))`.
            if states == (TrialState.WAITING,):
                trials: list[FrozenTrial] = []
                start = self._prev_waiting_trial_number[study_id]
                for trial in study_trials[start:]:
                    if trial.state == TrialState.WAITING:
                        if not trials:
                            # Remember the first WAITING trial as the next scan start.
                            self._prev_waiting_trial_number[study_id] = trial.number
                        trials.append(trial)
                if not trials:
                    # Nothing is waiting: the next scan can skip every existing trial.
                    self._prev_waiting_trial_number[study_id] = len(study_trials)
            else:
                trials = study_trials
                if states is not None:
                    trials = [t for t in trials if t.state in states]

            if deepcopy:
                trials = copy.deepcopy(trials)
            else:
                # This copy is required for the replacing trick in `set_trial_xxx`.
                trials = copy.copy(trials)

            return trials

    def _check_study_id(self, study_id: int) -> None:
        """Raise ``KeyError`` if no study with ``study_id`` exists."""
        if study_id not in self._studies:
            raise KeyError("No study with study_id {} exists.".format(study_id))

    def _check_trial_id(self, trial_id: int) -> None:
        """Raise ``KeyError`` if no trial with ``trial_id`` exists."""
        if trial_id not in self._trial_id_to_study_id_and_number:
            raise KeyError("No trial with trial_id {} exists.".format(trial_id))
class _StudyInfo:
def __init__(self, name: str, directions: list[StudyDirection]) -> None:
self.trials: list[FrozenTrial] = []
self.param_distribution: dict[str, distributions.BaseDistribution] = {}
self.user_attrs: dict[str, Any] = {}
self.system_attrs: dict[str, Any] = {}
self.name: str = name
self.directions: list[StudyDirection] = directions
self.best_trial_id: int | None = None