基于PyTorch的LSTM-GAN：读取11×11灰度图像文件夹并保存模型，之后加载模型，生成多张不同的图像并逐一保存——代码实现
时间: 2024-03-10 12:46:46 浏览: 29
好的,以下是代码实现:
1. 读取11*11灰度图像文件夹
```python
import os
from PIL import Image
import numpy as np
def load_data(data_dir, size=(11, 11), extensions=('.jpg',)):
    """Load every matching image in *data_dir* as a grayscale array scaled to [0, 1].

    Args:
        data_dir: Directory to scan (non-recursive).
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``;
            defaults to 11x11.
        extensions: File suffixes to load, matched case-insensitively, so
            ``photo.JPG`` is picked up as well as ``photo.jpg``.

    Returns:
        float64 ndarray of shape ``(N, height, width)``. Files are read in
        sorted filename order so the result is reproducible. A directory
        with no matching files yields shape ``(0, height, width)`` rather than
        a shapeless empty array.
    """
    width, height = size
    suffixes = tuple(ext.lower() for ext in extensions)
    data = []
    for filename in sorted(os.listdir(data_dir)):
        if not filename.lower().endswith(suffixes):
            continue
        # `with` closes the underlying file handle once the pixels are copied out.
        with Image.open(os.path.join(data_dir, filename)) as src:
            # LANCZOS is the filter the old Image.ANTIALIAS alias referred to;
            # ANTIALIAS itself was removed in Pillow 10.
            img = src.convert('L').resize((width, height), Image.LANCZOS)
        img_data = np.asarray(img, dtype=np.uint8) / 255.0  # normalize to [0, 1]
        data.append(img_data)
    if not data:
        return np.empty((0, height, width))
    return np.array(data)
```
2. 定义LSTM-GAN模型
```python
import torch
import torch.nn as nn
class Generator(nn.Module):
    """One LSTM cell followed by a linear projection.

    The cell takes ``output_size``-dim inputs and keeps a ``hidden_size``-dim
    state. The linear layer maps that state back to ``output_size`` values.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.lstm = nn.LSTMCell(output_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Advance the cell by one step.

        Args:
            x: Input of shape ``(batch, output_size)``.
            hidden: ``(h, c)`` pair, each ``(batch, hidden_size)``. ``None``
                (the default) starts from a zero state, as ``nn.LSTMCell`` does.

        Returns:
            ``(out, (h, c))``, where ``out`` is the ``(batch, output_size)``
            linear projection of the new hidden state. No activation is applied.
        """
        h, c = self.lstm(x, hidden)
        return self.fc(h), (h, c)
class Discriminator(nn.Module):
    """One LSTM cell followed by a linear layer that emits a single logit.

    The cell takes ``input_size``-dim inputs and keeps a ``hidden_size``-dim
    state. The linear layer maps that state to one unnormalized score.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, hidden=None):
        """Advance the cell by one step.

        Args:
            x: Input of shape ``(batch, input_size)``.
            hidden: ``(h, c)`` pair, each ``(batch, hidden_size)``. ``None``
                (the default) starts from a zero state, as ``nn.LSTMCell`` does.

        Returns:
            ``(logit, (h, c))``, where ``logit`` has shape ``(batch, 1)``. It is
            the raw linear output; no sigmoid is applied.
        """
        h, c = self.lstm(x, hidden)
        return self.fc(h), (h, c)
```
3. 训练LSTM-GAN模型并保存
```python
import torch.optim as optim

# Hyperparameters
input_size = 1      # the discriminator reads one pixel per LSTM step
hidden_size = 32
output_size = 121   # 11 * 11 pixels per image
batch_size = 64
num_epochs = 100
lr = 0.001

# Load the data and flatten each 11x11 image into a 121-pixel row once.
data_dir = './data'
data = load_data(data_dir)
if len(data) == 0:
    raise RuntimeError('no training images found in {!r}'.format(data_dir))
real_data = torch.from_numpy(data.reshape(len(data), -1)).float()
if real_data.size(1) != output_size:
    raise ValueError('expected {} pixels per image, got {}'.format(output_size, real_data.size(1)))

# Models
generator = Generator(hidden_size, output_size)
discriminator = Discriminator(input_size, hidden_size)

# Optimizers and loss (BCEWithLogitsLoss applies the sigmoid itself)
g_optimizer = optim.Adam(generator.parameters(), lr=lr)
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
criterion = nn.BCEWithLogitsLoss()


def generate(n):
    """Return n fake images as a (n, output_size) tensor with values in [0, 1]."""
    # The generator's LSTMCell takes output_size-dim inputs, so the noise has that size.
    z = torch.randn(n, output_size)
    hidden = (torch.zeros(n, hidden_size), torch.zeros(n, hidden_size))
    x, _ = generator(z, hidden)
    # Squash into the same [0, 1] range as the normalized real images.
    return torch.sigmoid(x)


def discriminate(imgs):
    """Return (n, 1) real/fake logits for flat images of shape (n, output_size).

    The image is fed to the discriminator's LSTMCell one pixel per step, starting
    from the discriminator's own zero state. The logit after the last pixel is the score.
    """
    n = imgs.size(0)
    hidden = (torch.zeros(n, hidden_size), torch.zeros(n, hidden_size))
    logits = None
    for t in range(imgs.size(1)):
        logits, hidden = discriminator(imgs[:, t:t + 1], hidden)  # (n, 1) pixel per step
    return logits


# Training
for epoch in range(num_epochs):
    # Reshuffle every epoch. The final partial batch is kept, so no sample is dropped
    # and at least one step always runs.
    perm = torch.randperm(real_data.size(0))
    for start in range(0, real_data.size(0), batch_size):
        real_imgs = real_data[perm[start:start + batch_size]]
        n = real_imgs.size(0)  # may be smaller than batch_size on the last batch
        real_labels = torch.ones(n, 1)
        fake_labels = torch.zeros(n, 1)

        fake_imgs = generate(n)

        # Discriminator step: detach fakes so no gradient flows into the generator.
        d_optimizer.zero_grad()
        d_loss_real = criterion(discriminate(real_imgs), real_labels)
        d_loss_fake = criterion(discriminate(fake_imgs.detach()), fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # Generator step: reuse the same fakes, this time without detaching, and
        # push the updated discriminator to label them as real.
        g_optimizer.zero_grad()
        g_loss = criterion(discriminate(fake_imgs), real_labels)
        g_loss.backward()
        g_optimizer.step()
    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'.format(epoch+1, num_epochs, d_loss.item(), g_loss.item()))

# Save the generator weights
torch.save(generator.state_dict(), 'generator.pth')
```
4. 加载模型并生成不同图像
```python
import os

import matplotlib.pyplot as plt
from PIL import Image

# Load the trained generator. weights_only avoids unpickling arbitrary objects,
# and map_location='cpu' lets a GPU-trained checkpoint load anywhere.
generator = Generator(hidden_size, output_size)
generator.load_state_dict(torch.load('generator.pth', map_location='cpu', weights_only=True))
generator.eval()

n_samples = 10
output_dir = './generated'
os.makedirs(output_dir, exist_ok=True)

# Generate n_samples different images. The Generator's LSTMCell expects
# output_size-dim input, so the noise has that size. One step from a zero
# state, squashed to [0, 1] with a sigmoid.
# NOTE(review): this must match how fakes were produced during training; keep the two in sync.
with torch.no_grad():
    z = torch.randn(n_samples, output_size)
    hidden = (torch.zeros(n_samples, hidden_size), torch.zeros(n_samples, hidden_size))
    x, _ = generator(z, hidden)
    samples = torch.sigmoid(x).numpy().reshape(n_samples, 11, 11)

# Save each generated image as its own 8-bit grayscale PNG.
for i, img in enumerate(samples):
    pixels = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(os.path.join(output_dir, 'sample_{}.png'.format(i)))

# Display. squeeze=False keeps `axes` 2-D even when n_samples == 1.
fig, axes = plt.subplots(nrows=n_samples, ncols=1, figsize=(10, 10), squeeze=False)
for ax, img in zip(axes[:, 0], samples):
    ax.imshow(img, cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
plt.show()
```
希望能够帮助到你!
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)