python 的pytroch ,举个简单例子
时间: 2024-04-30 07:17:51 浏览: 16
好的,一个简单的 PyTorch 示例是使用线性回归模型来拟合数据。以下是一个简单的示例代码:
```
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# 构造数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(1, 1)
def forward(self, x):
out = self.fc(x)
return out
net = Net()
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
# 训练模型
for epoch in range(200):
prediction = net(x)
loss = criterion(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print('epoch {}, loss {}'.format(epoch, loss.item()))
# 可视化结果
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show()
```
这个示例使用了 PyTorch 中的 nn.Module 和 nn.Linear 来构建线性回归模型,并且使用了 nn.MSELoss 作为损失函数,使用了 torch.optim.SGD 作为优化器,最后通过训练模型来拟合数据。
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)