那我如何记录每一次改变超参数相对应的图表呢
时间: 2024-03-03 10:47:31 浏览: 16
您可以使用一个列表来存储每个超参数组合的训练历史记录,然后使用字典来存储每个超参数组合的准确率和验证准确率。下面是一个示例代码:
```python
# 超参数列表
embedding_sizes = [64, 128, 256]
filter_sizes = [[3, 4, 5], [4, 5, 6], [5, 6, 7]]
num_filters = [64, 128, 256]
# 记录每个超参数组合的训练历史记录和准确率
histories = []
accuracies = {}
# 遍历所有超参数组合
for i in range(len(embedding_sizes)):
for j in range(len(filter_sizes)):
for k in range(len(num_filters)):
# 定义模型和超参数
embedding_size = embedding_sizes[i]
filter_size = filter_sizes[j]
num_filter = num_filters[k]
model = TextCNN(max_len, vocab_size, embedding_size, filter_size, num_filter, num_classes)
# 训练模型并存储历史记录
hist = model.fit(X_train, y_train, batch_size=128, epochs=10, validation_data=(X_test, y_test))
histories.append(hist.history)
# 存储准确率和验证准确率
key = f"embedding_size={embedding_size},filter_size={filter_size},num_filter={num_filter}"
accuracies[key] = {'accuracy': hist.history['accuracy'], 'val_accuracy': hist.history['val_accuracy']}
```
在上面的代码中,我们使用三层嵌套循环遍历所有可能的超参数组合。对于每个超参数组合,我们定义一个新的模型并训练它,然后将训练历史记录存储在一个列表中(histories)。我们还使用一个字典来存储每个超参数组合的准确率和验证准确率。键是一个字符串,包含每个超参数的值,例如“embedding_size=64,filter_size=[3,4,5],num_filter=64”。
一旦您完成训练,您就可以使用以下代码来绘制每个超参数组合的准确率图表:
```python
# 绘制每个超参数组合的准确率图表
for key in accuracies:
acc = accuracies[key]['accuracy']
val_acc = accuracies[key]['val_accuracy']
plt.plot(acc, label='accuracy')
plt.plot(val_acc, label='val_accuracy')
plt.title(key)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
```
在上面的代码中,我们使用一个循环遍历每个超参数组合的准确率字典。对于每个超参数组合,我们使用对应的准确率列表来绘制准确率图表,并使用超参数的字符串键作为标题。