Train an LSTM on the Reuters dataset for text classification in TensorFlow 2.0, and give the Python code
Asked: 2024-06-09 12:05:40
Sure. Here is Python code that trains an LSTM on the Reuters dataset for text classification:
```python
import numpy as np
from tensorflow.keras.datasets import reuters
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, Embedding
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
# Load Reuters dataset
(x_train, y_train), (x_test, y_test) = reuters.load_data(num_words=10000)
# Get word index dictionary
word_index = reuters.get_word_index()
# Reverse word index to get words from indices
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
# Decode the first newswire back to words for inspection
# (indices 0-2 are reserved for padding/start/unknown, hence the -3 offset)
decoded_newswire = ' '.join([reverse_word_index.get(i - 3, '?') for i in x_train[0]])
print(decoded_newswire[:100])
# Pad input sequences to have same length
maxlen = 500
x_train = pad_sequences(x_train, maxlen=maxlen)
x_test = pad_sequences(x_test, maxlen=maxlen)
# One-hot encode labels
num_classes = np.max(y_train) + 1
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
# Build LSTM model
model = Sequential()
model.add(Embedding(10000, 32))
model.add(LSTM(32))
model.add(Dense(num_classes, activation='softmax'))
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
# Train LSTM model
history = model.fit(x_train, y_train, epochs=10, batch_size=128, validation_split=0.2)
# Evaluate LSTM model on test data
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)
```
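One detail worth knowing about the padding step: by default `pad_sequences` pads with zeros on the left and, when a sequence is longer than `maxlen`, keeps its last `maxlen` tokens. A minimal numpy sketch of that default behavior (the `pad_pre` helper is illustrative, not a Keras function):

```python
import numpy as np

def pad_pre(seqs, maxlen):
    # Mimic Keras pad_sequences defaults (padding='pre', truncating='pre'):
    # left-pad with 0, and keep only the last `maxlen` tokens of long sequences
    out = np.zeros((len(seqs), maxlen), dtype='int32')
    for i, seq in enumerate(seqs):
        trunc = seq[-maxlen:]
        out[i, maxlen - len(trunc):] = trunc
    return out

padded = pad_pre([[1, 2], [1, 2, 3, 4]], maxlen=3)
# short sequence is left-padded, long sequence keeps its last 3 tokens
```

If you instead want to keep the beginning of long newswires, pass `truncating='post'` to `pad_sequences`.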
In this code, we first load the Reuters dataset, already split into training and test sets. We then fetch the dataset's word index and build a reverse index so that words can be recovered from the integer sequences. Next, we pad the input sequences to a common length and one-hot encode the labels. We then build an LSTM model from an Embedding layer, an LSTM layer, and a softmax output layer, and compile it with the rmsprop optimizer and categorical cross-entropy loss. Finally, we train while validating on 20% of the training data and evaluate the model's performance on the test set.
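As an aside on the one-hot step: `to_categorical` simply turns each integer label into a row vector with a single 1 at the label's index. A minimal numpy sketch of the same transformation (the `one_hot` helper here is illustrative, not a Keras API):

```python
import numpy as np

def one_hot(labels, num_classes):
    # Each row is all zeros except a 1 at the label's index
    out = np.zeros((len(labels), num_classes), dtype='float32')
    out[np.arange(len(labels)), labels] = 1.0
    return out

encoded = one_hot([0, 2, 1], num_classes=3)
```

Alternatively, Keras lets you skip this step entirely: keep the integer labels and compile the model with `loss='sparse_categorical_crossentropy'` instead.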