给出一个使用bert模型对新闻文本分类的示例,基于tensorflow2.0
时间: 2023-05-22 11:01:26 浏览: 116
基于tensorflow2.0中的keras进行中文的文本分类,实验数据为中文新闻分类文本cnews数据集。.zip
以下是一个使用BERT模型对新闻文本进行分类的示例,基于TensorFlow 2.0:
```python
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout
# 加载BERT模型
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2", trainable=True)
# 构建模型
input_word_ids = Input(shape=(128,), dtype=tf.int32, name="input_word_ids")
input_mask = Input(shape=(128,), dtype=tf.int32, name="input_mask")
segment_ids = Input(shape=(128,), dtype=tf.int32, name="segment_ids")
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
x = Dense(256, activation='relu')(pooled_output)
x = Dropout(0.2)(x)
output = Dense(10, activation='softmax')(x)
model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=output)
# 编译模型
optimizer = tf.keras.optimizers.Adam(lr=2e-5)
loss = tf.keras.losses.CategoricalCrossentropy()
metric = tf.keras.metrics.CategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
# 训练模型
history = model.fit(train_dataset, epochs=5, validation_data=val_dataset)
```
在这个示例中,我们使用了BERT模型对新闻文本进行分类。我们首先加载了BERT模型,然后构建了一个包含一个Dense层和一个Dropout层的分类器。我们使用了CategoricalCrossentropy作为损失函数,使用Adam优化器进行优化,并使用CategoricalAccuracy作为评估指标。最后,我们训练了模型并记录了训练历史。
阅读全文