没有合适的资源?快使用搜索试试~ 我知道了~
首页Keras实现将两个模型连接到一起
资源详情
资源评论
资源推荐
Keras实现将两个模型连接到一起实现将两个模型连接到一起
主要介绍了Keras实现将两个模型连接到一起,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来
看看吧
神经网络玩得越久就越会尝试一些网络结构上的大改动。
先说意图先说意图
有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可
以分离训练。
流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这
样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。
实现方法实现方法
首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。
第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。
第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。
可以看一个自编码器的代码(本人所编写):
class AE:
def __init__(self, dim, img_dim, batch_size):
self.dim = dim
self.img_dim = img_dim
self.batch_size = batch_size
self.encoder = self.encoder_construct()
self.decoder = self.decoder_construct()
def encoder_construct(self):
x_in = Input(shape=(self.img_dim, self.img_dim, 3))
x = x_in
x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = GlobalAveragePooling2D()(x)
encoder = Model(x_in, x)
return encoder
def decoder_construct(self):
map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
# print(type(map_size))
z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
z = z_in
z_dim = self.dim
z = Dense(np.prod(map_size) * z_dim)(z)
z = Reshape(map_size + (z_dim,))(z)
z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = Activation('tanh')(z)
weixin_38526208
- 粉丝: 3
- 资源: 940
上传资源 快速赚钱
- 我的内容管理 收起
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
会员权益专享
最新资源
- RTL8188FU-Linux-v5.7.4.2-36687.20200602.tar(20765).gz
- c++校园超市商品信息管理系统课程设计说明书(含源代码) (2).pdf
- 建筑供配电系统相关课件.pptx
- 企业管理规章制度及管理模式.doc
- vb打开摄像头.doc
- 云计算-可信计算中认证协议改进方案.pdf
- [详细完整版]单片机编程4.ppt
- c语言常用算法.pdf
- c++经典程序代码大全.pdf
- 单片机数字时钟资料.doc
- 11项目管理前沿1.0.pptx
- 基于ssm的“魅力”繁峙宣传网站的设计与实现论文.doc
- 智慧交通综合解决方案.pptx
- 建筑防潮设计-PowerPointPresentati.pptx
- SPC统计过程控制程序.pptx
- SPC统计方法基础知识.pptx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论0