Add error checking in Blas.cxx

This commit is contained in:
Gallo Alejandro 2022-09-12 19:07:48 +02:00
parent da704ad820
commit c20b9e3bcb

View File

@ -15,19 +15,22 @@
// [[file:~/cuda/atrip/atrip.org::*Blas][Blas:2]]
#include <atrip/Blas.hpp>
#include <atrip/Atrip.hpp>
#include <atrip/CUDA.hpp>
#if defined(HAVE_CUDA)
# include <cstring>
static inline
cublasOperation_t char_to_cublasOperation(const char* trans) {
static size_t dgem_call = 0;
static inline
cublasOperation_t char_to_cublasOperation(const char* trans) {
if (strncmp("N", trans, 1) == 0)
return CUBLAS_OP_N;
else if (strncmp("T", trans, 1) == 0)
return CUBLAS_OP_T;
else
return CUBLAS_OP_C;
}
}
#endif
@ -49,6 +52,8 @@ namespace atrip {
typename DataField<double>::type *C,
const int *ldc) {
#if defined(HAVE_CUDA)
// TODO: remove this verbose checking
const cublasStatus_t error =
cublasDgemm(Atrip::cuda.handle,
char_to_cublasOperation(transa),
char_to_cublasOperation(transb),
@ -56,6 +61,14 @@ namespace atrip {
alpha, A, *lda,
B, *ldb, beta,
C, *ldc);
if (error != 0) printf(":%-3ld (%4ldth) ERR<%4d> cublasDgemm: "
"A = %20ld "
"B = %20ld "
"C = %20ld "
"\n",
Atrip::rank,
dgem_call++,
error, A, B, C);
#else
dgemm_(transa, transb,
m, n, k,
@ -83,16 +96,17 @@ namespace atrip {
cuDoubleComplex
cu_alpha = {std::real(*alpha), std::imag(*alpha)},
cu_beta = {std::real(*beta), std::imag(*beta)};
_CHECK_CUBLAS_SUCCESS("cublasZgemm",
cublasZgemm(Atrip::cuda.handle,
char_to_cublasOperation(transa),
char_to_cublasOperation(transb),
*m, *n, *k,
&cu_alpha,
A, *lda,
B, *ldb,
&cu_beta,
C, *ldc);
C, *ldc));
#else
zgemm_(transa, transb,
m, n, k,