将一个.csv文件有多组多维数据,第一列是标签名,利用SVM进行分类的机器学习代码,预测最后一列数据属于哪一类,最后可视化展现各类分类情况效果
时间: 2023-12-23 08:01:54 浏览: 80
使用SVM做分类
好的,以下是一个简单的Python代码示例,使用SVM对多组多维数据进行分类,并预测最后一列数据属于哪一类,最后使用matplotlib库可视化展现各类分类情况效果:
``` python
import pandas as pd
import numpy as np
from sklearn import svm
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 加载数据
data = pd.read_csv('data.csv')
x = data.iloc[:, :-1].values
y = data.iloc[:, 0].values
predict_x = data.iloc[:, -1].values.reshape(-1, 1)
# 将标签名编码为数字
le = LabelEncoder()
y = le.fit_transform(y)
# 训练SVM模型
clf = svm.SVC(kernel='linear', C=1)
clf.fit(x, y)
# 预测分类结果
y_pred = clf.predict(x)
predict_y = clf.predict(predict_x)
# 可视化分类结果
fig = plt.figure()
if x.shape[1] == 2:
# 二维平面展示
plt.scatter(x[:, 0], x[:, 1], c=y_pred)
plt.scatter(predict_x[:, 0], np.zeros_like(predict_y), c=predict_y, marker='x', s=1000)
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
else:
# 三维空间展示
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=y_pred)
ax.scatter(predict_x[:, 0], np.zeros_like(predict_y), np.zeros_like(predict_y), c=predict_y, marker='x', s=1000)
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_zlabel('Feature 3')
plt.show()
```
上述代码中,我们首先使用pandas库加载名为"data.csv"的.csv文件,然后将数据集分为特征数据和标签数据。由于标签名是字符串类型的,我们需要将其编码为数字类型的标签。接下来,我们使用SVM训练模型,并使用训练好的模型对特征数据进行分类预测。最后,我们使用matplotlib库将分类结果可视化展示。如果特征数据为二维数据,则在图表中展示二维平面,否则在三维空间中展示。在图表中,不同标签名对应的数据点颜色不同,以显示各类分类情况。同时,我们用红色的叉号(x)标注出预测数据点的位置,并将其分类结果用红色或蓝色的颜色表示,以便比较预测结果和实际分类情况的差异。
阅读全文