pytorch读取文件夹中的图片 残差网络模型训练,预测,保存结果示例代码
时间: 2023-05-23 22:00:38 浏览: 112
一个使用Androidstudio开发的校园通知APP
很高兴回答您的问题,以下是 PyTorch 读取文件夹中的图片、残差网络模型训练,预测,保存结果的示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
# 定义数据预处理
transform_train = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载训练集和测试集
trainset = ImageFolder(root='path/to/training/folder',
transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
shuffle=True, num_workers=2)
testset = ImageFolder(root='path/to/testing/folder',
transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
shuffle=False, num_workers=2)
# 定义 ResNet18 模型
class ResNet18(nn.Module):
def __init__(self):
super(ResNet18, self).__init__()
self.resnet18 = torchvision.models.resnet18(pretrained=True)
self.fc = nn.Linear(1000, 2)
def forward(self, x):
out = self.resnet18(x)
out = self.fc(out)
return out
net = ResNet18()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# 训练模型
for epoch in range(10): # 训练 10 轮
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 100 == 99:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
# 测试模型
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the test images: %d %%' % (
100 * correct / total))
# 保存模型和预测结果
PATH = 'path/to/save/model'
torch.save(net.state_dict(), PATH)
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
for i in range(len(predicted)):
print('Predicted: ', predicted[i])
print('Label: ', labels[i])
```
以上是一个简单的示例,可以供参考。需要替换路径和类别数等一些必要参数,才能在你的项目中运行。
阅读全文