Automatic Classical Chinese Poetry Generation with PyTorch
### Implementing Automatic Classical Chinese Poetry Generation with PyTorch
#### Setup
To implement automatic classical Chinese poetry generation with PyTorch, first prepare the environment and import the required libraries. This ensures that the subsequent data processing and model training can proceed smoothly.
```python
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
```
#### Data Preprocessing
Data preprocessing is critical because it determines the quality of the information fed to the neural network. For a poetry-generation task this typically means character-level encoding and decoding[^1].
- **Build a vocabulary**: map every Chinese character to a unique integer ID.
- **Serialize the text**: convert each poem into a list of these IDs.
- **Pad to a fixed length**: poems vary in length, so shorter sequences are padded (and longer ones truncated) to a common maximum length.
```python
class Tokenizer():
    def __init__(self):
        # Character-to-ID and ID-to-character lookup tables
        self.vocab = {}
        self.reverse_vocab = {}

    def fit_on_texts(self, texts):
        # Build the vocabulary from every unique character; ID 0 is reserved for padding/unknown
        unique_chars = sorted(set(''.join(texts)))
        for idx, char in enumerate(unique_chars):
            self.vocab[char] = idx + 1
            self.reverse_vocab[idx + 1] = char

    def text_to_sequence(self, text):
        # Map each character to its ID, falling back to 0 for unseen characters
        return [self.vocab.get(c, 0) for c in text]


def pad_sequences(sequences, maxlen=None, padding='post', truncating='pre'):
    # Pad (or truncate) all sequences to the same length so they can be batched
    lengths = [len(s) for s in sequences]
    nb_samples = len(sequences)
    if maxlen is None:
        maxlen = max(lengths)
    x = np.zeros((nb_samples, maxlen), dtype=np.int64)
    for i, seq in enumerate(sequences):
        if not seq:
            continue
        if truncating == 'pre':
            trunc = seq[-maxlen:]
        else:  # 'post'
            trunc = seq[:maxlen]
        trunc = list(trunc)
        if padding == 'post':
            x[i, :len(trunc)] = trunc
        else:  # 'pre'
            x[i, -len(trunc):] = trunc
    return x
```
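As a quick sanity check, the tokenizer and padding helper can be exercised on a couple of toy poems. This is only a minimal sketch; the two sample strings and the `maxlen` value are placeholders, not a real corpus.
```python
poems = ['床前明月光疑是地上霜', '举头望明月低头思故乡']  # toy examples

tokenizer = Tokenizer()
tokenizer.fit_on_texts(poems)

sequences = [tokenizer.text_to_sequence(p) for p in poems]
padded = pad_sequences(sequences, maxlen=12)
print(padded.shape)  # (2, 12)
```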
#### Defining the Dataset Class
So that a `DataLoader` can read the data conveniently, a class inheriting from `Dataset` encapsulates the per-sample logic[^2].
```python
class PoemDataset(Dataset):
    def __init__(self, poems, tokenizer, max_len):
        super().__init__()
        self.poems = poems
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        # Encode one poem, pad it, and build (input, target) pairs shifted by one character
        poem = self.poems[index]
        tokens = self.tokenizer.text_to_sequence(poem)
        padded_tokens = pad_sequences([tokens], maxlen=self.max_len)[0]
        input_tensor = torch.tensor(padded_tokens[:-1])
        target_tensor = torch.tensor(padded_tokens[1:])
        return input_tensor.long(), target_tensor.long()

    def __len__(self):
        return len(self.poems)
```
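A minimal sketch of wiring the dataset into a `DataLoader` follows; the `poems` list and `max_len` value carry over from the toy example above, and the batch size is only an example (a real corpus would replace the toy list):
```python
max_len = 12  # example value; in practice, the longest poem length plus any special tokens

dataset = PoemDataset(poems, tokenizer, max_len)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
```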
#### Building the LSTM Model
Next comes the core part: the LSTM model architecture. The model takes the encoded character sequence as input and predicts the next likely character[^3].
```python
class PoetryGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_units=256):
        super(PoetryGenerator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_units, batch_first=True)
        self.fc_out = nn.Linear(hidden_units, vocab_size)

    def forward(self, inputs, hiddens=None):
        # inputs: (batch, seq_len) LongTensor of character IDs
        embeds = self.embedding(inputs)                    # (batch, seq_len, embedding_dim)
        lstm_output, (hidden_state, cell_state) = self.lstm(embeds, hiddens)
        outputs = self.fc_out(lstm_output)                 # (batch, seq_len, vocab_size)
        return outputs, (hidden_state, cell_state)
```
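Before training, it can help to confirm the output shapes with a dummy batch. The vocabulary size, batch size, and sequence length below are arbitrary example values:
```python
dummy_model = PoetryGenerator(vocab_size=5000)
dummy_input = torch.randint(1, 5000, (4, 11))   # (batch, seq_len) of character IDs
logits, (h, c) = dummy_model(dummy_input)
print(logits.shape)  # torch.Size([4, 11, 5000])
```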
#### Training Loop
With the pieces above in place, the actual training can begin. The parameters are updated iteratively to minimize the loss function and improve prediction accuracy[^4].
```python
device = "cuda" if torch.cuda.is_available() else "cpu"

model = PoetryGenerator(len(tokenizer.vocab) + 1).to(device)  # +1 for the padding/unknown ID 0
criterion = nn.CrossEntropyLoss(ignore_index=0)  # do not penalize padded positions
optimizer = optim.Adam(model.parameters())

num_epochs = 20  # example value; tune as needed

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.
    for X_batch, Y_batch in train_loader:
        optimizer.zero_grad()
        output, _ = model(X_batch.to(device))
        # Flatten (batch, seq_len, vocab_size) logits and targets for CrossEntropyLoss
        loss = criterion(output.view(-1, output.shape[-1]), Y_batch.reshape(-1).to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch + 1}/{num_epochs}, loss: {total_loss / len(train_loader):.4f}')
```
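After training, the weights can be saved so generation does not require retraining. This is a standard `state_dict` round trip; the file name is arbitrary:
```python
torch.save(model.state_dict(), 'poetry_generator.pt')

# Later, reload with:
# model = PoetryGenerator(len(tokenizer.vocab) + 1).to(device)
# model.load_state_dict(torch.load('poetry_generator.pt', map_location=device))
```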
#### Testing and Application
Once the steps above are complete, the trained model can be used to generate new lines of poetry.
```python
model.eval()

# <START> and <END> are special tokens; they must be added to the tokenizer's vocabulary
# (e.g. wrapped around each poem before fitting) for this lookup to succeed.
start_token = "<START>"
end_token = "<END>"

max_length_of_a_line = 7   # example value
lines_number = 4           # example value

generated_poem = start_token
input_seq = torch.tensor([[tokenizer.vocab[start_token]]]).long().to(device)
hidden_states = None

while True:
    with torch.no_grad():
        predictions, hidden_states = model(input_seq, hidden_states)
    # Greedy decoding: pick the highest-scoring character at the last time step
    predicted_index = int(torch.argmax(predictions[:, -1]))
    next_char = tokenizer.reverse_vocab.get(predicted_index, '')
    generated_poem += next_char
    if next_char == end_token or len(generated_poem) > max_length_of_a_line * lines_number:
        break
    input_seq = torch.tensor([[predicted_index]]).long().to(device)

print(f'Generated Poem:\n{generated_poem}')
```
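Greedy argmax decoding tends to loop on high-frequency characters. A common alternative, sketched below under the same setup, is temperature-scaled sampling with `torch.multinomial`; the temperature value is only an example:
```python
def sample_next_index(logits, temperature=0.8):
    # Scale the logits, convert to probabilities, and draw one character ID at random
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))

# Drop-in replacement for the argmax line in the loop above:
# predicted_index = sample_next_index(predictions[0, -1])
```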