Pytorch转转keras的有效方法的有效方法,以以FlowNet为例讲解为例讲解
Pytorch凭借动态图机制,获得了广泛的使用,大有超越tensorflow的趋势,不过在工程应用上,TF仍然占据优势。有的时候
我们会遇到这种情况,需要把模型应用到工业中,运用到实际项目上,TF支持的PB文件和TF的C++接口就成为了有效的工
具。今天就给大家讲解一下Pytorch转成Keras的方法,进而我们也可以获得Pb文件,因为Keras是支持tensorflow的,我将会
在下一篇博客讲解获得Pb文件,并使用Pb文件的方法。
Pytorch To Keras
首先,我们必须有清楚的认识,网上以及github上一些所谓的pytorch转换Keras或者Keras转换成Pytorch的工具代码几乎不能
运行或者有使用的局限性(比如仅仅能转换某一些模型),但是我们是可以用这些转换代码中看出一些端倪来,比如二者的参
数的尺寸(shape)的形式、channel的排序(first or last)是否一样,掌握到差异性,就能根据这些差异自己编写转换代码,
没错,自己编写转换代码,是最稳妥的办法。整个过程也就分为两个部分。笔者将会以Nvidia开源的FlowNet为例,将开源的
Pytorch代码转化为Keras模型。
按照Pytorch中模型的结构,编写对应的Keras代码,用keras的函数式API,构建起来会非常方便。
把Pytorch的模型参数,按照层的名称依次赋值给Keras的模型
以上两步虽然看上去简单,但实际我也走了不少弯路。这里一个关键的地方,就是参数的shape在两个框架中是否统一,那当
然是不统一的。下面我以FlowNet为例。
Pytorch中的中的FlowNet代码代码
我们仅仅展示层名称和层参数,就不把整个结构贴出来了,否则会占很多的空间,形成水文。
先看用Keras搭建的flowNet模型,直接用model.summary()输出模型信息
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 6, 512, 512) 0
__________________________________________________________________________________________________
conv0 (Conv2D) (None, 64, 512, 512) 3520 input_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 64, 512, 512) 0 conv0[0][0]
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 64, 514, 514) 0 leaky_re_lu_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 64, 256, 256) 36928 zero_padding2d_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 64, 256, 256) 0 conv1[0][0]
__________________________________________________________________________________________________
conv1_1 (Conv2D) (None, 128, 256, 256 73856 leaky_re_lu_2[0][0]
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 128, 256, 256 0 conv1_1[0][0]
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 128, 258, 258 0 leaky_re_lu_3[0][0]
__________________________________________________________________________________________________
conv2 (Conv2D) (None, 128, 128, 128 147584 zero_padding2d_2[0][0]
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 128, 128, 128 0 conv2[0][0]
__________________________________________________________________________________________________
conv2_1 (Conv2D) (None, 128, 128, 128 147584 leaky_re_lu_4[0][0]
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 128, 128, 128 0 conv2_1[0][0]
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 128, 130, 130 0 leaky_re_lu_5[0][0]
__________________________________________________________________________________________________
conv3 (Conv2D) (None, 256, 64, 64) 295168 zero_padding2d_3[0][0]
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 256, 64, 64) 0 conv3[0][0]
__________________________________________________________________________________________________
conv3_1 (Conv2D) (None, 256, 64, 64) 590080 leaky_re_lu_6[0][0]
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 256, 64, 64) 0 conv3_1[0][0]
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 256, 66, 66) 0 leaky_re_lu_7[0][0]
__________________________________________________________________________________________________
conv4 (Conv2D) (None, 512, 32, 32) 1180160 zero_padding2d_4[0][0]
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 512, 32, 32) 0 conv4[0][0]
__________________________________________________________________________________________________
conv4_1 (Conv2D) (None, 512, 32, 32) 2359808 leaky_re_lu_8[0][0]
__________________________________________________________________________________________________