给一个固定的数据集,并给出该数据集对应文件要求的代码
时间: 2025-01-03 10:44:11 浏览: 4
根据您提供的《机器学习实验7实验样例.doc》中的描述,以下是使用固定手写数据集并完成相应任务的Python代码示例:
```python
import time
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
import numpy as np
# 1.1 读取给定数据集
start_time = time.time()
mnist = datasets.fetch_openml('mnist_784')
X, y = mnist.data, mnist.target
print(f"载入训练数据...载入完成,耗时{time.time() - start_time:.6f}秒")
# 划分数据集
train_size = 0.8
random_state = 1
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_size, test_size=1-train_size, random_state=random_state)
# 输出图片个数和图片像素数目
num_images_train = len(X_train)
num_images_test = len(X_test)
pixels_per_image = X_train.shape[1]
image_shape = int(np.sqrt(pixels_per_image)), int(np.sqrt(pixels_per_image))
print(f"图片个数:{num_images_train + num_images_test},图片像素数目:{pixels_per_image},{image_shape}")
# 1.2 主要参数设置
C = 1000
kernel = 'rbf'
gamma = 1e-10
# 1.3 SVM 开始训练
svm = SVC(C=C, kernel=kernel, gamma=gamma)
start_time = time.time()
svm.fit(X_train, y_train)
training_time = time.time() - start_time
print(f"SVM开始训练...SVM训练结束,耗时{int(training_time // 60)}分钟{training_time % 60:.3f}秒")
# 1.4 输出SVM训练集准确率
y_pred_train = svm.predict(X_train)
train_accuracy = accuracy_score(y_train, y_pred_train)
print(f"SVM训练集准确率:{train_accuracy * 100:.3f}%,耗时{int(training_time // 60)}分钟{training_time % 60:.3f}秒")
# 1.5 输出SVM测试集准确率
start_time = time.time()
y_pred_test = svm.predict(X_test)
test_time = time.time() - start_time
test_accuracy = accuracy_score(y_test, y_pred_test)
print(f"SVM测试集准确率:{test_accuracy * 100:.3f}%,耗时{int(test_time // 60)}分钟{test_time % 60:.3f}秒")
# 1.6 输出分类错误的图片与结果
misclassified_indices = np.where(y_test != y_pred_test)[0]
for idx in misclassified_indices:
print(f"图像索引: {idx}, 实际标签: {y_test.iloc[idx]}, 预测标签: {y_pred_test[idx]}")
```
### 代码说明:
1. **读取数据集**:使用 `datasets.fetch_openml` 加载 MNIST 数据集。
2. **划分数据集**:将数据集划分为训练集和测试集,比例为 80% 和 20%。
3. **输出基本信息**:打印图片总数、每个图片的像素数以及图片的形状。
4. **设置 SVM 参数**:设置 SVM 的主要参数,包括 C、核函数和 gamma。
5. **训练 SVM 模型**:使用训练集训练 SVM 模型,并记录训练时间。
6. **评估模型性能**:计算并输出训练集和测试集的准确率。
7. **输出分类错误的图片**:找出并打印所有分类错误的图片及其实际标签和预测标签。
希望这段代码能满足您的需求!如果有任何进一步的问题,请随时告诉我。
阅读全文