Add error checking in Blas.cxx

This commit is contained in:
Gallo Alejandro 2022-09-12 19:07:48 +02:00
parent da704ad820
commit c20b9e3bcb

View File

@ -15,19 +15,22 @@
// [[file:~/cuda/atrip/atrip.org::*Blas][Blas:2]] // [[file:~/cuda/atrip/atrip.org::*Blas][Blas:2]]
#include <atrip/Blas.hpp> #include <atrip/Blas.hpp>
#include <atrip/Atrip.hpp> #include <atrip/Atrip.hpp>
#include <atrip/CUDA.hpp>
#if defined(HAVE_CUDA) #if defined(HAVE_CUDA)
# include <cstring> # include <cstring>
static inline static size_t dgem_call = 0;
cublasOperation_t char_to_cublasOperation(const char* trans) {
if (strncmp("N", trans, 1) == 0) static inline
return CUBLAS_OP_N; cublasOperation_t char_to_cublasOperation(const char* trans) {
else if (strncmp("T", trans, 1) == 0) if (strncmp("N", trans, 1) == 0)
return CUBLAS_OP_T; return CUBLAS_OP_N;
else else if (strncmp("T", trans, 1) == 0)
return CUBLAS_OP_C; return CUBLAS_OP_T;
} else
return CUBLAS_OP_C;
}
#endif #endif
@ -49,13 +52,23 @@ namespace atrip {
typename DataField<double>::type *C, typename DataField<double>::type *C,
const int *ldc) { const int *ldc) {
#if defined(HAVE_CUDA) #if defined(HAVE_CUDA)
cublasDgemm(Atrip::cuda.handle, // TODO: remove this verbose checking
char_to_cublasOperation(transa), const cublasStatus_t error =
char_to_cublasOperation(transb), cublasDgemm(Atrip::cuda.handle,
*m, *n, *k, char_to_cublasOperation(transa),
alpha, A, *lda, char_to_cublasOperation(transb),
B, *ldb, beta, *m, *n, *k,
C, *ldc); alpha, A, *lda,
B, *ldb, beta,
C, *ldc);
if (error != 0) printf(":%-3ld (%4ldth) ERR<%4d> cublasDgemm: "
"A = %20ld "
"B = %20ld "
"C = %20ld "
"\n",
Atrip::rank,
dgem_call++,
error, A, B, C);
#else #else
dgemm_(transa, transb, dgemm_(transa, transb,
m, n, k, m, n, k,
@ -83,16 +96,17 @@ namespace atrip {
cuDoubleComplex cuDoubleComplex
cu_alpha = {std::real(*alpha), std::imag(*alpha)}, cu_alpha = {std::real(*alpha), std::imag(*alpha)},
cu_beta = {std::real(*beta), std::imag(*beta)}; cu_beta = {std::real(*beta), std::imag(*beta)};
cublasZgemm(Atrip::cuda.handle,
char_to_cublasOperation(transa), _CHECK_CUBLAS_SUCCESS("cublasZgemm",
char_to_cublasOperation(transb), cublasZgemm(Atrip::cuda.handle,
*m, *n, *k, char_to_cublasOperation(transa),
&cu_alpha, char_to_cublasOperation(transb),
*m, *n, *k,
A, *lda, &cu_alpha,
B, *ldb, A, *lda,
&cu_beta, B, *ldb,
C, *ldc); &cu_beta,
C, *ldc));
#else #else
zgemm_(transa, transb, zgemm_(transa, transb,
m, n, k, m, n, k,