输入为11*11大小数据集用wgan生成11*11灰度图代码
时间: 2023-08-12 17:04:39 浏览: 126
下面是一个简单的 PyTorch WGAN-GP 模型,可以用来生成 11x11 的灰度图像。生成器先用一个全连接层把噪声向量映射成单通道特征图,再通过几层卷积生成图像;判别器(critic)则由若干卷积层组成。
首先,我们需要定义模型的生成器和判别器:
```python
import torch
import torch.nn as nn
class Generator(nn.Module):
    """Map latent noise vectors to single-channel square images.

    A linear layer projects each latent vector to ``img_size * img_size``
    values, which are reshaped to a ``1 x img_size x img_size`` map and
    refined by three size-preserving 3x3 convolutions. The final ``Tanh``
    bounds every output pixel to ``[-1, 1]``.

    Args:
        latent_dim: Length of the input noise vector.
        img_size: Height and width of the generated image (11 by default).
    """

    def __init__(self, latent_dim=100, img_size=11):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_size = img_size
        # The projection must yield exactly img_size**2 values; all the convs
        # below use kernel 3 / stride 1 / padding 1 and so keep the spatial
        # size, which means the output size is decided entirely here.
        self.fc = nn.Linear(latent_dim, img_size * img_size)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 1, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Return images of shape ``(batch, 1, img_size, img_size)``.

        Args:
            z: Noise tensor of shape ``(batch, latent_dim)``.
        """
        x = self.fc(z)
        x = x.view(-1, 1, self.img_size, self.img_size)
        return self.conv(x)
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 16, 3, stride=2, padding=1),
nn.BatchNorm2d(16),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(16, 32, 3, stride=2, padding=1),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(32, 64, 3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 1, 3, stride=1, padding=0),
)
def forward(self, x):
x = self.conv(x)
x = x.view(-1, 1)
return x
```
然后,我们需要定义 WGAN-GP 的损失函数、梯度惩罚项,以及负责创建优化器并执行训练循环的函数:
```python
import torch.optim as optim
def wasserstein_loss(real, fake):
    """Return the critic loss ``mean(fake) - mean(real)``, to be minimised.

    Minimising this value pushes critic scores up on real samples and down
    on generated ones, which is the counterpart of a generator loss of
    ``-mean(critic(fake))``.

    Args:
        real: Critic scores for real samples.
        fake: Critic scores for generated samples.

    Returns:
        A scalar tensor; it is negative when real samples score higher on
        average than generated ones.
    """
    return torch.mean(fake) - torch.mean(real)
def gradient_penalty(discriminator, real, fake):
    """Return the WGAN-GP penalty ``mean((||grad||_2 - 1) ** 2)``.

    The gradient is that of the critic output with respect to random
    per-sample blends of ``real`` and ``fake``. The graph of the gradient is
    kept (``create_graph=True``) so the penalty itself can be backpropagated
    into the critic parameters.

    Args:
        discriminator: Callable mapping a batch of inputs to critic scores.
        real: Batch of real samples, shape ``(batch, ...)``.
        fake: Batch of generated samples with the same shape as ``real``.

    Returns:
        A scalar tensor holding the mean penalty over the batch.
    """
    # One blend coefficient per sample, broadcast over every non-batch
    # dimension so inputs of any rank are supported.
    alpha_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device, dtype=real.dtype)
    interpolated = alpha * real + (1 - alpha) * fake
    interpolated.requires_grad_()
    d_interpolated = discriminator(interpolated)
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        # ones_like matches the device and dtype of the critic output
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view): autograd may hand back a non-contiguous gradient
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def train_wgan(generator, discriminator, dataloader, epochs=100, lr=0.0001, latent_dim=100,
               n_critic=5, gp_weight=10):
    """Train a generator/critic pair with the WGAN-GP procedure.

    For every batch the critic is updated ``n_critic`` times on
    ``wasserstein_loss`` plus ``gp_weight`` times ``gradient_penalty``, then
    the generator is updated once on ``-mean(critic(fake))``. Both models are
    moved to CUDA when it is available (otherwise CPU) and are left on that
    device. Progress is printed every 10 batches.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Critic module mapping images to scores.
        dataloader: Iterable of image batches; ``(images, ...)`` tuples or
            lists are accepted and only the first element is used.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate used for both optimizers.
        latent_dim: Length of the noise vectors fed to the generator.
        n_critic: Critic updates per generator update.
        gp_weight: Multiplier applied to the gradient penalty.

    Raises:
        ValueError: If ``n_critic`` is smaller than 1.
    """
    if n_critic < 1:
        raise ValueError("n_critic must be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)
    generator.train()
    discriminator.train()
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))

    for epoch in range(epochs):
        for i, batch in enumerate(dataloader):
            # Datasets that also yield labels produce (images, labels) batches.
            real_images = batch[0] if isinstance(batch, (list, tuple)) else batch
            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            # Train discriminator
            for _ in range(n_critic):
                z = torch.randn(batch_size, latent_dim, device=device)
                # The critic step must not backpropagate into the generator,
                # so the fakes are produced without building a graph.
                with torch.no_grad():
                    fake_images = generator(z)
                real_logits = discriminator(real_images)
                fake_logits = discriminator(fake_images)
                d_loss = wasserstein_loss(real_logits, fake_logits)
                gp = gradient_penalty(discriminator, real_images, fake_images)
                d_loss = d_loss + gp_weight * gp
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            # Train generator
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(z)
            fake_logits = discriminator(fake_images)
            g_loss = -torch.mean(fake_logits)
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            if i % 10 == 0:
                print("Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}"
                      .format(epoch+1, epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))
```
最后,我们可以用以下代码训练模型并生成 11x11 的灰度图像(这里的 RandomDataset 只是用随机噪声充当训练数据的占位示例,实际使用时请替换为自己的 11x11 数据集):
```python
import torch.utils.data as data
class RandomDataset(data.Dataset):
    """Placeholder dataset of standard-normal noise "images".

    Every item is a freshly drawn ``(1, size, size)`` tensor, so the same
    index returns different values on each access.

    Args:
        size: Height and width of each sample.
        length: Number of samples the dataset reports (1000 by default).

    Raises:
        ValueError: If ``length`` is negative.
    """

    def __init__(self, size, length=1000):
        if length < 0:
            raise ValueError("length must be non-negative")
        self.size = size
        self.length = length

    def __getitem__(self, index):
        # Without this bound plain iteration (``for x in dataset``) would
        # never stop, because it only ends when IndexError is raised.
        if not -self.length <= index < self.length:
            raise IndexError("index {} out of range for dataset of length {}"
                             .format(index, self.length))
        return torch.randn(1, self.size, self.size)

    def __len__(self):
        return self.length
# plt is used below for display; it is not imported anywhere else in this example.
import matplotlib.pyplot as plt

batch_size = 64
latent_dim = 100
dataset = RandomDataset(11)
dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()
train_wgan(generator, discriminator, dataloader, epochs=100, lr=0.0001, latent_dim=latent_dim)

# train_wgan moves the models to CUDA when it is available, so the noise has
# to be created on whatever device the generator ended up on.
device = next(generator.parameters()).device
# Inference mode: BatchNorm uses its running statistics instead of the
# statistics of this single-sample batch.
generator.eval()
with torch.no_grad():
    z = torch.randn(1, latent_dim, device=device)
    image = generator(z)
# .numpy() only works on CPU tensors.
image = image.cpu().numpy().squeeze()
plt.imshow(image, cmap='gray')
plt.show()
```
这个代码会生成 11x11 的灰度图像,并显示在屏幕上。注意,生成器的最后一层是 Tanh,因此输出的像素值已经在 [-1, 1] 范围内,无需再做截断;在 WGAN-GP 中没有值域限制的是判别器(critic)的输出,它是一个未经 Sigmoid 的实数评分。如果需要把结果保存为常规灰度图,可以用 (x + 1) / 2 将像素值线性映射到 [0, 1]。
阅读全文