获取在graph中的tensor形状
时间: 2024-09-23 12:14:53 浏览: 53
接口说明-ddr4 spd spec
在TensorFlow的图(Graph)中,获取张量(tensor)的形状通常通过`tf.shape()`函数完成。这个操作会返回一个one-hot编码的张量,其中每个元素表示原张量对应维度的大小。例如:
```python
# 假设我们有一个名为input_tensor的张量
input_tensor = tf.placeholder(tf.float32, shape=[None, 10, 5])
# 要获取其形状,可以这样做
shape_tensor = tf.shape(input_tensor)
# 运行计算图时,你可以得到具体的形状信息,比如
with tf.Session() as sess:
tensor_shape = sess.run(shape_tensor, {input_tensor: np.random.rand(3, 10, 5)})
print("输入张量的形状:", tensor_shape)
```
这将输出类似 `[3, 10, 5]` 的形状数组,表示该张量是一个3维张量,第一维有3个元素,第二维长度为10,第三维长度为5。
阅读全文