# 用 Python 写一段 GAN（生成对抗网络）的代码
# 时间: 2023-03-11 07:48:44 浏览: 83
import numpy as np
import matplotlib.pyplot as plt # 设置权重和偏差
# A minimal GAN written with plain NumPy.
#
# Generator:     G(z) = z * weights_g + bias_g   (per-coordinate scale and shift)
# Discriminator: D(x) = sigmoid([x, x**2] @ weights_d + bias_d)
#
# Both players are trained adversarially with the standard GAN losses and
# hand-derived gradients: one discriminator step, then one generator step, per
# iteration. This is a toy. Like any GAN it may oscillate instead of settling,
# so judge the result from the plot rather than from the step count.


def _sigmoid(x):
    """Return the logistic sigmoid of ``x`` without overflowing ``exp``."""
    # Identity: sigmoid(x) == (1 + tanh(x / 2)) / 2; tanh never overflows.
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def _softplus(x):
    """Return ``log(1 + exp(x))`` computed stably."""
    return np.logaddexp(0.0, x)


def _features(x):
    """Return discriminator features ``[x, x**2]`` for samples of shape (n, dim).

    The squared terms let the otherwise linear discriminator detect differences
    in spread, so the generator is pushed to match variance as well as mean.
    """
    return np.concatenate([x, x ** 2], axis=1)


def generate(noise, weights_g, bias_g):
    """Map ``noise`` of shape (n, dim) to samples ``noise * weights_g + bias_g``.

    ``weights_g`` and ``bias_g`` have shape (1, dim) and broadcast over rows, so
    each coordinate gets its own scale and shift.
    """
    return noise * weights_g + bias_g


def discriminate(x, weights_d, bias_d):
    """Return D(x), the predicted probability that each row of ``x`` is real.

    ``x`` has shape (n, dim), ``weights_d`` (2 * dim, 1) and ``bias_d`` (1, 1);
    the result has shape (n, 1).
    """
    return _sigmoid(_features(x) @ weights_d + bias_d)


def gan_losses_and_grads(real, noise, weights_g, bias_g, weights_d, bias_d):
    """Return both GAN losses and their parameter gradients for one minibatch.

    Args:
        real: real samples, shape (n_real, dim).
        noise: generator input, shape (n_fake, dim).
        weights_g, bias_g: generator parameters, shape (1, dim) each.
        weights_d: discriminator weights, shape (2 * dim, 1).
        bias_d: discriminator bias, shape (1, 1).

    Returns:
        ``(loss_d, loss_g, grads)``:

        - ``loss_d`` is the discriminator loss
          ``-mean(log D(real)) - mean(log(1 - D(fake)))``.
        - ``loss_g`` is the non-saturating generator loss ``-mean(log D(fake))``.
        - ``grads`` maps each parameter name to the gradient of its own player's
          loss: ``loss_d`` for the ``*_d`` entries, ``loss_g`` for the ``*_g``
          entries.
    """
    fake = generate(noise, weights_g, bias_g)
    feat_real = _features(real)
    feat_fake = _features(fake)
    logit_real = feat_real @ weights_d + bias_d
    logit_fake = feat_fake @ weights_d + bias_d
    p_real = _sigmoid(logit_real)
    p_fake = _sigmoid(logit_fake)

    # log(sigmoid(s)) == -softplus(-s) and log(1 - sigmoid(s)) == -softplus(s).
    loss_d = np.mean(_softplus(-logit_real)) + np.mean(_softplus(logit_fake))
    loss_g = np.mean(_softplus(-logit_fake))

    # Discriminator: d loss_d / d logit, then through the linear layer.
    g_real = (p_real - 1.0) / len(real)
    g_fake_d = p_fake / len(fake)
    grad_weights_d = feat_real.T @ g_real + feat_fake.T @ g_fake_d
    grad_bias_d = (g_real.sum(axis=0, keepdims=True)
                   + g_fake_d.sum(axis=0, keepdims=True))

    # Generator: chain rule through the logit, the [x, x**2] features and G.
    dim = fake.shape[1]
    g_fake_g = (p_fake - 1.0) / len(fake)
    d_logit_d_fake = weights_d[:dim].T + 2.0 * fake * weights_d[dim:].T
    g_sample = g_fake_g * d_logit_d_fake
    grad_weights_g = np.sum(g_sample * noise, axis=0, keepdims=True)
    grad_bias_g = np.sum(g_sample, axis=0, keepdims=True)

    grads = {
        "weights_d": grad_weights_d,
        "bias_d": grad_bias_d,
        "weights_g": grad_weights_g,
        "bias_g": grad_bias_g,
    }
    return float(loss_d), float(loss_g), grads


