用pytorch实现FM
时间: 2023-06-15 16:07:33 浏览: 154
FM(Factorization Machine)是一种处理稀疏数据的机器学习模型,通常用于推荐系统和广告推荐等场景。下面是使用 PyTorch 实现 FM 的示例代码:
首先,导入所需的库:
```python
import torch
import torch.nn as nn
```
接着,定义 FM 模型:
```python
class FM(nn.Module):
def __init__(self, num_features, num_factors):
super(FM, self).__init__()
self.num_features = num_features
self.num_factors = num_factors
self.linear = nn.Linear(num_features, 1)
self.embedding = nn.Embedding(num_features, num_factors)
def forward(self, x):
linear_terms = self.linear(x).squeeze()
interactions = torch.sum(self.embedding(x), dim=1).pow(2).sum(1)
interactions -= torch.sum(self.embedding(x).pow(2), dim=1).sum(1)
fm_output = linear_terms + interactions
return fm_output
```
其中,num_features 表示特征的数量,num_factors 表示每个特征的隐向量维度。
在 forward 函数中,首先对输入的 x 进行线性变换,并将结果压缩到一维,得到 linear_terms。然后,对 x 进行 embedding,并在第二维上求和,得到交叉项的一次方和。再对 embedding 进行平方,然后在第二个维度上求和,得到交叉项的二次方和。最后,将 linear_terms 和 interactions 相加,得到 FM 模型的输出。
接着,定义训练过程:
```python
def train(model, optimizer, loss_fn, train_loader, num_epochs):
for epoch in range(num_epochs):
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print('Epoch %d, loss: %.3f' % (epoch + 1, running_loss / len(train_loader)))
```
其中,model 是 FM 模型,optimizer 是优化器,loss_fn 是损失函数,train_loader 是训练数据的 DataLoader,num_epochs 是训练次数。在训练过程中,对输入数据进行前向传播,计算损失,然后进行反向传播和参数更新。
最后,使用示例数据进行训练和测试:
```python
# 示例数据
X = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
y = torch.tensor([1, 0, 1])
# 定义模型和优化器
model = FM(num_features=X.shape[1], num_factors=5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
train(model, optimizer, nn.MSELoss(), torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, y), batch_size=1), num_epochs=10)
# 测试模型
inputs = torch.tensor([[0, 1, 2], [3, 4, 5]])
outputs = model(inputs)
print(outputs)
```
在上面的示例中,输入数据 X 是一个 3x3 的张量,表示 3 个样本,每个样本有 3 个特征。标签 y 是一个长度为 3 的张量,表示每个样本的标签。在训练过程中,使用 MSE 损失函数进行优化。最后,输入测试数据进行预测,并输出模型的输出。
阅读全文