bert 池化 代码
时间: 2024-04-25 20:20:31 浏览: 129
BERT(Bidirectional Encoder Representations from Transformers)是一种预训练语言模型,可用于各种自然语言处理任务。BERT模型的输入是一段文本,经过多层双向Transformer编码器处理后得到每个词的向量表示。BERT池化是在得到每个词向量后,将所有词向量通过一些池化方式,如max-pooling、average-pooling等,得到一个定长的向量表示整个文本。
下面是一个使用BERT进行文本分类的例子:
```python
import torch
import torch.nn as nn
from transformers import BertModel
class BertClassifier(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.dropout = nn.Dropout(0.2)
self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs
pooled_output = self.dropout(pooled_output)
logits = self.fc(pooled_output)
return logits
```
在这个例子中,我们使用预训练的BERT模型来提取文本特征,并通过全连接层将其映射到目标类别的概率。`input_ids`是一个整数张量,表示文本中每个单词在BERT词表中的索引;`attention_mask`是一个二进制张量,用于指示哪些单词是实际输入,哪些单词是填充的。在`forward`函数中,我们使用BERT模型对输入进行编码,并取出最后一层的输出作为池化输出,然后将其传递给全连接层进行分类。
阅读全文