def compute_gradient_penalty(D, real_samples, fake_samples): alpha = torch.from_numpy(np.random.random((real_samples.size()[0], 1, 1, 1))).float().cuda() interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) d_interpolates, _ = D(interpolates) fake = autograd.Variable(torch.ones(real_samples.size()[0]), requires_grad=False).cuda() gradients = autograd.grad( outputs=d_interpolates, inputs=interpolates, grad_outputs=fake, create_graph=True, retain_graph=True, only_inputs=True, )[0] gradients = gradients.view(gradients.size(0), -1) gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() return gradient_penalty解释函数
时间: 2024-02-26 21:54:43 浏览: 200
经典rsa.rar_RSA.c_compute_rsa_rsa
这是一个计算 Wasserstein GAN 梯度惩罚项的函数,用于提高训练的稳定性和生成样本的质量。输入是鉴别器(D)、真实图像样本(real_samples)和生成图像样本(fake_samples),输出是梯度惩罚项(gradient_penalty)。
具体实现如下:
- 随机生成一个(0,1)之间的随机数alpha,其形状为(batch_size, 1, 1, 1)。
- 根据随机数alpha和真实图像样本、生成图像样本,生成一组插值样本(interpolates),其形状与真实图像样本一致。
- 将插值样本(interpolates)输入鉴别器(D)中,得到鉴别器输出(d_interpolates)。
- 构造一个与真实图像样本数目相同的全1张量(fake),作为计算梯度的输出。
- 对鉴别器输出(d_interpolates)关于插值样本(interpolates)的梯度进行计算。
- 将梯度进行reshape,并计算其L2范数(norm)。
- 计算梯度惩罚项(gradient_penalty),即对梯度范数减1后的平方进行平均。
- 返回梯度惩罚项(gradient_penalty)。
阅读全文