Add error checking in Blas.cxx

This commit is contained in:
Gallo Alejandro 2022-09-12 19:07:48 +02:00
parent da704ad820
commit c20b9e3bcb

View File

@ -15,19 +15,22 @@
// [[file:~/cuda/atrip/atrip.org::*Blas][Blas:2]] // [[file:~/cuda/atrip/atrip.org::*Blas][Blas:2]]
#include <atrip/Blas.hpp> #include <atrip/Blas.hpp>
#include <atrip/Atrip.hpp> #include <atrip/Atrip.hpp>
#include <atrip/CUDA.hpp>
#if defined(HAVE_CUDA) #if defined(HAVE_CUDA)
# include <cstring> # include <cstring>
static inline static size_t dgem_call = 0;
cublasOperation_t char_to_cublasOperation(const char* trans) {
static inline
cublasOperation_t char_to_cublasOperation(const char* trans) {
if (strncmp("N", trans, 1) == 0) if (strncmp("N", trans, 1) == 0)
return CUBLAS_OP_N; return CUBLAS_OP_N;
else if (strncmp("T", trans, 1) == 0) else if (strncmp("T", trans, 1) == 0)
return CUBLAS_OP_T; return CUBLAS_OP_T;
else else
return CUBLAS_OP_C; return CUBLAS_OP_C;
} }
#endif #endif
@ -49,6 +52,8 @@ namespace atrip {
typename DataField<double>::type *C, typename DataField<double>::type *C,
const int *ldc) { const int *ldc) {
#if defined(HAVE_CUDA) #if defined(HAVE_CUDA)
// TODO: remove this verbose checking
const cublasStatus_t error =
cublasDgemm(Atrip::cuda.handle, cublasDgemm(Atrip::cuda.handle,
char_to_cublasOperation(transa), char_to_cublasOperation(transa),
char_to_cublasOperation(transb), char_to_cublasOperation(transb),
@ -56,6 +61,14 @@ namespace atrip {
alpha, A, *lda, alpha, A, *lda,
B, *ldb, beta, B, *ldb, beta,
C, *ldc); C, *ldc);
if (error != 0) printf(":%-3ld (%4ldth) ERR<%4d> cublasDgemm: "
"A = %20ld "
"B = %20ld "
"C = %20ld "
"\n",
Atrip::rank,
dgem_call++,
error, A, B, C);
#else #else
dgemm_(transa, transb, dgemm_(transa, transb,
m, n, k, m, n, k,
@ -83,16 +96,17 @@ namespace atrip {
cuDoubleComplex cuDoubleComplex
cu_alpha = {std::real(*alpha), std::imag(*alpha)}, cu_alpha = {std::real(*alpha), std::imag(*alpha)},
cu_beta = {std::real(*beta), std::imag(*beta)}; cu_beta = {std::real(*beta), std::imag(*beta)};
_CHECK_CUBLAS_SUCCESS("cublasZgemm",
cublasZgemm(Atrip::cuda.handle, cublasZgemm(Atrip::cuda.handle,
char_to_cublasOperation(transa), char_to_cublasOperation(transa),
char_to_cublasOperation(transb), char_to_cublasOperation(transb),
*m, *n, *k, *m, *n, *k,
&cu_alpha, &cu_alpha,
A, *lda, A, *lda,
B, *ldb, B, *ldb,
&cu_beta, &cu_beta,
C, *ldc); C, *ldc));
#else #else
zgemm_(transa, transb, zgemm_(transa, transb,
m, n, k, m, n, k,