Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
- import optuna
- import sklearn.datasets
- from sklearn.model_selection import train_test_split
class OptunaPruningHook(tf.train.SessionRunHook):
    """Forward Estimator evaluation results to an Optuna trial and prune early.

    After each training step the hook reads the evaluation summaries found in
    ``estimator.eval_dir()``. When a summary for a step newer than the last
    reported one shows up, the value stored under ``metrics_name`` is reported
    to ``trial`` as an intermediate value. The trial's pruner is then consulted.
    If it says the trial should stop, ``optuna.structs.TrialPruned`` is raised,
    which aborts training.

    Args:
        trial: The Optuna trial being optimized.
        estimator: The estimator whose evaluation summaries are read.
        metrics_name: Key of the metric to report (e.g. ``"accuracy"``).
        is_higher_better: When True, the value reported is ``1.0 - metric``, so
            that smaller reported values are better. NOTE(review): this
            transform presumes the metric lies in [0, 1]; confirm for other
            metrics. When False, the raw metric is reported.
    """

    def __init__(self, trial, estimator, metrics_name, is_higher_better):
        self.trial = trial
        self.estimator = estimator
        self.metrics_name = metrics_name
        self.is_higher_better = is_higher_better
        # Step of the most recently reported summary; -1 means nothing yet.
        self.current_step = -1

    def after_run(self, run_context, run_value):
        """Report a newly available evaluation result, then check for pruning.

        Raises:
            optuna.structs.TrialPruned: if the trial's pruner requests a stop.
        """
        eval_dir = self.estimator.eval_dir()
        summaries = tf.contrib.estimator.read_eval_metrics(eval_dir)
        if not summaries:
            # No evaluation has been written yet; nothing to report or prune.
            return

        # Assumes the mapping is ordered by ascending step, so its last key is
        # the newest evaluation.
        latest_step = next(reversed(summaries))
        if latest_step > self.current_step:
            value = summaries[latest_step][self.metrics_name]
            score = 1.0 - value if self.is_higher_better else value
            self.trial.report(score, step=latest_step)
            self.current_step = latest_step

        if self.trial.should_prune(self.current_step):
            raise optuna.structs.TrialPruned(
                "Trial was pruned at iteration {}.".format(self.current_step)
            )
def create_input_fn():
    """Build train/eval ``input_fn`` callables over a fixed split of Iris.

    The Iris data is split 50/50 into train and eval halves with
    ``random_state=42``, so every call produces the same split.

    Returns:
        A ``(train_input_fn, eval_input_fn)`` pair. Each callable returns a
        ``tf.data.Dataset`` yielding ``({"x": features}, labels)`` batches,
        which Estimators accept directly from an ``input_fn``:

        * train: shuffled (buffer 128), repeated indefinitely, batches of 16;
        * eval: a single pass over the eval half in batches of 32.
    """
    iris = sklearn.datasets.load_iris()
    x, y = iris.data, iris.target
    x_train, x_eval, y_train, y_eval = train_test_split(
        x, y, test_size=0.5, random_state=42
    )

    def _train_input_fn():
        # Features are nested under "x" so they match a feature column keyed "x".
        # Returning the Dataset itself replaces the deprecated
        # make_one_shot_iterator()/get_next() plumbing; the Estimator builds
        # the iterator internally.
        dataset = tf.data.Dataset.from_tensor_slices(({"x": x_train}, y_train))
        return dataset.shuffle(128).repeat().batch(16)

    def _eval_input_fn():
        # No repeat(): evaluation ends when the eval half is exhausted.
        dataset = tf.data.Dataset.from_tensor_slices(({"x": x_eval}, y_eval))
        return dataset.batch(32)

    return _train_input_fn, _eval_input_fn
def objective(trial):
    """Train an Iris classifier with an Optuna-sampled learning rate.

    Samples ``learning_rate`` log-uniformly from [1e-5, 1e-1] and trains a
    ``DNNClassifier`` with no hidden layers for at most 500 steps. It is
    evaluated as checkpoints are written every ``save_steps`` steps. Training
    may end early through the early-stopping hook, or be aborted by
    ``OptunaPruningHook``.

    Args:
        trial: The Optuna trial supplying hyperparameters.

    Returns:
        The final evaluation error ``1.0 - accuracy`` as a Python float (lower
        is better, matching the default minimizing study).
    """
    save_steps = 50

    # Create input functions for train and eval.
    train_input_fn, eval_input_fn = create_input_fn()

    # Hyperparameters to be tuned with Optuna.
    learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-1)

    # Summaries and checkpoints share one cadence. With throttle_secs=0 below,
    # evaluation can run after every new checkpoint.
    config = tf.estimator.RunConfig(
        save_summary_steps=save_steps, save_checkpoints_steps=save_steps
    )

    clf = tf.estimator.DNNClassifier(
        feature_columns=[tf.feature_column.numeric_column(key="x", shape=[4])],
        n_classes=3,
        hidden_units=[],  # no hidden layers, so the model is effectively linear
        optimizer=tf.train.GradientDescentOptimizer(learning_rate=learning_rate),
        model_dir="outputs_pruning/lr_{}".format(learning_rate),
        config=config,
    )

    # Accuracy is higher-is-better, so stop once it has not *increased* for
    # `save_steps` steps. The previous stop_if_no_decrease_hook stopped training
    # when accuracy failed to go down, i.e. while the model was still improving.
    # NOTE(review): the hook only checks every `run_every_secs` (TF default 60 s),
    # so very short runs may finish before it ever fires.
    early_stopping_hook = tf.contrib.estimator.stop_if_no_increase_hook(
        clf, "accuracy", save_steps
    )
    optuna_pruning_hook = OptunaPruningHook(
        trial=trial, estimator=clf, metrics_name="accuracy", is_higher_better=True
    )
    hooks = [early_stopping_hook, optuna_pruning_hook]

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=500, hooks=hooks)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn, steps=1000, start_delay_secs=0, throttle_secs=0
    )

    tf.estimator.train_and_evaluate(clf, train_spec, eval_spec)

    # Final evaluation of the last checkpoint gives the objective value.
    result = clf.evaluate(input_fn=eval_input_fn, steps=100)
    # Cast the NumPy scalar to a plain float before handing it to Optuna.
    accuracy = float(result["accuracy"])
    return 1.0 - accuracy
if __name__ == "__main__":
    # The median pruner stops trials whose intermediate values compare badly
    # with the median of earlier trials. It stays inactive for the first 100
    # steps so every model gets a warm-up period.
    median_pruner = optuna.pruners.MedianPruner(n_warmup_steps=100)
    study = optuna.create_study(pruner=median_pruner)
    study.optimize(objective, n_trials=20)

    print(study.best_trial)
    # Final state of every trial (e.g. complete vs. pruned), in trial order.
    trial_states = [finished.state for finished in study.trials]
    print(trial_states)
Add Comment
Please, Sign In to add comment