没有合适的资源?快使用搜索试试~ 我知道了~
首页Pytorch: 自定义网络层实例
Pytorch: 自定义网络层实例
9 下载量 156 浏览量
更新于2023-03-16
评论
收藏 70KB PDF 举报
今天小编就为大家分享一篇Pytorch: 自定义网络层实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
资源详情
资源评论
资源推荐
Pytorch: 自定义网络层实例自定义网络层实例
今天小编就为大家分享一篇Pytorch: 自定义网络层实例,具有很好的参考价值,希望对大家有所帮助。一起跟
随小编过来看看吧
自定义自定义Autograd函数函数
对于浅层的网络,我们可以手动的书写前向传播和反向传播过程。但是当网络变得很大时,特别是在做深度学习时,网络结构
变得复杂。前向传播和反向传播也随之变得复杂,手动书写这两个过程就会存在很大的困难。幸运地是在pytorch中存在了自
动微分的包,可以用来解决该问题。在使用自动求导的时候,网络的前向传播会定义一个计算图(computational graph),
图中的节点是张量(tensor),两个节点之间的边对应了两个张量之间变换关系的函数。有了计算图的存在,张量的梯度计算
也变得容易了些。例如, x是一个张量,其属性 x.requires_grad = True,那么 x.grad就是一个保存这个张量x的梯度的一些标
量值。
最基础的自动求导操作在底层就是作用在两个张量上。前向传播函数是从输入张量到输出张量的计算过程;反向传播是输入输
出张量的梯度(一些标量)并输出输入张量的梯度(一些标量)。在pytorch中我们可以很容易地定义自己的自动求导操作,
通过继承torch.autograd.Function并定义forward和backward函数。
forward(): 前向传播操作。可以输入任意多的参数,任意的python对象都可以。
backward():反向传播(梯度公式)。输出的梯度个数需要与所使用的张量个数保持一致,且返回的顺序也要对应起来。
# Inherit from Function
class LinearFunction(Function):
# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
# ctx在这里类似self,ctx的属性可以在backward中调用
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0).squeeze(0)
return grad_input, grad_weight, grad_bias
#调用自定义的自动求导函数
linear = LinearFunction.apply(*args) #前向传播
linear.backward()#反向传播
linear.grad_fn.apply(*args)#反向传播
对于非参数化的张量(权重是常量,不需要更新),此时可以定义为:
class MulConstant(Function):
@staticmethod
def forward(ctx, tensor, constant):
# ctx is a context object that can be used to stash information
# for backward computation
ctx.constant = constant
return tensor * constant
@staticmethod
weixin_38724349
- 粉丝: 5
- 资源: 916
上传资源 快速赚钱
- 我的内容管理 收起
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
会员权益专享
最新资源
- RTL8188FU-Linux-v5.7.4.2-36687.20200602.tar(20765).gz
- c++校园超市商品信息管理系统课程设计说明书(含源代码) (2).pdf
- 建筑供配电系统相关课件.pptx
- 企业管理规章制度及管理模式.doc
- vb打开摄像头.doc
- 云计算-可信计算中认证协议改进方案.pdf
- [详细完整版]单片机编程4.ppt
- c语言常用算法.pdf
- c++经典程序代码大全.pdf
- 单片机数字时钟资料.doc
- 11项目管理前沿1.0.pptx
- 基于ssm的“魅力”繁峙宣传网站的设计与实现论文.doc
- 智慧交通综合解决方案.pptx
- 建筑防潮设计-PowerPointPresentati.pptx
- SPC统计过程控制程序.pptx
- SPC统计方法基础知识.pptx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论0