使用thchs30数据集训练CycleGAN模型实现语音转换,并将其封装起来的代码
时间: 2024-05-16 12:15:00 浏览: 17
以下是使用thchs30数据集训练CycleGAN模型实现语音转换的代码:
1. 安装必要的依赖项
```python
# Jupyter/Colab shell magics: pin exact library versions for reproducibility.
# (These lines only run in a notebook; in a plain script use `pip install` from a shell.)
!pip install torch==1.7.0 torchvision==0.8.1 torchaudio==0.7.0
!pip install librosa==0.8.0 matplotlib==3.3.2 numpy==1.19.2
!pip install tensorboard==2.4.0
!pip install tqdm==4.54.0
```
2. 导入必要的库
```python
import itertools
import os
import random

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchaudio
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
```
3. 定义超参数
```python
# Hyperparameters
num_epochs = 200          # full passes over the dataset
batch_size = 1            # samples per optimizer step
lr = 0.0002               # Adam learning rate for all optimizers
sample_rate = 16000       # audio sample rate in Hz (thchs30 is 16 kHz)
num_mels = 80             # mel-spectrogram channels; also the models' channel count
n_fft = 1024              # FFT window size for the STFT
hop_length = 256          # STFT hop length in samples
num_iters = 200000        # NOTE(review): used below as a fixed clip length in
                          # samples (~12.5 s at 16 kHz), not as an iteration count
num_frames = 128          # fixed number of spectrogram frames per training sample
lambda_cycle = 10         # weight of the cycle-consistency loss
lambda_identity = 5       # weight of the identity-mapping loss
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```
4. 定义数据集
```python
class Thchs30Dataset(data.Dataset):
    """One random utterance per speaker, as a fixed-size mel-spectrogram.

    BUG FIX: the original ``__getitem__`` returned the raw padded waveform
    (a 1-D tensor of ``num_iters`` samples).  Batched by the DataLoader that
    becomes a 2-D tensor, which ``Conv1d(num_mels, ...)`` in the generator
    cannot consume — it needs (batch, num_mels, frames).  This version
    returns a (num_mels, num_frames) mel-spectrogram, which also gives the
    otherwise-unused ``num_frames`` hyperparameter its intended role.
    """

    def __init__(self, root_dir):
        """root_dir: directory containing one sub-directory per speaker."""
        self.root_dir = root_dir
        # Keep only sub-directories (sorted for deterministic indexing);
        # the thchs30 layout is one folder of .wav files per speaker.
        self.speaker_dirs = [d for d in sorted(os.listdir(root_dir))
                             if os.path.isdir(os.path.join(root_dir, d))]
        # Mel front-end configured from the module-level hyperparameters.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft,
            hop_length=hop_length, n_mels=num_mels)

    def __len__(self):
        # One sample per speaker directory (a random utterance each epoch).
        return len(self.speaker_dirs)

    def __getitem__(self, idx):
        # Pick one random utterance from this speaker's directory.
        speaker_dir = self.speaker_dirs[idx]
        audio_files = os.listdir(os.path.join(self.root_dir, speaker_dir))
        audio_file = random.choice(audio_files)
        audio_path = os.path.join(self.root_dir, speaker_dir, audio_file)
        waveform, _ = torchaudio.load(audio_path)
        waveform = waveform.squeeze()
        mel = self.mel_transform(waveform)  # (num_mels, frames)
        # Random-crop or zero-pad to a fixed frame count so samples batch.
        if mel.shape[1] > num_frames:
            start = random.randint(0, mel.shape[1] - num_frames)
            mel = mel[:, start:start + num_frames]
        else:
            mel = F.pad(mel, (0, num_frames - mel.shape[1]), 'constant', 0)
        return mel
```
5. 定义生成器
```python
class Generator(nn.Module):
    """1-D convolutional encoder / transform / decoder generator.

    Maps a mel-spectrogram of shape (batch, n_mels, T) to another of the
    same shape.  T must be divisible by 4: the encoder down-samples twice
    with stride 2 and the decoder's transposed convolutions (with
    ``output_padding=1``) exactly undo that.

    Generalization: the hard-coded global ``num_mels`` is now an optional
    ``n_mels`` constructor argument; ``Generator()`` still reads the global
    at call time, so existing callers are unchanged.
    """

    def __init__(self, n_mels=None):
        super(Generator, self).__init__()
        if n_mels is None:
            # Fall back to the module-level hyperparameter, preserving the
            # original no-argument behavior.
            n_mels = num_mels
        # Encoder: widen to 512 channels, then down-sample twice (T -> T/4).
        self.encoder = nn.Sequential(
            nn.Conv1d(n_mels, 512, kernel_size=7, padding=3),
            nn.InstanceNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm1d(128),
            nn.ReLU(inplace=True)
        )
        # Bottleneck transform at the reduced temporal resolution.
        self.transform = nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm1d(128),
            nn.ReLU(inplace=True)
        )
        # Decoder: mirror of the encoder; Tanh bounds the output to [-1, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(128, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm1d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose1d(256, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm1d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose1d(512, n_mels, kernel_size=7, padding=3),
            nn.InstanceNorm1d(n_mels),
            nn.Tanh()
        )

    def forward(self, x):
        """x: (batch, n_mels, T) -> (batch, n_mels, T), values in [-1, 1]."""
        encoded = self.encoder(x)
        transformed = self.transform(encoded)
        decoded = self.decoder(transformed)
        return decoded
```
6. 定义判别器
```python
class Discriminator(nn.Module):
    """Convolutional discriminator producing one real/fake score per sample.

    BUG FIX: the original flattened the conv output straight into
    ``Linear(128 * 4, 1)``, which only matches an input of exactly 16
    frames; with ``num_frames = 128`` the flatten yields 128 * 32 features
    and the forward pass crashes with a shape mismatch.  An
    ``AdaptiveAvgPool1d(4)`` now reduces any temporal length to 4, so the
    original fc size works for all inputs.

    Generalization: ``n_mels`` is an optional constructor argument; the
    no-argument form still reads the global ``num_mels`` at call time.
    """

    def __init__(self, n_mels=None):
        super(Discriminator, self).__init__()
        if n_mels is None:
            n_mels = num_mels  # preserve original no-argument behavior
        self.conv1 = nn.Conv1d(n_mels, 64, kernel_size=15, padding=7)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=41, stride=4, padding=20)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        # Pool to a fixed temporal size so the linear layer accepts any T.
        self.pool = nn.AdaptiveAvgPool1d(4)
        self.fc1 = nn.Linear(128 * 4, 1)

    def forward(self, x):
        """x: (batch, n_mels, T) -> (batch, 1) unbounded LSGAN score."""
        x = F.leaky_relu(self.bn1(self.conv1(x)), 0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2)
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        return x
```
7. 定义损失函数和优化器
```python
# Two generators (G1: A->B, G2: B->A) and one discriminator per direction.
generator1 = Generator().to(device)
generator2 = Generator().to(device)
discriminator1 = Discriminator().to(device)
discriminator2 = Discriminator().to(device)

# Least-squares GAN loss for real/fake scores; L1 for cycle and identity terms.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# BUG FIX: the published snippet called itertools.chain without a guaranteed
# `import itertools`, raising NameError here.  Concatenating the parameter
# lists feeds Adam exactly the same parameters with no extra dependency.
optimizer_G = optim.Adam(
    list(generator1.parameters()) + list(generator2.parameters()),
    lr=lr, betas=(0.5, 0.999)
)
optimizer_D1 = optim.Adam(discriminator1.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D2 = optim.Adam(discriminator2.parameters(), lr=lr, betas=(0.5, 0.999))
```
8. 训练模型
```python
def train():
    """Train the CycleGAN-style generator/discriminator pairs on thchs30.

    Uses the module-level models, criteria and optimizers; logs scalar
    losses to TensorBoard and periodically checkpoints both generators.

    NOTE(review): the loop draws a single stream of samples ``x`` rather
    than two distinct speaker domains, so the "cycle" is
    x -> G1 -> y -> G2 -> x_reconstructed over one domain only — confirm
    this matches the intended CycleGAN setup.
    """
    writer = SummaryWriter()
    # Hard-coded dataset root — presumably the extracted thchs30 archive.
    train_dataset = Thchs30Dataset('thchs30/data')
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    for epoch in range(num_epochs):
        for i, x in enumerate(tqdm(train_loader)):
            x = x.to(device)
            # Train the generators
            optimizer_G.zero_grad()
            y = generator1(x)
            x_reconstructed = generator2(y)
            y_reconstructed = generator1(x_reconstructed)
            # Adversarial (LSGAN) losses.  Each discriminator is invoked
            # twice per term (once for the score, once for ones_like's
            # shape) — a cached score would halve that work.
            gan_loss1 = criterion_GAN(discriminator1(y), torch.ones_like(discriminator1(y)))
            gan_loss2 = criterion_GAN(discriminator2(x_reconstructed), torch.ones_like(discriminator2(x_reconstructed)))
            gan_loss = (gan_loss1 + gan_loss2) / 2
            # Cycle-consistency losses
            cycle_loss1 = criterion_cycle(x_reconstructed, x)
            cycle_loss2 = criterion_cycle(y_reconstructed, y)
            cycle_loss = (cycle_loss1 + cycle_loss2) / 2
            # Identity-mapping losses.
            # NOTE(review): standard CycleGAN compares G2(x) to x and G1(y)
            # to y (feeding each generator its *target*-domain input); here
            # the same-direction generators are reused — verify intent.
            identity_loss1 = criterion_identity(generator1(x), x)
            identity_loss2 = criterion_identity(generator2(y), y)
            identity_loss = (identity_loss1 + identity_loss2) / 2
            # Total generator loss
            total_loss = gan_loss + lambda_cycle * cycle_loss + lambda_identity * identity_loss
            total_loss.backward()
            optimizer_G.step()
            # Train discriminator 1.
            # NOTE(review): the "fake" shown to D1 is the twice-mapped
            # y_reconstructed rather than the first-pass fake y = G1(x);
            # typical CycleGAN uses G1(x).detach() here.
            optimizer_D1.zero_grad()
            real_loss = criterion_GAN(discriminator1(x), torch.ones_like(discriminator1(x)))
            fake_loss = criterion_GAN(discriminator1(y_reconstructed.detach()), torch.zeros_like(discriminator1(y_reconstructed)))
            loss_D1 = (real_loss + fake_loss) / 2
            loss_D1.backward()
            optimizer_D1.step()
            # Train discriminator 2.
            # BUG(review): x_reconstructed is NOT detached, so
            # loss_D2.backward() re-traverses the generator graph already
            # freed by total_loss.backward() — this raises "Trying to
            # backward through the graph a second time" at runtime.  Also
            # x.detach() (real data) is scored against a *zeros* target
            # while the generated x_reconstructed gets the *ones* target,
            # which looks inverted relative to LSGAN convention.
            optimizer_D2.zero_grad()
            real_loss = criterion_GAN(discriminator2(x_reconstructed), torch.ones_like(discriminator2(x_reconstructed)))
            fake_loss = criterion_GAN(discriminator2(x.detach()), torch.zeros_like(discriminator2(x)))
            loss_D2 = (real_loss + fake_loss) / 2
            loss_D2.backward()
            optimizer_D2.step()
            # Print losses
            if i % 100 == 0:
                print('[Epoch %d/%d] [Batch %d/%d] [G loss: %f] [D1 loss: %f] [D2 loss: %f]'
                      % (epoch + 1, num_epochs, i, len(train_loader), total_loss.item(), loss_D1.item(), loss_D2.item()))
            # Log losses to TensorBoard (x-axis is the global step)
            writer.add_scalar('G loss', total_loss.item(), epoch * len(train_loader) + i)
            writer.add_scalar('D1 loss', loss_D1.item(), epoch * len(train_loader) + i)
            writer.add_scalar('D2 loss', loss_D2.item(), epoch * len(train_loader) + i)
            # Checkpoint the generators (overwrites the same files each time)
            if i % 1000 == 0:
                torch.save(generator1.state_dict(), 'generator1.pth')
                torch.save(generator2.state_dict(), 'generator2.pth')
```
9. 运行训练函数
```python
# Kick off training (long-running; writes TensorBoard logs and .pth checkpoints).
train()
```
10. 使用模型进行语音转换
```python
def convert_audio(input_audio_path, output_audio_path):
    """Convert the voice in ``input_audio_path`` with the trained generator1.

    Pipeline: load WAV -> pad/truncate to a fixed sample count ->
    mel-spectrogram -> generator1 -> Griffin-Lim-style inversion -> save WAV.

    NOTE(review): the mel spectrogram is fed to the generator un-normalized
    (linear power), which must match whatever the generator saw in
    training — confirm against the training front-end.
    """
    generator1 = Generator().to(device)
    generator1.load_state_dict(torch.load('generator1.pth', map_location=device))
    generator1.eval()
    waveform, _ = torchaudio.load(input_audio_path)
    waveform = waveform.squeeze()
    # Pad or truncate to a fixed clip length.  ``num_iters`` doubles as the
    # sample count here (~12.5 s at 16 kHz) despite its misleading name.
    if waveform.shape[0] > num_iters:
        waveform = waveform[:num_iters]
    else:
        waveform = F.pad(waveform, (0, num_iters - waveform.shape[0]), 'constant', 0)
    # BUG FIX: pass the audio as the ``y=`` keyword — librosa >= 0.10
    # removed the positional audio argument, and ``y=`` is accepted by the
    # pinned 0.8.0 as well, so this works across versions.
    mel_spec = librosa.feature.melspectrogram(
        y=waveform.numpy(), sr=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=num_mels)
    mel_spec = torch.FloatTensor(mel_spec).unsqueeze(0).to(device)
    with torch.no_grad():
        output_mel_spec = generator1(mel_spec)
    output_mel_spec = output_mel_spec.squeeze().cpu().numpy()
    # Invert the mel spectrogram back to a waveform (Griffin-Lim under the hood).
    output_waveform = librosa.feature.inverse.mel_to_audio(
        output_mel_spec, sr=sample_rate, n_fft=n_fft, hop_length=hop_length)
    output_waveform = torch.FloatTensor(output_waveform).unsqueeze(0)
    torchaudio.save(output_audio_path, output_waveform, sample_rate)
```
11. 运行语音转换函数
```python
# Example conversion: requires generator1.pth from training and input.wav on disk.
convert_audio('input.wav', 'output.wav')
```
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)