def doTrainInstanceOf(self, ids): entity_embs = self.entity_vec(ids[[0, 2], :]) concept_embs = self.concept_vec(ids[[1, 3], :]) radius = concept_embs[:, :, -1] concept_embs = concept_embs[:, :, :-1] if self.args.pnorm==1: dis = F.relu(norm(entity_embs - concept_embs,pnorm=self.args.pnorm) - torch.abs(radius)) else: dis = F.relu(norm(entity_embs - concept_embs,pnorm=self.args.pnorm) - radius ** 2) loss = F.relu(dis[0] + self.args.margin_ins - dis[1]).sum() return loss改为分类损失
时间: 2023-06-15 08:04:31 浏览: 110
ENTITY_LIST类-ACIS系统简介
如果你想将这个函数改为分类损失,你需要知道你的输入数据有多少个类别。假设你的类别数为num_classes,你需要将原来的输出维度从1变成num_classes,并使用交叉熵损失函数。修改后的代码如下:
```python
def doTrainInstanceOf(self, ids, targets):
entity_embs = self.entity_vec(ids[[0, 2], :])
concept_embs = self.concept_vec(ids[[1, 3], :])
radius = concept_embs[:, :, -1]
concept_embs = concept_embs[:, :, :-1]
if self.args.pnorm==1:
dis = F.relu(norm(entity_embs - concept_embs,pnorm=self.args.pnorm) - torch.abs(radius))
else:
dis = F.relu(norm(entity_embs - concept_embs,pnorm=self.args.pnorm) - radius ** 2)
loss = F.cross_entropy(dis, targets)
return loss
```
在这个版本中,我们将输出从1维变成了num_classes维。我们还使用了交叉熵损失函数来计算损失。交叉熵损失函数可以将输出看作每个类别的概率分布,因此我们可以将dis直接输入到交叉熵损失函数中。我们还需要传入一个targets参数,该参数是一个长为batch_size的一维张量,其中包含每个样本的类别标签。
阅读全文