用C语言写出使用分治法解决规模为n*n(n小于100)矩阵乘法的代码
时间: 2023-06-03 18:01:32 浏览: 96
矩阵乘法分治算法
以下是使用分治法解决规模为n*n(n小于100)矩阵乘法的C语言代码:
```c
#include <stdio.h>
#define N 100

/*
 * Recursive divide-and-conquer worker.
 *
 * Adds the product of the sub-block of `a` (rows row..row+rows-1, columns
 * mid..mid+mids-1) and the sub-block of `b` (rows mid..mid+mids-1, columns
 * col..col+cols-1) into the sub-block of `c` (rows row..row+rows-1, columns
 * col..col+cols-1).
 *
 * The sub-blocks are addressed by index ranges inside the caller's arrays
 * instead of being copied into smaller temporaries, so every array keeps the
 * row stride N that the parameter type int[][N] requires.
 *
 * All three extents must be >= 1. The largest extent is halved at each step,
 * which lets odd sizes split unevenly (half and extent - half) without
 * dropping a row or column.
 */
static void matrix_mult_block(int a[][N], int b[][N], int c[][N],
                              int row, int col, int mid,
                              int rows, int cols, int mids)
{
    /* Base case: a 1x1 block times a 1x1 block is a single scalar product. */
    if (rows == 1 && cols == 1 && mids == 1) {
        c[row][col] += a[row][mid] * b[mid][col];
        return;
    }

    if (rows >= cols && rows >= mids) {
        /* Split the rows of a and c: top half, then bottom half. */
        int half = rows / 2;
        matrix_mult_block(a, b, c, row, col, mid, half, cols, mids);
        matrix_mult_block(a, b, c, row + half, col, mid, rows - half, cols, mids);
    } else if (cols >= mids) {
        /* Split the columns of b and c: left half, then right half. */
        int half = cols / 2;
        matrix_mult_block(a, b, c, row, col, mid, rows, half, mids);
        matrix_mult_block(a, b, c, row, col + half, mid, rows, cols - half, mids);
    } else {
        /*
         * Split the shared (inner) dimension: both halves contribute to the
         * same block of c, which is why the worker accumulates with +=.
         */
        int half = mids / 2;
        matrix_mult_block(a, b, c, row, col, mid, rows, cols, half);
        matrix_mult_block(a, b, c, row, col, mid + half, rows, cols, mids - half);
    }
}

/*
 * Computes c = a * b for n x n matrices using divide and conquer.
 *
 * a, b: input matrices; only the top-left n x n region is read.
 * c:    output matrix; only the top-left n x n region is written. It is
 *       cleared before accumulation, so it must not be the same array as
 *       a or b.
 * n:    matrix dimension, 1 <= n <= N. For any other value the function
 *       returns without touching c.
 *
 * Works for every n in range, including odd sizes and sizes that are not a
 * power of two.
 */
void matrix_mult(int a[][N], int b[][N], int c[][N], int n)
{
    if (n <= 0 || n > N) {
        return;
    }

    /* The worker accumulates into c, so start from an all-zero result. */
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            c[i][j] = 0;
        }
    }

    matrix_mult_block(a, b, c, 0, 0, 0, n, n, n);
}
int main(){
int n, i, j;
int a[N][N], b[N][N], c[N][N];
printf("Enter the value of n (n should be less than 100):\n");
scanf("%d", &n);
printf("Enter the elements of matrix a:\n");
for(i = 0; i < n; i++){
for(j = 0; j < n; j++){
scanf("%d", &a[i][j]);
}
}
printf("Enter the elements of matrix b:\n");
for(i = 0; i < n; i++){
for(j = 0; j < n; j++){
scanf("%d", &b[i][j]);
}
}
matrix_mult(a, b, c, n);
printf("Resultant matrix c:\n");
for(i = 0; i < n; i++){
for(j = 0; j < n; j++){
printf("%d\t", c[i][j]);
}
printf("\n");
}
return 0;
}
```
阅读全文