Implementing the Non-local Operator in TensorFlow
TensorFlow has no dedicated non-local op. Note that `tf.raw_ops.NonMaxSuppressionV4` performs non-maximum suppression for bounding boxes and is unrelated to the non-local operator. A non-local block (Wang et al., 2018) can instead be assembled from standard ops: 1x1 convolutions, `tf.matmul`, and `tf.nn.softmax`. Below is an example using the TensorFlow 1.x API (`tf.layers` and `tf.contrib` were removed in 2.x):
```python
import tensorflow as tf

def non_local_block(x, compression=2, mode='embedded'):
    """
    Non-local block (TensorFlow 1.x implementation)
    :param x: input tensor of shape [batch, height, width, channels]
    :param compression: channel-reduction ratio for the embedding convolutions
    :param mode: pairwise function, 'embedded' (embedded Gaussian) or 'dot' (dot product)
    :return: output tensor with the same shape as x
    """
    batch_size, height, width, channels = x.get_shape().as_list()
    inter_channels = channels // compression
    if mode == 'embedded':
        # Embedded Gaussian version: softmax over embedded dot products
        theta = tf.layers.conv2d(x, inter_channels, kernel_size=1, use_bias=False)
        phi = tf.layers.conv2d(x, inter_channels, kernel_size=1, use_bias=False)
        g = tf.layers.conv2d(x, inter_channels, kernel_size=1, use_bias=False)
        # Flatten the spatial dimensions; -1 keeps the batch dimension dynamic
        theta = tf.reshape(theta, [-1, height * width, inter_channels])
        phi = tf.reshape(phi, [-1, height * width, inter_channels])
        g = tf.reshape(g, [-1, height * width, inter_channels])
        # Compute pairwise affinity between each pair of spatial locations
        theta_phi = tf.matmul(theta, phi, transpose_b=True)
        theta_phi = tf.nn.softmax(theta_phi)
        # Aggregate features and project back to the input channel count
        y = tf.matmul(theta_phi, g)
        y = tf.reshape(y, [-1, height, width, inter_channels])
        y = tf.layers.conv2d(y, channels, kernel_size=1, use_bias=False)
    elif mode == 'dot':
        # Dot-product version of the non-local block
        theta = tf.layers.conv2d(x, inter_channels, kernel_size=1, use_bias=False)
        phi = tf.layers.conv2d(x, inter_channels, kernel_size=1, use_bias=False)
        g = tf.layers.conv2d(x, channels, kernel_size=1, use_bias=False)
        theta = tf.reshape(theta, [-1, height * width, inter_channels])
        phi = tf.reshape(phi, [-1, height * width, inter_channels])
        g = tf.reshape(g, [-1, height * width, channels])
        # Pairwise affinity, scaled by the embedding dimension before the softmax
        theta_phi = tf.matmul(theta, phi, transpose_b=True)
        theta_phi /= tf.cast(inter_channels, tf.float32)
        theta_phi = tf.nn.softmax(theta_phi)
        # Compute the output feature map
        y = tf.matmul(theta_phi, g)
        y = tf.reshape(y, [-1, height, width, channels])
    else:
        raise ValueError("mode must be 'embedded' or 'dot'")
    # Add residual connection, then apply layer normalization
    y = x + y
    return tf.contrib.layers.layer_norm(y)
```
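As a quick sanity check, the block can be wired into a TF 1.x graph as below. This is a minimal sketch, assuming a TF 1.x environment with `tf.contrib` available; the placeholder shape and random input are illustrative:

```python
import numpy as np
import tensorflow as tf

# Build a small graph and run the block on random features (TF 1.x)
x = tf.placeholder(tf.float32, [None, 16, 16, 64])
y = non_local_block(x, compression=2, mode='embedded')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.random.rand(2, 16, 16, 64).astype(np.float32)})
    print(out.shape)  # (2, 16, 16, 64)
```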
The `mode` parameter selects `'embedded'` or `'dot'`. In `'embedded'` mode the affinity matrix is computed with the embedded Gaussian formulation (a softmax over embedded dot products); in `'dot'` mode it is computed with scaled dot products.
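Since `tf.layers` and `tf.contrib` are removed in TensorFlow 2.x, here is a minimal Keras-based sketch of the embedded Gaussian variant. The `NonLocalBlock` class name and its constructor arguments are illustrative, not part of the original snippet:

```python
import tensorflow as tf

class NonLocalBlock(tf.keras.layers.Layer):
    """Embedded Gaussian non-local block, TensorFlow 2.x sketch."""

    def __init__(self, channels, compression=2, **kwargs):
        super().__init__(**kwargs)
        self.inter = channels // compression
        self.theta = tf.keras.layers.Conv2D(self.inter, 1, use_bias=False)
        self.phi = tf.keras.layers.Conv2D(self.inter, 1, use_bias=False)
        self.g = tf.keras.layers.Conv2D(self.inter, 1, use_bias=False)
        self.out = tf.keras.layers.Conv2D(channels, 1, use_bias=False)
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        shape = tf.shape(x)
        b, h, w = shape[0], shape[1], shape[2]
        # Project and flatten the spatial dimensions to [batch, h*w, inter]
        theta = tf.reshape(self.theta(x), [b, h * w, self.inter])
        phi = tf.reshape(self.phi(x), [b, h * w, self.inter])
        g = tf.reshape(self.g(x), [b, h * w, self.inter])
        # Embedded Gaussian affinity: softmax over embedded dot products
        attn = tf.nn.softmax(tf.matmul(theta, phi, transpose_b=True))
        y = tf.matmul(attn, g)
        y = tf.reshape(y, [b, h, w, self.inter])
        # Project back to the input channel count, add residual, normalize
        return self.norm(x + self.out(y))

# Example usage on a random feature map
x = tf.random.normal([2, 16, 16, 64])
y = NonLocalBlock(channels=64)(x)
print(y.shape)  # (2, 16, 16, 64)
```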