给我一份VGG16的训练代码
时间: 2024-04-10 13:28:15 浏览: 25
好的,以下是一个使用PyTorch训练VGG16模型的示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg16
from torchvision.transforms import ToTensor
from my_dataset import MyDataset
# 设置训练参数
batch_size = 16
num_epochs = 10
learning_rate = 0.001
# 加载数据集
train_dataset = MyDataset(train=True, transform=ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 创建VGG16模型
model = vgg16(pretrained=True)
num_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_features, 2) # 假设我们有两个类别
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 开始训练
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i + 1) % 10 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
# 保存模型
torch.save(model.state_dict(), 'vgg16_model.pth')
```
请注意,这只是一个示例代码,你需要根据你的数据集和需求进行相应的调整和修改。你还需要自定义自己的数据集类(MyDataset)来加载你的数据集,并根据你的实际情况更改模型的最后一层(classifier)以适应你的分类任务。另外,你需要确保安装了PyTorch和torchvision库。希望对你有帮助!
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![7z](https://img-home.csdnimg.cn/images/20210720083312.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)