用 Transformer 架构、基于 PyTorch 写一个 BERT 二分类的代码
时间: 2024-05-03 09:22:45 浏览: 135
以下是基于 Transformer 架构的 BERT 二分类代码（使用 PyTorch 和 Hugging Face `transformers` 库）：
```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
class BERTBinaryClassifier(nn.Module):
    """Pretrained BERT encoder followed by a dropout + linear classification head.

    Args:
        bert_model: Model name or path passed to ``BertModel.from_pretrained``
            (e.g. ``'bert-base-uncased'``).
        num_classes: Number of output logits; use 2 for binary classification
            with ``nn.CrossEntropyLoss``.
        dropout: Dropout probability applied to the pooled representation
            before the linear layer (default 0.1, the previously fixed value).
    """

    def __init__(self, bert_model, num_classes, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return unnormalized class logits, one row of ``num_classes`` per example.

        Args:
            input_ids: Token id tensor produced by the tokenizer.
            attention_mask: Mask tensor produced by the tokenizer (1 = real
                token, 0 = padding).
            token_type_ids: Optional segment ids for sentence-pair inputs;
                ``None`` lets BERT use its default (all zeros).

        Returns:
            Logits tensor; apply ``softmax``/``argmax`` over the last dim to
            get probabilities/predictions.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Per the Hugging Face docs, pooler_output is the first ([CLS]) token's
        # final hidden state after BERT's pooler layer - the same tensor that
        # the positional outputs[1] refers to, but named explicitly.
        pooled_output = self.dropout(outputs.pooler_output)
        return self.linear(pooled_output)
def encode_batch(tokenizer, batch, device, max_length=512):
    """Tokenize a batch of records and move the resulting tensors to ``device``.

    Args:
        tokenizer: Hugging Face tokenizer (here a ``BertTokenizer``).
        batch: List of dicts, each with a ``'text'`` string and a ``'label'``
            convertible to ``int``.
        device: ``torch.device`` the returned tensors are placed on.
        max_length: Longest token sequence kept; longer texts are truncated.
            512 is bert-base-uncased's maximum input length
            (``config.max_position_embeddings``).

    Returns:
        Tuple ``(input_ids, attention_mask, labels)`` on ``device``.
    """
    # The tokenizer must receive raw strings, not the record dicts themselves.
    texts = [record['text'] for record in batch]
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    labels = torch.tensor([int(record['label']) for record in batch], dtype=torch.long)
    return (
        encoded['input_ids'].to(device),
        encoded['attention_mask'].to(device),
        labels.to(device),
    )


def train_one_epoch(model, tokenizer, data, optimizer, loss_fn, batch_size, device):
    """Run one pass of mini-batch training over ``data``; return the mean batch loss."""
    model.train()  # enable dropout for training
    # Shuffle a copy so batch composition varies between epochs without
    # mutating the caller's list.
    shuffled = random.sample(data, len(data))
    total_loss = 0.0
    num_batches = 0
    for start in range(0, len(shuffled), batch_size):
        batch = shuffled[start:start + batch_size]
        input_ids, attention_mask, labels = encode_batch(tokenizer, batch, device)
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0


def evaluate(model, tokenizer, data, batch_size, device):
    """Return the accuracy of ``model`` on ``data``, or ``nan`` if ``data`` is empty."""
    if not data:
        return float('nan')
    model.eval()  # disable dropout so evaluation is deterministic
    correct = 0
    with torch.no_grad():
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            input_ids, attention_mask, labels = encode_batch(tokenizer, batch, device)
            predicted = model(input_ids, attention_mask).argmax(dim=1)
            correct += (predicted == labels).sum().item()
    return correct / len(data)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# instantiate the model
model = BERTBinaryClassifier(bert_model='bert-base-uncased', num_classes=2).to(device)

# load data and tokenizer
# Placeholders: replace with lists of {'text': str, 'label': 0 or 1} dicts.
train_data = ...
test_data = ...
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

# training loop: train for one epoch, then report held-out accuracy
epochs = 5
batch_size = 32
for epoch in range(epochs):
    train_one_epoch(model, tokenizer, train_data, optimizer, loss_fn, batch_size, device)
    accuracy = evaluate(model, tokenizer, test_data, batch_size, device)
    print(f"Epoch {epoch+1} | Accuracy: {accuracy:.3f}")
```
注意：这里假设数据已经预处理成列表形式，每个元素是一个包含 `text`（文本）和 `label`（整数标签 0 或 1）两个字段的字典，例如：
```python
# Each record pairs a raw input string ('text') with its integer class
# ('label'): 1 for the positive example, 0 for the negative one.
train_data = [
    {'text': 'This is a positive sentence.', 'label': 1},
    {'text': 'This is a negative sentence.', 'label': 0},
    # ... more records in the same format (a literal `...` here would be an
    # Ellipsis object in the list and break code that reads record['text'])
]
```
阅读全文