```python
# This is a sample Python script.
# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
import torch
import torchvision
from PIL.Image import Image
from torch.utils.tensorboard import SummaryWriter
from torch import nn, optim
from torch.utils.data import dataloader
from torchvision.transforms import transforms
from module import MyModule

train = torchvision.datasets.CIFAR10(root="../data",train=True, download=True, transform= transforms.ToTensor())
vgg_model = torchvision.models.vgg16(pretrained=True)
vgg_model.classifier.add_module('add_linear', nn.Linear(1000,2))
#ToImage = transforms.ToPILImage()
#Image.show(ToImage(train[0][0]))
train_data = dataloader.DataLoader(train, batch_size = 128, shuffle=True)
model = MyModule()
epochs = 5
learningRate = 1e-3
optimizer = optim.SGD(model.parameters(),lr = learningRate)
loss = nn.CrossEntropyLoss()
Writer = SummaryWriter(log_dir="Training")
step = 0
for epoch in range(epochs):
    total_loss = 0
    for data,labels in train_data:
        y = vgg_model(data)
        los = loss(y,labels)
        optimizer.zero_grad()
        los.backward()
        optimizer.step()
        Writer.add_scalar("Training",los,step)
        step = step + 1
        if step%100 == 0:
            print("Training for {0} times".format(step))
        total_loss += los
    print("total_loss is {0}".format(los))
Writer.close()
torch.save(vgg_model,"model_vgg.pth")
```
将以上代码修改为二分类
时间: 2024-01-28 17:05:19 浏览: 124
The MySQL server is running with the –read-only option so it cannot execute this statement
5星 · 资源好评率100%
好的,以下是将代码修改为二分类的版本:
```
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from module import MyModule
# Load the dataset.
# CIFAR-10 has 10 classes (labels 0-9), but the model is trained with only 2
# outputs, so the labels must be collapsed to {0, 1}; otherwise NLLLoss fails
# with a "target out of bounds" error. The split used here is vehicles vs.
# animals: CIFAR-10 indices 0 (airplane), 1 (automobile), 8 (ship) and
# 9 (truck) are vehicles, indices 2-7 are animals (a 40% / 60% split).
# Edit VEHICLE_CLASSES to define a different binary task.
VEHICLE_CLASSES = frozenset({0, 1, 8, 9})


def to_binary_label(label):
    """Map a CIFAR-10 class index to 0 (vehicle) or 1 (animal)."""
    return 0 if label in VEHICLE_CLASSES else 1


# The pretrained VGG16 weights were trained on 224x224 ImageNet images
# normalised with the ImageNet per-channel mean/std, so feed inputs prepared
# the same way instead of a generic (0.5, 0.5, 0.5) normalisation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
train_dataset = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform=train_transform,
    target_transform=to_binary_label,
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# Build VGG16 and replace only its output layer.
# The pretrained classifier ends in a 1000-way ImageNet Linear layer. Rebuilding
# the whole nn.Sequential from scratch would also re-initialise the hidden
# fully connected layers and throw away their pretrained weights, so keep
# every layer except the last and swap just the final one for a 2-way head.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of `weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1`.
NUM_CLASSES = 2

vgg_model = torchvision.models.vgg16(pretrained=True)
pretrained_layers = list(vgg_model.classifier.children())
# Reuse the old head's input width instead of hard-coding it.
head_in_features = pretrained_layers[-1].in_features
vgg_model.classifier = nn.Sequential(
    *pretrained_layers[:-1],
    nn.Linear(head_in_features, NUM_CLASSES),
    # Emit log-probabilities; the training code pairs this with nn.NLLLoss.
    nn.LogSoftmax(dim=1),
)
# Hyperparameters.
epochs = 5
learning_rate = 0.001

# Train on the GPU when one is available; VGG16 on 224x224 inputs is very slow
# on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model before building the optimizer, as the PyTorch docs
# recommend, so the optimizer is constructed over the parameters being trained.
vgg_model.to(device)
vgg_model.train()

optimizer = optim.SGD(vgg_model.parameters(), lr=learning_rate)
# The model's last layer is LogSoftmax, so NLLLoss yields the usual
# cross-entropy objective.
loss_fn = nn.NLLLoss()
writer = SummaryWriter(log_dir="Training")
step = 0

# Train the model.
try:
    for epoch in range(epochs):
        total_loss = 0.0
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            optimizer.zero_grad()
            output = vgg_model(data)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            # Extract the Python float once and reuse it for logging and the
            # running sum (avoids keeping tensors around).
            loss_value = loss.item()
            writer.add_scalar("Training Loss", loss_value, step)
            step += 1
            if step % 100 == 0:
                print("Training for {0} times".format(step))
            total_loss += loss_value
        print("Epoch {0}, total loss: {1}".format(epoch + 1, total_loss))
finally:
    # Flush and close the TensorBoard log even if training is interrupted.
    writer.close()

# Save the model. Move it back to the CPU first so the checkpoint can be loaded
# on machines without a GPU. This pickles the whole nn.Module; saving
# vgg_model.state_dict() is the more portable alternative.
torch.save(vgg_model.to("cpu"), "model_vgg.pth")
```
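在这个版本中，我们使用 `nn.LogSoftmax` 作为最后一层，并使用 `nn.NLLLoss()` 作为损失函数（两者组合等价于直接对线性层的原始输出使用 `nn.CrossEntropyLoss()`）。同时，我们将 VGG 网络的输出层修改为一个具有两个输出节点的线性层。另外，我们还对输入图片进行了大小调整和归一化处理。注意：CIFAR-10 本身有 10 个类别（标签 0–9），而模型只有 2 个输出，因此训练前必须先把标签映射到 {0, 1}（例如通过数据集的 `target_transform` 参数），否则 `NLLLoss` 会因目标类别越界而报错。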
阅读全文