利用pytorch实现推荐算法FM
时间: 2023-06-19 17:07:51 浏览: 44
FM(Factorization Machines)是一种经典的推荐算法。它通过对特征进行二阶多项式扩展,建立了一个能够考虑特征交互的模型。在这个模型中,每个特征都有一个隐向量表示,通过对这些隐向量的乘积求和来计算二阶特征交互的权重。
下面是使用PyTorch实现FM模型的代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
class FM(nn.Module):
def __init__(self, n_feat, k):
super(FM, self).__init__()
self.n_feat = n_feat
self.k = k
self.linear = nn.Linear(n_feat, 1)
self.embedding = nn.Embedding(n_feat, k)
def forward(self, x):
# x: (batch_size, n_feat)
linear_part = self.linear(x).squeeze(1) # (batch_size,)
embedding_part = torch.sum(self.embedding(x), dim=1) # (batch_size, k)
square_of_sum = torch.sum(embedding_part, dim=1) ** 2 # (batch_size,)
sum_of_square = torch.sum(embedding_part ** 2, dim=1) # (batch_size,)
second_order = 0.5 * (square_of_sum - sum_of_square) # (batch_size,)
y = linear_part + second_order # (batch_size,)
return y
class RatingDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index, :-1]
y = self.data[index, -1]
return x, y
if __name__ == '__main__':
# 读取数据
data = torch.tensor([
[0, 1, 1],
[1, 0, 1],
[1, 1, 1],
[0, 0, 0],
[0, 1, 0],
[1, 0, 0],
], dtype=torch.long)
target = torch.tensor([5, 3, 4, 1, 2, 1], dtype=torch.float)
dataset = RatingDataset(torch.cat([data, target.unsqueeze(dim=1)], dim=1))
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 定义模型
model = FM(n_feat=3, k=2)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(100):
for x, y in dataloader:
optimizer.zero_grad()
y_pred = model(x)
loss = criterion(y_pred, y)
loss.backward()
optimizer.step()
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 100, loss.item()))
# 预测
with torch.no_grad():
x_test = torch.tensor([[0, 1, 0], [1, 1, 0]], dtype=torch.long)
y_test = torch.tensor([0, 0], dtype=torch.float)
y_pred = model(x_test)
print('Test Loss: {:.4f}'.format(criterion(y_pred, y_test).item()))
print('Predictions:', y_pred)
```
在这个代码中,我们首先定义了一个继承自`nn.Module`的`FM`类,其中`linear`是一阶线性部分的权重,`embedding`是二阶交互部分的权重。在`forward`函数中,我们首先计算一阶部分的输出,然后计算二阶部分的输出,并将两部分相加得到最终输出。
接下来我们定义了一个继承自`Dataset`的`RatingDataset`类,用于读取数据和构建数据集。在主函数中,我们首先读取数据并构建数据集和数据加载器。然后定义了一个`FM`模型、一个均方误差损失函数和一个随机梯度下降优化器。在训练过程中,我们遍历数据集中的所有样本,计算损失并更新模型参数。最后,我们使用训练好的模型对测试集进行预测,并输出预测结果和测试误差。