from __future__ import annotations

from collections.abc import Container
from collections.abc import Sequence
import copy
import datetime
import enum
import pickle
import threading
from typing import Any
import uuid

import optuna
from optuna._typing import JSONSerializable
from optuna.distributions import BaseDistribution
from optuna.distributions import check_distribution_compatibility
from optuna.distributions import distribution_to_json
from optuna.distributions import json_to_distribution
from optuna.exceptions import DuplicatedStudyError
from optuna.exceptions import UpdateFinishedTrialError
from optuna.storages import BaseStorage
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.storages.journal._base import BaseJournalBackend
from optuna.storages.journal._base import BaseJournalSnapshot
from optuna.study._frozen import FrozenStudy
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


_logger = optuna.logging.get_logger(__name__)

NOT_FOUND_MSG = "Record does not exist."

# A heuristic interval number to dump snapshots.
SNAPSHOT_INTERVAL = 100


class JournalOperation(enum.IntEnum):
    CREATE_STUDY = 0
    DELETE_STUDY = 1
    SET_STUDY_USER_ATTR = 2
    SET_STUDY_SYSTEM_ATTR = 3
    CREATE_TRIAL = 4
    SET_TRIAL_PARAM = 5
    SET_TRIAL_STATE_VALUES = 6
    SET_TRIAL_INTERMEDIATE_VALUE = 7
    SET_TRIAL_USER_ATTR = 8
    SET_TRIAL_SYSTEM_ATTR = 9
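
# Illustrative sketch: each record appended via `JournalStorage._write_log` below is a
# JSON-serializable dict that combines the op code, the id of the issuing worker, and
# operation-specific fields. A CREATE_STUDY record might look like this (the field
# values shown are hypothetical):
#
#     {
#         "op_code": JournalOperation.CREATE_STUDY,  # stored as its integer value
#         "worker_id": "<uuid4>-<thread ident>",
#         "study_name": "example-study",
#         "directions": [StudyDirection.MINIMIZE],
#     }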


class JournalStorage(BaseStorage):
    """Storage class for the journal storage backend.

    Note that library users can instantiate this class, but the attributes
    provided by this class are not supposed to be directly accessed by them.

    Journal storage writes a record of every operation to the backend as it is executed
    and, at the same time, keeps the latest snapshot of the database in memory. If the
    database crashes for any reason, the storage can re-establish the contents in memory
    by replaying the operations stored from the beginning.

    Journal storage has several benefits over conventional value-logging storages.

    1. The number of I/Os can be reduced because of the larger granularity of logs.
    2. Journal storage has a simpler backend API than value-logging storage.
    3. Journal storage keeps a snapshot in memory, so no additional cache is needed.

    Example:

        .. code::

            import optuna


            def objective(trial):
                ...


            storage = optuna.storages.JournalStorage(
                optuna.storages.journal.JournalFileBackend("./optuna_journal_storage.log")
            )

            study = optuna.create_study(storage=storage)
            study.optimize(objective)

        In a Windows environment, an error message "A required privilege is not held by
        the client" may appear. In this case, you can avoid the problem by creating the
        storage with :class:`~optuna.storages.journal.JournalFileOpenLock` as follows.

        .. code::

            file_path = "./optuna_journal_storage.log"
            lock_obj = optuna.storages.journal.JournalFileOpenLock(file_path)

            storage = optuna.storages.JournalStorage(
                optuna.storages.journal.JournalFileBackend(file_path, lock_obj=lock_obj),
            )
    """

    def __init__(self, log_storage: BaseJournalBackend) -> None:
        self._worker_id_prefix = str(uuid.uuid4()) + "-"
        self._backend = log_storage
        self._thread_lock = threading.Lock()
        self._replay_result = JournalStorageReplayResult(self._worker_id_prefix)

        with self._thread_lock:
            if isinstance(self._backend, BaseJournalSnapshot):
                snapshot = self._backend.load_snapshot()
                if snapshot is not None:
                    self.restore_replay_result(snapshot)
            self._sync_with_backend()

    def __getstate__(self) -> dict[Any, Any]:
        state = self.__dict__.copy()
        del state["_worker_id_prefix"]
        del state["_replay_result"]
        del state["_thread_lock"]
        return state

    def __setstate__(self, state: dict[Any, Any]) -> None:
        self.__dict__.update(state)
        self._worker_id_prefix = str(uuid.uuid4()) + "-"
        self._replay_result = JournalStorageReplayResult(self._worker_id_prefix)
        self._thread_lock = threading.Lock()

    def restore_replay_result(self, snapshot: bytes) -> None:
        try:
            r: JournalStorageReplayResult | None = pickle.loads(snapshot)
        except (pickle.UnpicklingError, KeyError):
            _logger.warning("Failed to restore `JournalStorageReplayResult`.")
            return

        if r is None:
            return
        if not isinstance(r, JournalStorageReplayResult):
            _logger.warning("The restored object is not `JournalStorageReplayResult`.")
            return

        r._worker_id_prefix = self._worker_id_prefix
        r._worker_id_to_owned_trial_id = {}
        r._last_created_trial_id_by_this_process = -1
        self._replay_result = r

    def _write_log(self, op_code: int, extra_fields: dict[str, Any]) -> None:
        worker_id = self._replay_result.worker_id
        self._backend.append_logs([{"op_code": op_code, "worker_id": worker_id, **extra_fields}])

    def _sync_with_backend(self) -> None:
        logs = self._backend.read_logs(self._replay_result.log_number_read)
        self._replay_result.apply_logs(logs)

    def create_new_study(
        self, directions: Sequence[StudyDirection], study_name: str | None = None
    ) -> int:
        study_name = study_name or DEFAULT_STUDY_NAME_PREFIX + str(uuid.uuid4())
        with self._thread_lock:
            self._write_log(
                JournalOperation.CREATE_STUDY,
                {"study_name": study_name, "directions": directions},
            )
            self._sync_with_backend()

            for frozen_study in self._replay_result.get_all_studies():
                if frozen_study.study_name != study_name:
                    continue
                _logger.info("A new study created in Journal with name: {}".format(study_name))
                study_id = frozen_study._study_id

                # Dump snapshot here.
                if (
                    isinstance(self._backend, BaseJournalSnapshot)
                    and study_id != 0
                    and study_id % SNAPSHOT_INTERVAL == 0
                ):
                    self._backend.save_snapshot(pickle.dumps(self._replay_result))

                return study_id
            assert False, "Should not reach."

    def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: int) -> int:
        with self._thread_lock:
            self._sync_with_backend()
            if len(self._replay_result._study_id_to_trial_ids[study_id]) <= trial_number:
                raise KeyError(
                    "No trial with trial number {} exists in study with study_id {}.".format(
                        trial_number, study_id
                    )
                )
            return self._replay_result._study_id_to_trial_ids[study_id][trial_number]
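
# Illustrative note on worker identity in the replay logic below: every process gets a
# fresh UUID prefix in `__init__`/`__setstate__`, and `JournalStorageReplayResult.worker_id`
# appends the current thread ident, so a "worker" is effectively one thread in one process.
# Replay uses this id to decide whether an error found in a log (e.g. a duplicated study
# name) should be raised locally or silently skipped because another worker issued the
# operation.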


class JournalStorageReplayResult:
    def __init__(self, worker_id_prefix: str) -> None:
        self.log_number_read = 0
        self._worker_id_prefix = worker_id_prefix
        self._studies: dict[int, FrozenStudy] = {}
        self._trials: dict[int, FrozenTrial] = {}

        self._study_id_to_trial_ids: dict[int, list[int]] = {}
        self._trial_id_to_study_id: dict[int, int] = {}
        self._next_study_id: int = 0
        self._worker_id_to_owned_trial_id: dict[str, int] = {}

    def apply_logs(self, logs: list[dict[str, Any]]) -> None:
        for log in logs:
            self.log_number_read += 1
            op = log["op_code"]
            if op == JournalOperation.CREATE_STUDY:
                self._apply_create_study(log)
            elif op == JournalOperation.DELETE_STUDY:
                self._apply_delete_study(log)
            elif op == JournalOperation.SET_STUDY_USER_ATTR:
                self._apply_set_study_user_attr(log)
            elif op == JournalOperation.SET_STUDY_SYSTEM_ATTR:
                self._apply_set_study_system_attr(log)
            elif op == JournalOperation.CREATE_TRIAL:
                self._apply_create_trial(log)
            elif op == JournalOperation.SET_TRIAL_PARAM:
                self._apply_set_trial_param(log)
            elif op == JournalOperation.SET_TRIAL_STATE_VALUES:
                self._apply_set_trial_state_values(log)
            elif op == JournalOperation.SET_TRIAL_INTERMEDIATE_VALUE:
                self._apply_set_trial_intermediate_value(log)
            elif op == JournalOperation.SET_TRIAL_USER_ATTR:
                self._apply_set_trial_user_attr(log)
            elif op == JournalOperation.SET_TRIAL_SYSTEM_ATTR:
                self._apply_set_trial_system_attr(log)
            else:
                assert False, "Should not reach."

    def get_study(self, study_id: int) -> FrozenStudy:
        if study_id not in self._studies:
            raise KeyError(NOT_FOUND_MSG)
        return self._studies[study_id]

    def get_all_studies(self) -> list[FrozenStudy]:
        return list(self._studies.values())

    def get_trial(self, trial_id: int) -> FrozenTrial:
        if trial_id not in self._trials:
            raise KeyError(NOT_FOUND_MSG)
        return self._trials[trial_id]

    def get_all_trials(
        self, study_id: int, states: Container[TrialState] | None
    ) -> list[FrozenTrial]:
        if study_id not in self._studies:
            raise KeyError(NOT_FOUND_MSG)

        frozen_trials: list[FrozenTrial] = []
        for trial_id in self._study_id_to_trial_ids[study_id]:
            trial = self._trials[trial_id]
            if states is None or trial.state in states:
                frozen_trials.append(trial)
        return frozen_trials

    @property
    def worker_id(self) -> str:
        return self._worker_id_prefix + str(threading.get_ident())

    @property
    def owned_trial_id(self) -> int | None:
        return self._worker_id_to_owned_trial_id.get(self.worker_id)

    def _is_issued_by_this_worker(self, log: dict[str, Any]) -> bool:
        return log["worker_id"] == self.worker_id

    def _study_exists(self, study_id: int, log: dict[str, Any]) -> bool:
        if study_id in self._studies:
            return True
        if self._is_issued_by_this_worker(log):
            raise KeyError(NOT_FOUND_MSG)
        return False

    def _apply_create_study(self, log: dict[str, Any]) -> None:
        study_name = log["study_name"]
        directions = [StudyDirection(d) for d in log["directions"]]

        if study_name in [s.study_name for s in self._studies.values()]:
            if self._is_issued_by_this_worker(log):
                raise DuplicatedStudyError(
                    "Another study with name '{}' already exists. "
""Please specify a different name, or reuse the existing one ""by setting `load_if_exists` (for Python API) or ""`--skip-if-exists` flag (for CLI).".format(study_name))returnstudy_id=self._next_study_idself._next_study_id+=1self._studies[study_id]=FrozenStudy(study_name=study_name,direction=None,user_attrs={},system_attrs={},study_id=study_id,directions=directions,)self._study_id_to_trial_ids[study_id]=[]def_apply_delete_study(self,log:dict[str,Any])->None:study_id=log["study_id"]ifself._study_exists(study_id,log):fs=self._studies.pop(study_id)assertfs._study_id==study_iddef_apply_set_study_user_attr(self,log:dict[str,Any])->None:study_id=log["study_id"]ifself._study_exists(study_id,log):assertlen(log["user_attr"])==1self._studies[study_id].user_attrs.update(log["user_attr"])def_apply_set_study_system_attr(self,log:dict[str,Any])->None:study_id=log["study_id"]ifself._study_exists(study_id,log):assertlen(log["system_attr"])==1self._studies[study_id].system_attrs.update(log["system_attr"])def_apply_create_trial(self,log:dict[str,Any])->None:study_id=log["study_id"]ifnotself._study_exists(study_id,log):returntrial_id=len(self._trials)distributions={}if"distributions"inlog:distributions={k:json_to_distribution(v)fork,vinlog["distributions"].items()}params={}if"params"inlog:params={k:distributions[k].to_external_repr(p)fork,pinlog["params"].items()}iflog["datetime_start"]isnotNone:datetime_start=datetime.datetime.fromisoformat(log["datetime_start"])else:datetime_start=Noneif"datetime_complete"inlog:datetime_complete=datetime.datetime.fromisoformat(log["datetime_complete"])else:datetime_complete=Noneself._trials[trial_id]=FrozenTrial(trial_id=trial_id,number=len(self._study_id_to_trial_ids[study_id]),state=TrialState(log.get("state",TrialState.RUNNING.value)),params=params,distributions=distributions,user_attrs=log.get("user_attrs",{}),system_attrs=log.get("system_attrs",{}),value=log.get("value",None),intermediate_values={int(k):vfork,vinlog.get("intermediate_values",{}).items()},datetime_start=datetime_start,datetime_complete=datetime_complete,values=log.get("values",None),)self._study_id_to_trial_ids[study_id].append(trial_id)self._trial_id_to_study_id[trial_id]=study_idifself._is_issued_by_this_worker(log):self._last_created_trial_id_by_this_process=trial_idifself._trials[trial_id].state==TrialState.RUNNING:self._worker_id_to_owned_trial_id[self.worker_id]=trial_iddef_apply_set_trial_param(self,log:dict[str,Any])->None:trial_id=log["trial_id"]ifnotself._trial_exists_and_updatable(trial_id,log):returnparam_name=log["param_name"]param_value_internal=log["param_value_internal"]distribution=json_to_distribution(log["distribution"])study_id=self._trial_id_to_study_id[trial_id]forprev_trial_idinself._study_id_to_trial_ids[study_id]:prev_trial=self._trials[prev_trial_id]ifparam_nameinprev_trial.params.keys():try:check_distribution_compatibility(prev_trial.distributions[param_name],distribution)exceptException:ifself._is_issued_by_this_worker(log):raisereturnbreaktrial=copy.copy(self._trials[trial_id])trial.params={**copy.copy(trial.params),param_name:distribution.to_external_repr(param_value_internal),}trial.distributions={**copy.copy(trial.distributions),param_name:distribution}self._trials[trial_id]=trialdef_apply_set_trial_state_values(self,log:dict[str,Any])->None:trial_id=log["trial_id"]ifnotself._trial_exists_and_updatable(trial_id,log):returnstate=TrialState(log["state"])ifstate==self._trials[trial_id].stateandstate==TrialState.RUNNING:returntrial=copy.copy(self._trials[trial_id])ifstate==Tri
        if state == TrialState.RUNNING:
            trial.datetime_start = datetime.datetime.fromisoformat(log["datetime_start"])
            if self._is_issued_by_this_worker(log):
                self._worker_id_to_owned_trial_id[self.worker_id] = trial_id

        if state.is_finished():
            trial.datetime_complete = datetime.datetime.fromisoformat(log["datetime_complete"])

        trial.state = state
        if log["values"] is not None:
            trial.values = log["values"]
        self._trials[trial_id] = trial

    def _apply_set_trial_intermediate_value(self, log: dict[str, Any]) -> None:
        trial_id = log["trial_id"]
        if self._trial_exists_and_updatable(trial_id, log):
            trial = copy.copy(self._trials[trial_id])
            trial.intermediate_values = {
                **copy.copy(trial.intermediate_values),
                log["step"]: log["intermediate_value"],
            }
            self._trials[trial_id] = trial

    def _apply_set_trial_user_attr(self, log: dict[str, Any]) -> None:
        trial_id = log["trial_id"]
        if self._trial_exists_and_updatable(trial_id, log):
            assert len(log["user_attr"]) == 1
            trial = copy.copy(self._trials[trial_id])
            trial.user_attrs = {**copy.copy(trial.user_attrs), **log["user_attr"]}
            self._trials[trial_id] = trial

    def _apply_set_trial_system_attr(self, log: dict[str, Any]) -> None:
        trial_id = log["trial_id"]
        if self._trial_exists_and_updatable(trial_id, log):
            assert len(log["system_attr"]) == 1
            trial = copy.copy(self._trials[trial_id])
            trial.system_attrs = {
                **copy.copy(trial.system_attrs),
                **log["system_attr"],
            }
            self._trials[trial_id] = trial

    def _trial_exists_and_updatable(self, trial_id: int, log: dict[str, Any]) -> bool:
        if trial_id not in self._trials:
            if self._is_issued_by_this_worker(log):
                raise KeyError(NOT_FOUND_MSG)
            return False
        elif self._trials[trial_id].state.is_finished():
            if self._is_issued_by_this_worker(log):
                raise UpdateFinishedTrialError(
                    "Trial#{} has already finished and can not be updated.".format(
                        self._trials[trial_id].number
                    )
                )
            return False
        else:
            return True
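

# A minimal, hedged usage sketch mirroring the class docstring above: persist a study
# through the journal backend and reload it from the same log file. The file path and
# study name are arbitrary placeholders, not part of the library API.
if __name__ == "__main__":

    def _objective(trial: optuna.Trial) -> float:
        x = trial.suggest_float("x", -10, 10)
        return (x - 2) ** 2

    _storage = optuna.storages.JournalStorage(
        optuna.storages.journal.JournalFileBackend("./optuna_journal_storage.log")
    )
    _study = optuna.create_study(study_name="journal-demo", storage=_storage, load_if_exists=True)
    _study.optimize(_objective, n_trials=10)

    # Opening a second storage on the same log file replays the recorded operations
    # and restores the study state in memory.
    _reloaded_storage = optuna.storages.JournalStorage(
        optuna.storages.journal.JournalFileBackend("./optuna_journal_storage.log")
    )
    _reloaded = optuna.load_study(study_name="journal-demo", storage=_reloaded_storage)
    print(len(_reloaded.trials))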