Would the results be better if nn.BatchNorm2d(self.num_filters) were replaced with nn.InstanceNorm2d(out_features)?
When training an SNGAN, replacing BatchNorm with InstanceNorm can indeed give better results. BatchNorm computes one mean and variance per channel across the whole batch, whereas InstanceNorm computes them per channel within each individual sample. Because the generator and the discriminator each see heterogeneous inputs (and GAN batches are often small), per-sample normalization tends to behave more consistently in both networks.
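The difference in normalization axes can be seen directly. Below is a minimal sketch (the tensor shape is chosen arbitrarily for illustration) comparing the two layers on the same input:

```python
import torch
import torch.nn as nn

x = torch.randn(4, 8, 16, 16)  # (batch, channels, H, W)

# BatchNorm2d: one mean/var per channel, computed over (batch, H, W)
bn = nn.BatchNorm2d(8)
# InstanceNorm2d: one mean/var per (sample, channel), computed over (H, W) only
inorm = nn.InstanceNorm2d(8)

print(bn(x).shape, inorm(x).shape)  # both keep the input shape

# The same statistics computed by hand, for comparison
bn_mean = x.mean(dim=(0, 2, 3))  # shape (8,)   -> shared across the whole batch
in_mean = x.mean(dim=(2, 3))     # shape (4, 8) -> separate for every sample
print(bn_mean.shape, in_mean.shape)
```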
So if you replace `nn.BatchNorm2d` with `nn.InstanceNorm2d` in `SNGANGenerator` and `SNGANDiscriminator`, you may well see better results. Here is a code example:
```python
import torch
import torch.nn as nn


class SNGANGenerator(nn.Module):
    def __init__(self, z_dim=100, image_size=64, num_channels=3, num_filters=64):
        super(SNGANGenerator, self).__init__()
        self.image_size = image_size
        self.num_channels = num_channels
        self.num_filters = num_filters
        self.z_dim = z_dim
        # The blocks below upsample four times (x16), so the initial feature map
        # must be image_size // 16 on each side.
        self.init_size = image_size // 16
        self.linear = nn.Linear(z_dim, self.num_filters * 8 * self.init_size * self.init_size)
        self.blocks = nn.Sequential(
            nn.InstanceNorm2d(self.num_filters * 8),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.num_filters * 8, self.num_filters * 4, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(self.num_filters * 4),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.num_filters * 4, self.num_filters * 2, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(self.num_filters * 2),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.num_filters * 2, self.num_filters, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(self.num_filters),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.num_filters, self.num_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, noise):
        x = self.linear(noise)
        x = x.view(-1, self.num_filters * 8, self.init_size, self.init_size)
        x = self.blocks(x)
        return x


class SNGANDiscriminator(nn.Module):
    def __init__(self, image_size=64, num_channels=3, num_filters=64):
        super(SNGANDiscriminator, self).__init__()
        self.image_size = image_size
        self.num_channels = num_channels
        self.num_filters = num_filters
        # Three stride-2 convolutions downsample the input by 8, matching the
        # image_size // 8 expected by the final linear layer.
        self.blocks = nn.Sequential(
            nn.Conv2d(self.num_channels, self.num_filters, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(self.num_filters),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.num_filters, self.num_filters * 2, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(self.num_filters * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.num_filters * 2, self.num_filters * 4, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(self.num_filters * 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.num_filters * 4, self.num_filters * 8, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(self.num_filters * 8),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(self.num_filters * 8 * (self.image_size // 8) * (self.image_size // 8), 1)

    def forward(self, img):
        x = self.blocks(img)
        x = x.view(-1, self.num_filters * 8 * (self.image_size // 8) * (self.image_size // 8))
        x = self.linear(x)
        return x
```
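As a quick sanity check (a minimal sketch using the default constructor arguments above), the two networks can be instantiated and run on dummy data to confirm that the shapes line up. Note that a full SNGAN would additionally wrap the discriminator's convolutional and linear layers with `nn.utils.spectral_norm`; that is omitted here to keep the focus on the norm-layer swap.

```python
# Shape-only sanity check with dummy data; no training involved.
gen = SNGANGenerator(z_dim=100, image_size=64)
disc = SNGANDiscriminator(image_size=64)

z = torch.randn(4, 100)       # a batch of 4 latent vectors
fake_images = gen(z)          # expected shape: (4, 3, 64, 64)
scores = disc(fake_images)    # expected shape: (4, 1)
print(fake_images.shape, scores.shape)
```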
Note that InstanceNorm computes its statistics per sample, so you should make sure all samples (and therefore the feature maps) have the same spatial size; otherwise the results may degrade.
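One more practical difference worth keeping in mind (a general PyTorch fact, not something specific to this model): `nn.InstanceNorm2d` defaults to `affine=False` and `track_running_stats=False`, whereas `nn.BatchNorm2d` defaults to `affine=True` and `track_running_stats=True`. If you want the replacement layers to keep learnable scale/shift parameters, pass `affine=True` explicitly:

```python
# BatchNorm2d has affine scale/shift and running statistics by default;
# InstanceNorm2d does not, so enable them explicitly if desired.
norm = nn.InstanceNorm2d(64, affine=True, track_running_stats=False)
print(sum(p.numel() for p in norm.parameters()))  # 128 learnable parameters (weight + bias)
```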