请给我一段使用Python写的带有详细注释的使用GPT2模型进行文本分类的代码
时间: 2024-02-13 07:02:35 浏览: 104
文本分类算法研究及python代码实现
5星 · 资源好评率100%
好的,这里是一个使用 Python 和 PyTorch 实现的基于 GPT2 模型的文本分类的代码。
```python
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2Model, GPT2Config
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
def __init__(self, texts, labels, tokenizer, max_len):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = str(self.texts[idx])
label = self.labels[idx]
inputs = self.tokenizer.encode_plus(
text,
None,
add_special_tokens=True,
max_length=self.max_len,
padding='max_length',
truncation=True
)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
return {
'input_ids': torch.tensor(input_ids, dtype=torch.long),
'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
'label': torch.tensor(label, dtype=torch.long)
}
class GPT2Classifier(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.config = GPT2Config.from_pretrained('gpt2')
self.transformer = GPT2Model(self.config)
self.classifier = nn.Linear(self.config.n_embd, num_classes)
def forward(self, input_ids, attention_mask=None):
hidden_states = self.transformer(input_ids, attention_mask=attention_mask).last_hidden_state
pooled_output = hidden_states.mean(dim=1)
logits = self.classifier(pooled_output)
return logits
# 加载训练数据和测试数据
train_texts = ['This is a positive sentence.', 'This is a negative sentence.']
train_labels = [1, 0]
test_texts = ['This is another positive sentence.', 'This is another negative sentence.']
test_labels = [1, 0]
# 初始化 tokenizer 和数据集
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len=32)
test_dataset = TextDataset(test_texts, test_labels, tokenizer, max_len=32)
# 初始化模型和优化器
model = GPT2Classifier(num_classes=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# 训练模型
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
for epoch in range(5):
model.train()
for batch in train_dataloader:
input_ids, attention_mask, label = tuple(t.to(device) for t in batch.values())
optimizer.zero_grad()
logits = model(input_ids, attention_mask=attention_mask)
loss_func = nn.CrossEntropyLoss()
loss = loss_func(logits, label)
loss.backward()
optimizer.step()
print('Epoch {}, train loss: {:.4f}'.format(epoch, loss.item()))
# 测试模型
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)
model.eval()
with torch.no_grad():
for batch in test_dataloader:
input_ids, attention_mask, label = tuple(t.to(device) for t in batch.values())
logits = model(input_ids, attention_mask=attention_mask)
preds = logits.argmax(dim=1)
print('Preds: ', preds)
print('Labels: ', label)
```
上面的代码中,我们首先定义了一个 `TextDataset` 类,用于加载文本数据和标签,并将文本数据转换为 GPT2 模型的输入格式。然后我们定义了一个 `GPT2Classifier` 类,用于实例化 GPT2 模型和分类器。接下来,我们加载训练数据和测试数据,并使用 `TextDataset` 类将它们转换为 PyTorch 数据集。然后,我们使用 Adam 优化器训练模型,并在每个 epoch 结束后计算训练集的损失。最后,我们测试模型并输出预测结果和真实标签。
当然,这只是一个简单的例子,实际应用时你可能需要进行更多的调整和优化。但是,希望这段代码能够帮助你理解如何使用 GPT2 模型实现文本分类。
阅读全文