用 Java 实现协同过滤算法
时间: 2023-09-30 16:08:37 浏览: 84
协同过滤算法是一种常见的推荐系统算法，它根据用户对物品的评分数据，向用户推荐其可能感兴趣的物品。下面是用 Java 实现基于用户的协同过滤（User-Based CF）算法的示例代码：
```java
import java.util.*;
public class UserBasedCF {
private Map<Integer, Map<Integer, Double>> userItemRatingMatrix; // 用户-物品评分矩阵
private Map<Integer, Set<Integer>> itemUserMatrix; // 物品-用户评分矩阵
private Map<Integer, Set<Integer>> userSimMatrix; // 用户相似度矩阵
public UserBasedCF(Map<Integer, Map<Integer, Double>> userItemRatingMatrix) {
this.userItemRatingMatrix = userItemRatingMatrix;
itemUserMatrix = new HashMap<Integer, Set<Integer>>();
userSimMatrix = new HashMap<Integer, Set<Integer>>();
for (Integer userId : userItemRatingMatrix.keySet()) {
Map<Integer, Double> itemRatingMap = userItemRatingMatrix.get(userId);
for (Integer itemId : itemRatingMap.keySet()) {
if (!itemUserMatrix.containsKey(itemId)) {
itemUserMatrix.put(itemId, new HashSet<Integer>());
}
itemUserMatrix.get(itemId).add(userId);
}
}
}
/**
* 计算用户相似度矩阵
*/
public void calculateUserSimMatrix() {
for (Integer u1 : userItemRatingMatrix.keySet()) {
for (Integer u2 : userItemRatingMatrix.keySet()) {
if (u1.equals(u2)) {
continue;
}
double sim = cosineSimilarity(u1, u2);
if (sim > 0) {
if (!userSimMatrix.containsKey(u1)) {
userSimMatrix.put(u1, new HashSet<Integer>());
}
userSimMatrix.get(u1).add(u2);
}
}
}
}
/**
* 基于用户的协同过滤算法
*
* @param userId 用户ID
* @param topN 推荐物品数量
* @return 推荐的物品ID列表
*/
public List<Integer> recommendItemsByUserCF(Integer userId, Integer topN) {
Set<Integer> ratedItems = userItemRatingMatrix.get(userId).keySet(); // 用户已评价物品集合
Map<Integer, Double> scores = new HashMap<Integer, Double>();
for (Integer simUser : userSimMatrix.get(userId)) {
for (Integer itemId : userItemRatingMatrix.get(simUser).keySet()) {
if (ratedItems.contains(itemId)) { // 用户已评价过该物品
continue;
}
if (!scores.containsKey(itemId)) {
scores.put(itemId, 0.0);
}
scores.put(itemId, scores.get(itemId) + userItemRatingMatrix.get(simUser).get(itemId));
}
}
List<Map.Entry<Integer, Double>> sortedScores = new ArrayList<>(scores.entrySet());
sortedScores.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue())); // 按分数降序排序
List<Integer> recommendedItems = new ArrayList<>();
for (int i = 0; i < Math.min(topN, sortedScores.size()); i++) {
recommendedItems.add(sortedScores.get(i).getKey());
}
return recommendedItems;
}
/**
* 计算用户间的余弦相似度
*
* @param u1 用户1 ID
* @param u2 用户2 ID
* @return 余弦相似度
*/
private double cosineSimilarity(Integer u1, Integer u2) {
Set<Integer> commonItems = new HashSet<Integer>(userItemRatingMatrix.get(u1).keySet());
commonItems.retainAll(userItemRatingMatrix.get(u2).keySet()); // 共同评分物品集合
if (commonItems.isEmpty()) {
return 0;
}
double dotProduct = 0;
double norm1 = 0;
double norm2 = 0;
for (Integer itemId : commonItems) {
double rating1 = userItemRatingMatrix.get(u1).get(itemId);
double rating2 = userItemRatingMatrix.get(u2).get(itemId);
dotProduct += rating1 * rating2;
norm1 += rating1 * rating1;
norm2 += rating2 * rating2;
}
return dotProduct / Math.sqrt(norm1 * norm2);
}
public static void main(String[] args) {
Map<Integer, Map<Integer, Double>> userItemRatingMatrix = new HashMap<Integer, Map<Integer, Double>>();
Map<Integer, Double> itemRatingMap1 = new HashMap<Integer, Double>();
itemRatingMap1.put(1, 5.0);
itemRatingMap1.put(2, 3.0);
itemRatingMap1.put(3, 4.0);
Map<Integer, Double> itemRatingMap2 = new HashMap<Integer, Double>();
itemRatingMap2.put(1, 3.0);
itemRatingMap2.put(2, 4.0);
itemRatingMap2.put(4, 4.0);
Map<Integer, Double> itemRatingMap3 = new HashMap<Integer, Double>();
itemRatingMap3.put(2, 2.0);
itemRatingMap3.put(3, 3.0);
itemRatingMap3.put(4, 5.0);
Map<Integer, Double> itemRatingMap4 = new HashMap<Integer, Double>();
itemRatingMap4.put(1, 4.0);
itemRatingMap4.put(3, 3.0);
itemRatingMap4.put(4, 4.0);
userItemRatingMatrix.put(1, itemRatingMap1);
userItemRatingMatrix.put(2, itemRatingMap2);
userItemRatingMatrix.put(3, itemRatingMap3);
userItemRatingMatrix.put(4, itemRatingMap4);
UserBasedCF userBasedCF = new UserBasedCF(userItemRatingMatrix);
userBasedCF.calculateUserSimMatrix();
List<Integer> recommendedItems = userBasedCF.recommendItemsByUserCF(2, 2);
System.out.println(recommendedItems);
}
}
```
注:上述代码中的`userItemRatingMatrix`是一个以用户ID为Key、以物品ID和评分为Value的Map,存储用户对物品的评分数据。`itemUserMatrix`是一个以物品ID为Key、以用户ID集合为Value的Map,用于计算用户相似度矩阵。`userSimMatrix`是一个以用户ID为Key、以相似用户ID集合为Value的Map,存储用户间的相似度关系。`cosineSimilarity`方法计算用户间的余弦相似度。`recommendItemsByUserCF`方法根据用户相似度矩阵推荐给用户可能感兴趣的物品。
阅读全文