pytorch 预测函数代码
时间: 2023-12-19 09:25:27 浏览: 97
以下是一个PyTorch中的简单预测函数代码示例:
```python
import torch
import torch.nn as nn
# 定义一个网络模型类,继承自nn.Module类
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 定义预测函数
def predict(model, data):
# 将数据转换为Tensor对象
inputs = torch.Tensor(data)
# 将模型设为评估模式
model.eval()
# 禁用梯度计算
with torch.no_grad():
# 将输入数据传入模型中进行预测
outputs = model(inputs)
# 将输出转换为numpy数组
predictions = outputs.numpy()
# 返回预测结果
return predictions
```
上述代码定义了一个简单的网络模型类Net,该模型有两个全连接层,其中第一个层的输入维度为10,输出维度为5,第二个层的输入维度为5,输出维度为1。同时,还定义了一个预测函数predict,该函数接受一个PyTorch模型对象和一个数据数组作为输入,返回一个预测结果数组。在预测函数中,首先将数据转换为Tensor对象,然后将模型设为评估模式并禁用梯度计算,接着将输入数据传入模型中进行预测,并将预测结果转换为numpy数组后返回。
阅读全文