把这个网页里的代码理清然后完整给出并标注是哪个.py里的,如果他们是一个文件里的代码,那当我没说
时间: 2024-12-17 13:23:21 浏览: 3
根据《海底世界》制作的游戏及思维导图-少儿编程scratch项目源代码文件案例素材.zip
根据您提供的文档内容,整个GAN项目涉及的代码似乎都是在一个文件内完成的,因此以下整理的代码不需要拆分成多个`.py`文件。如果您希望将其拆分为多个文件,请明确指出具体需求。
以下是整理后的代码,适用于单个Python脚本文件:
### `gan_handwritten_digits.py`
```python
import matplotlib.pyplot as plt
import pandas
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 返回size大小的均值为0,标准差为1的随机数
def generate_random(size):
random_data = torch.randn(size)
return random_data
# 鉴别器
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(784, 200),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 1),
nn.Sigmoid()
)
self.loss_function = nn.BCELoss()
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
self.counter = 0
self.progress = []
def forward(self, inputs):
return self.model(inputs)
def train(self, inputs, targets):
outputs = self.forward(inputs)
loss = self.loss_function(outputs, targets)
self.counter += 1
if (self.counter % 10 == 0):
self.progress.append(loss.item())
if (self.counter % 10000 == 0):
print("counter = ", self.counter)
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
# 生成器
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(100, 200),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 784),
nn.Sigmoid()
)
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
self.counter = 0
self.progress = []
def forward(self, inputs):
return self.model(inputs)
def train(self, D, inputs, targets):
g_output = self.forward(inputs)
d_output = D.forward(g_output)
loss = D.loss_function(d_output, targets)
self.counter += 1
if (self.counter % 10 == 0):
self.progress.append(loss.item())
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
# 数据集加载
train_data = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_data, batch_size=1, shuffle=False)
# 实例化鉴别器和生成器
D = Discriminator()
G = Generator()
# 训练GAN
epochs = 4
for epoch in range(epochs):
for step, (images, labels) in enumerate(train_loader):
image_data_tensor = images.view(-1)
# 使用真实数据训练鉴别器
D.train(image_data_tensor, torch.FloatTensor([1.0]))
# 用生成样本训练鉴别器
D.train(G.forward(generate_random(100)).detach(), torch.FloatTensor([0.0]))
# 训练生成器
G.train(D, generate_random(100), torch.FloatTensor([1.0]))
# 保存模型
torch.save(D, 'GAN_Digits_D.pt')
torch.save(G, 'GAN_Digits_G.pt')
# 加载模型
G = torch.load('GAN_Digits_G.pt')
# 生成效果图
f, axarr = plt.subplots(2, 3, figsize=(16, 8))
for i in range(2):
for j in range(3):
output = G.forward(generate_random(100))
img = output.detach().numpy().reshape(28, 28)
axarr[i, j].imshow(img, interpolation='none', cmap='Blues')
plt.show()
```
这个脚本涵盖了从数据加载、模型定义、训练到生成图像的全过程。您可以直接运行这个脚本来实现手写数字生成的任务。如果有任何特定的需求或进一步的修改,请告诉我。
阅读全文