model_input = tf.concat([tf.expand_dims(perm, axis=1), tf.expand_dims(poro, axis=1)], axis=1) 中的axis=1都是什么意思
时间: 2023-09-03 21:08:43 浏览: 136
Python中axis=0和axis=1的理解
在这个代码中,`tf.expand_dims()`函数是用来增加维度的,`axis`参数表示在哪个维度上增加维度。而`tf.concat()`函数是用来在某个维度上拼接张量的。
具体来说,`tf.expand_dims(perm, axis=1)`表示将`perm`张量在第1维上增加一个维度,即将形状为`(batch_size, length)`的张量变成形状为`(batch_size, 1, length)`的张量。同理,`tf.expand_dims(poro, axis=1)`也是在第1维上增加一个维度。
然后,`tf.concat([tf.expand_dims(perm, axis=1), tf.expand_dims(poro, axis=1)], axis=1)`表示在第1维上拼接这两个形状相同的张量,即将两个形状为`(batch_size, 1, length)`的张量拼接成形状为`(batch_size, 2, length)`的张量。其中,`axis=1`表示在第1维上拼接。
阅读全文