class TextMatchDataset(dataset.Dataset): def __init__(self, args, tokenizer, file_path): self.config = args self.tokenizer = tokenizer self.path = file_path self.inference = False self.max_seq_len = self.config.max_seq_len self.labels2id = args.labels2id_list[0] self.contents = self.load_dataset_match(self.config)
时间: 2023-06-10 14:06:06 浏览: 117
这段代码是一个自定义的 PyTorch Dataset 类,用于加载文本匹配任务的数据集。其中包含了如下的属性和方法:
- `__init__(self, args, tokenizer, file_path)`:初始化函数,参数包括训练参数 `args`、分词器 `tokenizer`、数据集文件路径 `file_path`。同时还包括一些其他的属性,例如 `inference` 表示是否为预测模式,`max_seq_len` 表示最大序列长度,`labels2id` 表示标签的映射关系等。
- `load_dataset_match(self, config)`:加载数据集的方法,返回一个 `List[List[str]]` 类型的数据,每个元素都是一个长度为 3 的列表,分别表示 query、pos_doc 和 neg_doc。
- `__len__(self)`:返回数据集的长度。
- `__getitem__(self, index)`:根据索引返回一个样本,返回的是一个字典类型,包括了 query、pos_doc、neg_doc 的分词结果以及对应的标签。
该自定义 Dataset 类可以被用于 PyTorch 模型的训练和评估。
相关问题
AttributeError: 'Seq2SeqTrainer' object has no attribute 'is_deepspeed_enabled'
这个错误通常是由于使用了深度学习框架Hugging Face的Seq2SeqTrainer类的一个属性,但是没有正确地配置深度学习框架的深度学习加速库Deepspeed。解决这个问题的方法是在代码中添加一些必要的配置,以确保Deepspeed正确地启用。以下是一些可能有用的步骤:
1. 确保你已经安装了Deepspeed库,并且版本与你的深度学习框架版本兼容。
2. 在你的代码中添加以下导入语句:
```python
from transformers import DeepspeedConfig, set_seed
```
3. 在你的代码中添加以下配置:
```python
deepspeed_config = DeepspeedConfig()
deepspeed_config["deepspeed"] = {
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu"
}
}
}
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
tokenizer=tokenizer,
deepspeed=deepspeed_config
)
```
这些配置将确保Deepspeed正确地启用,并且你的代码应该能够正常运行。
https://github.com/weizhepei/CasRel中run.py解读
`run.py` 是 `CasRel` 项目的入口文件,用于训练和测试模型。以下是 `run.py` 的主要代码解读和功能说明:
### 导入依赖包和模块
首先,`run.py` 导入了所需的依赖包和模块,包括 `torch`、`numpy`、`argparse`、`logging` 等。
```python
import argparse
import logging
import os
import random
import time
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from casrel import CasRel
from dataset import RE_Dataset
from utils import init_logger, load_tokenizer, set_seed, collate_fn
```
### 解析命令行参数
接下来,`run.py` 解析了命令行参数,包括训练数据路径、模型保存路径、预训练模型路径、学习率等参数。
```python
def set_args():
parser = argparse.ArgumentParser()
parser.add_argument("--train_data", default=None, type=str, required=True,
help="The input training data file (a text file).")
parser.add_argument("--dev_data", default=None, type=str, required=True,
help="The input development data file (a text file).")
parser.add_argument("--test_data", default=None, type=str, required=True,
help="The input testing data file (a text file).")
parser.add_argument("--model_path", default=None, type=str, required=True,
help="Path to save, load model")
parser.add_argument("--pretrain_path", default=None, type=str,
help="Path to pre-trained model")
parser.add_argument("--vocab_path", default=None, type=str, required=True,
help="Path to vocabulary")
parser.add_argument("--batch_size", default=32, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument("--gradient_accumulation_steps", default=1, type=int,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", default=3, type=int,
help="Total number of training epochs to perform.")
parser.add_argument("--max_seq_length", default=256, type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Linear warmup over warmup_steps.")
parser.add_argument("--weight_decay", default=0.01, type=float,
help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm.")
parser.add_argument("--logging_steps", type=int, default=500,
help="Log every X updates steps.")
parser.add_argument("--save_steps", type=int, default=500,
help="Save checkpoint every X updates steps.")
parser.add_argument("--seed", type=int, default=42,
help="random seed for initialization")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
help="selected device (default: cuda if available)")
args = parser.parse_args()
return args
```
### 加载数据和模型
接下来,`run.py` 加载了训练、验证和测试数据,以及 `CasRel` 模型。
```python
def main():
args = set_args()
init_logger()
set_seed(args)
tokenizer = load_tokenizer(args.vocab_path)
train_dataset = RE_Dataset(args.train_data, tokenizer, args.max_seq_length)
dev_dataset = RE_Dataset(args.dev_data, tokenizer, args.max_seq_length)
test_dataset = RE_Dataset(args.test_data, tokenizer, args.max_seq_length)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size,
collate_fn=collate_fn)
dev_sampler = SequentialSampler(dev_dataset)
dev_dataloader = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=args.batch_size,
collate_fn=collate_fn)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=args.batch_size,
collate_fn=collate_fn)
model = CasRel(args)
if args.pretrain_path:
model.load_state_dict(torch.load(args.pretrain_path, map_location="cpu"))
logging.info(f"load pre-trained model from {args.pretrain_path}")
model.to(args.device)
```
### 训练模型
接下来,`run.py` 开始训练模型,包括前向传播、反向传播、梯度更新等步骤。
```python
optimizer = torch.optim.Adam([{'params': model.bert.parameters(), 'lr': args.learning_rate},
{'params': model.subject_fc.parameters(), 'lr': args.learning_rate},
{'params': model.object_fc.parameters(), 'lr': args.learning_rate},
{'params': model.predicate_fc.parameters(), 'lr': args.learning_rate},
{'params': model.linear.parameters(), 'lr': args.learning_rate}],
lr=args.learning_rate, eps=args.adam_epsilon, weight_decay=args.weight_decay)
total_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
warmup_steps = int(total_steps * args.warmup_proportion)
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lambda epoch: 1 / (1 + 0.05 * (epoch - 1))
)
global_step = 0
best_f1 = 0
for epoch in range(args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
model.train()
batch = tuple(t.to(args.device) for t in batch)
inputs = {
"input_ids": batch[0],
"attention_mask": batch[1],
"token_type_ids": batch[2],
"subj_pos": batch[3],
"obj_pos": batch[4],
"subj_type": batch[5],
"obj_type": batch[6],
"subj_label": batch[7],
"obj_label": batch[8],
"predicate_label": batch[9],
}
outputs = model(**inputs)
loss = outputs[0]
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1
if global_step % args.logging_steps == 0:
logging.info(f"Epoch:[{epoch + 1}]/[{args.num_train_epochs}] Step:[{global_step}] "
f"Train loss:{loss.item():.6f}")
if global_step % args.save_steps == 0:
f1 = evaluate(model, dev_dataloader, args)
if f1 > best_f1:
best_f1 = f1
torch.save(model.state_dict(), os.path.join(args.model_path, "best_model.bin"))
logging.info(f"Save model at step [{global_step}] with best f1 {best_f1:.4f}")
```
### 测试模型
最后,`run.py` 对模型进行测试,输出模型在测试集上的预测结果。
```python
model.load_state_dict(torch.load(os.path.join(args.model_path, "best_model.bin"), map_location="cpu"))
logging.info(f"load best model from {os.path.join(args.model_path, 'best_model.bin')}")
f1, precision, recall = evaluate(model, test_dataloader, args)
logging.info(f"Test f1:{f1:.4f} precision:{precision:.4f} recall:{recall:.4f}")
```
以上就是 `run.py` 的主要代码解读和功能说明。
阅读全文
相关推荐
![-](https://img-home.csdnimg.cn/images/20241231044930.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![-](https://img-home.csdnimg.cn/images/20241231045053.png)
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)
![-](https://img-home.csdnimg.cn/images/20241226111658.png)