Write a PyTorch version of a VAE, with complete code
In PyTorch, a Variational Autoencoder (VAE) is a generative model that combines the autoencoder (AE) architecture with a probabilistic latent-variable model. Below is a simplified code example that builds a basic VAE; it assumes `torch` and `torchvision` are installed.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
# Data preprocessing: ToTensor scales pixel values to [0, 1], which matches
# the decoder's Sigmoid output and the binary cross-entropy loss used below.
transform = ToTensor()
dataset = MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim  # stored so forward() can split mu / logvar
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, latent_dim + latent_dim)  # outputs mu and logvar
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # for binary-like data such as MNIST
        )
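
    # Reparameterization trick (used below): sampling z ~ N(mu, sigma^2) directly
    # is not differentiable w.r.t. mu and logvar, so we instead draw eps ~ N(0, I)
    # and compute z = mu + sigma * eps, where sigma = exp(0.5 * logvar).
    # Gradients then flow through mu and logvar while the randomness stays in eps.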
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    def forward(self, x):
        encoded = self.encoder(x.view(-1, 784))  # flatten the 28x28 images
        mu, logvar = encoded[:, :self.latent_dim], encoded[:, self.latent_dim:]
        z = self.reparameterize(mu, logvar)
        decoded = self.decoder(z)
        return decoded, mu, logvar

model = VAE(input_dim=784, hidden_dim=400, latent_dim=20)  # input is a flattened 28x28 MNIST image
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def loss_function(recon_x, x, mu, logvar):
    # Reconstruction loss against the flattened input (shapes must match recon_x).
    BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
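
# The KLD term above is the closed-form KL divergence between the approximate
# posterior N(mu, sigma^2) and the standard normal prior N(0, I):
#   KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
# Both terms are summed (not averaged), so they stay on a consistent scale.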

num_epochs = 10  # not defined in the original snippet; choose as needed
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(images)
        loss = loss_function(recon_batch, images, mu, logvar)
        loss.backward()
        optimizer.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")
```
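After training, the decoder doubles as a generator: draw latent vectors from the standard normal prior and decode them into images. The sketch below assumes the `model` trained above; the sample count of 64 and the filename `vae_samples.png` are arbitrary illustration choices.
```python
import torch
from torchvision.utils import save_image

model.eval()
with torch.no_grad():
    z = torch.randn(64, 20)     # 64 draws from the N(0, I) prior; 20 matches latent_dim
    samples = model.decoder(z)  # decode to pixel space; Sigmoid keeps values in (0, 1)
    save_image(samples.view(64, 1, 28, 28), 'vae_samples.png', nrow=8)
```
Because the prior is a standard Gaussian, nearby latent vectors tend to decode to similar-looking digits, which also makes simple latent-space interpolation possible.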