编写一个用SVM分类MNIST手写集带有batch的迭代循环的代码,并将动态循环的训练集和测试集的准确率画出来
时间: 2023-03-28 07:02:21 浏览: 58
我可以回答这个问题。以下是一个用SVM分类MNIST手写集带有batch的迭代循环的代码:
```
from sklearn import svm
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
# 加载MNIST数据集
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist['data'], mnist['target']
X = X / 255. # 数据归一化
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2, random_state=42)
# 定义SVM分类器
clf = svm.SVC(kernel='linear')
# 定义batch大小和迭代次数
batch_size = 100
n_iterations = len(X_train) // batch_size
# 动态循环训练集和测试集
train_accs = []
test_accs = []
for iteration in range(n_iterations):
X_batch = X_train[iteration*batch_size:(iteration+1)*batch_size]
y_batch = y_train[iteration*batch_size:(iteration+1)*batch_size]
clf.fit(X_batch, y_batch)
train_acc = accuracy_score(y_batch, clf.predict(X_batch))
test_acc = accuracy_score(y_test, clf.predict(X_test))
train_accs.append(train_acc)
test_accs.append(test_acc)
# 画出动态循环的训练集和测试集的准确率
import matplotlib.pyplot as plt
plt.plot(np.arange(n_iterations), train_accs, label='train')
plt.plot(np.arange(n_iterations), test_accs, label='test')
plt.xlabel('iteration')
plt.ylabel('accuracy')
plt.legend()
plt.show()
```
希望这可以帮助你。