diff --git a/atrip.org b/atrip.org index 1fd7d91..2c14694 100644 --- a/atrip.org +++ b/atrip.org @@ -1863,6 +1863,9 @@ is mainly using the =DGEMM= function, which we declare as #+begin_src c++ :tangle (atrip-blas-h) #pragma once namespace atrip { + + using Complex = std::complex; + extern "C" { void dgemm_( const char *transa, @@ -1871,14 +1874,73 @@ namespace atrip { const int *n, const int *k, double *alpha, - const double *A, + const double *a, const int *lda, - const double *B, + const double *b, const int *ldb, double *beta, - double *C, + double *c, const int *ldc ); + + void zgemm_( + const char *transa, + const char *transb, + const int *m, + const int *n, + const int *k, + Complex *alpha, + const Complex *A, + const int *lda, + const Complex *B, + const int *ldb, + Complex *beta, + Complex *C, + const int *ldc + ); + } + + + template + void xgemm(const char *transa, + const char *transb, + const int *m, + const int *n, + const int *k, + F *alpha, + const F *A, + const int *lda, + const F *B, + const int *ldb, + F *beta, + F *C, + const int *ldc) { + dgemm_(transa, transb, + m, n, k, + alpha, A, lda, + B, ldb, beta, + C, ldc); + } + + template <> + void xgemm(const char *transa, + const char *transb, + const int *m, + const int *n, + const int *k, + Complex *alpha, + const Complex *A, + const int *lda, + const Complex *B, + const int *ldb, + Complex *beta, + Complex *C, + const int *ldc) { + zgemm_(transa, transb, + m, n, k, + alpha, A, lda, + B, ldb, beta, + C, ldc); } } #+end_src