优化代码import numpy as np from PIL import Image from sklearn import svm from sklearn.model_selection import train_test_split import os import matplotlib.pyplot as plt # 定义图像文件夹路径和类别 cat_path = "cats/" dog_path = "dogs/" cat_label = 0 dog_label = 1 # 定义图像预处理函数 def preprocess_image(file_path): # 读取图像并转换为灰度图像 img = Image.open(file_path).convert('L') # 调整图像尺寸 img = img.resize((100, 100)) # 将图像转换为 Numpy 数组 img_array = np.array(img) # 将二维数组展平为一维数组 img_array = img_array.reshape(-1) return img_array # 读取猫和狗的图像并转换成 Numpy 数组 X = [] y = [] for file_name in os.listdir(cat_path): file_path = os.path.join(cat_path, file_name) img_array = preprocess_image(file_path) X.append(img_array) y.append(cat_label) for file_name in os.listdir(dog_path): file_path = os.path.join(dog_path, file_name) img_array = preprocess_image(file_path) X.append(img_array) y.append(dog_label) X = np.array(X) y = np.array(y) # 将数据集划分为训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) # 训练 SVM 分类器 clf = svm.SVC(kernel='linear') clf.fit(X_train, y_train) # 在测试集上进行预测 y_pred = clf.predict(X_test) # 计算测试集上的准确率 accuracy = np.mean(y_pred == y_test) print("Accuracy:", accuracy) # 显示测试集中的前 16 张图像和它们的预测结果 fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8)) for i, ax in enumerate(axes.flat): # 显示图像 ax.imshow(X_test[i].reshape(100, 100), cmap='gray') # 显示预测结果和标签 if y_pred[i] == 0: ax.set_xlabel("Cat") else: ax.set_xlabel("Dog") ax.set_xticks([]) ax.set_yticks([]) plt.show()
时间: 2024-03-24 22:40:02 浏览: 26
这段代码已经很简洁了,不过还有一些可以优化的地方:
1. 在导入模块时,可以将不同模块分行导入,以提高代码可读性。
2. 在预处理函数中,可以将调整图像尺寸和转换为 Numpy 数组合并为一步,以减少代码行数。
3. 在循环读取图像并转换为 Numpy 数组时,可以使用列表推导式代替 for 循环,以提高代码效率。
4. 在计算准确率时,可以使用 sklearn.metrics 中的 accuracy_score 函数代替自己编写的代码。
下面是优化后的代码:
```python
import numpy as np
from PIL import Image
from sklearn import svm, metrics
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
# 定义图像文件夹路径和类别
cat_path = "cats/"
dog_path = "dogs/"
cat_label = 0
dog_label = 1
# 定义图像预处理函数
def preprocess_image(file_path):
# 读取图像并转换为灰度图像,调整尺寸并转换为 Numpy 数组
img = np.array(Image.open(file_path).convert('L').resize((100, 100)))
# 将二维数组展平为一维数组
img_array = img.reshape(-1)
return img_array
# 读取猫和狗的图像并转换成 Numpy 数组
cat_files = [os.path.join(cat_path, f) for f in os.listdir(cat_path)]
dog_files = [os.path.join(dog_path, f) for f in os.listdir(dog_path)]
X = [preprocess_image(f) for f in cat_files + dog_files]
y = np.concatenate([np.full(len(cat_files), cat_label), np.full(len(dog_files), dog_label)])
# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
# 训练 SVM 分类器
clf = svm.SVC(kernel='linear')
clf.fit(X_train, y_train)
# 在测试集上进行预测
y_pred = clf.predict(X_test)
# 计算测试集上的准确率
accuracy = metrics.accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
# 显示测试集中的前 16 张图像和它们的预测结果
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
# 显示图像
ax.imshow(X_test[i].reshape(100, 100), cmap='gray')
# 显示预测结果和标签
if y_pred[i] == 0:
ax.set_xlabel("Cat")
else:
ax.set_xlabel("Dog")
ax.set_xticks([])
ax.set_yticks([])
plt.show()
```
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)