BERT Adversarial Training PyTorch Code
Below is an example of BERT adversarial training in PyTorch:
```
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from tqdm import tqdm
import random
import numpy as np
# Device: use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Random seeds for reproducibility
random.seed(2022)
np.random.seed(2022)
torch.manual_seed(2022)
torch.cuda.manual_seed_all(2022)
# Load the pre-trained BERT model and tokenizer
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
# Load the dataset (toy example)
train_texts = ['This is a positive sentence.', 'This is a negative sentence.']
train_labels = [1, 0]
test_texts = ['This is another positive sentence.', 'This is another negative sentence.']
test_labels = [1, 0]
# Convert the texts to BERT input format (return PyTorch tensors so they can go into a TensorDataset)
train_encodings = tokenizer(train_texts, truncation=True, padding=True, return_tensors='pt')
test_encodings = tokenizer(test_texts, truncation=True, padding=True, return_tensors='pt')
train_labels = torch.tensor(train_labels)
test_labels = torch.tensor(test_labels)
train_dataset = TensorDataset(train_encodings['input_ids'], train_encodings['attention_mask'], train_labels)
test_dataset = TensorDataset(test_encodings['input_ids'], test_encodings['attention_mask'], test_labels)
# Training hyperparameters
epochs = 3
batch_size = 8
learning_rate = 2e-5
epsilon = 1e-8        # AdamW numerical-stability term
num_adv_steps = 1
adv_epsilon = 1.0     # size of the FGM perturbation on the embeddings
# Adversarial example generation (FGM): perturb the word embeddings along the loss gradient.
# Integer input_ids cannot carry gradients, so the perturbation is applied to the embedding vectors.
def fgm_attack(input_ids, attention_mask, labels, adv_epsilon):
    model.train()
    # Look up the word embeddings for this batch and track gradients on them
    embeddings = model.bert.embeddings.word_embeddings(input_ids).detach()
    embeddings.requires_grad_(True)
    outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask, labels=labels)
    # Gradient of the loss w.r.t. the embeddings only (parameter gradients are left untouched)
    grad = torch.autograd.grad(outputs.loss, embeddings)[0]
    # FGM step: move adv_epsilon along the L2-normalized gradient direction
    norm = torch.norm(grad)
    if norm != 0 and not torch.isnan(norm):
        embeddings = embeddings + adv_epsilon * grad / norm
    return embeddings.detach()
# DataLoader, optimizer, and learning-rate scheduler
train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=epsilon)
total_steps = len(train_dataloader) * epochs   # one scheduler step per optimizer step
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_dataloader, desc="Training"):
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)
        # Standard forward/backward pass on the clean inputs
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        total_loss += loss.item()
        # Adversarial training: accumulate gradients from the loss on the perturbed embeddings
        for _ in range(num_adv_steps):
            adv_embeddings = fgm_attack(input_ids, attention_mask, labels, adv_epsilon)
            adv_outputs = model(inputs_embeds=adv_embeddings, attention_mask=attention_mask, labels=labels)
            adv_loss = adv_outputs.loss
            adv_loss.backward()
        # Update the parameters with the combined clean + adversarial gradients
        optimizer.step()
        scheduler.step()
        model.zero_grad()
    avg_loss = total_loss / len(train_dataloader)
    print("Epoch:", epoch + 1, "Train loss:", avg_loss)
# Evaluate on the (clean) test set
test_dataloader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=batch_size)
model.eval()
total_accuracy = 0
for batch in tqdm(test_dataloader, desc="Testing"):
    input_ids = batch[0].to(device)
    attention_mask = batch[1].to(device)
    labels = batch[2].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    preds = torch.argmax(logits, dim=1)
    accuracy = torch.sum(preds == labels).item() / len(labels)
    total_accuracy += accuracy
avg_accuracy = total_accuracy / len(test_dataloader)
print("Test accuracy:", avg_accuracy)
```
This example uses FGM-style adversarial training: in each training step the word embeddings are perturbed once along the (L2-normalized) gradient of the loss, and the loss on the perturbed embeddings is backpropagated together with the clean loss before the parameter update. During evaluation the model is tested on unperturbed test data.
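A widely used alternative is to apply FGM to the embedding weight matrix in place, rather than feeding perturbed embeddings through inputs_embeds as above. The following is a minimal sketch of such a helper, assuming torch is imported and the model's embedding parameters contain 'word_embeddings' in their names; the class name FGM and the emb_name argument are illustrative, not part of the original code:
```
class FGM:
    """In-place FGM helper: back up, perturb, and restore the embedding weights."""
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1.0, emb_name='word_embeddings'):
        # Add an L2-normalized gradient step to every matching embedding parameter
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(epsilon * param.grad / norm)

    def restore(self, emb_name='word_embeddings'):
        # Put the original embedding weights back after the adversarial backward pass
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
```
With such a helper, the per-batch loop becomes: backward on the clean loss, fgm.attack(), a second forward/backward on the same batch, fgm.restore(), then optimizer.step().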