用 Python 将 Neo4j 知识图谱通过 TransH 嵌入方法转换为嵌入向量的代码(附详细注释)
时间: 2024-03-07 14:49:24 浏览: 242
好的,我可以为您提供Python实现将Neo4j的知识图谱利用TransH嵌入方法转换成嵌入向量的代码,并附有详细注释。在提供代码之前,需要先安装PyTorch和py2neo模块。
```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
from py2neo import Graph, Node, Relationship
# 定义TransH模型
class TransH(torch.nn.Module):
    """TransH knowledge-graph embedding model (Wang et al., 2014).

    Every relation r owns a translation vector d_r (``rel_embedding``) and a
    hyperplane normal w_r (``norm_vector``). Head and tail embeddings are
    projected onto the relation hyperplane, ``e_perp = e - (w_r . e) w_r``,
    and a triple (h, r, t) is scored by ``||h_perp + d_r - t_perp||_p``;
    a lower score means a more plausible triple.

    Args:
        ent_num: number of entities (rows of the entity embedding table).
        rel_num: number of relation types.
        dim: embedding dimension.
        margin: margin of the ranking loss between positive and negative triples.
        p_norm: order of the norm used in the score (2 by default, as before).
    """

    def __init__(self, ent_num, rel_num, dim, margin, p_norm=2):
        super(TransH, self).__init__()
        self.ent_num = ent_num
        self.rel_num = rel_num
        self.dim = dim
        self.margin = margin
        self.p_norm = p_norm
        self.ent_embedding = torch.nn.Embedding(self.ent_num, self.dim)
        self.rel_embedding = torch.nn.Embedding(self.rel_num, self.dim)
        self.norm_vector = torch.nn.Embedding(self.rel_num, self.dim)
        # Xavier-uniform init instead of nn.Embedding's default N(0, 1), which
        # produces large initial distances relative to a margin of ~1.
        for table in (self.ent_embedding, self.rel_embedding, self.norm_vector):
            torch.nn.init.xavier_uniform_(table.weight)

    def _calc(self, h, t, r):
        """Return the TransH distance of each (h, t, r) id triple.

        Args:
            h: LongTensor of head entity ids.
            t: LongTensor of tail entity ids, same shape as ``h``.
            r: LongTensor of relation ids, same shape as ``h``.

        Returns:
            Float tensor with the same shape as ``h``; smaller is more plausible.
        """
        h_e = self.ent_embedding(h)
        t_e = self.ent_embedding(t)
        d_r = self.rel_embedding(r)
        # The projection formula needs a unit normal; normalize() guards
        # against division by zero, unlike a hand-written r / ||r||.
        w_r = torch.nn.functional.normalize(self.norm_vector(r), p=2, dim=-1)
        h_perp = h_e - (h_e * w_r).sum(dim=-1, keepdim=True) * w_r
        t_perp = t_e - (t_e * w_r).sum(dim=-1, keepdim=True) * w_r
        return torch.norm(h_perp + d_r - t_perp, p=self.p_norm, dim=-1)

    def forward(self, pos_h, pos_t, pos_r, neg_h, neg_t, neg_r):
        """Return the mean margin ranking loss of positive vs. negative triples.

        All six arguments are LongTensors of ids with a common shape.
        """
        pos_score = self._calc(pos_h, pos_t, pos_r)
        neg_score = self._calc(neg_h, neg_t, neg_r)
        # Target -1 asks for pos_score < neg_score:
        # loss = mean(max(0, pos_score - neg_score + margin)).
        # ones_like keeps the target on the scores' device and dtype.
        target = -torch.ones_like(pos_score)
        return torch.nn.functional.margin_ranking_loss(
            pos_score, neg_score, target, margin=self.margin
        )

    def ent_embeddings(self):
        """Return the entity embedding matrix as a NumPy array (row i = entity id i)."""
        return self.ent_embedding.weight.detach().cpu().numpy()
# 加载知识图谱数据
class KnowledgeGraphDataLoader(Dataset):
    """Map-style dataset of (head, tail, relation) id triples read from Neo4j.

    Despite its name this is a ``torch.utils.data.Dataset``; wrap it in a
    ``DataLoader`` to batch it. Item ``i`` is a LongTensor
    ``[h_id, t_id, r_id]``, so default collation yields a ``(batch, 3)`` tensor.

    Node ids returned by the query are mapped to contiguous entity ids, and
    relationship types to contiguous relation ids, in order of first
    appearance in the query result.

    Args:
        graph: object with a ``run(query)`` method yielding ``(h, t, r)`` rows,
            e.g. a py2neo ``Graph``.
        batch_size: stored on the instance; not used by the dataset itself.
        num_workers: stored on the instance; not used by the dataset itself.
    """

    def __init__(self, graph, batch_size, num_workers):
        super().__init__()
        self.graph = graph
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.ent2id = {}
        self.rel2id = {}
        self.id2ent = {}
        self.id2rel = {}
        self.train_triples = []
        self.dev_triples = []
        self.test_triples = []
        self.load_data()

    def __len__(self):
        return len(self.train_triples)

    def __getitem__(self, idx):
        return torch.tensor(self.train_triples[idx], dtype=torch.long)

    def load_data(self):
        """Read every (h)-[r]->(t) edge and append it to ``train_triples`` as ids."""
        # ORDER BY makes the row order (and therefore the id mapping)
        # reproducible when the loader is rebuilt on an unchanged graph.
        # NOTE(review): id() is deprecated since Neo4j 5 in favour of
        # elementId(); kept here because callers receive these ids.
        query = "MATCH (h)-[r]->(t) RETURN id(h), id(t), type(r) ORDER BY id(r)"
        for h, t, r in self.graph.run(query):
            for ent in (h, t):
                if ent not in self.ent2id:
                    self.ent2id[ent] = len(self.ent2id)
                    self.id2ent[self.ent2id[ent]] = ent
            if r not in self.rel2id:
                self.rel2id[r] = len(self.rel2id)
                self.id2rel[self.rel2id[r]] = r
            self.train_triples.append((self.ent2id[h], self.ent2id[t], self.rel2id[r]))

    def get_corrupted_triples(self, pos_h, pos_t, pos_r):
        """Build one negative triple per positive one by corrupting head OR tail.

        For each triple a fair coin picks the side to corrupt; the replacement
        entity is drawn uniformly from the other ``ent_num - 1`` entities, so
        the corrupted side always differs from the original when there are
        at least two entities.

        Args:
            pos_h, pos_t, pos_r: LongTensors of ids with a common shape.

        Returns:
            ``(neg_h, neg_t, neg_r)`` LongTensors; ``neg_r`` is a copy of ``pos_r``.
        """
        ent_num = self.get_ent_num()
        corrupt_head = torch.rand(pos_h.shape, device=pos_h.device) < 0.5
        original = torch.where(corrupt_head, pos_h, pos_t)
        if ent_num > 1:
            # A non-zero offset modulo ent_num never maps an id onto itself.
            offset = torch.randint(1, ent_num, pos_h.shape, device=pos_h.device)
            replacement = (original + offset) % ent_num
        else:
            replacement = original
        neg_h = torch.where(corrupt_head, replacement, pos_h)
        neg_t = torch.where(corrupt_head, pos_t, replacement)
        return neg_h, neg_t, pos_r.clone()

    def get_train_data(self):
        """Return the list of ``(h_id, t_id, r_id)`` training triples."""
        return self.train_triples

    def get_ent_num(self):
        """Return the number of distinct entities."""
        return len(self.ent2id)

    def get_rel_num(self):
        """Return the number of distinct relation types."""
        return len(self.rel2id)

    def get_ent_id(self, ent):
        """Return the contiguous id of node id ``ent`` (KeyError if unknown)."""
        return self.ent2id[ent]

    def get_rel_id(self, rel):
        """Return the contiguous id of relation type ``rel`` (KeyError if unknown)."""
        return self.rel2id[rel]

    def get_ent(self, ent_id):
        """Return the node id for contiguous entity id ``ent_id``."""
        return self.id2ent[ent_id]

    def get_rel(self, rel_id):
        """Return the relation type for contiguous relation id ``rel_id``."""
        return self.id2rel[rel_id]
# 训练TransH模型
def _corrupt_batch(pos_h, pos_t, ent_num):
    """Return ``(neg_h, neg_t)`` with either the head or the tail of each triple replaced.

    The replacement is drawn uniformly from the other ``ent_num - 1``
    entities, so the corrupted side always differs when ``ent_num > 1``.
    """
    corrupt_head = torch.rand(pos_h.shape, device=pos_h.device) < 0.5
    original = torch.where(corrupt_head, pos_h, pos_t)
    if ent_num > 1:
        # A non-zero offset modulo ent_num never maps an id onto itself.
        offset = torch.randint(1, ent_num, pos_h.shape, device=pos_h.device)
        replacement = (original + offset) % ent_num
    else:
        replacement = original
    neg_h = torch.where(corrupt_head, replacement, pos_h)
    neg_t = torch.where(corrupt_head, pos_t, replacement)
    return neg_h, neg_t


def train_transh(graph, dim=50, margin=1.0, lr=0.01, batch_size=1000, epochs=500, num_workers=8):
    """Train TransH on every (h)-[r]->(t) edge of ``graph`` and return entity embeddings.

    Args:
        graph: database handle passed to ``KnowledgeGraphDataLoader``.
        dim: embedding dimension.
        margin: ranking-loss margin.
        lr: SGD learning rate.
        batch_size: number of triples per optimisation step.
        epochs: number of passes over all triples.
        num_workers: DataLoader worker processes (0 loads in the main process).

    Returns:
        NumPy array of shape ``(ent_num, dim)``; row i is the embedding of
        contiguous entity id i as assigned by ``KnowledgeGraphDataLoader``.

    Raises:
        ValueError: if the graph contains no triples.
    """
    data_loader = KnowledgeGraphDataLoader(graph, batch_size, num_workers)
    train_data = data_loader.get_train_data()
    if not train_data:
        raise ValueError("no (h)-[r]->(t) triples found in the graph; nothing to train on")
    triples = torch.tensor(train_data, dtype=torch.long)
    ent_num = data_loader.get_ent_num()
    rel_num = data_loader.get_rel_num()
    transh = TransH(ent_num, rel_num, dim, margin)
    optimizer = torch.optim.SGD(transh.parameters(), lr=lr)
    # Batch plain id tensors rather than the loader object: they are cheap to
    # send to worker processes, whereas the loader holds the DB connection.
    # Persistent workers avoid re-spawning processes on every epoch.
    batches = DataLoader(
        TensorDataset(triples),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    transh.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for (batch,) in batches:
            pos_h, pos_t, pos_r = batch[:, 0], batch[:, 1], batch[:, 2]
            neg_h, neg_t = _corrupt_batch(pos_h, pos_t, ent_num)
            loss = transh(pos_h, pos_t, pos_r, neg_h, neg_t, pos_r)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # TransH constraint ||e|| <= 1 on entity embeddings; without it
            # the model can satisfy the margin by simply growing the vectors.
            with torch.no_grad():
                transh.ent_embedding.weight.renorm_(p=2, dim=0, maxnorm=1.0)
            total_loss += loss.item() * batch.size(0)
        # Report the mean loss over the whole epoch, not just the last batch.
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, epochs, total_loss / len(triples)))
    return transh.ent_embeddings()
# Script entry point. The guard matters: with DataLoader worker processes
# started via "spawn" (the default on Windows/macOS), each worker re-imports
# this module, and unguarded top-level code would re-run training.
if __name__ == "__main__":
    # Connect to the Neo4j database.
    graph = Graph("bolt://localhost:7687", auth=('neo4j', 'password'))
    # Train TransH; row i of the result is the embedding of contiguous entity id i.
    ent_embeddings = train_transh(graph, dim=50, margin=1.0, lr=0.01, batch_size=1000, epochs=500, num_workers=8)
    # train_transh keeps its loader private, so rebuild one to recover the
    # entity-id -> node-id mapping. This relies on the loader's query
    # returning rows in the same order as during training (graph unchanged).
    data_loader = KnowledgeGraphDataLoader(graph, 1000, 8)
    # Store each vector on the node it was learned for, instead of creating
    # new disconnected "Entity" nodes. The driver does not serialize NumPy
    # arrays, so convert every row to a plain list of Python floats.
    rows = [
        {"node_id": data_loader.get_ent(ent_id), "embedding": vector.tolist()}
        for ent_id, vector in enumerate(ent_embeddings)
    ]
    write_chunk = 1000  # rows per UNWIND round-trip
    for start in range(0, len(rows), write_chunk):
        graph.run(
            "UNWIND $rows AS row "
            "MATCH (n) WHERE id(n) = row.node_id "
            "SET n.embedding = row.embedding",
            rows=rows[start:start + write_chunk],
        )
```
以上代码实现了将 Neo4j 的知识图谱利用 TransH 嵌入方法转换成嵌入向量,并将实体嵌入向量保存到 Neo4j 数据库中。其中,TransH 模型使用随机梯度下降(SGD 优化器)进行训练,损失函数采用 MarginRankingLoss(要求正例三元组的得分比负例至少低一个 margin);训练过程中通过 DataLoader 的 num_workers 参数使用多个子进程(而非多线程)加载数据。由于 Neo4j 的节点属性支持存储数值列表,因此每个实体的嵌入向量作为节点的一个属性保存。
阅读全文