/tf_static
时间: 2024-12-29 17:26:26 浏览: 5
### TensorFlow 静态图 (`tf_static`) 使用方法与示例
#### 定义静态计算图
在TensorFlow中,静态图是指预先构建好整个计算流程,在执行前已经确定了所有的操作和张量关系。这种模式允许优化器提前对图形进行分析和优化。
对于静态图的操作通常涉及以下几个方面:
- **创建会话 (Session)**:用于启动默认的计算环境。
```python
import tensorflow as tf
sess = tf.Session()
```
- **定义变量和占位符**:这些是用来存储数据或接收外部输入的数据结构。
```python
# 创建常量节点
const_node = tf.constant(2.0, name="const_node")
# 创建可变参数节点
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="weights")
# 占位符用来提供外部输入
input_data = tf.placeholder(dtype=tf.float32, shape=[None, 784])
```
- **建立运算逻辑**:通过调用各种算子来连接各个节点形成完整的计算路径。
```python
output = tf.matmul(input_data, weights) + bias
```
- **初始化全局变量**
```python
init_op = tf.global_variables_initializer()
sess.run(init_op)
```
- **保存模型**:当训练完成后可以将当前状态序列化以便后续加载重用。
```python
saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)[^1]
```
#### 加载预训练好的模型并推理
如果已经有了一个经过训练后的`.pb`文件,则可以直接读取该二进制格式的协议缓冲区表示形式来进行预测工作。
```python
from tensorflow.python.platform import gfile
with tf.Session() as sess:
with gfile.FastGFile('/path/to/frozen_inference_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
# 获取特定层作为输入输出接口
input_tensor = sess.graph.get_tensor_by_name('input:0')
output_tensor = sess.graph.get_tensor_by_name('output:0')
predictions = sess.run(output_tensor, feed_dict={input_tensor: test_images})
```
上述过程展示了如何利用已有的冻结图(`frozen_inference_graph.pb`)完成图像分类任务中的推断部分。
阅读全文