CNN与GRU结合代码 pytorch
时间: 2024-05-04 17:19:20 浏览: 163
下面是一个简单的示例代码,将CNN和GRU结合在一起,用于文本分类任务。代码使用PyTorch框架。
```
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN_GRU(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
n_filters, filter_sizes, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.convs = nn.ModuleList([
nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(fs, embedding_dim))
for fs in filter_sizes
])
self.gru = nn.GRU(input_size=n_filters * len(filter_sizes),
hidden_size=hidden_dim,
num_layers=1,
bidirectional=True,
batch_first=True,
dropout=0 if num_layers < 2 else dropout)
self.fc = nn.Linear(hidden_dim * 2, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, text):
# text = [batch size, sent len]
embedded = self.embedding(text)
# embedded = [batch size, sent len, emb dim]
embedded = embedded.unsqueeze(1)
# embedded = [batch size, 1, sent len, emb dim]
conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
# conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
# pooled_n = [batch size, n_filters]
cat = self.dropout(torch.cat(pooled, dim=1))
# cat = [batch size, n_filters * len(filter_sizes)]
_, hidden = self.gru(cat.unsqueeze(0))
# hidden = [n layers * n directions, batch size, emb dim]
hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
# hidden = [batch size, hid dim * num directions]
return self.fc(hidden)
```
在这个模型中,我们首先通过嵌入层将输入的文本转换为词向量,然后使用卷积神经网络进行特征提取,最后将提取的特征输入到GRU中进行序列建模。最终,我们使用全连接层将GRU的输出映射到所需的输出维度。
阅读全文