Add error checking in Blas.cxx
This commit is contained in:
parent
da704ad820
commit
c20b9e3bcb
@ -15,19 +15,22 @@
|
|||||||
// [[file:~/cuda/atrip/atrip.org::*Blas][Blas:2]]
|
// [[file:~/cuda/atrip/atrip.org::*Blas][Blas:2]]
|
||||||
#include <atrip/Blas.hpp>
|
#include <atrip/Blas.hpp>
|
||||||
#include <atrip/Atrip.hpp>
|
#include <atrip/Atrip.hpp>
|
||||||
|
#include <atrip/CUDA.hpp>
|
||||||
|
|
||||||
#if defined(HAVE_CUDA)
|
#if defined(HAVE_CUDA)
|
||||||
# include <cstring>
|
# include <cstring>
|
||||||
|
|
||||||
static inline
|
static size_t dgem_call = 0;
|
||||||
cublasOperation_t char_to_cublasOperation(const char* trans) {
|
|
||||||
if (strncmp("N", trans, 1) == 0)
|
static inline
|
||||||
return CUBLAS_OP_N;
|
cublasOperation_t char_to_cublasOperation(const char* trans) {
|
||||||
else if (strncmp("T", trans, 1) == 0)
|
if (strncmp("N", trans, 1) == 0)
|
||||||
return CUBLAS_OP_T;
|
return CUBLAS_OP_N;
|
||||||
else
|
else if (strncmp("T", trans, 1) == 0)
|
||||||
return CUBLAS_OP_C;
|
return CUBLAS_OP_T;
|
||||||
}
|
else
|
||||||
|
return CUBLAS_OP_C;
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -49,13 +52,23 @@ namespace atrip {
|
|||||||
typename DataField<double>::type *C,
|
typename DataField<double>::type *C,
|
||||||
const int *ldc) {
|
const int *ldc) {
|
||||||
#if defined(HAVE_CUDA)
|
#if defined(HAVE_CUDA)
|
||||||
cublasDgemm(Atrip::cuda.handle,
|
// TODO: remove this verbose checking
|
||||||
char_to_cublasOperation(transa),
|
const cublasStatus_t error =
|
||||||
char_to_cublasOperation(transb),
|
cublasDgemm(Atrip::cuda.handle,
|
||||||
*m, *n, *k,
|
char_to_cublasOperation(transa),
|
||||||
alpha, A, *lda,
|
char_to_cublasOperation(transb),
|
||||||
B, *ldb, beta,
|
*m, *n, *k,
|
||||||
C, *ldc);
|
alpha, A, *lda,
|
||||||
|
B, *ldb, beta,
|
||||||
|
C, *ldc);
|
||||||
|
if (error != 0) printf(":%-3ld (%4ldth) ERR<%4d> cublasDgemm: "
|
||||||
|
"A = %20ld "
|
||||||
|
"B = %20ld "
|
||||||
|
"C = %20ld "
|
||||||
|
"\n",
|
||||||
|
Atrip::rank,
|
||||||
|
dgem_call++,
|
||||||
|
error, A, B, C);
|
||||||
#else
|
#else
|
||||||
dgemm_(transa, transb,
|
dgemm_(transa, transb,
|
||||||
m, n, k,
|
m, n, k,
|
||||||
@ -83,16 +96,17 @@ namespace atrip {
|
|||||||
cuDoubleComplex
|
cuDoubleComplex
|
||||||
cu_alpha = {std::real(*alpha), std::imag(*alpha)},
|
cu_alpha = {std::real(*alpha), std::imag(*alpha)},
|
||||||
cu_beta = {std::real(*beta), std::imag(*beta)};
|
cu_beta = {std::real(*beta), std::imag(*beta)};
|
||||||
cublasZgemm(Atrip::cuda.handle,
|
|
||||||
char_to_cublasOperation(transa),
|
_CHECK_CUBLAS_SUCCESS("cublasZgemm",
|
||||||
char_to_cublasOperation(transb),
|
cublasZgemm(Atrip::cuda.handle,
|
||||||
*m, *n, *k,
|
char_to_cublasOperation(transa),
|
||||||
&cu_alpha,
|
char_to_cublasOperation(transb),
|
||||||
|
*m, *n, *k,
|
||||||
A, *lda,
|
&cu_alpha,
|
||||||
B, *ldb,
|
A, *lda,
|
||||||
&cu_beta,
|
B, *ldb,
|
||||||
C, *ldc);
|
&cu_beta,
|
||||||
|
C, *ldc));
|
||||||
#else
|
#else
|
||||||
zgemm_(transa, transb,
|
zgemm_(transa, transb,
|
||||||
m, n, k,
|
m, n, k,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user