pytorch textcnn模型代码
时间: 2023-10-31 22:04:36 浏览: 154
下面是一个基于 PyTorch 实现的 TextCNN 模型的代码示例:
```python
import torch
import torch.nn as nn
class TextCNN(nn.Module):
def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, output_dim, dropout):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.convs = nn.ModuleList([
nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(fs, embedding_dim))
for fs in filter_sizes
])
self.fc = nn.Linear(len(filter_sizes) * num_filters, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, text):
# text = [batch size, sent len]
embedded = self.embedding(text)
# embedded = [batch size, sent len, emb dim]
embedded = embedded.unsqueeze(1)
# embedded = [batch size, 1, sent len, emb dim]
conved = [nn.functional.relu(conv(embedded)).squeeze(3) for conv in self.convs]
# conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
pooled = [nn.functional.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
# pooled_n = [batch size, n_filters]
cat = self.dropout(torch.cat(pooled, dim=1))
# cat = [batch size, n_filters * len(filter_sizes)]
return self.fc(cat)
```
其中,`vocab_size` 表示词汇表大小,`embedding_dim` 表示嵌入维度,`num_filters` 表示卷积核数量,`filter_sizes` 表示卷积核大小,`output_dim` 表示输出维度,`dropout` 表示 dropout 概率。在 `forward` 方法中,我们首先将输入文本进行词嵌入,然后通过不同大小的卷积核进行卷积和池化,最后将池化结果拼接起来并通过全连接层输出。
阅读全文