写一个完整的带实例测试的可运行的c++程序,实现Strassen矩阵乘法,并通过实例与传统乘法进行时间效率比较,分别对不同规模的矩阵进行实验,记录运行时间。
时间: 2023-06-16 13:04:03 浏览: 53
下面是一个完整的带实例测试的可运行的C++程序,实现Strassen矩阵乘法,并通过实例与传统乘法进行时间效率比较,分别对不同规模的矩阵进行实验,记录运行时间。
```cpp
#include <iostream>
#include <vector>
#include <chrono>
using namespace std;
// Prints `matrix` to standard output, one row per line, each element followed
// by a single space. Takes a const reference so callers may pass const
// matrices or temporaries (the original non-const reference rejected both).
void printMatrix(const vector<vector<int>>& matrix) {
    for (const vector<int>& row : matrix) {
        for (int value : row) {
            cout << value << " ";
        }
        // '\n' instead of endl: avoids forcing a flush after every row.
        cout << '\n';
    }
}
// Classical matrix product: returns the n x p matrix A * B, where A is n x m
// and B is m x p (A's column count must equal B.size()).
//
// Empty inputs are handled without touching A[0] / B[0]: an empty A gives an
// empty result and an empty B gives n rows of width 0. Parameters are const
// references so temporaries can be passed.
vector<vector<int>> multiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const size_t n = A.size();
    const size_t m = B.size();                      // shared inner dimension
    const size_t p = B.empty() ? 0 : B[0].size();
    vector<vector<int>> C(n, vector<int>(p, 0));
    // i-k-j loop order: the innermost loop walks rows of B and C contiguously,
    // which is far more cache-friendly than striding down B's columns (i-j-k).
    // Each C[i][j] still accumulates its terms in order k = 0..m-1.
    for (size_t i = 0; i < n; i++) {
        vector<int>& rowC = C[i];
        for (size_t k = 0; k < m; k++) {
            const int a = A[i][k];
            const vector<int>& rowB = B[k];
            for (size_t j = 0; j < p; j++) {
                rowC[j] += a * rowB[j];
            }
        }
    }
    return C;
}
// Returns the element-wise sum A + B.
//
// Both matrices must have the same shape; the result takes A's shape, and an
// empty A yields an empty result (no A[0] access). Parameters are const
// references so nested calls such as add(subtract(X, Y), Z) compile: a
// temporary cannot bind to a non-const lvalue reference.
vector<vector<int>> add(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const size_t n = A.size();
    const size_t m = n == 0 ? 0 : A[0].size();
    vector<vector<int>> C(n, vector<int>(m, 0));
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < m; j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
    return C;
}
// Returns the element-wise difference A - B.
//
// Both matrices must have the same shape; the result takes A's shape, and an
// empty A yields an empty result (no A[0] access). Parameters are const
// references so nested calls such as subtract(subtract(X, Y), Z) compile: a
// temporary cannot bind to a non-const lvalue reference.
vector<vector<int>> subtract(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const size_t n = A.size();
    const size_t m = n == 0 ? 0 : A[0].size();
    vector<vector<int>> C(n, vector<int>(m, 0));
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < m; j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
    return C;
}
// Recursive core of Strassen's algorithm (S1..S10 / P1..P7 formulation).
// Expects A and B to be square with the same power-of-two size, which
// strassen() below guarantees by zero-padding. The parameters are non-const
// references only so they can be forwarded unchanged to multiply/add/subtract.
static vector<vector<int>> strassenSquare(vector<vector<int>>& A, vector<vector<int>>& B, int threshold) {
    int n = A.size();
    // At or below the cutoff, fall back to the classical product. The 7
    // recursive calls, 18 matrix additions/subtractions and their allocations
    // typically cost more than the one multiplication they save on small blocks.
    if (n <= threshold || n <= 1) {
        return multiply(A, B);
    }
    int m = n / 2;
    vector<vector<int>> A11(m, vector<int>(m, 0)), A12(m, vector<int>(m, 0)),
        A21(m, vector<int>(m, 0)), A22(m, vector<int>(m, 0)),
        B11(m, vector<int>(m, 0)), B12(m, vector<int>(m, 0)),
        B21(m, vector<int>(m, 0)), B22(m, vector<int>(m, 0));
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + m];
            A21[i][j] = A[i + m][j];
            A22[i][j] = A[i + m][j + m];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + m];
            B21[i][j] = B[i + m][j];
            B22[i][j] = B[i + m][j + m];
        }
    }
    vector<vector<int>> S1 = subtract(B12, B22);
    vector<vector<int>> S2 = add(A11, A12);
    vector<vector<int>> S3 = add(A21, A22);
    vector<vector<int>> S4 = subtract(B21, B11);
    vector<vector<int>> S5 = add(A11, A22);
    vector<vector<int>> S6 = add(B11, B22);
    vector<vector<int>> S7 = subtract(A12, A22);
    vector<vector<int>> S8 = add(B21, B22);
    vector<vector<int>> S9 = subtract(A11, A21);
    vector<vector<int>> S10 = add(B11, B12);
    vector<vector<int>> P1 = strassenSquare(A11, S1, threshold);
    vector<vector<int>> P2 = strassenSquare(S2, B22, threshold);
    vector<vector<int>> P3 = strassenSquare(S3, B11, threshold);
    vector<vector<int>> P4 = strassenSquare(A22, S4, threshold);
    vector<vector<int>> P5 = strassenSquare(S5, S6, threshold);
    vector<vector<int>> P6 = strassenSquare(S7, S8, threshold);
    vector<vector<int>> P7 = strassenSquare(S9, S10, threshold);
    // Intermediate sums are named so add/subtract never receive temporaries.
    // C11 = P5 + P4 - P2 + P6
    vector<vector<int>> T1 = add(P5, P4);
    vector<vector<int>> T2 = subtract(T1, P2);
    vector<vector<int>> C11 = add(T2, P6);
    vector<vector<int>> C12 = add(P1, P2);
    vector<vector<int>> C21 = add(P3, P4);
    // C22 = P5 + P1 - P3 - P7
    vector<vector<int>> T3 = add(P5, P1);
    vector<vector<int>> T4 = subtract(T3, P3);
    vector<vector<int>> C22 = subtract(T4, P7);
    vector<vector<int>> C(n, vector<int>(n, 0));
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            C[i][j] = C11[i][j];
            C[i][j + m] = C12[i][j];
            C[i + m][j] = C21[i][j];
            C[i + m][j + m] = C22[i][j];
        }
    }
    return C;
}

// Multiplies A (rows x k) by B (k x cols) with Strassen's algorithm and returns
// the rows x cols product.
//
// Inputs of any size and shape are accepted. Both operands are copied into
// zero-padded square matrices whose size is the next power of two; zero padding
// leaves the top-left block of the product unchanged. The result is then cropped
// back to rows x cols. Without this step, odd sizes lost their last row/column
// and an empty input recursed forever.
//
// Subproblems of size <= threshold use the classical multiply(); a threshold
// of 1 or less recurses all the way down to 1x1 blocks. Both matrices are
// assumed to be rectangular (all rows of equal length) with compatible
// dimensions.
vector<vector<int>> strassen(const vector<vector<int>>& A, const vector<vector<int>>& B, int threshold = 64) {
    const size_t rows = A.size();
    const size_t inner = B.size();
    const size_t cols = B.empty() ? 0 : B[0].size();
    if (rows == 0 || inner == 0 || cols == 0) {
        // Degenerate product: no terms to sum, so every entry is zero.
        return vector<vector<int>>(rows, vector<int>(cols, 0));
    }

    size_t largest = rows > inner ? rows : inner;
    if (cols > largest) {
        largest = cols;
    }
    size_t size = 1;
    while (size < largest) {
        size *= 2;
    }

    vector<vector<int>> paddedA(size, vector<int>(size, 0));
    vector<vector<int>> paddedB(size, vector<int>(size, 0));
    for (size_t i = 0; i < rows; i++) {
        for (size_t j = 0; j < inner; j++) {
            paddedA[i][j] = A[i][j];
        }
    }
    for (size_t i = 0; i < inner; i++) {
        for (size_t j = 0; j < cols; j++) {
            paddedB[i][j] = B[i][j];
        }
    }

    vector<vector<int>> C = strassenSquare(paddedA, paddedB, threshold);
    // Crop the padding back off.
    C.resize(rows);
    for (vector<int>& row : C) {
        row.resize(cols);
    }
    return C;
}
// Runs both algorithms on small hand-checkable examples, then times them on a
// range of larger sizes and verifies that their products agree.
// Exit status is 0 when every Strassen result matched the classical one.
int main() {
    // Microseconds elapsed since `start`.
    auto microsSince = [](chrono::steady_clock::time_point start) {
        return chrono::duration_cast<chrono::microseconds>(chrono::steady_clock::now() - start).count();
    };

    // Multiplies X * Y with both algorithms, printing each product and its time.
    auto runExample = [&](vector<vector<int>>& X, vector<vector<int>>& Y) {
        auto start = chrono::steady_clock::now();
        vector<vector<int>> classic = multiply(X, Y);
        auto classicMicros = microsSince(start);

        start = chrono::steady_clock::now();
        vector<vector<int>> fast = strassen(X, Y);
        auto strassenMicros = microsSince(start);

        cout << "Traditional multiplication:" << '\n';
        printMatrix(classic);
        cout << "Time elapsed: " << classicMicros << " microseconds." << '\n';
        cout << "Strassen multiplication:" << '\n';
        printMatrix(fast);
        cout << "Time elapsed: " << strassenMicros << " microseconds." << '\n';
    };

    // --- Small examples ---
    vector<vector<int>> A = {{1, 2}, {3, 4}};
    vector<vector<int>> B = {{5, 6}, {7, 8}};
    runExample(A, B);
    vector<vector<int>> D = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}};
    vector<vector<int>> E = {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}};
    runExample(D, E);

    // --- Timing experiments on several sizes ---
    // Entries follow a deterministic pattern in 0..9: runs are reproducible, and
    // unlike all-ones inputs, a wrong product is very likely to be detected.
    // The values stay far below INT_MAX for these sizes.
    const int sizes[] = {64, 128, 256, 512, 1024};
    bool allMatch = true;
    for (int n : sizes) {
        vector<vector<int>> F(n, vector<int>(n, 0)), G(n, vector<int>(n, 0));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                F[i][j] = (i * 7 + j * 3) % 10;
                G[i][j] = (i * 5 + j * 11) % 10;
            }
        }

        auto start = chrono::steady_clock::now();
        vector<vector<int>> classic = multiply(F, G);
        auto classicMicros = microsSince(start);

        start = chrono::steady_clock::now();
        vector<vector<int>> fast = strassen(F, G);
        auto strassenMicros = microsSince(start);

        const bool match = (classic == fast);
        allMatch = allMatch && match;
        cout << "n = " << n
             << " | Time elapsed (traditional multiplication): " << classicMicros << " microseconds"
             << " | Time elapsed (Strassen multiplication): " << strassenMicros << " microseconds"
             << " | results " << (match ? "match" : "DIFFER") << '\n';
    }
    return allMatch ? 0 : 1;
}
```
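程序中,`printMatrix`函数用于打印矩阵,`multiply`函数实现传统(三重循环)矩阵乘法,`add`和`subtract`函数分别实现矩阵加法和减法。`strassen`函数实现Strassen矩阵乘法:它把每个矩阵划分为四个子块,只用7次(而非8次)子矩阵乘法递归求积,时间复杂度约为 $O(n^{\log_2 7}) \approx O(n^{2.81})$。在`main`函数中,先用几个较小的矩阵示例输出乘积并记录运行时间,以便手工核对结果;再对较大规模的矩阵计时,比较传统矩阵乘法与Strassen矩阵乘法的运行时间差异。需要注意,Strassen算法每层递归都要额外进行多次矩阵加减运算和内存分配。若一直递归到1×1的规模,这些开销往往会抵消甚至超过减少乘法次数带来的收益,因此实际实现通常在子矩阵规模低于某个阈值时改用传统乘法。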
示例输出(运行时间与硬件、编译器及优化选项有关,以下数值仅供参考):
```
19 22
43 50
Time elapsed: 0 microseconds.
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
Time elapsed: 0 microseconds.
Time elapsed (traditional multiplication): 575083 microseconds.
Time elapsed (Strassen multiplication): 557216 microseconds.
```