如果输入变量为两列,代码怎么修改
时间: 2024-05-06 09:17:22 浏览: 8
如果输入变量为两列,则需要将数据处理成一个文本序列,然后将其输入到模型中进行训练和预测。具体的代码修改如下:
1. 数据读取和预处理
```python
import pandas as pd
# 读取数据
data = pd.read_csv("data.csv")
# 将两列数据合并成一个文本序列,并添加分隔符
data["text"] = data["column1"].astype(str) + " [SEP] " + data["column2"].astype(str)
# 保存处理后的数据
data.to_csv("processed_data.csv", index=False)
```
2. 模型训练和预测
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained("CSDN/ChitGPT")
model = AutoModelForCausalLM.from_pretrained("CSDN/ChitGPT")
# 加载处理后的数据
data = pd.read_csv("processed_data.csv")
# 将文本序列编码成数字序列
inputs = tokenizer(data["text"].tolist(), padding=True, truncation=True, return_tensors="pt")
# 将数字序列输入到模型中进行预测
outputs = model.generate(inputs["input_ids"], max_length=50, do_sample=True)
# 将预测结果解码成文本序列
predictions = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
# 打印预测结果
for prediction in predictions:
print(prediction)
```