用python实现在鸢尾花数据集和 MNIST 手写数字数据库上,训练一个SVM,基于交叉验证(数据分为训练集(80%)与测试集(20%) ),计算分类器性能,包括准确率、查全率(召回率)、查准率(精准率)、F1值、混淆矩阵等。并输出运行结果。
时间: 2023-12-10 16:41:44 浏览: 109
mnist_1.zip_mnist 手写识别_mnist 数字识别_mnist数据集实现对手写数字识别_手写识别 MNIST
好的,我可以为您提供这个实现。
首先,我们需要引入必要的库和数据集:
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.svm import SVC
# 加载鸢尾花数据集
iris = datasets.load_iris()
X_iris = iris.data
y_iris = iris.target
# 加载MNIST手写数字数据库
digits = datasets.load_digits()
X_digits = digits.images.reshape((len(digits.images), -1))
y_digits = digits.target
```
接下来,我们定义一个函数 `evaluate_classifier` 来评估分类器的性能:
```python
def evaluate_classifier(X_train, y_train, X_test, y_test, classifier):
# 训练分类器
classifier.fit(X_train, y_train)
# 在测试集上进行预测
y_pred = classifier.predict(X_test)
# 输出混淆矩阵
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
# 输出分类报告
print("Classification Report:\n", classification_report(y_test, y_pred))
```
然后,我们定义一个函数 `train_and_evaluate_classifier` 来训练并评估分类器:
```python
def train_and_evaluate_classifier(X, y, classifier):
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 评估分类器性能
evaluate_classifier(X_train, y_train, X_test, y_test, classifier)
```
最后,我们可以使用 SVM 分别在鸢尾花数据集和 MNIST 手写数字数据库上进行训练和评估:
```python
# 在鸢尾花数据集上训练和评估SVM
svm_iris = SVC()
train_and_evaluate_classifier(X_iris, y_iris, svm_iris)
# 在MNIST手写数字数据库上训练和评估SVM
svm_digits = SVC()
train_and_evaluate_classifier(X_digits, y_digits, svm_digits)
```
完整代码如下:
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.svm import SVC
# 加载鸢尾花数据集
iris = datasets.load_iris()
X_iris = iris.data
y_iris = iris.target
# 加载MNIST手写数字数据库
digits = datasets.load_digits()
X_digits = digits.images.reshape((len(digits.images), -1))
y_digits = digits.target
def evaluate_classifier(X_train, y_train, X_test, y_test, classifier):
# 训练分类器
classifier.fit(X_train, y_train)
# 在测试集上进行预测
y_pred = classifier.predict(X_test)
# 输出混淆矩阵
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
# 输出分类报告
print("Classification Report:\n", classification_report(y_test, y_pred))
def train_and_evaluate_classifier(X, y, classifier):
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 评估分类器性能
evaluate_classifier(X_train, y_train, X_test, y_test, classifier)
# 在鸢尾花数据集上训练和评估SVM
svm_iris = SVC()
train_and_evaluate_classifier(X_iris, y_iris, svm_iris)
# 在MNIST手写数字数据库上训练和评估SVM
svm_digits = SVC()
train_and_evaluate_classifier(X_digits, y_digits, svm_digits)
```
运行该代码将会输出每个数据集上的分类器性能。
阅读全文