LSTM-GAN 作曲机 PyTorch 代码
时间: 2023-11-11 13:04:07 浏览: 118
GAN PyTorch 算法代码
以下是一个使用LSTM-GAN生成音乐的PyTorch代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
class Generator(nn.Module):
    """LSTM generator that maps an input (noise) sequence to an output sequence.

    A batch-first ``nn.LSTM`` is followed by a linear layer that is applied to
    the LSTM output at every time step.

    Args:
        input_size: Feature size of each input step.
        hidden_size: Width of the LSTM hidden and cell states.
        output_size: Feature size of each produced step.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, hidden=None):
        """Run the LSTM over ``x`` and project every step through ``fc``.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)``.
            hidden: ``(h_0, c_0)`` tuple as returned by :meth:`init_hidden`.
                When ``None``, ``nn.LSTM`` starts from zero states.

        Returns:
            ``(out, hidden)`` where ``out`` has shape
            ``(batch, seq_len, output_size)`` and ``hidden`` is the final
            ``(h_n, c_n)`` tuple from the LSTM.
        """
        out, hidden = self.lstm(x, hidden)
        return self.fc(out), hidden

    def init_hidden(self, batch_size: int):
        """Return zero ``(h_0, c_0)`` states for ``batch_size`` sequences.

        Each tensor has shape ``(num_layers, batch_size, hidden_size)`` and is
        created with the dtype and device of the module's first parameter, so
        the states follow the module when it is moved with ``.to(device)``.
        """
        reference = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_size)
        h_0 = torch.zeros(shape, dtype=reference.dtype, device=reference.device)
        c_0 = torch.zeros(shape, dtype=reference.dtype, device=reference.device)
        return h_0, c_0
class Discriminator(nn.Module):
    """LSTM discriminator that scores a sequence step by step.

    A batch-first ``nn.LSTM`` is followed by a linear head. No activation is
    applied after the head, so the returned scores are unbounded (raw) values.

    Args:
        input_size: Feature size of each input step.
        hidden_size: Width of the LSTM hidden and cell states.
        output_size: Number of scores produced by the linear head.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, hidden=None):
        """Run the LSTM over ``x`` and apply the linear head to every step.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)``.
            hidden: ``(h_0, c_0)`` tuple as returned by :meth:`init_hidden`.
                When ``None``, ``nn.LSTM`` starts from zero states.

        Returns:
            ``(out, hidden)`` where ``out`` has shape
            ``(batch, seq_len, output_size)`` and ``hidden`` is the final
            ``(h_n, c_n)`` tuple from the LSTM.
        """
        out, hidden = self.lstm(x, hidden)
        return self.fc(out), hidden

    def init_hidden(self, batch_size: int):
        """Return zero ``(h_0, c_0)`` states for ``batch_size`` sequences.

        Each tensor has shape ``(num_layers, batch_size, hidden_size)`` and is
        created with the dtype and device of the module's first parameter, so
        the states follow the module when it is moved with ``.to(device)``.
        """
        reference = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_size)
        h_0 = torch.zeros(shape, dtype=reference.dtype, device=reference.device)
        c_0 = torch.zeros(shape, dtype=reference.dtype, device=reference.device)
        return h_0, c_0
def _final_state_score(discriminator, sequences):
    """Score whole sequences via ``discriminator.fc`` on the last layer's final hidden state."""
    state = discriminator.init_hidden(sequences.size(0))
    # One full-sequence LSTM call ends in the same final state as feeding the
    # steps one at a time, without a Python-level loop over time.
    _, (h_n, _) = discriminator(sequences, state)
    return discriminator.fc(h_n[-1])


