卷积自编码器pytorch的代码
时间: 2023-11-12 11:51:59 浏览: 97
torch实现自编码器-Pytorch卷积自动编码器
以下是一个简单的卷积自编码器的PyTorch代码实现:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# 定义网络结构
class ConvAutoencoder(nn.Module):
def __init__(self):
super(ConvAutoencoder, self).__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=7)
)
# 解码器
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 32, kernel_size=7),
nn.ReLU(),
nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
nn.Sigmoid()
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
# 加载数据集
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
# 定义模型、损失函数和优化器
model = ConvAutoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
epochs = 10
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(epochs):
running_loss = 0.0
for data in trainloader:
inputs, _ = data
inputs = inputs.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, inputs)
loss.backward()
optimizer.step()
running_loss += loss.item()
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, running_loss/len(trainloader)))
# 测试模型
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
model.eval()
with torch.no_grad():
for data in testloader:
inputs, _ = data
inputs = inputs.to(device)
outputs = model(inputs)
inputs = inputs.cpu()
outputs = outputs.cpu()
torchvision.utils.save_image(inputs, 'input.png')
torchvision.utils.save_image(outputs, 'output.png')
break
```
这个模型使用了CIFAR-10数据集进行训练和测试,输入数据是RGB图像,输出数据也是RGB图像。模型的编码器部分使用了3个卷积层,解码器部分使用了3个转置卷积层,最后使用Sigmoid函数将输出值限制在0到1之间。模型使用MSE损失函数进行训练,优化器使用Adam算法。在训练过程中,每个epoch结束后会打印出训练集上的平均损失。在测试过程中,程序会读取测试集中的一个batch数据,将输入和输出图像分别保存到input.png和output.png文件中。
阅读全文