Fix problem with complex numbers
This commit is contained in:
parent
565fb1dcc8
commit
c757c4650c
365
atrip.org
365
atrip.org
@ -214,17 +214,49 @@ of a GPU architecture.
|
||||
|
||||
#+begin_src c++ :tangle (atrip-types-h)
|
||||
#pragma once
|
||||
#include <atrip/Complex.hpp>
|
||||
#include <atrip/Atrip.hpp>
|
||||
|
||||
namespace atrip {
|
||||
|
||||
// Trait mapping a scalar type F to an element type exposed as the nested
// alias `type`. Only specializations are defined; the primary template is
// declared but left incomplete.
template <typename F>
|
||||
struct DataField;
|
||||
|
||||
// For double the mapped element type is double itself.
template <>
|
||||
struct DataField<double> {
|
||||
using type = double;
|
||||
};
|
||||
|
||||
#if defined(HAVE_CUDA)
|
||||
|
||||
// With HAVE_CUDA defined, DataPtr<F> is CUdeviceptr for every F.
template <typename F>
|
||||
using DataPtr = CUdeviceptr;
|
||||
// Null value for a DataPtr in this branch.
#define DataNullPtr 0x00
|
||||
// Indexes _array after casting it to F*. F is not a macro parameter, so a
// type named F must be in scope wherever _AT_ is expanded.
#define _AT_(_array, _idx) ((F*)(_array))[(_idx)]
|
||||
|
||||
// With HAVE_CUDA defined, Complex maps to cuDoubleComplex.
template <>
|
||||
struct DataField<Complex> {
|
||||
using type = cuDoubleComplex;
|
||||
};
|
||||
|
||||
|
||||
#else
|
||||
|
||||
// Without HAVE_CUDA, DataPtr<F> is an ordinary F* pointer.
template <typename F>
|
||||
using DataPtr = F*;
|
||||
// Null value for a DataPtr in this branch.
#define DataNullPtr nullptr
|
||||
// Plain subscript of _array at _idx; no cast is applied in this branch.
#define _AT_(_array, _idx) (_array)[(_idx)]
|
||||
|
||||
// Without HAVE_CUDA, Complex maps to itself.
template <>
|
||||
struct DataField<Complex> {
|
||||
using type = Complex;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
// Shorthand for DataField<F>::type.
template <typename F>
|
||||
using DataFieldType = typename DataField<F>::type;
|
||||
|
||||
}
|
||||
#+end_src
|
||||
|
||||
|
||||
@ -748,7 +780,7 @@ The main behaviour of the function should
|
||||
|
||||
#if defined(HAVE_CUDA)
|
||||
// copy the retrieved mpi data to the device
|
||||
cuMemcpy((DataPtr<F>)mpi_data, data, size);
|
||||
cuMemcpyHtoD(data, (void*)mpi_data, sizeof(F) * size);
|
||||
std::free(mpi_data);
|
||||
#endif
|
||||
|
||||
@ -2558,6 +2590,10 @@ tensor contractions.
|
||||
#include<atrip/Blas.hpp>
|
||||
#include<atrip/Utils.hpp>
|
||||
|
||||
#if defined(HAVE_CUDA)
|
||||
#include<thrust/device_vector.h>
|
||||
#endif
|
||||
|
||||
|
||||
namespace atrip {
|
||||
using ABCTuple = std::array<size_t, 3>;
|
||||
@ -2886,26 +2922,27 @@ V^{{\color{blue}ab}}_{{\color{red}e}i} T^{{\color{blue}c}{\color{red}e}}_{ij} \
|
||||
, size_t const No
|
||||
, size_t const Nv
|
||||
// -- VABCI
|
||||
, F* const VABph
|
||||
, F* const VACph
|
||||
, F* const VBCph
|
||||
, F* const VBAph
|
||||
, F* const VCAph
|
||||
, F* const VCBph
|
||||
, DataPtr<F> const VABph
|
||||
, DataPtr<F> const VACph
|
||||
, DataPtr<F> const VBCph
|
||||
, DataPtr<F> const VBAph
|
||||
, DataPtr<F> const VCAph
|
||||
, DataPtr<F> const VCBph
|
||||
// -- VHHHA
|
||||
, F* const VhhhA
|
||||
, F* const VhhhB
|
||||
, F* const VhhhC
|
||||
, DataPtr<F> const VhhhA
|
||||
, DataPtr<F> const VhhhB
|
||||
, DataPtr<F> const VhhhC
|
||||
// -- TA
|
||||
, F* const TAphh
|
||||
, F* const TBphh
|
||||
, F* const TCphh
|
||||
, DataPtr<F> const TAphh
|
||||
, DataPtr<F> const TBphh
|
||||
, DataPtr<F> const TCphh
|
||||
// -- TABIJ
|
||||
, F* const TABhh
|
||||
, F* const TAChh
|
||||
, F* const TBChh
|
||||
, DataPtr<F> const TABhh
|
||||
, DataPtr<F> const TAChh
|
||||
, DataPtr<F> const TBChh
|
||||
// -- TIJK
|
||||
, F* Tijk
|
||||
// , DataPtr<F> Tijk
|
||||
, DataFieldType<F>* Tijk_
|
||||
);
|
||||
#+end_src
|
||||
|
||||
@ -2919,32 +2956,35 @@ V^{{\color{blue}ab}}_{{\color{red}e}i} T^{{\color{blue}c}{\color{red}e}}_{ij} \
|
||||
, size_t const No
|
||||
, size_t const Nv
|
||||
// -- VABCI
|
||||
, F* const VABph
|
||||
, F* const VACph
|
||||
, F* const VBCph
|
||||
, F* const VBAph
|
||||
, F* const VCAph
|
||||
, F* const VCBph
|
||||
, DataPtr<F> const VABph
|
||||
, DataPtr<F> const VACph
|
||||
, DataPtr<F> const VBCph
|
||||
, DataPtr<F> const VBAph
|
||||
, DataPtr<F> const VCAph
|
||||
, DataPtr<F> const VCBph
|
||||
// -- VHHHA
|
||||
, F* const VhhhA
|
||||
, F* const VhhhB
|
||||
, F* const VhhhC
|
||||
, DataPtr<F> const VhhhA
|
||||
, DataPtr<F> const VhhhB
|
||||
, DataPtr<F> const VhhhC
|
||||
// -- TA
|
||||
, F* const TAphh
|
||||
, F* const TBphh
|
||||
, F* const TCphh
|
||||
, DataPtr<F> const TAphh
|
||||
, DataPtr<F> const TBphh
|
||||
, DataPtr<F> const TCphh
|
||||
// -- TABIJ
|
||||
, F* const TABhh
|
||||
, F* const TAChh
|
||||
, F* const TBChh
|
||||
, DataPtr<F> const TABhh
|
||||
, DataPtr<F> const TAChh
|
||||
, DataPtr<F> const TBChh
|
||||
// -- TIJK
|
||||
, F* Tijk
|
||||
// , DataPtr<F> Tijk_
|
||||
, DataFieldType<F>* Tijk_
|
||||
) {
|
||||
|
||||
const size_t a = abc[0], b = abc[1], c = abc[2]
|
||||
, NoNo = No*No, NoNv = No*Nv
|
||||
;
|
||||
|
||||
typename DataField<F>::type* Tijk = (typename DataField<F>::type*) Tijk_;
|
||||
|
||||
#if defined(ATRIP_USE_DGEMM)
|
||||
#define _IJK_(i, j, k) i + j*No + k*NoNo
|
||||
#define REORDER(__II, __JJ, __KK) \
|
||||
@ -2952,58 +2992,105 @@ V^{{\color{blue}ab}}_{{\color{red}e}i} T^{{\color{blue}c}{\color{red}e}}_{ij} \
|
||||
for (size_t k = 0; k < No; k++) \
|
||||
for (size_t j = 0; j < No; j++) \
|
||||
for (size_t i = 0; i < No; i++) { \
|
||||
Tijk[_IJK_(i, j, k)] += _t_buffer[_IJK_(__II, __JJ, __KK)]; \
|
||||
Tijk[_IJK_(i, j, k)] += _t_buffer_p[_IJK_(__II, __JJ, __KK)]; \
|
||||
} \
|
||||
)
|
||||
#if defined(HAVE_CUDA)
|
||||
#define __TO_DEVICEPTR(_v) (DataFieldType<F>*)(CUdeviceptr)thrust::raw_pointer_cast((_v))
|
||||
#define DGEMM_PARTICLES(__A, __B) \
|
||||
atrip::xgemm<F>( "T" \
|
||||
, "N" \
|
||||
, (int const*)&NoNo \
|
||||
, (int const*)&No \
|
||||
, (int const*)&Nv \
|
||||
, &one \
|
||||
, __A \
|
||||
, (int const*)&Nv \
|
||||
, __B \
|
||||
, (int const*)&Nv \
|
||||
, &zero \
|
||||
, _t_buffer.data() \
|
||||
, (int const*)&NoNo \
|
||||
);
|
||||
atrip::xgemm<F>("T", \
|
||||
"N", \
|
||||
(int const*)&NoNo, \
|
||||
(int const*)&No, \
|
||||
(int const*)&Nv, \
|
||||
&one, \
|
||||
(DataFieldType<F>*)__A, \
|
||||
(int const*)&Nv, \
|
||||
(DataFieldType<F>*)__B, \
|
||||
(int const*)&Nv, \
|
||||
&zero, \
|
||||
_t_buffer_p, \
|
||||
(int const*)&NoNo);
|
||||
#define DGEMM_HOLES(__A, __B, __TRANSB) \
|
||||
atrip::xgemm<F>( "N" \
|
||||
, __TRANSB \
|
||||
, (int const*)&NoNo \
|
||||
, (int const*)&No \
|
||||
, (int const*)&No \
|
||||
, &m_one \
|
||||
, __A \
|
||||
, (int const*)&NoNo \
|
||||
, __B \
|
||||
, (int const*)&No \
|
||||
, &zero \
|
||||
, _t_buffer.data() \
|
||||
, (int const*)&NoNo \
|
||||
atrip::xgemm<F>("N", \
|
||||
__TRANSB, \
|
||||
(int const*)&NoNo, \
|
||||
(int const*)&No, \
|
||||
(int const*)&No, \
|
||||
&m_one, \
|
||||
__TO_DEVICEPTR(__A), \
|
||||
(int const*)&NoNo, \
|
||||
(DataFieldType<F>*)__B, \
|
||||
(int const*)&No, \
|
||||
&zero, \
|
||||
_t_buffer_p, \
|
||||
(int const*)&NoNo \
|
||||
);
|
||||
#define MAYBE_CONJ(_conj, _buffer) \
|
||||
for (size_t __i = 0; __i < NoNoNo; ++__i) \
|
||||
_conj[__i] = maybeConjugate<F>(_buffer[__i]); \
|
||||
_conj[__i] = \
|
||||
maybeConjugate<DataFieldType<F>>(((DataFieldType<F>*)_buffer)[__i]);
|
||||
#else
|
||||
#define __TO_DEVICEPTR(_v) (_v)
|
||||
#define DGEMM_PARTICLES(__A, __B) \
|
||||
atrip::xgemm<F>("T", \
|
||||
"N", \
|
||||
(int const*)&NoNo, \
|
||||
(int const*)&No, \
|
||||
(int const*)&Nv, \
|
||||
&one, \
|
||||
__A, \
|
||||
(int const*)&Nv, \
|
||||
__B, \
|
||||
(int const*)&Nv, \
|
||||
&zero, \
|
||||
_t_buffer_p, \
|
||||
(int const*)&NoNo \
|
||||
);
|
||||
#define DGEMM_HOLES(__A, __B, __TRANSB) \
|
||||
atrip::xgemm<F>("N", \
|
||||
__TRANSB, \
|
||||
(int const*)&NoNo, \
|
||||
(int const*)&No, \
|
||||
(int const*)&No, \
|
||||
&m_one, \
|
||||
__A, \
|
||||
(int const*)&NoNo, \
|
||||
__B, \
|
||||
(int const*)&No, \
|
||||
&zero, \
|
||||
_t_buffer_p, \
|
||||
(int const*)&NoNo \
|
||||
);
|
||||
#define MAYBE_CONJ(_conj, _buffer) \
|
||||
for (size_t __i = 0; __i < NoNoNo; ++__i) \
|
||||
_conj[__i] = maybeConjugate<F>(_buffer[__i]);
|
||||
#endif
|
||||
|
||||
const size_t NoNoNo = No*NoNo;
|
||||
#ifdef HAVE_CUDA
|
||||
thrust::device_vector< DataFieldType<F> > _t_buffer;
|
||||
#else
|
||||
std::vector<F> _t_buffer;
|
||||
#endif
|
||||
_t_buffer.reserve(NoNoNo);
|
||||
DataFieldType<F>* _t_buffer_p = __TO_DEVICEPTR(_t_buffer.data());
|
||||
F one{1.0}, m_one{-1.0}, zero{0.0};
|
||||
|
||||
WITH_CHRONO("double:reorder",
|
||||
for (size_t k = 0; k < NoNoNo; k++) {
|
||||
Tijk[k] = 0.0;
|
||||
Tijk[k] = DataFieldType<F>{0.0};
|
||||
})
|
||||
|
||||
// TOMERGE: replace chronos
|
||||
WITH_CHRONO("doubles:holes",
|
||||
{ // Holes part %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
|
||||
#ifdef HAVE_CUDA
|
||||
thrust::device_vector< DataFieldType<F> > _vhhh(NoNoNo);
|
||||
#else
|
||||
std::vector<F> _vhhh(NoNoNo);
|
||||
#endif
|
||||
|
||||
// VhhhC[i + k*No + L*NoNo] * TABhh[L + j*No]; H1
|
||||
MAYBE_CONJ(_vhhh, VhhhC)
|
||||
@ -3135,59 +3222,59 @@ V^{{\color{blue}ab}}_{{\color{red}e}i} T^{{\color{blue}c}{\color{red}e}}_{ij} \
|
||||
|
||||
// instantiate templates
|
||||
template
|
||||
void doublesContribution
|
||||
void doublesContribution<double>
|
||||
( const ABCTuple &abc
|
||||
, size_t const No
|
||||
, size_t const Nv
|
||||
// -- VABCI
|
||||
, double* const VABph
|
||||
, double* const VACph
|
||||
, double* const VBCph
|
||||
, double* const VBAph
|
||||
, double* const VCAph
|
||||
, double* const VCBph
|
||||
, DataPtr<double> const VABph
|
||||
, DataPtr<double> const VACph
|
||||
, DataPtr<double> const VBCph
|
||||
, DataPtr<double> const VBAph
|
||||
, DataPtr<double> const VCAph
|
||||
, DataPtr<double> const VCBph
|
||||
// -- VHHHA
|
||||
, double* const VhhhA
|
||||
, double* const VhhhB
|
||||
, double* const VhhhC
|
||||
, DataPtr<double> const VhhhA
|
||||
, DataPtr<double> const VhhhB
|
||||
, DataPtr<double> const VhhhC
|
||||
// -- TA
|
||||
, double* const TAphh
|
||||
, double* const TBphh
|
||||
, double* const TCphh
|
||||
, DataPtr<double> const TAphh
|
||||
, DataPtr<double> const TBphh
|
||||
, DataPtr<double> const TCphh
|
||||
// -- TABIJ
|
||||
, double* const TABhh
|
||||
, double* const TAChh
|
||||
, double* const TBChh
|
||||
, DataPtr<double> const TABhh
|
||||
, DataPtr<double> const TAChh
|
||||
, DataPtr<double> const TBChh
|
||||
// -- TIJK
|
||||
, double* Tijk
|
||||
, DataFieldType<double>* Tijk
|
||||
);
|
||||
|
||||
template
|
||||
void doublesContribution
|
||||
void doublesContribution<Complex>
|
||||
( const ABCTuple &abc
|
||||
, size_t const No
|
||||
, size_t const Nv
|
||||
// -- VABCI
|
||||
, Complex* const VABph
|
||||
, Complex* const VACph
|
||||
, Complex* const VBCph
|
||||
, Complex* const VBAph
|
||||
, Complex* const VCAph
|
||||
, Complex* const VCBph
|
||||
, DataPtr<Complex> const VABph
|
||||
, DataPtr<Complex> const VACph
|
||||
, DataPtr<Complex> const VBCph
|
||||
, DataPtr<Complex> const VBAph
|
||||
, DataPtr<Complex> const VCAph
|
||||
, DataPtr<Complex> const VCBph
|
||||
// -- VHHHA
|
||||
, Complex* const VhhhA
|
||||
, Complex* const VhhhB
|
||||
, Complex* const VhhhC
|
||||
, DataPtr<Complex> const VhhhA
|
||||
, DataPtr<Complex> const VhhhB
|
||||
, DataPtr<Complex> const VhhhC
|
||||
// -- TA
|
||||
, Complex* const TAphh
|
||||
, Complex* const TBphh
|
||||
, Complex* const TCphh
|
||||
, DataPtr<Complex> const TAphh
|
||||
, DataPtr<Complex> const TBphh
|
||||
, DataPtr<Complex> const TCphh
|
||||
// -- TABIJ
|
||||
, Complex* const TABhh
|
||||
, Complex* const TAChh
|
||||
, Complex* const TBChh
|
||||
, DataPtr<Complex> const TABhh
|
||||
, DataPtr<Complex> const TAChh
|
||||
, DataPtr<Complex> const TBChh
|
||||
// -- TIJK
|
||||
, Complex* Tijk
|
||||
, DataFieldType<Complex>* Tijk
|
||||
);
|
||||
#+end_src
|
||||
|
||||
@ -3212,6 +3299,7 @@ is mainly using the =DGEMM= function, which we declare as
|
||||
#pragma once
|
||||
|
||||
#include <atrip/Complex.hpp>
|
||||
#include <atrip/Types.hpp>
|
||||
#include "config.h"
|
||||
|
||||
namespace atrip {
|
||||
@ -3259,12 +3347,12 @@ namespace atrip {
|
||||
const int *n,
|
||||
const int *k,
|
||||
F *alpha,
|
||||
const F *A,
|
||||
const DataFieldType<F> *A,
|
||||
const int *lda,
|
||||
const F *B,
|
||||
const DataFieldType<F> *B,
|
||||
const int *ldb,
|
||||
F *beta,
|
||||
F *C,
|
||||
DataFieldType<F> *C,
|
||||
const int *ldc);
|
||||
}
|
||||
#+end_src
|
||||
@ -3292,18 +3380,18 @@ namespace atrip {
|
||||
|
||||
|
||||
template <>
|
||||
void xgemm(const char *transa,
|
||||
void xgemm<double>(const char *transa,
|
||||
const char *transb,
|
||||
const int *m,
|
||||
const int *n,
|
||||
const int *k,
|
||||
double *alpha,
|
||||
const double *A,
|
||||
const typename DataField<double>::type *A,
|
||||
const int *lda,
|
||||
const double *B,
|
||||
const typename DataField<double>::type *B,
|
||||
const int *ldb,
|
||||
double *beta,
|
||||
double *C,
|
||||
typename DataField<double>::type *C,
|
||||
const int *ldc) {
|
||||
#if defined(HAVE_CUDA)
|
||||
cublasDgemm(Atrip::cuda.handle,
|
||||
@ -3323,20 +3411,21 @@ namespace atrip {
|
||||
}
|
||||
|
||||
template <>
|
||||
void xgemm(const char *transa,
|
||||
void xgemm<Complex>(const char *transa,
|
||||
const char *transb,
|
||||
const int *m,
|
||||
const int *n,
|
||||
const int *k,
|
||||
Complex *alpha,
|
||||
const Complex *A,
|
||||
const typename DataField<Complex>::type *A,
|
||||
const int *lda,
|
||||
const Complex *B,
|
||||
const typename DataField<Complex>::type *B,
|
||||
const int *ldb,
|
||||
Complex *beta,
|
||||
Complex *C,
|
||||
typename DataField<Complex>::type *C,
|
||||
const int *ldc) {
|
||||
#if defined(HAVE_CUDA)
|
||||
#pragma warning HAVE_CUDA
|
||||
cuDoubleComplex
|
||||
cu_alpha = {std::real(*alpha), std::imag(*alpha)},
|
||||
cu_beta = {std::real(*beta), std::imag(*beta)};
|
||||
@ -3345,10 +3434,11 @@ namespace atrip {
|
||||
char_to_cublasOperation(transb),
|
||||
,*m, *n, *k,
|
||||
&cu_alpha,
|
||||
reinterpret_cast<const cuDoubleComplex*>(A), *lda,
|
||||
reinterpret_cast<const cuDoubleComplex*>(B), *ldb,
|
||||
|
||||
A, *lda,
|
||||
B, *ldb,
|
||||
&cu_beta,
|
||||
reinterpret_cast<const cuDoubleComplex*>(C), *ldc);
|
||||
C, *ldc);
|
||||
#else
|
||||
zgemm_(transa, transb,
|
||||
m, n, k,
|
||||
@ -3531,9 +3621,9 @@ Atrip::Output Atrip::run(Atrip::Input<F> const& in) {
|
||||
cuMemAlloc(&epsi, sizeof(F) * _epsi.size());
|
||||
cuMemAlloc(&epsa, sizeof(F) * _epsa.size());
|
||||
|
||||
cuMemcpy(Tai, (DataPtr<F>)_Tai.data(), sizeof(F) * _Tai.size());
|
||||
cuMemcpy(epsi,(DataPtr<F>)_epsi.data(), sizeof(F) * _epsi.size());
|
||||
cuMemcpy(epsa, (DataPtr<F>)_epsa.data(), sizeof(F) * _epsa.size());
|
||||
cuMemcpyHtoD(Tai, (void*)_Tai.data(), sizeof(F) * _Tai.size());
|
||||
cuMemcpyHtoD(epsi,(void*)_epsi.data(), sizeof(F) * _epsi.size());
|
||||
cuMemcpyHtoD(epsa, (void*)_epsa.data(), sizeof(F) * _epsa.size());
|
||||
|
||||
DataPtr<F> Tijk, Zijk;
|
||||
cuMemAlloc(&Tijk, sizeof(F) * No * No * No);
|
||||
@ -3930,27 +4020,27 @@ Atrip::Output Atrip::run(Atrip::Input<F> const& in) {
|
||||
WITH_CHRONO("doubles",
|
||||
doublesContribution<F>( abc, (size_t)No, (size_t)Nv
|
||||
// -- VABCI
|
||||
, (F*)abph.unwrapSlice(Slice<F>::AB, abc)
|
||||
, (F*)abph.unwrapSlice(Slice<F>::AC, abc)
|
||||
, (F*)abph.unwrapSlice(Slice<F>::BC, abc)
|
||||
, (F*)abph.unwrapSlice(Slice<F>::BA, abc)
|
||||
, (F*)abph.unwrapSlice(Slice<F>::CA, abc)
|
||||
, (F*)abph.unwrapSlice(Slice<F>::CB, abc)
|
||||
, abph.unwrapSlice(Slice<F>::AB, abc)
|
||||
, abph.unwrapSlice(Slice<F>::AC, abc)
|
||||
, abph.unwrapSlice(Slice<F>::BC, abc)
|
||||
, abph.unwrapSlice(Slice<F>::BA, abc)
|
||||
, abph.unwrapSlice(Slice<F>::CA, abc)
|
||||
, abph.unwrapSlice(Slice<F>::CB, abc)
|
||||
// -- VHHHA
|
||||
, (F*)hhha.unwrapSlice(Slice<F>::A, abc)
|
||||
, (F*)hhha.unwrapSlice(Slice<F>::B, abc)
|
||||
, (F*)hhha.unwrapSlice(Slice<F>::C, abc)
|
||||
, hhha.unwrapSlice(Slice<F>::A, abc)
|
||||
, hhha.unwrapSlice(Slice<F>::B, abc)
|
||||
, hhha.unwrapSlice(Slice<F>::C, abc)
|
||||
// -- TA
|
||||
, (F*)taphh.unwrapSlice(Slice<F>::A, abc)
|
||||
, (F*)taphh.unwrapSlice(Slice<F>::B, abc)
|
||||
, (F*)taphh.unwrapSlice(Slice<F>::C, abc)
|
||||
, taphh.unwrapSlice(Slice<F>::A, abc)
|
||||
, taphh.unwrapSlice(Slice<F>::B, abc)
|
||||
, taphh.unwrapSlice(Slice<F>::C, abc)
|
||||
// -- TABIJ
|
||||
, (F*)tabhh.unwrapSlice(Slice<F>::AB, abc)
|
||||
, (F*)tabhh.unwrapSlice(Slice<F>::AC, abc)
|
||||
, (F*)tabhh.unwrapSlice(Slice<F>::BC, abc)
|
||||
, tabhh.unwrapSlice(Slice<F>::AB, abc)
|
||||
, tabhh.unwrapSlice(Slice<F>::AC, abc)
|
||||
, tabhh.unwrapSlice(Slice<F>::BC, abc)
|
||||
// -- TIJK
|
||||
#if defined(HAVE_CUDA)
|
||||
, (F*)Tijk
|
||||
, (DataFieldType<F>*)Tijk
|
||||
#else
|
||||
, Tijk.data()
|
||||
#endif
|
||||
@ -4423,6 +4513,10 @@ namespace atrip {
|
||||
|
||||
template <typename F> F maybeConjugate(const F);
|
||||
|
||||
#if defined(HAVE_CUDA)
|
||||
void operator+=(cuDoubleComplex& lz, cuDoubleComplex const& rz);
|
||||
#endif
|
||||
|
||||
namespace traits {
|
||||
|
||||
template <typename FF> bool isComplex();
|
||||
@ -4445,6 +4539,13 @@ namespace atrip {
|
||||
// double specialization: returns the argument unchanged.
template <> double maybeConjugate(const double a) { return a; }
|
||||
// Complex specialization: returns std::conj of the argument.
template <> Complex maybeConjugate(const Complex a) { return std::conj(a); }
|
||||
|
||||
#if defined(HAVE_CUDA)
|
||||
// In-place component-wise addition for cuDoubleComplex: adds rz.x to lz.x
// and rz.y to lz.y. Returns void, so the expression `a += b` yields no value
// and cannot be chained or used as an operand.
void operator+=(cuDoubleComplex& lz, cuDoubleComplex const& rz) {
|
||||
lz.x += rz.x;
|
||||
lz.y += rz.y;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
namespace traits {
|
||||
// Primary template: answers false for any F.
// NOTE(review): specializations that return true are presumably defined elsewhere; none are visible in this excerpt.
template <typename F> bool isComplex() { return false; }
|
||||
|
||||
@ -7,19 +7,19 @@
|
||||
|
||||
let
|
||||
|
||||
mkl = import ./etc/nix/mkl.nix { pkgs = (import <nixpkgs> {
|
||||
unfree-pkgs = import <nixpkgs> {
|
||||
config.allowUnfree = true;
|
||||
}); };
|
||||
};
|
||||
|
||||
openblas = import ./etc/nix/openblas.nix { inherit pkgs; };
|
||||
|
||||
cuda-pkg = if cuda then (import ./cuda.nix { inherit pkgs; }) else {};
|
||||
mkl = import ./etc/nix/mkl.nix { pkgs = unfree-pkgs; };
|
||||
cuda-pkg = if cuda then (import ./cuda.nix { pkgs = unfree-pkgs; }) else {};
|
||||
|
||||
in
|
||||
|
||||
pkgs.mkShell rec {
|
||||
|
||||
|
||||
compiler-pkg
|
||||
= if compiler == "gcc11" then pkgs.gcc11
|
||||
else if compiler == "gcc10" then pkgs.gcc10
|
||||
|
||||
Loading…
Reference in New Issue
Block a user