def train(generator, discriminator, train_data, num_epochs, batch_size, seq_len, lr, device):
    """Adversarially train ``generator`` against ``discriminator``.

    ``train_data`` is consumed in consecutive, non-overlapping chunks of
    ``batch_size * seq_len`` rows; each chunk is reshaped to
    ``(batch_size, seq_len, -1)``. Every chunk drives one discriminator update
    (real chunk vs. detached generator output) followed by one generator
    update (fresh noise, scored against the "real" label). The losses of the
    last chunk are printed once per epoch.

    Args:
        generator: Module exposing ``input_size``, ``init_hidden(batch)`` and
            ``forward(x, hidden) -> (out, hidden)``.
        discriminator: Module exposing ``fc``, ``init_hidden(batch)`` and
            ``forward(x, hidden) -> (out, hidden)``; ``fc`` must map the final
            hidden state to one raw score per sequence.
        train_data: Sliceable data supporting ``len()``; assumed to be a
            ``(steps, features)`` array -- TODO confirm against the data file.
        num_epochs: Number of passes over ``train_data``.
        batch_size: Sequences per batch.
        seq_len: Time steps per sequence.
        lr: Learning rate for both Adam optimizers.
        device: Device on which batches, noise and labels are placed.

    Raises:
        ValueError: If ``batch_size * seq_len`` is not positive or
            ``train_data`` holds fewer rows than one chunk.
    """
    chunk = batch_size * seq_len
    if chunk <= 0:
        raise ValueError('batch_size and seq_len must be positive')
    if len(train_data) < chunk:
        raise ValueError(
            'train_data has {} rows but one batch needs {}'.format(len(train_data), chunk)
        )

    # The discriminator head emits raw scores, so the sigmoid has to live in
    # the loss; plain BCELoss rejects inputs outside [0, 1].
    criterion = nn.BCEWithLogitsLoss()
    g_optimizer = optim.Adam(generator.parameters(), lr=lr)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

    # Labels never change between batches, so build them once.
    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)

    for epoch in range(num_epochs):
        # "+ 1" keeps the final full chunk, which would otherwise be skipped
        # whenever len(train_data) is an exact multiple of the chunk size.
        for start in range(0, len(train_data) - chunk + 1, chunk):
            real_data = torch.as_tensor(
                train_data[start:start + chunk], dtype=torch.float32
            ).reshape(batch_size, seq_len, -1).to(device)

            # Train discriminator
            discriminator.zero_grad()
            real_loss = criterion(_final_state_score(discriminator, real_data), real_labels)

            noise = torch.randn(batch_size, seq_len, generator.input_size).to(device)
            fake_data = generator(noise, generator.init_hidden(batch_size))[0]
            # Detach: this step must not back-propagate into the generator.
            fake_loss = criterion(
                _final_state_score(discriminator, fake_data.detach()), fake_labels
            )

            d_loss = real_loss + fake_loss
            d_loss.backward()
            d_optimizer.step()

            # Train generator
            generator.zero_grad()
            noise = torch.randn(batch_size, seq_len, generator.input_size).to(device)
            fake_data = generator(noise, generator.init_hidden(batch_size))[0]
            # The generator is rewarded when its output is scored as real.
            g_loss = criterion(_final_state_score(discriminator, fake_data), real_labels)
            g_loss.backward()
            g_optimizer.step()

        print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'.format(epoch+1, num_epochs, d_loss.item(), g_loss.item()))
def generate(generator, seq_len, device):
    """Sample one sequence of ``seq_len`` steps from ``generator``.

    The generator is switched to eval mode (and left there), fed a single
    noise sequence of shape ``(1, seq_len, generator.input_size)`` and its
    output is returned as a NumPy array.

    Args:
        generator: Module exposing ``input_size``, ``output_size``,
            ``init_hidden(batch)`` and ``forward(x, hidden) -> (out, hidden)``.
        seq_len: Number of time steps to generate.
        device: Device the noise is moved to before the forward pass.

    Returns:
        ``numpy.ndarray`` reshaped to ``(-1, generator.output_size)``.
    """
    generator.eval()
    hidden = generator.init_hidden(1)
    noise = torch.randn(1, seq_len, generator.input_size).to(device)
    # Inference only: skip building an autograd graph instead of detaching it afterwards.
    with torch.no_grad():
        sequence = generator(noise, hidden)[0]
    return sequence.cpu().numpy().reshape(-1, generator.output_size)
# Example usage
# Model dimensions: the generator maps 128-wide noise steps to 128-wide
# output steps through a 2-layer LSTM with 256 hidden units.
input_size, output_size = 128, 128
hidden_size = 256
num_layers = 2

# Batching and optimisation settings.
batch_size, seq_len = 64, 32
num_epochs = 100
lr = 0.001

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

train_data = np.load('train_data.npy')

# The discriminator consumes generator-sized steps and emits a single score.
generator = Generator(input_size, hidden_size, output_size, num_layers)
generator = generator.to(device)
discriminator = Discriminator(output_size, hidden_size, 1, num_layers)
discriminator = discriminator.to(device)

train(
    generator,
    discriminator,
    train_data,
    num_epochs,
    batch_size,
    seq_len,
    lr,
    device,
)
generated_data = generate(generator, seq_len, device)
```
阅读全文