举例2个模型相似度代码
时间: 2024-03-16 14:46:39 浏览: 62
以下是两个PyTorch模型的相似度计算示例代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义两个模型
class ModelA(nn.Module):
def __init__(self):
super(ModelA, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
x = self.fc(x)
return x
class ModelB(nn.Module):
def __init__(self):
super(ModelB, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
x = self.fc(x)
return x
# 初始化两个模型
modelA = ModelA()
modelB = ModelB()
# 计算两个模型的相似度
def calc_similarity(model1, model2):
# 将模型权重展开为向量
vec1 = torch.cat([p.view(-1) for p in model1.parameters()])
vec2 = torch.cat([p.view(-1) for p in model2.parameters()])
# 计算两个向量的余弦相似度
sim = F.cosine_similarity(vec1, vec2, dim=0)
return sim.item()
similarity = calc_similarity(modelA, modelB)
print("The similarity between ModelA and ModelB is:", similarity)
```
在上面的代码中,我们定义了两个PyTorch模型`ModelA`和`ModelB`,并初始化了它们。然后,我们定义了一个函数`calc_similarity`来计算两个模型的相似度。在这个函数中,我们首先将模型权重展开为向量,然后使用PyTorch内置函数`F.cosine_similarity`来计算两个向量的余弦相似度。最后,我们输出了两个模型的相似度。
阅读全文