diff --git a/src/atrip/Blas.cxx b/src/atrip/Blas.cxx
index 05b34b1..4b6b37a 100644
--- a/src/atrip/Blas.cxx
+++ b/src/atrip/Blas.cxx
@@ -15,19 +15,26 @@
 // [[file:~/cuda/atrip/atrip.org::*Blas][Blas:2]]
 #include
 #include
+#include
 
 #if defined(HAVE_CUDA)
 #  include
 
-  static inline
-  cublasOperation_t char_to_cublasOperation(const char* trans) {
-    if (strncmp("N", trans, 1) == 0)
-      return CUBLAS_OP_N;
-    else if (strncmp("T", trans, 1) == 0)
-      return CUBLAS_OP_T;
-    else
-      return CUBLAS_OP_C;
-  }
+// Counter printed in the cublasDgemm error report below; it is incremented
+// only when that report is printed, so it does not count successful calls.
+static size_t dgem_call = 0;
+
+// Translate a BLAS transpose flag into a cuBLAS operation: a leading 'N' gives
+// CUBLAS_OP_N, a leading 'T' CUBLAS_OP_T, anything else (case-sensitive) CUBLAS_OP_C.
+static inline
+cublasOperation_t char_to_cublasOperation(const char* trans) {
+  if (strncmp("N", trans, 1) == 0)
+    return CUBLAS_OP_N;
+  else if (strncmp("T", trans, 1) == 0)
+    return CUBLAS_OP_T;
+  else
+    return CUBLAS_OP_C;
+}
 
 #endif
 
@@ -49,13 +56,28 @@ namespace atrip {
                    typename DataField::type *C,
                    const int *ldc) {
 #if defined(HAVE_CUDA)
-    cublasDgemm(Atrip::cuda.handle,
-                char_to_cublasOperation(transa),
-                char_to_cublasOperation(transb),
-                *m, *n, *k,
-                alpha, A, *lda,
-                B, *ldb, beta,
-                C, *ldc);
+    // TODO: remove this verbose checking
+    const cublasStatus_t error =
+      cublasDgemm(Atrip::cuda.handle,
+                  char_to_cublasOperation(transa),
+                  char_to_cublasOperation(transb),
+                  *m, *n, *k,
+                  alpha, A, *lda,
+                  B, *ldb, beta,
+                  C, *ldc);
+    // Report the failure and carry on. Every argument is cast so that it
+    // matches its conversion specifier exactly: printf is not type-checked.
+    if (error != CUBLAS_STATUS_SUCCESS) printf(":%-3ld (%4zuth) ERR<%4d> cublasDgemm: "
+                    "A = %20p "
+                    "B = %20p "
+                    "C = %20p "
+                    "\n",
+                    static_cast<long>(Atrip::rank),
+                    dgem_call++,
+                    static_cast<int>(error),
+                    static_cast<const void*>(A),
+                    static_cast<const void*>(B),
+                    static_cast<const void*>(C));
 #else
     dgemm_(transa, transb,
            m, n, k,
@@ -83,16 +105,17 @@ namespace atrip {
     cuDoubleComplex
       cu_alpha = {std::real(*alpha), std::imag(*alpha)},
       cu_beta = {std::real(*beta), std::imag(*beta)};
-    cublasZgemm(Atrip::cuda.handle,
-                char_to_cublasOperation(transa),
-                char_to_cublasOperation(transb),
-                *m, *n, *k,
-                &cu_alpha,
-
-                A, *lda,
-                B, *ldb,
-                &cu_beta,
-                C, *ldc);
+
+    _CHECK_CUBLAS_SUCCESS("cublasZgemm",
+      cublasZgemm(Atrip::cuda.handle,
+                  char_to_cublasOperation(transa),
+                  char_to_cublasOperation(transb),
+                  *m, *n, *k,
+                  &cu_alpha,
+                  A, *lda,
+                  B, *ldb,
+                  &cu_beta,
+                  C, *ldc));
 #else
     zgemm_(transa, transb,
            m, n, k,