变分自动编码器(VAE)的 PyTorch 实现
时间: 2023-09-14 18:05:17 浏览: 59
好的,以下是变分自动编码器(VAE)的 PyTorch 实现代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms
class VAE(nn.Module):
    """Fully connected variational autoencoder (VAE).

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent space; the decoder maps a latent
    sample back to a vector of per-feature values in (0, 1).

    Args:
        input_dim: Number of features of one flattened input sample.
        hidden_dim: Width of the single hidden layer used by both the
            encoder and the decoder.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Encoder layers: one shared hidden layer followed by two heads,
        # fc21 for the mean and fc22 for the log-variance.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)

        # Decoder layers
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for an input of shape (batch, input_dim)."""
        h1 = F.relu(self.fc1(x))
        mu = self.fc21(h1)
        logvar = self.fc22(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Writing the sample this way keeps the randomness in ``eps`` so
        that gradients can flow back through ``mu`` and ``logvar``.
        """
        # logvar is log(sigma^2), hence sigma = exp(0.5 * logvar).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors to reconstructions squashed into (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode a batch.

        ``x`` is flattened to (-1, input_dim) before encoding, so image
        shaped batches can be passed in directly.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
# Define loss function
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstruction of shape (batch, features), values in [0, 1].
        x: Target batch; flattened to the shape of ``recon_x``'s last
            dimension before comparison, so image shaped input is accepted.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        A scalar tensor, summed (not averaged) over the batch.
    """
    # Flatten the target to the reconstruction's feature width instead of
    # a hard-coded 784 so the loss works for any input dimension.
    target = x.reshape(-1, recon_x.size(-1))
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) with logvar = log(sigma^2).
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transform, download=True)
# The training loop iterates over `train_loader`, so the dataset has to be
# wrapped in a DataLoader; without it the loop fails with a NameError.
# The batch size of 128 is a chosen value, not one given by the dataset.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True)

# Initialize model and optimizer
model = VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train model
for epoch in range(1, 11):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        # loss is summed over the batch, so accumulate the raw value and
        # divide by the number of samples when reporting.
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
    # Per-sample average over the whole epoch.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
```
这段代码定义了一个由编码器和解码器组成的 VAE 模型。训练时,模型从 MNIST 数据集中读取图像数据,并使用 Adam 优化器更新模型参数;每隔 100 个批次输出一次当前批次的单样本平均损失,并在每个 epoch 结束后输出该 epoch 的单样本平均损失。