```
def compile_train(model, data_train, data_label, data_val, label_val, deep=True):
    if deep:
        import matplotlib.pyplot as plt
        model.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['accuracy'])
        history = model.fit(data_train, data_label, validation_data=(data_val, label_val),
                            epochs=10, batch_size=64, verbose=1)
        plt.plot(history.history['accuracy'])
        plt.plot(history.history['loss'])
        plt.title('Train Accuracy and Loss')
        plt.ylabel('Accuracy/Loss')
        plt.xlabel('Epoch')
        plt.legend(['accuracy', 'loss'], loc='upper left')
        plt.show()
        plt.plot(history.history['val_accuracy'])
        plt.plot(history.history['val_loss'])
        plt.title('Validation Accuracy and Loss')
        plt.ylabel('Accuracy/Loss')
        plt.xlabel('Epoch')
        plt.legend(['accuracy', 'loss'], loc='upper left')
        plt.show()
        print('Model Compiled and Trained')
        return model
```
Rewrite this as a function for a decision tree.
Posted: 2024-01-16 07:03:25 · Views: 106
For a decision tree there is no compile step, because decision trees are not trained by gradient-based optimization. Instead, we can fit the training data and labels directly with the DecisionTreeClassifier class from the sklearn library:
```
from sklearn.tree import DecisionTreeClassifier

def fit_decision_tree(data_train, data_label):
    # Limit tree depth to curb overfitting; fix random_state for reproducibility
    model = DecisionTreeClassifier(max_depth=10, random_state=42)
    model.fit(data_train, data_label)
    print('Decision Tree Model Trained')
    return model
```
Here we define a new function, fit_decision_tree, which takes the training data and labels as input and returns the trained decision tree model. The max_depth parameter limits the depth of the tree to avoid overfitting, and random_state makes the result reproducible across runs.
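As a hedged usage sketch, the function above can be exercised end to end on a small digit-classification dataset. Since no data files are provided here, sklearn's built-in load_digits (8x8 digit images, an assumption standing in for the actual data) is used, with a hypothetical train/validation split:

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

def fit_decision_tree(data_train, data_label):
    # Same function as above: bounded depth, fixed seed
    model = DecisionTreeClassifier(max_depth=10, random_state=42)
    model.fit(data_train, data_label)
    print('Decision Tree Model Trained')
    return model

# load_digits is a stand-in dataset: 1797 digit images flattened to 64 features
digits = load_digits()
X_train, X_val, y_train, y_val = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=42)

model = fit_decision_tree(X_train, y_train)
val_acc = accuracy_score(y_val, model.predict(X_val))
print('Validation accuracy:', val_acc)
```

Unlike the Keras version, there is no epoch-by-epoch history to plot; a single fit call builds the whole tree, and validation accuracy is computed once afterwards.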