将以下 PyTorch 代码改为 TensorFlow 形式:

class BottleneckTransformer(nn.Module):
    def __init__(self, in_c, out_c, fm_sz, head_n=4):
        super(BottleneckTransformer, self).__init__()
        self.botneck = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.sa = nn.Sequential(
            MultiHeadSelfAttention(in_c=in_c, out_c=out_c // head_n, head_n=head_n, fm_sz=fm_sz),
            MultiHeadSelfAttention(in_c=out_c, out_c=out_c // head_n, head_n=head_n, fm_sz=fm_sz)
        )

    def forward(self, x):
        x0 = self.botneck(x)
        x = self.sa(x)
        x = x + x0
        x = self.pool(x)
        return x
时间: 2024-04-23 22:23:25 浏览: 109
import tensorflow as tf
class BottleneckTransformer(tf.keras.layers.Layer):
    """Bottleneck block with multi-head self-attention (TensorFlow port of the PyTorch version).

    Runs the input through two stacked ``MultiHeadSelfAttention`` layers,
    adds a 1x1-convolution shortcut of the original input, then downsamples
    the sum with 2x2 max pooling.

    NOTE(review): assumes channels-last NHWC input, the Keras default —
    confirm against how MultiHeadSelfAttention indexes its input.
    """

    def __init__(self, in_c, out_c, fm_sz, head_n=4):
        super(BottleneckTransformer, self).__init__()
        # in_c is kept for interface parity with the PyTorch original;
        # Keras Conv2D infers its input channel count at build time.
        self.botneck = tf.keras.layers.Conv2D(filters=out_c, kernel_size=1)
        self.pool = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)
        # Two attention stages: the first maps in_c -> out_c, the second
        # refines at out_c; each head produces out_c // head_n channels.
        self.sa = tf.keras.Sequential([
            MultiHeadSelfAttention(in_c=in_c, out_c=out_c // head_n, head_n=head_n, fm_sz=fm_sz),
            MultiHeadSelfAttention(in_c=out_c, out_c=out_c // head_n, head_n=head_n, fm_sz=fm_sz),
        ])

    def call(self, x):
        """Attention branch + 1x1-conv shortcut, residual add, then 2x2 max pool."""
        shortcut = self.botneck(x)
        attended = self.sa(x)
        return self.pool(attended + shortcut)
阅读全文