CNN-GRU PyTorch code
Date: 2023-08-29 13:14:00
The following is a simple CNN-GRU model in PyTorch for text classification.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_GRU(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, emb_dim,
                 kernel_sizes, num_filters, dropout, pad_idx):
        super().__init__()
        # input_dim = vocabulary size; pad_idx tokens map to a zero vector
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        # padding='same' (PyTorch >= 1.9) keeps every conv output at the input
        # length, so feature maps of different kernel sizes line up per time step
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(in_channels=emb_dim, out_channels=num_filters,
                      kernel_size=ks, padding='same')
            for ks in kernel_sizes
        ])
        self.gru = nn.GRU(input_size=num_filters * len(kernel_sizes),
                          hidden_size=hidden_dim,
                          num_layers=1,
                          bidirectional=True,
                          batch_first=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text, text_lengths):
        # text = [batch size, seq len], text_lengths = [batch size]
        embedded = self.dropout(self.embedding(text))
        # embedded = [batch size, seq len, emb dim]
        embedded = embedded.permute(0, 2, 1)
        # embedded = [batch size, emb dim, seq len]
        conved = [F.relu(conv(embedded)) for conv in self.conv_layers]
        # conved_n = [batch size, num_filters, seq len]
        # Keep the time dimension (no global max pooling) so the GRU gets a sequence
        cat = self.dropout(torch.cat(conved, dim=1)).permute(0, 2, 1)
        # cat = [batch size, seq len, num_filters * len(kernel_sizes)]
        # Pack so the GRU stops at each sequence's true length; lengths must be on the CPU
        packed = nn.utils.rnn.pack_padded_sequence(
            cat, text_lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, hidden = self.gru(packed)
        # hidden = [num layers * num directions, batch size, hid dim]
        # hidden[-2] / hidden[-1] are the last layer's forward / backward final states
        hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        # hidden = [batch size, hid dim * num directions]
        return self.fc(hidden)
```
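Packing the padded batch before calling the GRU is what keeps the padding positions out of the final hidden state. The standalone check below illustrates this under stated assumptions: random inputs and hypothetical sizes (`feat_dim`, `hidden_dim`), with one sequence of true length 3 padded with zeros to length 5 and run through a bidirectional GRU with and without packing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this check
feat_dim, hidden_dim = 8, 16
gru = nn.GRU(feat_dim, hidden_dim, bidirectional=True, batch_first=True)

# One sequence of true length 3, padded with zeros to length 5
seq = torch.randn(1, 3, feat_dim)
padded = torch.cat([seq, torch.zeros(1, 2, feat_dim)], dim=1)

# Reference: run the unpadded sequence directly
_, h_ref = gru(seq)

# Packed: the GRU only reads the first 3 steps in each direction
packed = nn.utils.rnn.pack_padded_sequence(
    padded, torch.tensor([3]), batch_first=True, enforce_sorted=False)
_, h_packed = gru(packed)

# Unpacked: the GRU also reads the 2 padding steps
_, h_unpacked = gru(padded)

same_with_packing = torch.allclose(h_ref, h_packed, atol=1e-5)
same_without_packing = torch.allclose(h_ref, h_unpacked, atol=1e-5)
print(same_with_packing, same_without_packing)  # True False
```

Without packing, the forward direction keeps updating its state over the padding, and the backward direction starts from the padding instead of the last real token, so both final states drift away from the reference.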
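The model takes as input a batch of text sequences made of word indices, together with the length of each sequence. An embedding layer first turns the indices into dense vectors; convolution layers then extract local text features; a bidirectional GRU models the sequence of extracted features; finally, a fully connected layer maps the GRU's final forward and backward hidden states to the space of class labels. Dropout is applied after the embedding, convolution, and GRU stages to reduce overfitting.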
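The data flow described above can be traced step by step with plain layers and the tensor shapes printed at each stage. This is a minimal sketch, not the model class itself: the hyperparameters and the toy batch are hypothetical, and it assumes PyTorch 1.9 or later for `padding='same'`, which keeps each convolution's output at the input length so the GRU still receives one feature vector per token.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical hyperparameters, picked so each shape is easy to tell apart
vocab_size, emb_dim, num_filters, hidden_dim, num_classes = 100, 20, 16, 24, 4
kernel_sizes, pad_idx = [3, 5], 0

embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
convs = nn.ModuleList(
    [nn.Conv1d(emb_dim, num_filters, ks, padding='same') for ks in kernel_sizes])
gru = nn.GRU(num_filters * len(kernel_sizes), hidden_dim,
             bidirectional=True, batch_first=True)
fc = nn.Linear(hidden_dim * 2, num_classes)

# Toy batch: 2 index sequences padded with pad_idx to length 7
text = torch.tensor([[5, 8, 13, 21, 34, 0, 0],
                     [2, 3, 4, 5, 6, 7, 9]])
text_lengths = torch.tensor([5, 7])

x = embedding(text)                                   # [2, 7, 20]
x = x.permute(0, 2, 1)                                # [2, 20, 7]
x = torch.cat([F.relu(c(x)) for c in convs], dim=1)   # [2, 32, 7]
x = x.permute(0, 2, 1)                                # [2, 7, 32]
packed = nn.utils.rnn.pack_padded_sequence(
    x, text_lengths, batch_first=True, enforce_sorted=False)
_, h = gru(packed)                                    # [2, 2, 24]
h = torch.cat((h[-2], h[-1]), dim=1)                  # [2, 48]
logits = fc(h)                                        # [2, 4]
print(logits.shape)  # torch.Size([2, 4])
```

Concatenating the conv outputs along the channel dimension (rather than pooling them over time) is the design choice that makes the GRU stage meaningful: pooling over the whole sequence would leave a single vector per example and nothing for the recurrent layer to iterate over.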