import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D x_data = [1.0, 2.0, 3.0, 4.0] y_data = [2.0, 4.0, 6.0, 8.0] fig = plt.figure() ax = fig.add_subplot(111, projection='3d') def forward(x): return x * w + b def loss(x, y): y_pred = forward(x) return (y_pred - y) * (y_pred - y) b_list = [] w_list = [] mse_list = [] for w in np.arange(0.0, 4.1, 0.1): loss_sum = 0 for b in np.arange(-2, 2, 0.1): print(f"w={w}, b={b}") for x_val, y_val in zip(x_data, y_data): y_pred_val = forward(x_val) loss_val = loss(x_val, y_val) loss_sum = loss_sum + loss_val print('/t', x_val, y_val, y_pred_val, loss_val) print('mse=', loss_sum/3) w_list.append(w) b_list.append(b) mse_list.append(loss_sum/3) ax.plot_surface(w_list, b_list, mse_list) ax.set_xlabel("w") ax.set_ylabel("b") ax.set_zlabel("Loss")哪里错了怎么改
时间: 2023-12-02 16:05:41 浏览: 168
import numpy as np import matplotlib.pyplot as plt import math
在代码中有几个问题需要修改:
1. 在loss函数中,应该使用平均损失而不是总损失,因此需要将返回的值除以样本数量。
2. 在循环中,应该遍历所有的b值,而不是仅仅使用np.arange(-2, 2, 0.1)生成的一组值。
3. 在绘制图形时,应该使用meshgrid函数生成所有的(w, b)坐标对,并且使用flatten函数将坐标矩阵展平成一维数组。
下面是修改后的代码:
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
x_data = [1.0, 2.0, 3.0, 4.0]
y_data = [2.0, 4.0, 6.0, 8.0]
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
def forward(x, w, b):
return x * w + b
def loss(x, y, w, b):
y_pred = forward(x, w, b)
return np.mean((y_pred - y) ** 2)
w_list = np.arange(0.0, 4.1, 0.1)
b_list = np.arange(-2.0, 2.1, 0.1)
W, B = np.meshgrid(w_list, b_list)
mse_list = np.empty_like(W)
for i, w in enumerate(w_list):
for j, b in enumerate(b_list):
loss_sum = 0
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val, w, b)
loss_val = loss(x_val, y_val, w, b)
loss_sum += loss_val
mse_list[j, i] = loss_sum/len(x_data)
ax.plot_surface(W, B, mse_list)
ax.set_xlabel("w")
ax.set_ylabel("b")
ax.set_zlabel("Loss")
plt.show()
```
这个代码会绘制出一个3D图形,其中的x轴表示w,y轴表示b,z轴表示损失。
阅读全文