写一个pytorch训练coco数据集的代码
时间: 2023-09-20 15:07:22 浏览: 185
Sure! 下面是一个基本的 PyTorch 训练 COCO 数据集的代码:
```python
import torch
import torchvision
import torchvision.transforms as transforms
from coco_dataset import CocoDetection
# 设置一些超参数
num_epochs = 10
batch_size = 4
learning_rate = 0.001
# 加载 COCO 数据集
transform_train = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])
train_dataset = CocoDetection(root='./data/coco/train2017',
annFile='./data/coco/annotations/instances_train2017.json',
transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, num_workers=2)
# 定义模型
model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 80)
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i + 1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
# 保存模型
torch.save(model.state_dict(), 'coco_resnet18.pth')
```
注意,这里假设你已经有了一个名为 `coco_dataset.py` 的自定义 COCO 数据集类。你也需要在 `./data/coco/train2017` 中放置 COCO 训练图像,以及在 `./data/coco/annotations` 中放置 COCO 注释。另外,你还需要安装 `pycocotools` 才能使用 `CocoDetection` 类。
阅读全文