使用torch对知识表示学习模型RotatE进行修改,使之输出自对抗负采样损失(NSSALoss),而不保留实体关系嵌入向量
时间: 2024-02-25 09:51:39 浏览: 136
Pytorch深度学习(5) — 词向量及Glove预训练模型使用
5星 · 资源好评率100%
好的,下面是一个可能的实现方法:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class RotatE_NSSA(nn.Module):
def __init__(self, num_entities, num_relations, embedding_dim, margin, nssa_lambda):
super(RotatE_NSSA, self).__init__()
self.num_entities = num_entities
self.num_relations = num_relations
self.embedding_dim = embedding_dim
self.margin = margin
self.nssa_lambda = nssa_lambda
self.entity_embedding = nn.Embedding(num_entities, embedding_dim)
self.relation_embedding = nn.Embedding(num_relations, embedding_dim)
self.relation_embedding_weight = nn.Parameter(torch.Tensor(num_relations, embedding_dim))
nn.init.xavier_uniform_(self.entity_embedding.weight.data)
nn.init.xavier_uniform_(self.relation_embedding.weight.data)
nn.init.xavier_uniform_(self.relation_embedding_weight.data)
def _calc(self, h, r, t):
pi = 3.14159265358979323846
re_head, im_head = torch.chunk(h, 2, dim=-1)
re_tail, im_tail = torch.chunk(t, 2, dim=-1)
re_relation, im_relation = torch.chunk(self.relation_embedding_weight[r], 2, dim=-1)
# make phases of relation embeddings in [-pi/2, pi/2]
phase_relation = torch.atan2(im_relation, re_relation)
phase_relation = torch.remainder(phase_relation + pi/2, pi) - pi/2
re_relation = torch.cos(phase_relation)
im_relation = torch.sin(phase_relation)
re_score = re_head * re_relation - im_head * im_relation - re_tail
im_score = re_head * im_relation + im_head * re_relation - im_tail
score = torch.stack([re_score, im_score], dim=0)
score = score.norm(dim=0)
return score
def forward(self, pos_h, pos_r, pos_t, neg_h, neg_r, neg_t):
pos_score = self._calc(self.entity_embedding(pos_h), pos_r, self.entity_embedding(pos_t))
neg_score = self._calc(self.entity_embedding(neg_h), neg_r, self.entity_embedding(neg_t))
score = torch.cat([pos_score, neg_score], dim=-1)
# nssa_loss
nssa_loss = torch.mean(torch.logsumexp(score * self.nssa_lambda, dim=-1))
# margin loss
margin_loss = torch.mean(F.relu(self.margin + pos_score - neg_score))
return margin_loss + nssa_loss
def predict(self, h, r, t):
score = self._calc(self.entity_embedding(h), r, self.entity_embedding(t))
return score
```
在这个修改版的RotatE模型中,我们添加了一个`nssa_loss`,用于计算自对抗负采样损失。具体来说,在每个训练batch中,我们先计算正样本和负样本的分数,然后将它们拼接起来,得到一个形状为`[2, batch_size]`的张量。接着,我们将这个张量乘以一个超参数`nssa_lambda`,并对每个张量元素取`logsumexp`,最后对所有元素取平均值,得到`nssa_loss`。
在训练过程中,我们可以将`nssa_loss`和`margin_loss`相加,得到总的损失函数。`margin_loss`和原版的RotatE是一样的,用于保证正样本的分数高于负样本的分数。但是,由于我们将`nssa_loss`作为模型的输出,而不是实体关系嵌入向量,所以我们不需要保留实体关系嵌入向量。这样可以减少模型的计算复杂度,并提高模型的训练效率。
在使用这个修改版的RotatE模型进行训练时,我们只需要传入正样本和负样本的三元组,而不需要传入实体关系嵌入向量。在训练完毕后,我们可以使用`predict`方法来预测任意一个三元组的分数。
阅读全文