请给出使用pytorch搭建Faster R-CNN训练网络模型的代码
时间: 2024-02-11 21:07:59 浏览: 94
Python-用PyTorch实现FasterRCNN
5星 · 资源好评率100%
以下是使用PyTorch搭建Faster R-CNN训练网络模型的代码:
```python
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
# Set device to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Define transforms for the data
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Load the dataset
trainset = torchvision.datasets.CocoDetection(root='./data', annFile='./annotations/instances_train2017.json', transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
# Define the model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.to(device)
# Define the optimizer and loss function
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
criterion = torchvision.models.detection.fasterrcnn_loss
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# Get the inputs and labels
inputs, targets = data[0].to(device), [{k: v.to(device) for k, v in t.items()} for t in data[1]]
# Zero the parameter gradients
optimizer.zero_grad()
# Forward pass
outputs = model(inputs, targets)
loss = sum(criterion(out, targ) for out, targ in zip(outputs, targets))
# Backward pass
loss.backward()
optimizer.step()
# Print statistics
running_loss += loss.item()
if i % 100 == 99:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
# Update the learning rate
lr_scheduler.step()
print('Finished Training')
```
在这个代码中,我们使用了PyTorch内置的`torchvision`库来加载COCO数据集,并使用`torch.utils.data.DataLoader`来创建数据加载器。我们还使用了`torchvision.models.detection.fasterrcnn_resnet50_fpn`来定义Faster R-CNN模型,并使用`torch.optim.SGD`作为优化器,以及`torch.optim.lr_scheduler.StepLR`来动态调整学习率。最后,我们使用了`torchvision.models.detection.fasterrcnn_loss`作为损失函数,并在训练循环中使用了标准的前向传递、后向传递和优化步骤。
需要注意的是,这个代码只是一个简单的示例,如果要在实际项目中使用,需要根据具体的需求进行修改和优化。
阅读全文