import torch import matplotlib.pyplot as plt def features(x): x = x.unsqueeze(1) return torch.cat([x ** i for i in range(1, 5)], 1) x_weight = torch.Tensor([-1.13, -2.14, 3.15, -0.01]).unsqueeze(1) b = torch.Tensor([0.512]) def target(x): return x.mm(x_weight) + b.item() def get_batch_data(batch_size): batch_x = torch.rand(batch_size) features_x = features(batch_x) target_y = target(features_x) return features_x, target_y class PR(torch.nn.Module): def __init__(self): super(PR, self).__init__() self.poly = torch.nn.Linear(4, 1) def forward(self, x): return self.poly(x) epochs = 10000 batch_size = 32 model = PR() criterion = torch.nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), 0.01) for i in range(epochs): batch_x, batch_y = get_batch_data(batch_size) out = model(batch_x) loss = criterion(out, batch_y) optimizer.zero_grad() loss.backward() optimizer.step() if(i % 100 == 0): print(f"Epoch:{i+1}/{epochs}, loss:{loss.item}") if(i%1000 == 0): predict = model(features(batch_x)) plt.plot(batch_x.data.numpy(), predict.squeeze().data.numpy(), "r") loss = criterion(predict, batch_y) plt.title(f"Loss:{loss.item()}") plt.xlabel("X") plt.ylabel("Y") plt.scatter(batch_x, batch_y) plt.show()哪里出错了
时间: 2023-11-12 13:04:18 浏览: 84
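代码中有几处错误。第一,打印损失的那一行应为`print(f"Epoch:{i+1}/{epochs}, loss:{loss.item()}")`:`loss.item`后面需要加上括号才会调用该方法并输出数值,否则打印出来的是方法对象本身。第二,也是更关键的一处:`get_batch_data`返回的`batch_x`已经是经过`features`展开的 (batch_size, 4) 多项式特征,绘图时再写`model(features(batch_x))`会把特征重复展开一次,得到形状错误的预测。应直接使用`predict = model(batch_x)`,并以原始的 x(即`batch_x[:, 0]`,且按 x 排序后再画线)作为`plt.plot`和`plt.scatter`的横坐标。至于`loss = criterion(predict, batch_y)`,在`predict`计算正确之后它本身没有问题;改成`loss = criterion(out, batch_y)`也可以,只是显示的是参数更新前的损失。参考代码如下所示: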
```
import torch
import matplotlib.pyplot as plt
def features(x):
    """Expand a 1-D tensor of samples into polynomial features.

    For an input of shape (N,), returns an (N, 4) tensor whose columns are
    x, x**2, x**3 and x**4, in that order.
    """
    col = x.unsqueeze(1)  # (N,) -> (N, 1) so the powers can be concatenated column-wise
    powers = [col.pow(p) for p in (1, 2, 3, 4)]
    return torch.cat(powers, dim=1)
# Ground-truth coefficients for the feature columns [x, x^2, x^3, x^4]
# produced by features(), stored as a (4, 1) column vector.
x_weight = torch.Tensor([[-1.13], [-2.14], [3.15], [-0.01]])
# Ground-truth intercept.
b = torch.Tensor([0.512])


def target(x):
    """Evaluate the ground-truth polynomial on a feature matrix.

    ``x`` must be a 2-D (N, 4) feature matrix (``mm`` rejects other ranks).
    Returns ``x @ x_weight + b`` with shape (N, 1).
    """
    intercept = b.item()
    return torch.mm(x, x_weight) + intercept
def get_batch_data(batch_size):
    """Draw a random training batch.

    Samples ``batch_size`` values uniformly from [0, 1), expands them with
    ``features`` and labels them with ``target``.

    Returns:
        A tuple ``(features_x, target_y)``: the (batch_size, 4) feature
        matrix and the matching (batch_size, 1) targets.
    """
    raw = torch.rand(batch_size)
    feats = features(raw)
    return feats, target(feats)
class PR(torch.nn.Module):
    """Polynomial-regression model: a single linear layer over 4 input features."""

    def __init__(self):
        super().__init__()
        # Maps each 4-feature row to one scalar prediction (weights + bias).
        self.poly = torch.nn.Linear(4, 1)

    def forward(self, x):
        """Apply the linear layer to ``x``; the last dimension of ``x`` must be 4."""
        prediction = self.poly(x)
        return prediction
# Train PR on freshly sampled batches with SGD, logging the loss every 100
# iterations and plotting the fit on the current batch every 1000 iterations.
epochs = 10000
batch_size = 32
model = PR()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), 0.01)
for i in range(epochs):
    batch_x, batch_y = get_batch_data(batch_size)
    out = model(batch_x)
    loss = criterion(out, batch_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        print(f"Epoch:{i+1}/{epochs}, loss:{loss.item()}")
    if i % 1000 == 0:
        with torch.no_grad():
            # batch_x already holds the features [x, x^2, x^3, x^4]; calling
            # features() on it again would expand it a second time.
            predict = model(batch_x)
            plot_loss = criterion(predict, batch_y)
        # Column 0 of the feature matrix is the raw x. Sort by it so the
        # prediction line is drawn left to right instead of zig-zagging.
        x_raw = batch_x[:, 0]
        order = torch.argsort(x_raw)
        plt.plot(x_raw[order].numpy(), predict.squeeze(1)[order].numpy(), "r")
        plt.scatter(x_raw.numpy(), batch_y.squeeze(1).numpy())
        plt.title(f"Loss:{plot_loss.item()}")
        plt.xlabel("X")
        plt.ylabel("Y")
        # NOTE: in matplotlib's non-interactive mode this blocks until the
        # window is closed, pausing training at each snapshot.
        plt.show()
```
阅读全文