# Adversarial Supervise Architecture E_Hat = self.generator_aux(Z) H_hat = self.supervisor(E_Hat) Y_fake = self.discriminator(H_hat) self.adversarial_supervised = Model(inputs=Z, outputs=Y_fake, name='AdversarialSupervised')
时间: 2024-02-14 10:26:50 浏览: 81
在这段代码中,实现了一个对抗监督架构(Adversarial Supervised Architecture)。
首先,通过将随机噪声输入 Z 传递给生成器模型 self.generator_aux,得到生成器的输出 E_Hat。接下来,将 E_Hat 作为输入传递给监督模型 self.supervisor,得到监督模型的输出 H_hat。然后,将 H_hat 作为输入传递给判别器模型 self.discriminator,得到判别器的输出 Y_fake。
最后,使用 Keras 的 Model 类创建了一个名为 adversarial_supervised 的模型对象,其输入为 Z,输出为 Y_fake。这个 adversarial_supervised 模型将生成器、监督模型和判别器连接在一起,形成了一个整体的对抗监督架构。
这种架构的目的是通过生成器和监督模型的协同训练,使生成器能够生成逼真的数据样本,并通过判别器对生成的样本进行判别和评估。这样可以实现对抗训练和生成器的优化。
相关问题
@function def train_generator(self, x, z, opt): with GradientTape() as tape: y_fake = self.adversarial_supervised(z) generator_loss_unsupervised = self._bce(y_true=ones_like(y_fake), y_pred=y_fake) y_fake_e = self.adversarial_embedded(z) generator_loss_unsupervised_e = self._bce(y_true=ones_like(y_fake_e), y_pred=y_fake_e) h = self.embedder(x) h_hat_supervised = self.supervisor(h) generator_loss_supervised = self._mse(h[:, 1:, :], h_hat_supervised[:, 1:, :]) x_hat = self.generator(z) generator_moment_loss = self.calc_generator_moments_loss(x, x_hat) generator_loss = (generator_loss_unsupervised + generator_loss_unsupervised_e + 100 * sqrt(generator_loss_supervised) + 100 * generator_moment_loss) var_list = self.generator_aux.trainable_variables + self.supervisor.trainable_variables gradients = tape.gradient(generator_loss, var_list) opt.apply_gradients(zip(gradients, var_list)) return generator_loss_unsupervised, generator_loss_supervised, generator_moment_loss
这是一个用于训练生成器的函数。该函数接受三个输入,`x`和`z`分别表示真实样本和生成样本,`opt`表示优化器。
在函数内部,首先使用 `adversarial_supervised` 模型对生成样本进行预测,得到 `y_fake`。然后使用二元交叉熵损失函数 `_bce` 计算生成样本的非监督损失 `generator_loss_unsupervised`。
接下来,通过 `adversarial_embedded` 模型对生成样本进行预测,得到 `y_fake_e`。然后使用二元交叉熵损失函数 `_bce` 计算生成样本的嵌入式非监督损失 `generator_loss_unsupervised_e`。
然后,通过 `embedder` 模型对真实样本进行预测,得到 `h`。使用 `supervisor` 模型对 `h` 进行预测,得到 `h_hat_supervised`。然后使用均方误差损失函数 `_mse` 计算生成样本的监督损失 `generator_loss_supervised`。
接下来,使用 `generator` 模型对生成样本进行预测,得到 `x_hat`。然后使用 `calc_generator_moments_loss` 函数计算生成样本的生成器矩损失 `generator_moment_loss`。
最后,将非监督损失、嵌入式非监督损失、监督损失以及生成器矩损失进行加权求和,得到最终的生成器损失 `generator_loss`。
使用 `GradientTape` 记录梯度信息,并根据生成器损失和可训练变量计算梯度。然后使用优化器 `opt` 应用梯度更新模型参数。
最后,返回非监督损失、监督损失和生成器矩损失三个部分的损失值。
def discriminator_loss(self, x, z): y_real = self.discriminator_model(x) discriminator_loss_real = self._bce(y_true=ones_like(y_real), y_pred=y_real) y_fake = self.adversarial_supervised(z) discriminator_loss_fake = self._bce(y_true=zeros_like(y_fake), y_pred=y_fake) y_fake_e = self.adversarial_embedded(z) discriminator_loss_fake_e = self._bce(y_true=zeros_like(y_fake_e), y_pred=y_fake_e) return (discriminator_loss_real + discriminator_loss_fake + self.gamma * discriminator_loss_fake_e)
这是一个用于计算鉴别器损失的函数。该函数接受两个输入,`x`和`z`,分别表示真实样本和生成样本。在函数中,首先通过鉴别器模型对真实样本进行预测,得到`y_real`。然后使用二元交叉熵损失函数 `_bce` 计算真实样本的鉴别器损失 `discriminator_loss_real`。
接下来,通过对生成样本使用两个不同的辅助鉴别器模型 `adversarial_supervised` 和 `adversarial_embedded` 进行预测。分别得到 `y_fake` 和 `y_fake_e`。然后使用二元交叉熵损失函数 `_bce` 分别计算生成样本的鉴别器损失 `discriminator_loss_fake` 和 `discriminator_loss_fake_e`。
最后,通过加权求和将三个损失项组合起来,其中 `self.gamma` 是一个权重参数。返回最终的鉴别器损失值。
请注意,该代码片段中的 `_bce` 函数可能是定义在其他地方的一个二元交叉熵损失函数。
阅读全文
相关推荐
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![gz](https://img-home.csdnimg.cn/images/20210720083447.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![py](https://img-home.csdnimg.cn/images/20250102104920.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![pdf](https://img-home.csdnimg.cn/images/20241231044930.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)