如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件1
时间: 2024-04-30 15:18:09 浏览: 15
训练模型成pb文件:
首先,你需要先定义并训练好一个Tensorflow模型。在模型训练完成后,你可以使用Tensorflow的`freeze_graph.py`脚本将模型转换为pb文件。
在终端中输入以下命令:
```
python tensorflow/python/tools/freeze_graph.py \
--input_graph=<input_graph.pb> \
--input_checkpoint=<input_checkpoint> \
--output_graph=<output_graph.pb> \
--output_node_names=<output_node_names>
```
其中:
- `<input_graph.pb>`:模型的GraphDef文件。
- `<input_checkpoint>`:模型的checkpoint文件。
- `<output_graph.pb>`:转换后的pb文件的输出路径。
- `<output_node_names>`:输出节点的名称,可以在定义模型时指定。
例如:
```
python tensorflow/python/tools/freeze_graph.py \
--input_graph=./models/graph.pb \
--input_checkpoint=./models/model.ckpt \
--output_graph=./models/frozen_graph.pb \
--output_node_names=output_node
```
这将把`graph.pb`和`model.ckpt`转换为`frozen_graph.pb`文件。其中`output_node`是模型定义时指定的输出节点名称。
加载已经训练好的模型文件:
要加载已经训练好的模型文件,你需要使用Tensorflow的`tf.Session()`来创建一个会话,并使用`tf.train.import_meta_graph()`方法将模型的MetaGraph文件导入到当前的计算图中。然后,你可以使用`tf.get_default_graph()`方法获取默认的计算图,并使用`get_tensor_by_name()`方法获取模型中的张量。
以下是一个加载已经训练好的模型文件的示例代码:
```
import tensorflow as tf
# 创建一个会话
sess = tf.Session()
# 加载MetaGraph文件
saver = tf.train.import_meta_graph('./models/model.ckpt.meta')
# 恢复变量
saver.restore(sess, './models/model.ckpt')
# 获取默认计算图
graph = tf.get_default_graph()
# 获取模型中的张量
input_tensor = graph.get_tensor_by_name('input:0')
output_tensor = graph.get_tensor_by_name('output:0')
```
在这个例子中,我们使用`saver.restore()`方法恢复了模型的变量,然后获取了模型中的`input`和`output`张量。这里`input`和`output`是在定义模型时所命名的张量名称。