Code for Chinese text classification with BERT
Date: 2023-11-27 19:52:45
Below is code for Chinese text classification with BERT, covering data preprocessing, model construction, training, and evaluation:
1. Data preprocessing
```python
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer

# Load the data; data.csv is expected to have a 'text' and a 'label' column
df = pd.read_csv('data.csv')

# Map each label to an integer id, and back
labels = df['label'].unique()
label2id = {label: index for index, label in enumerate(labels)}
id2label = {index: label for index, label in enumerate(labels)}

# Text processing
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
max_length = 128


def create_inputs_targets(data):
    """Tokenize the 'text' column and convert the 'label' column to ids."""
    # Tokenize all texts in one call. Every sequence is padded or truncated
    # to max_length, so both arrays have shape (num_samples, max_length).
    encoded = tokenizer(
        data['text'].tolist(),
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',   # replaces the deprecated pad_to_max_length=True
        truncation=True,        # cut texts longer than max_length
        return_attention_mask=True,
        return_tensors='np',
    )
    targets = data['label'].map(label2id).to_numpy()
    return encoded['input_ids'], encoded['attention_mask'], targets


# Split into training and test sets
train_data, test_data = train_test_split(df, test_size=0.2, random_state=42)
train_inputs, train_masks, train_targets = create_inputs_targets(train_data)
test_inputs, test_masks, test_targets = create_inputs_targets(test_data)
```
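To make the `max_length` and attention-mask settings in the preprocessing step concrete, here is a minimal pure-Python sketch of what padding or truncating to a fixed length amounts to. The helper `pad_or_truncate`, the `pad_id` default, and the token ids are assumptions made for illustration. This is not the tokenizer's implementation, and it ignores details the real tokenizer takes care of, such as keeping the closing special token when a text is truncated.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    ids = list(token_ids)[:max_length]      # truncate sequences that are too long
    mask = [1] * len(ids)                   # 1 marks a real token
    n_pad = max_length - len(ids)           # number of padding positions needed
    return ids + [pad_id] * n_pad, mask + [0] * n_pad   # 0 marks padding


# Hypothetical token ids for a short and a long sentence
short_ids, short_mask = pad_or_truncate([101, 2769, 4263, 102], max_length=6)
long_ids, long_mask = pad_or_truncate([101, 1, 2, 3, 4, 5, 6, 102], max_length=6)

print(short_ids, short_mask)
print(long_ids, long_mask)
```

Because every row ends up the same length, the per-text results can be stacked into a single 2-D array, and the mask tells the model which positions to ignore.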
2. Model construction
```python
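from transformers import TFBertForSequenceClassification

# Load the pretrained model with one output per label
model = TFBertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=len(labels))

# Define the optimizer and loss function; the model outputs raw logits
# and the targets are integer ids, hence the sparse loss with from_logits=True
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Compile the model
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])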
```
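The loss chosen in this step takes raw logits and integer class ids rather than probabilities and one-hot vectors. As a rough illustration of what that computes for a single sample, here is a pure-Python sketch. The function name and the example logits are hypothetical, and the Keras implementation additionally averages over the batch.

```python
import math


def sparse_categorical_crossentropy(logits, target_id):
    """Cross-entropy of one sample, given raw logits and an integer class id."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # Equivalent to -log(softmax(logits)[target_id])
    return log_sum_exp - logits[target_id]


# Equal logits over 3 classes: the model is undecided, so the loss is log(3)
uniform_loss = sparse_categorical_crossentropy([0.0, 0.0, 0.0], target_id=1)
# A large logit on the correct class drives the loss toward 0
confident_loss = sparse_categorical_crossentropy([0.0, 10.0, 0.0], target_id=1)

print(uniform_loss, confident_loss)
```

With `from_logits=True` the softmax is folded into the loss, which is why the model needs no final softmax layer.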
3. Model training
```python
batch_size = 32
epochs = 3

# Train the model, holding out 10% of the training data for validation
history = model.fit(
    [train_inputs, train_masks],
    train_targets,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
)
```
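Keras documents that `validation_split` takes the validation samples from the end of the arrays, before any shuffling. The sketch below mimics that split and works out how many batches one epoch contains. The helper name and the sample count of 800 are assumptions for illustration, not values from the original data.

```python
import math


def split_for_validation(num_samples, validation_split):
    """Return (train_indices, val_indices), taking the validation set from the end."""
    split_at = int(num_samples * (1.0 - validation_split))
    indices = list(range(num_samples))
    return indices[:split_at], indices[split_at:]


num_samples = 800   # hypothetical number of rows in train_inputs
batch_size = 32

train_idx, val_idx = split_for_validation(num_samples, validation_split=0.1)
steps_per_epoch = math.ceil(len(train_idx) / batch_size)  # last batch may be partial

print(len(train_idx), len(val_idx), steps_per_epoch)
```

Since `train_test_split` shuffles the rows by default, the tail of the training arrays is already an arbitrary sample, so holding out the last 10% is reasonable here.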
4. Model evaluation
```python
test_loss, test_accuracy = model.evaluate([test_inputs, test_masks], test_targets, batch_size=batch_size)
print('Test Loss:', test_loss)
print('Test Accuracy:', test_accuracy)
```
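Beyond loss and accuracy, you will usually want readable labels for new texts. The preprocessing step builds an `id2label` mapping for exactly this purpose. The sketch below shows the argmax-and-lookup step on its own. The label names and logits are made up, and with the real model the logits would presumably come from the output of `model.predict` on tokenized inputs.

```python
def predict_labels(logits_batch, id2label):
    """Pick the highest-scoring class for each row of logits and map it to its label."""
    predictions = []
    for logits in logits_batch:
        best_id = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        predictions.append(id2label[best_id])
    return predictions


# Hypothetical label mapping and logits for two texts
id2label = {0: 'sports', 1: 'finance', 2: 'tech'}
logits_batch = [[0.2, 2.5, -1.0], [1.7, 0.1, 0.3]]

predicted = predict_labels(logits_batch, id2label)
print(predicted)
```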
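That is the complete code for Chinese text classification with BERT. You need to prepare the dataset yourself: a `data.csv` file with a `text` column and a `label` column.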
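As a rough guide to the file layout the preprocessing code expects, the sketch below builds a two-row CSV in memory and reads it back. The example sentences and label names are invented. The only thing taken from the code above is that it looks up columns named `text` and `label`.

```python
import csv
import io

# Hypothetical rows; a real dataset would have many examples per label
rows = [
    {'text': '这支球队昨晚赢得了比赛', 'label': '体育'},
    {'text': '央行宣布下调基准利率', 'label': '财经'},
]

# Write a header line followed by one line per example
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=['text', 'label'])
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
print(csv_text)

# Reading it back yields the same two columns the preprocessing code uses
parsed = list(csv.DictReader(io.StringIO(csv_text)))
```

Saving `csv_text` to `data.csv` in UTF-8 would give a file that `pd.read_csv('data.csv')` can load with its default encoding.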