"""Run a ClearML hyper-parameter optimization for the SpoterEmbedding model.

A controller task clones an existing base training task (``base_task_id``),
samples hyper-parameters with Optuna, minimizes the reported ``train_loss/loss``
scalar, and prints the ids of the best experiments once the search finishes.
"""
from clearml.automation import UniformParameterRange, UniformIntegerParameterRange, DiscreteParameterRange
from clearml.automation import HyperParameterOptimizer
from clearml.automation.optuna import OptimizerOptuna
# HyperbandPruner is imported as an alternative to MedianPruner but is currently unused.
from optuna.pruners import HyperbandPruner, MedianPruner
from clearml import Task

task = Task.init(
    project_name='SpoterEmbedding',
    task_name='Automatic Hyper-Parameter Optimization',
    task_type=Task.TaskTypes.optimizer,
    # Always create a fresh controller task instead of overwriting a previous run.
    reuse_last_task_id=False,
)

optimizer = HyperParameterOptimizer(
    # Task to be optimized. It must already exist in the system so it can be cloned.
    base_task_id="4504e0b3ec6745249d3d4c94d3d40652",
    # Hyper-parameters to search. Names must match the base task's "Args" section.
    hyper_parameters=[
        # epochs (fixed to a single value)
        DiscreteParameterRange('Args/epochs', [200]),
        # learning rate
        # NOTE(review): sampled linearly over four orders of magnitude, so most
        # samples land near 1e-2; a log-scale range may be intended.
        UniformParameterRange('Args/lr', 0.000001, 0.01),
        # optimizer
        DiscreteParameterRange('Args/optimizer', ['ADAM', 'SGD']),
        # embedding vector length
        UniformIntegerParameterRange('Args/vector_length', 10, 100),
    ],
    # Objective: minimize the scalar reported under title 'train_loss', series 'loss'.
    objective_metric_title='train_loss',
    objective_metric_series='loss',
    objective_metric_sign='min',
    # Search strategy.
    optimizer_class=OptimizerOptuna,
    # Optimization budget. ClearML expresses time limits and the polling period in minutes.
    execution_queue='default',
    optimization_time_limit=360,
    compute_time_limit=480,
    total_max_jobs=20,
    min_iteration_per_job=0,
    max_iteration_per_job=150000,
    pool_period_min=0.1,
    # Keep only the 3 best tasks; the rest are archived.
    save_top_k_tasks_only=3,
    # Forwarded to OptimizerOptuna: stop unpromising trials early.
    optuna_pruner=MedianPruner(),
)


def job_complete_callback(
    job_id: str,
    objective_value: float,
    objective_iteration: int,
    job_parameters: dict,
    top_performance_job_id: str,
) -> None:
    """Log a finished optimization job and announce when it is the new best.

    Passed to ``optimizer.start`` below, so the optimizer calls it once per
    completed job.

    Args:
        job_id: Id of the job that just completed.
        objective_value: Objective metric value the job reached.
        objective_iteration: Iteration at which that value was reported.
        job_parameters: Hyper-parameter values used by the job.
        top_performance_job_id: Id of the best job seen so far.
    """
    print('Job completed!', job_id, objective_value, objective_iteration, job_parameters)
    if job_id == top_performance_job_id:
        print('WOOT WOOT we broke the record! Objective reached {}'.format(objective_value))


# Hand the controller off to the 'hypertuning' queue and exit the local
# process; everything below runs on the agent that picks up this task.
task.execute_remotely(queue_name='hypertuning', exit_process=True)

# Report progress every 0.3 minutes.
optimizer.set_report_period(0.3)
optimizer.start(job_complete_callback=job_complete_callback)
# Block until the optimization finishes or its time limit is reached.
optimizer.wait()

top_exp = optimizer.get_top_experiments(top_k=3)
print([t.id for t in top_exp])

# Make sure any background optimization jobs are stopped.
optimizer.stop()