import numpy as np import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs from sklearn import model_selection from sklearn.metrics import f1_score def show_svm(a, b, bt): plt.figure(bt) plt.title('SVM with ' + bt) # 建立图像坐标 axis = plt.gca() plt.scatter(a[:, 0], a[:, 1], c=b, s=30) xlim = [a[:, 0].min(), a[:, 0].max()] ylim = [a[:, 1].min(), a[:, 1].max()] # 生成两个等差数列 xx = np.linspace(xlim[0], xlim[1], 50) yy = np.linspace(ylim[0], ylim[1], 50) X, Y = np.meshgrid(xx, yy) xy = np.vstack([X.ravel(), Y.ravel()]).T Z = clf.decision_function(xy).reshape(X.shape) # 画出分界线 axis.contour(X, Y, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--', '-', '--']) axis.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=200, linewidths=1, facecolors='none') if __name__ == '__main__': # data = np.loadtxt('separable_data.txt', delimiter=',') # data = np.loadtxt('non_separable_data.txt', delimiter=',') # data = np.loadtxt('banknote.txt', delimiter=',') data = np.loadtxt('ionosphere.txt', delimiter=',') # data = np.loadtxt('wdbc.txt', delimiter=',') X = data[:, 0:-1] y = data[:, -1] """标签中有一类标签为1""" y = y + 1 ymin = min(y) if not (1 in set(y)): ll = max(list(set(y))) + 1 for i in range(len(y)): if y[i] == ymin: y[i] = 1 # 建立一个线性核(多项式核)的SVM clf = svm.SVC(kernel='linear') clf.fit(X, y) """显示所有数据用于训练后的可视化结果""" show_svm(X, y, 'all dataset') """divide the data into two sections: training and test datasets""" X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=42) """training""" clf = svm.SVC(kernel='linear')#线性内核 # clf = svm.SVC(kernel='poly')# 多项式内核 # clf = svm.SVC(kernel='sigmoid')# Sigmoid内核 clf.fit(X_train, y_train) # show_svm(X_train, y_train, 'training dataset') """predict""" pred = clf.predict(X_test) pred = np.array(pred) y_test = np.array(y_test) print(f'SVM 的预测结果 f1-score:{f1_score(y_test, pred)}') # plt.show()结果与分析
时间: 2024-04-22 10:22:18 浏览: 150
Matplotlib.pyplot 三维绘图的实现示例
这段代码实现了一个支持向量机(SVM)分类器,并对数据进行了可视化展示。具体来说,它包括以下几个主要步骤:
1. 导入需要使用的库,如 numpy、matplotlib、sklearn 等。
2. 定义一个名为 show_svm 的函数,用于对数据进行可视化展示。函数有三个参数:a 表示数据的特征值,b 表示数据的标签,bt 表示图像的标题。
3. 在主函数中,首先读取数据,然后将标签中的一类标签设为 1,并划分数据集为训练集和测试集。
4. 建立一个线性核的 SVM 分类器,并用训练集进行训练。
5. 对训练后的模型进行可视化展示,包括所有数据集和训练数据集的分界线。
6. 对测试集进行预测,并计算预测结果的 f1-score。
7. 最后展示结果。
总的来说,这段代码的功能是对数据进行 SVM 分类,并对分类结果进行可视化和评估。其中,SVM 模型的核函数可以通过修改代码中的 kernel 参数进行更换。
阅读全文