MapReduce实现矩阵相乘java代码及实例
时间: 2024-05-04 14:18:34 浏览: 5
以下是一个简单的 MapReduce 实现矩阵相乘的 Java 代码:
Map 阶段:
```java
public static class Map extends Mapper<LongWritable, Text, Text, Text> {
private Text outputKey = new Text();
private Text outputValue = new Text();
public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
Configuration conf = context.getConfiguration();
int m = Integer.parseInt(conf.get("m"));
int p = Integer.parseInt(conf.get("p"));
String line = value.toString();
String[] tokens = line.split(",");
if (tokens[0].equals("A")) {
for (int k = 0; k < p; k++) {
outputKey.set(tokens[1] + "," + k);
outputValue.set("A," + tokens[2] + "," + tokens[3]);
context.write(outputKey, outputValue);
}
} else {
for (int i = 0; i < m; i++) {
outputKey.set(i + "," + tokens[2]);
outputValue.set("B," + tokens[1] + "," + tokens[3]);
context.write(outputKey, outputValue);
}
}
}
}
```
Reduce 阶段:
```java
public static class Reduce extends Reducer<Text, Text, Text, Text> {
public void reduce(Text key, Iterable<Text> values, Context context) throws IOException, InterruptedException {
String[] value;
HashMap<Integer, Float> hashA = new HashMap<Integer, Float>();
HashMap<Integer, Float> hashB = new HashMap<Integer, Float>();
for (Text val : values) {
value = val.toString().split(",");
if (value[0].equals("A")) {
hashA.put(Integer.parseInt(value[1]), Float.parseFloat(value[2]));
} else {
hashB.put(Integer.parseInt(value[1]), Float.parseFloat(value[2]));
}
}
int n = Integer.parseInt(context.getConfiguration().get("n"));
float result = 0.0f;
float a_ij, b_jk;
for (int j = 0; j < n; j++) {
a_ij = hashA.containsKey(j) ? hashA.get(j) : 0.0f;
b_jk = hashB.containsKey(j) ? hashB.get(j) : 0.0f;
result += a_ij * b_jk;
}
context.write(null, new Text(key.toString() + "," + Float.toString(result)));
}
}
```
驱动程序:
```java
/**
 * Configures and runs the MatrixMultiply job.
 *
 * <p>Arguments: {@code <input path> <output path> [m n p]}. The three optional
 * integers are stored under the configuration keys {@code "m"}, {@code "n"} and
 * {@code "p"} that the mapper and reducer read; each defaults to 2 when the
 * optional arguments are not supplied.
 *
 * <p>The process exits with status 0 if the job succeeds, 1 if it fails, and 2
 * if the required paths are missing.
 *
 * @param args command-line arguments as described above
 * @throws Exception if job setup or execution fails
 */
public static void main(String[] args) throws Exception {
    if (args.length < 2) {
        // Fail with a readable message instead of ArrayIndexOutOfBoundsException.
        System.err.println("Usage: MatrixMultiply <input path> <output path> [m n p]");
        System.exit(2);
    }

    int m = 2;
    int n = 2;
    int p = 2;
    if (args.length >= 5) {
        m = Integer.parseInt(args[2]);
        n = Integer.parseInt(args[3]);
        p = Integer.parseInt(args[4]);
    }

    Configuration conf = new Configuration();
    conf.setInt("m", m);
    conf.setInt("n", n);
    conf.setInt("p", p);

    // Job.getInstance replaces the deprecated Job(Configuration, String) constructor.
    Job job = Job.getInstance(conf, "MatrixMultiply");
    job.setJarByClass(MatrixMultiply.class);
    job.setOutputKeyClass(Text.class);
    job.setOutputValueClass(Text.class);
    job.setMapperClass(Map.class);
    job.setReducerClass(Reduce.class);
    job.setInputFormatClass(TextInputFormat.class);
    job.setOutputFormatClass(TextOutputFormat.class);
    FileInputFormat.addInputPath(job, new Path(args[0]));
    FileOutputFormat.setOutputPath(job, new Path(args[1]));

    // Propagate job success/failure to the caller through the exit status.
    System.exit(job.waitForCompletion(true) ? 0 : 1);
}
```
输入数据格式为:
```
A,0,0,1
A,0,1,2
B,0,0,3
B,1,0,4
```
输出数据格式为:
```
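0,0,11.0
0,1,0.0
1,0,0.0
```
其中第一列为 A 矩阵的行号,第二列为 B 矩阵的列号,第三列为两个矩阵相乘的结果。由于 Reduce 阶段以 null 作为键输出,TextOutputFormat 只会写出值,因此每行不带键前缀。示例输入只给出了 A 的第 0 行和 B 的第 0 列,缺失的元素按 0 处理,所以只有 (0,0) 的结果非零(1×3 + 2×4 = 11);而 (1,1) 没有收到任何输入元素,不会出现在输出中。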