tf.gather函数用法
时间: 2023-06-19 21:08:48 浏览: 186
`tf.gather`函数是 TensorFlow 中的一个操作,用于根据索引从张量中收集元素。它的语法如下:
```python
tf.gather(params, indices, axis=None, batch_dims=0, name=None)
```
参数说明:
- `params`: 要从中收集元素的张量,可以是任何形状的张量。
- `indices`: 一个张量,指定要收集哪些元素。它可以是任何形状的整数张量。
- `axis`: 一个可选的整数,指定从哪个轴收集元素。默认为`None`,表示要将`indices`解释为一维向量。如果指定`axis`,则`indices`必须是具有相同形状的张量。
- `batch_dims`: 一个可选的整数,指定批次维度的数量。默认为0,表示没有批次维度。例如,如果`params`的形状是`(batch_size, height, width, channels)`,则`batch_dims`可以设置为1,表示`indices`的形状是`(batch_size, num_elements)`。
- `name`: (可选)操作的名称。
`tf.gather`函数的返回值是一个张量,其中包含来自`params`的元素,其索引由`indices`指定。
下面是一个使用`tf.gather`函数的示例代码:
```python
import tensorflow as tf
# 创建一个3x3的矩阵
x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 从第1个轴(即行)收集第0行和第2行
indices = tf.constant([0, 2])
y = tf.gather(x, indices, axis=0)
print(y.numpy()) # 输出 [[1 2 3] [7 8 9]]
```
在这个例子中,我们创建了一个3x3的矩阵`x`,然后使用`tf.gather`函数从第一个轴(即行)收集第0行和第2行,得到了形状为`(2, 3)`的矩阵`y`。注意,`indices`的形状是`(2,)`,因为我们没有指定`axis`参数。
阅读全文