jieshixi解释下B = out.size(0)//p # repeat重复指定维度 hidden = self.rnn(out, hidden.repeat(p,1,1,1)) labels = torch.softmax(self.classifier(hidden), 1) labels = labels.reshape(p, B, self.classes, self.ggridsz**2) # (p,64,10,49) 预测出的labels # ent(p,64,49) ent = (- labels * torch.log(labels)).sum(-2) # ent=-x*log(x)[log以e为底],sum(-2)是以-2维度相加 # cent(p,64,49) pastp是没有rnn之前的概率上面一行的labels是rnn之后的概率 cent = (- labels * torch.log(pastp[None,:,:,None])).sum(-2) kld = (cent - ent).mean(0) # (64,49) kld[mask.reshape(B, self.ggridsz**2)] = -100 locmax = kld.argmax(-1) loc = torch.stack([locmax//self.ggridsz, locmax%self.ggridsz],-1) kld = kld.reshape(B, 1, self.ggridsz, self.ggridsz)
时间: 2023-04-07 09:03:47 浏览: 94
使用 tf.nn.dynamic_rnn 展开时间维度方式
这段代码中,out.size(0)表示out张量的第一个维度的大小,即张量的batch size。//表示整除运算符,p为一个整数。因此,B = out.size(0)//p表示将batch size除以p,得到的结果向下取整,赋值给变量B。
阅读全文