Source code of the cblas_dgemm function
Date: 2024-01-24 18:16:57
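Below is a simplified sketch of cblas_dgemm written as a wrapper around the Fortran dgemm_ routine. It illustrates the idea and is not the verbatim source of any particular BLAS library: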
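```
#include <stdio.h>
#include <stdlib.h>
#include "cblas.h"

/* Fortran BLAS kernel; it always assumes column-major storage. The trailing
   underscore is a common name-mangling convention and may differ by toolchain. */
extern void dgemm_(const char *transa, const char *transb, const int *m,
                   const int *n, const int *k, const double *alpha,
                   const double *a, const int *lda, const double *b,
                   const int *ldb, const double *beta, double *c,
                   const int *ldc);

void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
                 const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
                 const int K, const double alpha, const double *A,
                 const int lda, const double *B, const int ldb,
                 const double beta, double *C, const int ldc)
{
    /* 'N' = use the matrix as stored, 'T' = use its transpose.
       For real matrices CblasConjTrans behaves like CblasTrans. */
    const char TA = (TransA == CblasNoTrans) ? 'N' : 'T';
    const char TB = (TransB == CblasNoTrans) ? 'N' : 'T';
    const double zero = 0.0;
    int rows, cols;      /* shape of the column-major buffer */
    int outer, inner;
    double *buffer;

    if (M <= 0 || N <= 0) {
        return;          /* empty result, nothing to compute */
    }
    /* A row-major M x N matrix occupies memory like a column-major N x M one. */
    if (Order == CblasRowMajor) {
        rows = N;
        cols = M;
    } else {
        rows = M;
        cols = N;
    }
    buffer = (double *) malloc((size_t) rows * (size_t) cols * sizeof(double));
    if (buffer == NULL) {
        fprintf(stderr, "Memory allocation failed\n");
        exit(1);
    }
    /* Compute buffer = alpha * op(A) * op(B); beta is applied in the copy step. */
    if (Order == CblasRowMajor) {
        /* Row-major case: compute C^T = alpha * op(B)^T * op(A)^T by swapping
           the operands and the M/N dimensions. */
        dgemm_(&TB, &TA, &N, &M, &K, &alpha, B, &ldb, A, &lda,
               &zero, buffer, &rows);
    } else {
        dgemm_(&TA, &TB, &M, &N, &K, &alpha, A, &lda, B, &ldb,
               &zero, buffer, &rows);
    }
    /* Write back C = beta * C + buffer. The buffer shares C's layout, so one
       loop serves both cases. When beta is zero, C is not read (BLAS convention). */
    for (outer = 0; outer < cols; outer++) {
        for (inner = 0; inner < rows; inner++) {
            double *c = &C[(size_t) outer * ldc + inner];
            const double prod = buffer[(size_t) outer * rows + inner];
            *c = (beta == 0.0) ? prod : beta * (*c) + prod;
        }
    }
    free(buffer);
}
```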
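To check a wrapper like the one above, it helps to have an independent definition of the expected result. The following is a minimal sketch of such a reference: a plain triple loop that evaluates C = alpha * op(A) * op(B) + beta * C for either storage layout. The names `naive_dgemm`, `op_elem`, and `enum layout` are hypothetical helpers introduced here for illustration and are not part of CBLAS.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical layout flag for this sketch; not a CBLAS type. */
enum layout { ROW_MAJOR, COL_MAJOR };

/* Element (r, c) of op(X), where X is stored with leading dimension ld
   and op(X) is X itself (trans == 0) or its transpose (trans != 0). */
static double op_elem(enum layout order, int trans, const double *X,
                      int ld, int r, int c)
{
    if (trans) {                 /* op(X)(r, c) = X(c, r) */
        int t = r;
        r = c;
        c = t;
    }
    return (order == ROW_MAJOR) ? X[(size_t) r * ld + c]
                                : X[(size_t) c * ld + r];
}

/* C <- alpha * op(A) * op(B) + beta * C,
   where op(A) is M x K, op(B) is K x N and C is M x N. */
void naive_dgemm(enum layout order, int transA, int transB,
                 int M, int N, int K, double alpha,
                 const double *A, int lda, const double *B, int ldb,
                 double beta, double *C, int ldc)
{
    assert(M >= 0 && N >= 0 && K >= 0);
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            double sum = 0.0;
            for (int k = 0; k < K; k++) {
                sum += op_elem(order, transA, A, lda, i, k)
                     * op_elem(order, transB, B, ldb, k, j);
            }
            double *c = (order == ROW_MAJOR) ? &C[(size_t) i * ldc + j]
                                             : &C[(size_t) j * ldc + i];
            /* BLAS convention: when beta is zero, C is not read. */
            *c = alpha * sum + (beta == 0.0 ? 0.0 : beta * (*c));
        }
    }
}
```

For the row-major matrices A = [[1, 2, 3], [4, 5, 6]] and B = [[7, 8], [9, 10], [11, 12]], the call `naive_dgemm(ROW_MAJOR, 0, 0, 2, 2, 3, 1.0, A, 3, B, 2, 0.0, C, 2)` fills C with 58, 64, 139, 154. A cblas_dgemm call with matching arguments should produce the same values.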
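This function implements general matrix multiplication, C = alpha * op(A) * op(B) + beta * C, by calling the dgemm routine of the BLAS library. It supports two storage layouts: row-major and column-major. Because the underlying Fortran routine assumes column-major storage, the function first adjusts the matrix dimensions, operand order, and transpose flags according to the requested layout. It then calls dgemm to compute the product into a temporary buffer, and finally writes the buffer back into the output matrix C.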
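The layout adjustment described above rests on one identity: a row-major matrix occupies memory exactly like the column-major layout of its transpose, and (A * B)^T = B^T * A^T. The sketch below illustrates this under simplifying assumptions. `colmajor_gemm` is a hypothetical stand-in for the Fortran kernel, limited to the no-transpose case, and `rowmajor_gemm` is a hypothetical wrapper that only swaps the operands and the M/N dimensions.

```c
#include <assert.h>
#include <stddef.h>

/* Stand-in for a column-major kernel (like Fortran dgemm with both
   transpose flags 'N'): C(m x n) <- alpha * A(m x k) * B(k x n) + beta * C. */
void colmajor_gemm(int m, int n, int k, double alpha,
                   const double *A, int lda, const double *B, int ldb,
                   double beta, double *C, int ldc)
{
    assert(lda >= m && ldb >= k && ldc >= m);
    for (int j = 0; j < n; j++) {
        for (int i = 0; i < m; i++) {
            double sum = 0.0;
            for (int p = 0; p < k; p++) {
                sum += A[(size_t) p * lda + i] * B[(size_t) j * ldb + p];
            }
            double *c = &C[(size_t) j * ldc + i];
            *c = alpha * sum + (beta == 0.0 ? 0.0 : beta * (*c));
        }
    }
}

/* Row-major C(M x N) <- alpha * A(M x K) * B(K x N) + beta * C.
   Seen through a column-major kernel, the same memory holds A^T, B^T and
   C^T, so we ask the kernel for C^T = alpha * B^T * A^T + beta * C^T. */
void rowmajor_gemm(int M, int N, int K, double alpha,
                   const double *A, int lda, const double *B, int ldb,
                   double beta, double *C, int ldc)
{
    colmajor_gemm(N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
}
```

With this identity the result lands directly in C, so a wrapper built this way would not need a temporary buffer or a copy step. That is a design alternative to the buffer-based approach, which keeps the kernel call and the write-back to C as separate steps.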