试用 sklearn 的 svc 函数对西瓜数据集3.0a进行分类,画出散点图和分离超平面。并与logistic分类方法比较。
时间: 2023-06-12 11:07:31 浏览: 119
西瓜数据集3.0-数据集
好的,我来给您讲一下如何使用sklearn的SVC函数对西瓜数据集3.0a进行分类。
首先,我们需要导入相关的库和数据集。代码如下:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
# 导入西瓜数据集3.0a
data = np.loadtxt('watermelon3.0a.txt', delimiter=',')
X = data[:, 1:3] # 取特征
y = data[:, 3] # 取标签
```
接下来,我们可以使用SVC函数进行分类。代码如下:
```python
# 创建SVC分类器对象
svc = svm.SVC(kernel='linear', C=1.0)
# 训练模型
svc.fit(X, y)
# 预测
y_pred = svc.predict(X)
```
然后,我们可以利用matplotlib库绘制散点图和分离超平面。代码如下:
```python
# 画出散点图
plt.scatter(X[:, 0], X[:, 1], c=y)
# 画出分离超平面
w = svc.coef_[0]
a = -w[0] / w[1]
xx = np.linspace(0, 1)
yy = a * xx - (svc.intercept_[0]) / w[1]
plt.plot(xx, yy)
# 设置坐标轴范围
plt.xlim(0, 1)
plt.ylim(0, 1)
# 显示图像
plt.show()
```
最后,我们可以将SVC分类方法和Logistic分类方法进行比较。代码如下:
```python
from sklearn.linear_model import LogisticRegression
# 创建Logistic分类器对象
lr = LogisticRegression()
# 训练模型
lr.fit(X, y)
# 预测
y_pred_lr = lr.predict(X)
# 画出散点图和分离超平面
plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.title('SVC')
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.plot(xx, yy)
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.subplot(122)
plt.title('Logistic')
plt.scatter(X[:, 0], X[:, 1], c=y)
w = lr.coef_[0]
a = -w[0] / w[1]
xx = np.linspace(0, 1)
yy = a * xx - (lr.intercept_[0]) / w[1]
plt.plot(xx, yy)
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.show()
```
这样,我们就可以得到SVC分类方法和Logistic分类方法的比较结果了。
阅读全文