import tensorflow as tf
from tensorflow import keras
import numpy as np

# Load the IMDB dataset, keeping only the 10,000 most frequent words.
imdb = keras.datasets.imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

# Pad (or truncate) the integer sequences so they all have the same length.
maxlen = 256
train_data = keras.preprocessing.sequence.pad_sequences(train_data, value=0, padding='post', maxlen=maxlen)
test_data = keras.preprocessing.sequence.pad_sequences(test_data, value=0, padding='post', maxlen=maxlen)

# Build the model.
model = keras.Sequential()
model.add(keras.layers.Embedding(10000, 16))
model.add(keras.layers.GlobalAveragePooling1D())
model.add(keras.layers.Dense(16, activation=tf.nn.relu))
model.add(keras.layers.Dense(1, activation=tf.sigmoid))

# Compile the model.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model.
# NOTE(review): the test set is also passed as validation data. Nothing here
# selects a model from val metrics, so this is monitoring only, but a held-out
# slice of the training data (e.g. validation_split=0.2) keeps the test set clean.
history = model.fit(train_data, train_labels, epochs=40, batch_size=512,
                    validation_data=(test_data, test_labels), verbose=1)

# Evaluate the model on the test data.
results = model.evaluate(test_data, test_labels)
print('Test loss:', results[0])
print('Test accuracy:', results[1])

# Request: optimize the code above.
时间: 2023-07-15 16:15:00 浏览: 176
以下是对代码的优化:
1. 整理导入语句的顺序(注意:`numpy` 在整个脚本中并未被使用,可以删除):
```
from tensorflow import keras
import tensorflow as tf
import numpy as np
```
2. 将模型的架构定义为一个函数,这样可以使代码更加清晰:
```
def build_model(vocab_size=10000, embedding_dim=16, hidden_units=16):
    """Build and compile the binary sentiment classifier.

    Args:
        vocab_size: Input dimension of the Embedding layer; should be at
            least the ``num_words`` value used when loading the data.
        embedding_dim: Length of each word-embedding vector.
        hidden_units: Width of the hidden ReLU layer.

    Returns:
        A compiled ``keras.Sequential`` model whose single sigmoid output is
        trained with binary cross-entropy.
    """
    model = keras.Sequential([
        keras.layers.Embedding(vocab_size, embedding_dim),
        # Average the embeddings over the sequence axis so each review is
        # summarized by one fixed-size vector.
        keras.layers.GlobalAveragePooling1D(),
        keras.layers.Dense(hidden_units, activation=tf.nn.relu),
        keras.layers.Dense(1, activation=tf.sigmoid),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
```
3. 使用 EarlyStopping 和 ModelCheckpoint 回调函数:当验证损失不再下降时提前停止训练,并保存(和恢复)最佳模型。它们本身不会提升模型能力,但可以避免过拟合和无效的训练轮次。验证集应从训练数据中划分,而不是使用测试集:
```
# Stop once val_loss has not improved for 5 epochs, and roll the model back to
# the weights of its best epoch instead of keeping the last epoch's weights.
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5,
                                           restore_best_weights=True)
# Keep the best model (ModelCheckpoint monitors val_loss by default) on disk.
model_checkpoint = keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)
# Hold out the last 20% of the training data for validation. Monitoring the
# test set here would let it drive early stopping and checkpoint selection,
# which biases the final test score.
history = model.fit(train_data, train_labels, epochs=40, batch_size=512,
                    validation_split=0.2, verbose=1,
                    callbacks=[early_stop, model_checkpoint])
```
4. 通过元组解包直接取得损失和准确率,并用 f-string 输出测试结果:
```
# Assumes the model was compiled with metrics=['accuracy'], so evaluate()
# yields exactly two values: the loss followed by the accuracy.
test_metrics = model.evaluate(test_data, test_labels)
loss, accuracy = test_metrics
print(f'Test loss: {loss}, Test accuracy: {accuracy}')
```
综上所述,我们可以对代码进行如下的优化:
```
from tensorflow import keras
import tensorflow as tf
import numpy as np
def build_model(vocab_size=10000, embedding_dim=16, hidden_units=16):
    """Build and compile the binary sentiment classifier.

    Args:
        vocab_size: Input dimension of the Embedding layer; should be at
            least the ``num_words`` value used when loading the data.
        embedding_dim: Length of each word-embedding vector.
        hidden_units: Width of the hidden ReLU layer.

    Returns:
        A compiled ``keras.Sequential`` model whose single sigmoid output is
        trained with binary cross-entropy.
    """
    model = keras.Sequential([
        keras.layers.Embedding(vocab_size, embedding_dim),
        # Average the embeddings over the sequence axis so each review is
        # summarized by one fixed-size vector.
        keras.layers.GlobalAveragePooling1D(),
        keras.layers.Dense(hidden_units, activation=tf.nn.relu),
        keras.layers.Dense(1, activation=tf.sigmoid),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
# Load the IMDB dataset, keeping only the 10,000 most frequent words.
(train_data, train_labels), (test_data, test_labels) = keras.datasets.imdb.load_data(num_words=10000)

# Pad (or truncate) every integer sequence to the same length.
maxlen = 256
train_data = keras.preprocessing.sequence.pad_sequences(train_data, value=0, padding='post', maxlen=maxlen)
test_data = keras.preprocessing.sequence.pad_sequences(test_data, value=0, padding='post', maxlen=maxlen)

# Build the model.
model = build_model()

# Stop once val_loss has not improved for 5 epochs and restore the weights of
# the best epoch, so the model evaluated below is the best one rather than
# the last one trained. The checkpoint also keeps the best model on disk
# (ModelCheckpoint monitors val_loss by default).
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5,
                                           restore_best_weights=True)
model_checkpoint = keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)

# Train the model. Validation uses the last 20% of the training data, so the
# test set plays no part in early stopping or checkpoint selection and the
# final test score stays an unbiased estimate.
history = model.fit(train_data, train_labels, epochs=40, batch_size=512,
                    validation_split=0.2, verbose=1,
                    callbacks=[early_stop, model_checkpoint])

# Evaluate on the held-out test set.
loss, accuracy = model.evaluate(test_data, test_labels)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')
```
阅读全文
相关推荐
![rar](https://img-home.csdnimg.cn/images/20241231044955.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20241231045053.png)