@experimental_class("2.8.0")
class PatientPruner(BasePruner):
    """Pruner which wraps another pruner with tolerance.

    This pruner monitors intermediate values in a trial and prunes the trial if the improvement in
    the intermediate values after a patience period is less than a threshold.

    The pruner handles NaN values in the following manner:
        1. If all intermediate values before or during the patience period are NaN, the trial
           will not be pruned.
        2. During the pruning calculations, NaN values are ignored. Only valid numeric values are
           considered.

    Example:
        .. testcode::

            import numpy as np
            from sklearn.datasets import load_iris
            from sklearn.linear_model import SGDClassifier
            from sklearn.model_selection import train_test_split

            import optuna

            X, y = load_iris(return_X_y=True)
            X_train, X_valid, y_train, y_valid = train_test_split(X, y)
            classes = np.unique(y)


            def objective(trial):
                alpha = trial.suggest_float("alpha", 0.0, 1.0)
                clf = SGDClassifier(alpha=alpha)
                n_train_iter = 100

                for step in range(n_train_iter):
                    clf.partial_fit(X_train, y_train, classes=classes)

                    intermediate_value = clf.score(X_valid, y_valid)
                    trial.report(intermediate_value, step)

                    if trial.should_prune():
                        raise optuna.TrialPruned()

                return clf.score(X_valid, y_valid)


            study = optuna.create_study(
                direction="maximize",
                pruner=optuna.pruners.PatientPruner(optuna.pruners.MedianPruner(), patience=1),
            )
            study.optimize(objective, n_trials=20)

    Args:
        wrapped_pruner:
            Wrapped pruner to perform pruning when :class:`~optuna.pruners.PatientPruner` allows a
            trial to be pruned. If it is :obj:`None`, this pruner is equivalent to
            early stopping using the intermediate values in the individual trial.
        patience:
            Pruning is disabled until the objective doesn't improve for
            ``patience`` consecutive steps.
        min_delta:
            Tolerance value to check whether or not the objective improves.
            This value should be non-negative.

    Raises:
        ValueError:
            If ``patience`` or ``min_delta`` is negative.
    """

    def __init__(
        self, wrapped_pruner: BasePruner | None, patience: int, min_delta: float = 0.0
    ) -> None:
        # Validate eagerly so a misconfigured pruner fails at construction time.
        if patience < 0:
            raise ValueError(f"patience cannot be negative but got {patience}.")

        if min_delta < 0:
            raise ValueError(f"min_delta cannot be negative but got {min_delta}.")

        self._wrapped_pruner = wrapped_pruner
        self._patience = patience
        self._min_delta = min_delta

    def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
        """Decide whether ``trial`` should be pruned.

        The reported intermediate values are ordered by step and split into a trailing window
        holding the last ``patience + 1`` reports and everything reported before it. The trial
        becomes a pruning candidate when the best value in the trailing window is worse than the
        best earlier value by more than ``min_delta``. A candidate is pruned outright when no
        wrapped pruner is set; otherwise the wrapped pruner makes the final decision.

        Args:
            study:
                Study the trial belongs to; its ``direction`` selects min or max comparison.
            trial:
                Trial whose ``intermediate_values`` are inspected.

        Returns:
            :obj:`True` if the trial should be pruned, :obj:`False` otherwise.
        """
        if trial.last_step is None:
            return False

        intermediate_values = trial.intermediate_values

        # Do not prune if the number of reported steps is insufficient to fill both windows.
        if len(intermediate_values) <= self._patience + 1:
            return False

        ordered_steps = sorted(intermediate_values)
        # NOTE(review): assumes reported values are real numbers (possibly NaN) -- confirm.
        scores = np.asarray([intermediate_values[s] for s in ordered_steps], dtype=float)

        # Scores reported before the trailing window ...
        scores_before_patience = scores[: -self._patience - 1]
        # ... and the scores of the last ``patience + 1`` reports.
        scores_after_patience = scores[-self._patience - 1 :]

        # Read the direction before any NaN short-circuit so that an error raised by this
        # attribute access is not masked by the early return below.
        direction = study.direction

        # An all-NaN window has no valid value to compare. ``np.nanmin``/``np.nanmax`` would emit
        # a RuntimeWarning and return NaN (making the comparison False), so bail out explicitly
        # with the same result and without the warning.
        if np.isnan(scores_before_patience).all() or np.isnan(scores_after_patience).all():
            return False

        if direction == StudyDirection.MINIMIZE:
            maybe_prune = np.nanmin(scores_before_patience) + self._min_delta < np.nanmin(
                scores_after_patience
            )
        else:
            maybe_prune = np.nanmax(scores_before_patience) - self._min_delta > np.nanmax(
                scores_after_patience
            )

        if not maybe_prune:
            return False

        if self._wrapped_pruner is None:
            return True

        return self._wrapped_pruner.prune(study, trial)