TextCNN文本分类代码(Keras)
时间: 2023-08-24 10:05:01 浏览: 106
以下是使用Keras实现TextCNN文本分类的示例代码:
```python
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Dropout, Embedding, Conv1D, GlobalMaxPooling1D
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
# 读取数据
def read_data(filename):
    """Load ``sentence<TAB>label`` records from a UTF-8 text file.

    Each non-blank line must contain exactly one tab separating the sentence
    from its label. Leading and trailing whitespace of the line is stripped
    before splitting. Blank (whitespace-only) lines are skipped.

    Args:
        filename: Path of the text file to read.

    Returns:
        A tuple ``(sentences, labels)`` of two equally long lists of strings,
        in file order.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            tab-separated fields; the message names the file and line number.
        OSError: If the file cannot be opened.
    """
    sentences = []
    labels = []
    with open(filename, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            record = line.strip()
            if not record:
                # A blank line carries no sample; unpacking it would fail.
                continue
            fields = record.split('\t')
            if len(fields) != 2:
                # Fail with the location instead of an opaque unpacking error.
                raise ValueError(
                    f"{filename}:{lineno}: expected 'sentence<TAB>label', "
                    f"got {len(fields)} tab-separated field(s)"
                )
            sentence, label = fields
            sentences.append(sentence)
            labels.append(label)
    return sentences, labels
# 处理数据
def process_data(sentences, labels, max_len, word_dict):
    """Convert space-tokenised sentences and their labels into arrays.

    Args:
        sentences: Iterable of strings whose tokens are separated by a single
            space character.
        labels: Sequence of class labels handed to ``to_categorical`` after
            being wrapped in a NumPy array.
        max_len: Length every id sequence is padded or truncated to.
        word_dict: Mapping from token to integer id.

    Returns:
        A tuple ``(X, Y)``: ``X`` is the result of ``pad_sequences`` on the
        id sequences, ``Y`` is the one-hot encoding of ``labels`` produced by
        ``to_categorical``.
    """
    # Look every token up in the vocabulary; unknown tokens map to 0, which
    # is also the value used for padding below.
    encoded = [
        [word_dict.get(token, 0) for token in text.split(' ')]
        for text in sentences
    ]
    # Pad with 0 and truncate, both at the end of the sequence.
    X = pad_sequences(
        encoded,
        maxlen=max_len,
        padding='post',
        truncating='post',
        value=0,
    )
    # One-hot encode the labels.
    label_array = np.array(labels)
    Y = to_categorical(label_array)
    return X, Y
# 构建模型
def build_model(max_len, max_word, embedding_dim, num_classes=2,
                filters=128, kernel_size=5, hidden_units=128,
                dropout_rate=0.5):
    """Build and compile a single-branch 1D-CNN text classifier.

    The network is: embedding -> one ``Conv1D`` with ReLU -> global max
    pooling -> dense ReLU layer -> dropout -> softmax output. With the
    default keyword arguments it is the same network the three-argument call
    has always produced.

    Args:
        max_len: Length of each input id sequence.
        max_word: Size of the embedding vocabulary (``input_dim`` of the
            embedding layer).
        embedding_dim: Width of each embedding vector.
        num_classes: Number of units in the softmax output layer; it should
            equal the number of columns in the one-hot label matrix.
        filters: Number of convolution filters.
        kernel_size: Width of the convolution window, in tokens.
        hidden_units: Width of the dense layer before the output.
        dropout_rate: Dropout rate applied after the dense layer.

    Returns:
        A compiled Keras ``Model`` using the Adam optimizer, categorical
        cross-entropy loss and the accuracy metric.
    """
    # Input: one row of max_len token ids per sample.
    inputs = Input(shape=(max_len,))
    # Token ids -> embedding vectors.
    x = Embedding(max_word, embedding_dim)(inputs)
    # Convolve over the token axis, then keep each filter's maximum response.
    x = Conv1D(filters, kernel_size, activation='relu')(x)
    x = GlobalMaxPooling1D()(x)
    # Classifier head.
    x = Dense(hidden_units, activation='relu')(x)
    x = Dropout(dropout_rate)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# 主函数
if __name__ == '__main__':
    # Hyper-parameters: sequence length and embedding width.
    max_len = 100
    embedding_dim = 100

    # Load the raw sentences and labels.
    sentences, labels = read_data('data.txt')

    # Build the vocabulary: ids start at 1 and follow first appearance, so 0
    # stays free for padding and unknown tokens.
    word_dict = {}
    for sentence in sentences:
        for word in sentence.split(' '):
            word_dict.setdefault(word, len(word_dict) + 1)

    # Encode sentences as padded id sequences and labels as one-hot rows.
    X, Y = process_data(sentences, labels, max_len, word_dict)

    # Hold out 20% of the samples, with a fixed seed for a repeatable split.
    X_train, X_test, Y_train, Y_test = train_test_split(
        X, Y, test_size=0.2, random_state=42
    )

    # The embedding needs one row per vocabulary id plus one for id 0.
    max_word = len(word_dict) + 1
    model = build_model(max_len, max_word, embedding_dim)

    # Train, using the held-out split as validation data.
    model.fit(
        X_train,
        Y_train,
        batch_size=64,
        epochs=10,
        validation_data=(X_test, Y_test),
    )
```
这是一个简化的TextCNN示例:只使用了单一尺寸(5)的卷积核,输出层为二分类;数据文件每行的格式为“句子<TAB>标签”,句子需预先用空格分词,标签需为可转换成整数的类别编号。可以根据实际需求进行调整和优化,例如并联多种卷积核尺寸或修改输出类别数。
阅读全文