tensorflow2保存模型格式
时间: 2023-09-14 10:10:15 浏览: 103
TensorFlow模型保存/载入的两种方法
在TensorFlow 2中,可以使用以下两种格式来保存模型:
1. SavedModel格式:这是TensorFlow 2中推荐使用的格式。它是一种通用的格式,可以在不同的平台和语言中使用,例如Python、C++、Java等。SavedModel格式保存的是整个模型的图结构和权重。
保存模型:
```python
import tensorflow as tf
# 创建模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(784,), activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
# 训练模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
# 保存模型
model.save('my_model')
```
加载模型:
```python
import tensorflow as tf
# 加载模型
model = tf.keras.models.load_model('my_model')
# 使用模型预测
predictions = model.predict(x_test)
```
2. Keras HDF5格式:这是一种仅限于Python使用的格式,它只保存了模型的权重,没有保存图结构。
保存模型:
```python
import tensorflow as tf
# 创建模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(784,), activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
# 训练模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
# 保存模型
model.save_weights('my_model_weights.h5')
```
加载模型:
```python
import tensorflow as tf
# 创建模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(10, input_shape=(784,), activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
# 加载模型权重
model.load_weights('my_model_weights.h5')
# 编译模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 使用模型预测
predictions = model.predict(x_test)
```
阅读全文