def train_gan(real_mean=(2.0, 2.0), real_std=0.5, n_steps=5000, batch_size=64,
              lr=0.01, seed=0):
    """Train the toy GAN on Gaussian "real" data and return its parameters.

    Args:
        real_mean: mean of the real Gaussian; its length sets the data dimension.
        real_std: standard deviation of the real data (scalar or per coordinate).
        n_steps: training iterations, each one D step followed by one G step.
        batch_size: number of real and of generated samples per step.
        lr: learning rate for plain gradient descent.
        seed: seed for ``np.random.default_rng``, so runs are reproducible.

    Returns:
        A dict with these entries:

        - ``weights_g``, ``bias_g``, ``weights_d``, ``bias_d``: the trained
          parameter arrays.
        - ``loss_d``, ``loss_g``: lists of the per-step losses.

    Raises:
        ValueError: if ``n_steps`` is negative or ``batch_size`` is below 1.
    """
    if n_steps < 0:
        raise ValueError(f"n_steps must be >= 0, got {n_steps}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    rng = np.random.default_rng(seed)
    mean = np.asarray(real_mean, dtype=float).reshape(1, -1)
    std = np.asarray(real_std, dtype=float)
    dim = mean.shape[1]

    weights_g = rng.normal(0.0, 1.0, size=(1, dim))
    bias_g = rng.normal(0.0, 1.0, size=(1, dim))
    # Small initial D weights keep early logits small, so training starts away
    # from the flat regions of the sigmoid.
    weights_d = rng.normal(0.0, 0.1, size=(2 * dim, 1))
    bias_d = np.zeros((1, 1))

    history_d, history_g = [], []
    for _ in range(n_steps):
        real = rng.normal(mean, std, size=(batch_size, dim))

        # Discriminator step: descend on loss_d.
        noise = rng.uniform(-1.0, 1.0, size=(batch_size, dim))
        loss_d, _, grads = gan_losses_and_grads(
            real, noise, weights_g, bias_g, weights_d, bias_d)
        weights_d -= lr * grads["weights_d"]
        bias_d -= lr * grads["bias_d"]

        # Generator step against the freshly updated discriminator, with new
        # noise (as in the alternating scheme of the original GAN algorithm).
        noise = rng.uniform(-1.0, 1.0, size=(batch_size, dim))
        _, loss_g, grads = gan_losses_and_grads(
            real, noise, weights_g, bias_g, weights_d, bias_d)
        weights_g -= lr * grads["weights_g"]
        bias_g -= lr * grads["bias_g"]

        history_d.append(loss_d)
        history_g.append(loss_g)

    return {
        "weights_g": weights_g,
        "bias_g": bias_g,
        "weights_d": weights_d,
        "bias_d": bias_d,
        "loss_d": history_d,
        "loss_g": history_g,
    }


if __name__ == "__main__":
    REAL_MEAN, REAL_STD = (2.0, 2.0), 0.5
    params = train_gan(real_mean=REAL_MEAN, real_std=REAL_STD)

    plot_rng = np.random.default_rng(1)
    real_data = plot_rng.normal(REAL_MEAN, REAL_STD, size=(500, 2))
    gen_data = generate(plot_rng.uniform(-1.0, 1.0, size=(500, 2)),
                        params["weights_g"], params["bias_g"])

    # `plt` is the file-level matplotlib.pyplot import.
    plt.scatter(real_data[:, 0], real_data[:, 1], s=5, alpha=0.5, label="real")
    plt.scatter(gen_data[:, 0], gen_data[:, 1], s=5, alpha=0.5, label="generated")
    plt.legend()
    plt.show()