Saving the training process with Optuna
Date: 2024-06-12 19:08:12
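Optuna offers several ways to save the training process. Here are a few of them:
1. Export the trials to JSON (through the pandas DataFrame returned by `study.trials_dataframe()`):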
```python
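import optuna


def objective(trial):
    ...  # elided in the original; must return the value to optimize


# The SQLite storage persists every trial; load_if_exists lets the script be re-run
study = optuna.create_study(study_name='example_study', storage='sqlite:///example.db', load_if_exists=True)
study.optimize(objective, n_trials=100)
# Save the optimization history plot (needs plotly; write_image also needs kaleido)
optuna.visualization.plot_optimization_history(study).write_image('optimization_history.png')
# Save the trials as a JSON file
study.trials_dataframe().to_json('trials.json')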
```
2. Export the trials to CSV (again through `study.trials_dataframe()`):
```python
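import optuna


def objective(trial):
    ...  # elided in the original; must return the value to optimize


# The SQLite storage persists every trial; load_if_exists lets the script be re-run
study = optuna.create_study(study_name='example_study', storage='sqlite:///example.db',
                            pruner=optuna.pruners.MedianPruner(), load_if_exists=True)
study.optimize(objective, n_trials=100)
# Save the optimization history plot (needs plotly; write_image also needs kaleido)
optuna.visualization.plot_optimization_history(study).write_image('optimization_history.png')
# Save the trials as a CSV file
study.trials_dataframe().to_csv('trials.csv')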
```
3. Use a custom recorder to save more information about each trial:
```python
import optuna
import pandas as pd


class CustomRecorder:
    """Callback that collects extra information about every finished trial."""

    def __init__(self):
        self._results = []

    def __call__(self, study, trial):
        # Optuna calls each callback with the study and a FrozenTrial after every trial
        self._results.append((trial.datetime_start, trial.value, trial.params,
                              trial.user_attrs, trial.system_attrs))

    def save_results(self, filename):
        df = pd.DataFrame(self._results, columns=['datetime_start', 'value', 'params',
                                                  'user_attrs', 'system_attrs'])
        df.to_pickle(filename)


def objective(trial):
    ...  # elided in the original; must return the value to optimize


recorder = CustomRecorder()
study = optuna.create_study(study_name='example_study', storage='sqlite:///example.db',
                            pruner=optuna.pruners.MedianPruner(), load_if_exists=True)
study.optimize(objective, n_trials=100, callbacks=[recorder])
# Save the collected records in a custom (pickle) file
recorder.save_results('results.pkl')
```