线性回归三维spider代码
时间: 2024-10-18 16:11:49 浏览: 27
线性回归三维蜘蛛图是一种可视化工具,用于展示一个因变量(通常是预测值)如何依赖于两个自变量(特征)。在Python中,我们可以使用matplotlib库的scatter和plot_surface函数结合来创建这种图表。下面是一个简单的例子,假设我们已经有了训练数据集(X_train, y_train),并且已经训练了一个线性回归模型:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
# 假设我们有训练数据
X = X_train[:, :2] # 我们只取前两个特征作为二维坐标
y = y_train
# 创建并训练线性回归模型
model = LinearRegression()
model.fit(X, y)
# 预测值网格
x_min, x_max = X.min() - 1, X.max() + 1
y_min, y_max = y.min() - 1, y.max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
np.linspace(y_min, y_max, 100))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
# 绘制三维散点图和曲面
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], y, c='r', s=50)
ax.plot_surface(xx, yy, Z, cmap='viridis')
# 设置轴标签
ax.set_xlabel('特征1')
ax.set_ylabel('特征2')
ax.set_zlabel('目标变量')
plt.title('三维蜘蛛图示例')
plt.show()
阅读全文