# Source code for optuna.storages._in_memory

from __future__ import annotations

from collections.abc import Container
from collections.abc import Sequence
import copy
from datetime import datetime
import threading
from typing import Any
import uuid

import optuna
from optuna import distributions  # NOQA
from optuna._typing import JSONSerializable
from optuna.exceptions import DuplicatedStudyError
from optuna.storages import BaseStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.study._frozen import FrozenStudy
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


# Module-level logger; used by InMemoryStorage.create_new_study to report new studies.
_logger = optuna.logging.get_logger(__name__)


class InMemoryStorage(BaseStorage):
    """Storage class that stores data in memory of the Python process.

    Example:

        Create an :class:`~optuna.storages.InMemoryStorage` instance.

        .. testcode::

            import optuna


            def objective(trial):
                x = trial.suggest_float("x", -100, 100)
                return x**2


            storage = optuna.storages.InMemoryStorage()

            study = optuna.create_study(storage=storage)
            study.optimize(objective, n_trials=10)
    """

    # NOTE: the public methods below implement the ``BaseStorage`` interface. They carry
    # ``#`` comments rather than docstrings, so that ``inspect.getdoc`` (and tools built
    # on it) fall back to the base-class documentation.

    def __init__(self) -> None:
        # trial_id -> (study_id, trial number). The trial number doubles as the index of
        # the trial inside ``_StudyInfo.trials`` (see ``_get_trial``).
        self._trial_id_to_study_id_and_number: dict[int, tuple[int, int]] = {}
        self._study_name_to_id: dict[str, int] = {}
        self._studies: dict[int, _StudyInfo] = {}
        # Highest IDs handed out so far; the next ID is always max + 1.
        self._max_study_id = -1
        self._max_trial_id = -1
        # Reentrant because locked methods call other locked methods
        # (e.g. ``get_best_trial`` -> ``get_trial``).
        self._lock = threading.RLock()
        # study_id -> trial number at which ``get_all_trials`` starts scanning when only
        # WAITING trials are requested.
        self._prev_waiting_trial_number: dict[int, int] = {}

    def __getstate__(self) -> dict[Any, Any]:
        # Lock objects cannot be pickled: drop it here and recreate it in __setstate__.
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state: dict[Any, Any]) -> None:
        self.__dict__.update(state)
        self._lock = threading.RLock()

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: str | None = None
    ) -> int:
        # Registers a new, empty study and returns its ID. When ``study_name`` is None, a
        # unique name is generated from DEFAULT_STUDY_NAME_PREFIX and a random UUID.
        # Raises DuplicatedStudyError if ``study_name`` is already taken.
        with self._lock:
            if study_name is None:
                study_name = DEFAULT_STUDY_NAME_PREFIX + str(uuid.uuid4())
            elif study_name in self._study_name_to_id:
                raise DuplicatedStudyError(
                    "Another study with name '{}' already exists.".format(study_name)
                )

            # Allocate the ID only after validation, so a rejected call consumes no ID.
            study_id = self._max_study_id + 1
            self._max_study_id = study_id

            self._studies[study_id] = _StudyInfo(study_name, list(directions))
            self._study_name_to_id[study_name] = study_id
            self._prev_waiting_trial_number[study_id] = 0

            _logger.info("A new study created in memory with name: %s", study_name)

            return study_id

    def delete_study(self, study_id: int) -> None:
        # Removes the study together with every index entry that refers to it or its trials.
        with self._lock:
            self._check_study_id(study_id)

            for trial in self._studies[study_id].trials:
                del self._trial_id_to_study_id_and_number[trial._trial_id]
            study_name = self._studies[study_id].name
            del self._study_name_to_id[study_name]
            del self._studies[study_id]
            del self._prev_waiting_trial_number[study_id]

    def set_study_user_attr(self, study_id: int, key: str, value: Any) -> None:
        with self._lock:
            self._check_study_id(study_id)

            self._studies[study_id].user_attrs[key] = value

    def set_study_system_attr(self, study_id: int, key: str, value: JSONSerializable) -> None:
        with self._lock:
            self._check_study_id(study_id)

            self._studies[study_id].system_attrs[key] = value

    def get_study_id_from_name(self, study_name: str) -> int:
        # Raises KeyError for an unknown name.
        with self._lock:
            if study_name not in self._study_name_to_id:
                raise KeyError("No such study {}.".format(study_name))

            return self._study_name_to_id[study_name]

    def get_study_name_from_id(self, study_id: int) -> str:
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].name

    def get_study_directions(self, study_id: int) -> list[StudyDirection]:
        # Returns the stored list itself, not a copy.
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].directions

    def get_study_user_attrs(self, study_id: int) -> dict[str, Any]:
        # Returns the stored dict itself, not a copy.
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].user_attrs

    def get_study_system_attrs(self, study_id: int) -> dict[str, Any]:
        # Returns the stored dict itself, not a copy.
        with self._lock:
            self._check_study_id(study_id)
            return self._studies[study_id].system_attrs

    def get_all_studies(self) -> list[FrozenStudy]:
        with self._lock:
            return [self._build_frozen_study(study_id) for study_id in self._studies]

    def _build_frozen_study(self, study_id: int) -> FrozenStudy:
        """Build a ``FrozenStudy`` snapshot; attribute dicts are deep-copied.

        Must be called with ``self._lock`` held.
        """
        study = self._studies[study_id]
        return FrozenStudy(
            study_name=study.name,
            direction=None,
            directions=study.directions,
            user_attrs=copy.deepcopy(study.user_attrs),
            system_attrs=copy.deepcopy(study.system_attrs),
            study_id=study_id,
        )

    def create_new_trial(self, study_id: int, template_trial: FrozenTrial | None = None) -> int:
        # Appends a trial to the study and returns its new globally unique trial ID. The
        # trial is a fresh RUNNING trial, or a deep copy of ``template_trial``. Either way,
        # its number and ID are overwritten here.
        with self._lock:
            self._check_study_id(study_id)

            if template_trial is None:
                trial = self._create_running_trial()
            else:
                trial = copy.deepcopy(template_trial)

            trial_id = self._max_trial_id + 1
            self._max_trial_id += 1
            # The trial number equals the trial's index in the study's trial list.
            trial.number = len(self._studies[study_id].trials)
            trial._trial_id = trial_id
            self._trial_id_to_study_id_and_number[trial_id] = (study_id, trial.number)
            self._studies[study_id].trials.append(trial)
            # A template trial may already be COMPLETE, so the best-trial cache may change.
            self._update_cache(trial_id, study_id)
            return trial_id

    @staticmethod
    def _create_running_trial() -> FrozenTrial:
        """Return an empty RUNNING trial started now; its ID and number are placeholders."""
        return FrozenTrial(
            trial_id=-1,  # dummy value.
            number=-1,  # dummy value.
            state=TrialState.RUNNING,
            params={},
            distributions={},
            user_attrs={},
            system_attrs={},
            value=None,
            intermediate_values={},
            datetime_start=datetime.now(),
            datetime_complete=None,
        )

    def set_trial_param(
        self,
        trial_id: int,
        param_name: str,
        param_value_internal: float,
        distribution: distributions.BaseDistribution,
    ) -> None:
        # Records a parameter on an updatable trial. A distribution previously recorded for
        # ``param_name`` in the same study is first passed, with the new distribution, to
        # ``distributions.check_distribution_compatibility``.
        with self._lock:
            trial = self._get_trial(trial_id)

            self.check_trial_is_updatable(trial_id, trial.state)

            study_id = self._trial_id_to_study_id_and_number[trial_id][0]
            # Check param distribution compatibility with previous trial(s).
            if param_name in self._studies[study_id].param_distribution:
                distributions.check_distribution_compatibility(
                    self._studies[study_id].param_distribution[param_name], distribution
                )

            # Set param distribution.
            self._studies[study_id].param_distribution[param_name] = distribution

            # Set param. The stored trial is replaced by a modified shallow copy instead of
            # being mutated, so trial lists previously returned by ``get_all_trials`` are
            # unaffected.
            trial = copy.copy(trial)
            trial.params = copy.copy(trial.params)
            trial.params[param_name] = distribution.to_external_repr(param_value_internal)
            trial.distributions = copy.copy(trial.distributions)
            trial.distributions[param_name] = distribution
            self._set_trial(trial_id, trial)

    def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
        # Raises KeyError for an unknown study or an out-of-range (including negative)
        # trial number.
        with self._lock:
            self._check_study_id(study_id)
            trials = self._studies[study_id].trials
            # Reject negative numbers explicitly: ``trials[-1]`` would otherwise index from
            # the end of the list and pick an unrelated trial.
            if not 0 <= trial_number < len(trials):
                raise KeyError(
                    "No trial with trial number {} exists in study with study_id {}.".format(
                        trial_number, study_id
                    )
                )

            trial = trials[trial_number]
            assert trial.number == trial_number

            return trial._trial_id

    def get_trial_number_from_id(self, trial_id: int) -> int:
        with self._lock:
            self._check_trial_id(trial_id)

            return self._trial_id_to_study_id_and_number[trial_id][1]

    def get_best_trial(self, study_id: int) -> FrozenTrial:
        # Raises ValueError when no trial has completed yet, and RuntimeError for
        # multi-objective studies.
        with self._lock:
            self._check_study_id(study_id)

            best_trial_id = self._studies[study_id].best_trial_id

            if best_trial_id is None:
                raise ValueError("No trials are completed yet.")
            elif len(self._studies[study_id].directions) > 1:
                raise RuntimeError(
                    "Best trial can be obtained only for single-objective optimization."
                )
            return self.get_trial(best_trial_id)

    def get_trial_param(self, trial_id: int, param_name: str) -> float:
        # Returns the parameter in its internal representation. Raises KeyError if the trial
        # has no such parameter.
        with self._lock:
            trial = self._get_trial(trial_id)

            distribution = trial.distributions[param_name]
            return distribution.to_internal_repr(trial.params[param_name])

    def set_trial_state_values(
        self, trial_id: int, state: TrialState, values: Sequence[float] | None = None
    ) -> bool:
        # Returns False, changing nothing, when asked to move a trial that is not WAITING
        # into RUNNING (presumably so that only one caller can claim a WAITING trial);
        # otherwise it applies the change and returns True.
        with self._lock:
            trial = copy.copy(self._get_trial(trial_id))
            self.check_trial_is_updatable(trial_id, trial.state)

            if state == TrialState.RUNNING and trial.state != TrialState.WAITING:
                return False

            trial.state = state
            if values is not None:
                trial.values = values

            if state == TrialState.RUNNING:
                trial.datetime_start = datetime.now()

            if state.is_finished():
                trial.datetime_complete = datetime.now()
                self._set_trial(trial_id, trial)
                study_id = self._trial_id_to_study_id_and_number[trial_id][0]
                # Only a finished trial can become the new best trial.
                self._update_cache(trial_id, study_id)
            else:
                self._set_trial(trial_id, trial)

            return True

    def _update_cache(self, trial_id: int, study_id: int) -> None:
        """Refresh the study's cached ``best_trial_id`` after ``trial_id`` changed.

        Only COMPLETE trials are considered. For multi-objective studies, the cache is set
        to the first COMPLETE trial and is not updated afterwards. Must be called with
        ``self._lock`` held.
        """
        trial = self._get_trial(trial_id)

        if trial.state != TrialState.COMPLETE:
            return

        study = self._studies[study_id]
        best_trial_id = study.best_trial_id
        if best_trial_id is None:
            study.best_trial_id = trial_id
            return

        directions = study.directions
        if len(directions) > 1:
            return

        best_trial = self._get_trial(best_trial_id)
        if best_trial.value is None:
            study.best_trial_id = trial_id
            return

        # Complete trials do not have `None` values.
        assert trial.value is not None
        # Ties keep the earlier best trial.
        if directions[0] == StudyDirection.MAXIMIZE:
            is_better = trial.value > best_trial.value
        else:
            is_better = trial.value < best_trial.value
        if is_better:
            study.best_trial_id = trial_id

    def set_trial_intermediate_value(
        self, trial_id: int, step: int, intermediate_value: float
    ) -> None:
        # Stores (or overwrites) the intermediate value reported at ``step``.
        with self._lock:
            trial = self._get_trial(trial_id)
            self.check_trial_is_updatable(trial_id, trial.state)

            # Copy-and-replace rather than mutate; see ``set_trial_param``.
            trial = copy.copy(trial)
            trial.intermediate_values = copy.copy(trial.intermediate_values)
            trial.intermediate_values[step] = intermediate_value
            self._set_trial(trial_id, trial)

    def set_trial_user_attr(self, trial_id: int, key: str, value: Any) -> None:
        with self._lock:
            # ``_get_trial`` already validates ``trial_id``.
            trial = self._get_trial(trial_id)
            self.check_trial_is_updatable(trial_id, trial.state)

            trial = copy.copy(trial)
            trial.user_attrs = copy.copy(trial.user_attrs)
            trial.user_attrs[key] = value
            self._set_trial(trial_id, trial)

    def set_trial_system_attr(self, trial_id: int, key: str, value: JSONSerializable) -> None:
        with self._lock:
            trial = self._get_trial(trial_id)
            self.check_trial_is_updatable(trial_id, trial.state)

            trial = copy.copy(trial)
            trial.system_attrs = copy.copy(trial.system_attrs)
            trial.system_attrs[key] = value
            self._set_trial(trial_id, trial)

    def get_trial(self, trial_id: int) -> FrozenTrial:
        # Returns the stored trial object itself, not a copy.
        with self._lock:
            return self._get_trial(trial_id)

    def _get_trial(self, trial_id: int) -> FrozenTrial:
        """Return the stored trial for ``trial_id``; raises KeyError if it is unknown."""
        self._check_trial_id(trial_id)
        study_id, trial_number = self._trial_id_to_study_id_and_number[trial_id]
        return self._studies[study_id].trials[trial_number]

    def _set_trial(self, trial_id: int, trial: FrozenTrial) -> None:
        """Replace the stored trial object for an existing ``trial_id``."""
        study_id, trial_number = self._trial_id_to_study_id_and_number[trial_id]
        self._studies[study_id].trials[trial_number] = trial

    def get_all_trials(
        self,
        study_id: int,
        deepcopy: bool = True,
        states: Container[TrialState] | None = None,
    ) -> list[FrozenTrial]:
        # Returns the study's trials in number order, optionally filtered by ``states``.
        # The list is always a new object. Its elements are deep copies when ``deepcopy``
        # is True.
        with self._lock:
            self._check_study_id(study_id)

            # Optimized retrieval of trials in the WAITING state to improve performance
            # for the call, `get_all_trials(states=(TrialState.WAITING,))`.
            if states == (TrialState.WAITING,):
                trials: list[FrozenTrial] = []
                for trial in self._studies[study_id].trials[
                    self._prev_waiting_trial_number[study_id] :
                ]:
                    if trial.state == TrialState.WAITING:
                        if not trials:
                            # Next scan starts at the first WAITING trial found now.
                            self._prev_waiting_trial_number[study_id] = trial.number
                        trials.append(trial)
                if not trials:
                    # NOTE(review): this assumes a trial before the cursor never becomes
                    # WAITING again. ``set_trial_state_values`` does not visibly forbid
                    # RUNNING -> WAITING, so such a trial would be skipped. Confirm.
                    self._prev_waiting_trial_number[study_id] = len(self._studies[study_id].trials)
            else:
                trials = self._studies[study_id].trials
                if states is not None:
                    trials = [t for t in trials if t.state in states]

            if deepcopy:
                trials = copy.deepcopy(trials)
            else:
                # This copy is required for the replacing trick in `set_trial_xxx`: those
                # methods store modified copies instead of mutating trials, so a shallow
                # copy of the list is a stable snapshot.
                trials = copy.copy(trials)

            return trials

    def _check_study_id(self, study_id: int) -> None:
        """Raise KeyError if ``study_id`` is not a known study."""
        if study_id not in self._studies:
            raise KeyError("No study with study_id {} exists.".format(study_id))

    def _check_trial_id(self, trial_id: int) -> None:
        """Raise KeyError if ``trial_id`` is not a known trial."""
        if trial_id not in self._trial_id_to_study_id_and_number:
            raise KeyError("No trial with trial_id {} exists.".format(trial_id))
class _StudyInfo: def __init__(self, name: str, directions: list[StudyDirection]) -> None: self.trials: list[FrozenTrial] = [] self.param_distribution: dict[str, distributions.BaseDistribution] = {} self.user_attrs: dict[str, Any] = {} self.system_attrs: dict[str, Any] = {} self.name: str = name self.directions: list[StudyDirection] = directions self.best_trial_id: int | None = None