Self-Attention GAN Code and GAN-Based Anomaly Detection
Below are code examples for a Self-Attention GAN and for GAN-based anomaly detection:
Self-Attention GAN code:
```
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Self-attention layer over 2D feature maps (as used in SAGAN)."""

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # 1x1 convolutions produce the query/key (reduced channels) and value projections
        self.query_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
        # Learnable scale, initialized to 0 so the layer starts out as an identity mapping
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        # Flatten the spatial dimensions: N = height * width positions
        proj_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)  # (B, N, C//8)
        proj_key = self.key_conv(x).view(m_batchsize, -1, height * width)                       # (B, C//8, N)
        energy = torch.bmm(proj_query, proj_key)                                                # (B, N, N)
        attention = torch.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(m_batchsize, -1, height * width)                   # (B, C, N)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))                                 # (B, C, N)
        out = out.view(m_batchsize, C, height, width)
        # Residual connection scaled by gamma
        out = self.gamma * out + x
        return out
```
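To show how the layer slots into a network, here is a minimal usage sketch. The channel counts, input size, and surrounding conv layers are arbitrary assumptions for illustration, not part of the original code.
```
# Illustrative only: drop SelfAttention between convolutional layers (assumed shapes).
import torch
import torch.nn as nn

blocks = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    SelfAttention(64),               # attends over all 32x32 spatial positions
    nn.Conv2d(64, 3, kernel_size=3, padding=1),
)

x = torch.randn(4, 3, 32, 32)        # batch of 4 RGB images, 32x32
y = blocks(x)
print(y.shape)                        # torch.Size([4, 3, 32, 32]) -- shape is preserved
```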
GAN + anomaly detection code:
```
import numpy as np
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # MLP generator: latent vector -> flattened image in [-1, 1]
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.img_shape = img_shape
        # MLP discriminator: flattened image -> probability of being real
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


def valid_loss(validity_real, validity_fake):
    # Standard GAN discriminator loss: real samples should score 1, generated samples 0
    real_loss = nn.functional.binary_cross_entropy(validity_real, torch.ones_like(validity_real))
    fake_loss = nn.functional.binary_cross_entropy(validity_fake, torch.zeros_like(validity_fake))
    return (real_loss + fake_loss) / 2


class GAN_Anomaly_Detector(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(GAN_Anomaly_Detector, self).__init__()
        self.latent_dim = latent_dim
        self.generator = Generator(latent_dim, img_shape)
        self.discriminator = Discriminator(img_shape)

    def forward(self, x):
        # Sample random latent codes on the same device as the input
        z = torch.randn(x.shape[0], self.latent_dim, device=x.device)
        gen_imgs = self.generator(z)
        validity_real = self.discriminator(x)
        validity_fake = self.discriminator(gen_imgs)
        # Anomaly score = L1 reconstruction difference + discriminator loss
        return torch.mean(torch.abs(x - gen_imgs)) + valid_loss(validity_real, validity_fake)
```
Here, GAN-based anomaly detection combines two signals: the reconstruction difference between the input image and the image generated from a random latent code, and the discriminator's real/fake loss. The closer the generated image is to the input, the lower the resulting score; a large discrepancy indicates an anomaly.
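As a rough illustration of how the detector might be used to score a batch, here is a hedged sketch; the latent dimension, image shape, and threshold below are placeholder assumptions, not values from the original post.
```
# Illustrative scoring loop (assumed hyperparameters, untrained weights).
import torch

latent_dim = 100                       # assumed latent size
img_shape = (1, 28, 28)                # assumed image shape, e.g. MNIST

detector = GAN_Anomaly_Detector(latent_dim, img_shape)
detector.eval()

x = torch.randn(16, *img_shape)        # a batch of test images
with torch.no_grad():
    score = detector(x)                # reconstruction error + discriminator loss

threshold = 0.5                        # assumed threshold, tuned on validation data
print("anomalous batch" if score.item() > threshold else "normal batch")
```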