strassen矩阵乘法的串行程序,并输出运行时间
时间: 2024-05-07 19:16:50 浏览: 146
Strassen矩阵问题
以下是串行实现strassen矩阵乘法的示例代码,并输出运行时间:
```python
import time
# 定义矩阵乘法函数
def matrix_multiply(a, b):
ra, ca = a.shape # 矩阵a的行数和列数
rb, cb = b.shape # 矩阵b的行数和列数
# 判断能否相乘
if ca != rb:
raise ValueError("Matrix shapes are not aligned. %s != %s" % (ca, rb))
c = np.zeros((ra, cb)) # 结果矩阵c初始化
# 矩阵相乘
for i in range(ra):
for j in range(cb):
for k in range(ca):
c[i, j] += a[i, k] * b[k, j]
return c
# 定义strassen矩阵乘法函数
def strassen_mult(a, b):
# 如果矩阵尺寸小于等于64,则采用暴力矩阵乘法
if a.shape[0] <= 64:
return matrix_multiply(a, b)
# 矩阵尺寸一半
m = a.shape[0] // 2
# 划分为子问题
a11, a12 = a[:m, :m], a[:m, m:]
a21, a22 = a[m:, :m], a[m:, m:]
b11, b12 = b[:m, :m], b[:m, m:]
b21, b22 = b[m:, :m], b[m:, m:]
# 7个子问题
p1 = strassen_mult(a11 + a22, b11 + b22)
p2 = strassen_mult(a21 + a22, b11)
p3 = strassen_mult(a11, b12 - b22)
p4 = strassen_mult(a22, b21 - b11)
p5 = strassen_mult(a11 + a12, b22)
p6 = strassen_mult(a21 - a11, b11 + b12)
p7 = strassen_mult(a12 - a22, b21 + b22)
# 计算结果矩阵的四个子矩阵
c11 = p1 + p4 - p5 + p7
c12 = p3 + p5
c21 = p2 + p4
c22 = p1 - p2 + p3 + p6
# 合并子矩阵
c = np.zeros((a.shape[0], b.shape[1]))
c[:m, :m], c[:m, m:], c[m:, :m], c[m:, m:] = c11, c12, c21, c22
return c
# 生成待计算的矩阵a和b
a = np.random.rand(2048, 2048)
b = np.random.rand(2048, 2048)
# 记录运行时间
start_time = time.time()
# 调用strassen矩阵乘法函数
c = strassen_mult(a, b)
end_time = time.time()
# 输出运行时间
print("运行时间:", end_time - start_time, "秒")
```
注意:以上代码为Python代码,需要结合相关库使用。运行时间与测试环境有关,仅供参考。
阅读全文