Templatize unions

This commit is contained in:
Alejandro Gallo 2022-01-27 20:44:09 +01:00
parent 6776a7134c
commit 05f5bb6104

202
atrip.org
View File

@ -1241,12 +1241,13 @@ and define subclasses of slice unions.
namespace atrip { namespace atrip {
template <typename F=double>
void sliceIntoVector void sliceIntoVector
( std::vector<double> &v ( std::vector<F> &v
, CTF::Tensor<double> &toSlice , CTF::Tensor<F> &toSlice
, std::vector<int64_t> const low , std::vector<int64_t> const low
, std::vector<int64_t> const up , std::vector<int64_t> const up
, CTF::Tensor<double> const& origin , CTF::Tensor<F> const& origin
, std::vector<int64_t> const originLow , std::vector<int64_t> const originLow
, std::vector<int64_t> const originUp , std::vector<int64_t> const originUp
) { ) {
@ -1273,155 +1274,159 @@ namespace atrip {
, origin_.low.data() , origin_.low.data()
, origin_.up.data() , origin_.up.data()
, 1.0); , 1.0);
memcpy(v.data(), toSlice.data, sizeof(double) * v.size()); memcpy(v.data(), toSlice.data, sizeof(F) * v.size());
#endif #endif
} }
struct TAPHH : public SliceUnion { template <typename F=double>
TAPHH( Tensor const& sourceTensor struct TAPHH : public SliceUnion<F> {
TAPHH( CTF::Tensor<F> const& sourceTensor
, size_t No , size_t No
, size_t Nv , size_t Nv
, size_t np , size_t np
, MPI_Comm child_world , MPI_Comm child_world
, MPI_Comm global_world , MPI_Comm global_world
) : SliceUnion( sourceTensor ) : SliceUnion<F>( sourceTensor
, {Slice::A, Slice::B, Slice::C} , {Slice<F>::A, Slice<F>::B, Slice<F>::C}
, {Nv, No, No} // size of the slices , {Nv, No, No} // size of the slices
, {Nv} , {Nv}
, np , np
, child_world , child_world
, global_world , global_world
, Slice::TA , Slice<F>::TA
, 4) { , 4) {
init(sourceTensor); init(sourceTensor);
} }
void sliceIntoBuffer(size_t it, Tensor &to, Tensor const& from) override void sliceIntoBuffer(size_t it, CTF::Tensor<F> &to, CTF::Tensor<F> const& from) override
{ {
const int Nv = sliceLength[0] const int Nv = this->sliceLength[0]
, No = sliceLength[1] , No = this->sliceLength[1]
, a = rankMap.find({static_cast<size_t>(Atrip::rank), it}); , a = this->rankMap.find({static_cast<size_t>(Atrip::rank), it});
; ;
sliceIntoVector( sources[it] sliceIntoVector<F>( this->sources[it]
, to, {0, 0, 0}, {Nv, No, No} , to, {0, 0, 0}, {Nv, No, No}
, from, {a, 0, 0, 0}, {a+1, Nv, No, No} , from, {a, 0, 0, 0}, {a+1, Nv, No, No}
); );
} }
}; };
struct HHHA : public SliceUnion { template <typename F=double>
HHHA( Tensor const& sourceTensor struct HHHA : public SliceUnion<F> {
HHHA( CTF::Tensor<F> const& sourceTensor
, size_t No , size_t No
, size_t Nv , size_t Nv
, size_t np , size_t np
, MPI_Comm child_world , MPI_Comm child_world
, MPI_Comm global_world , MPI_Comm global_world
) : SliceUnion( sourceTensor ) : SliceUnion<F>( sourceTensor
, {Slice::A, Slice::B, Slice::C} , {Slice<F>::A, Slice<F>::B, Slice<F>::C}
, {No, No, No} // size of the slices , {No, No, No} // size of the slices
, {Nv} // size of the parametrization , {Nv} // size of the parametrization
, np , np
, child_world , child_world
, global_world , global_world
, Slice::VIJKA , Slice<F>::VIJKA
, 4) { , 4) {
init(sourceTensor); init(sourceTensor);
} }
void sliceIntoBuffer(size_t it, Tensor &to, Tensor const& from) override void sliceIntoBuffer(size_t it, CTF::Tensor<F> &to, CTF::Tensor<F> const& from) override
{ {
const int No = sliceLength[0] const int No = this->sliceLength[0]
, a = rankMap.find({static_cast<size_t>(Atrip::rank), it}) , a = this->rankMap.find({static_cast<size_t>(Atrip::rank), it})
; ;
sliceIntoVector( sources[it] sliceIntoVector<F>( this->sources[it]
, to, {0, 0, 0}, {No, No, No} , to, {0, 0, 0}, {No, No, No}
, from, {0, 0, 0, a}, {No, No, No, a+1} , from, {0, 0, 0, a}, {No, No, No, a+1}
); );
} }
}; };
struct ABPH : public SliceUnion { template <typename F=double>
ABPH( Tensor const& sourceTensor struct ABPH : public SliceUnion<F> {
ABPH( CTF::Tensor<F> const& sourceTensor
, size_t No , size_t No
, size_t Nv , size_t Nv
, size_t np , size_t np
, MPI_Comm child_world , MPI_Comm child_world
, MPI_Comm global_world , MPI_Comm global_world
) : SliceUnion( sourceTensor ) : SliceUnion<F>( sourceTensor
, { Slice::AB, Slice::BC, Slice::AC , { Slice<F>::AB, Slice<F>::BC, Slice<F>::AC
, Slice::BA, Slice::CB, Slice::CA , Slice<F>::BA, Slice<F>::CB, Slice<F>::CA
} }
, {Nv, No} // size of the slices , {Nv, No} // size of the slices
, {Nv, Nv} // size of the parametrization , {Nv, Nv} // size of the parametrization
, np , np
, child_world , child_world
, global_world , global_world
, Slice::VABCI , Slice<F>::VABCI
, 2*6) { , 2*6) {
init(sourceTensor); init(sourceTensor);
} }
void sliceIntoBuffer(size_t it, Tensor &to, Tensor const& from) override { void sliceIntoBuffer(size_t it, CTF::Tensor<F> &to, CTF::Tensor<F> const& from) override {
const int Nv = sliceLength[0] const int Nv = this->sliceLength[0]
, No = sliceLength[1] , No = this->sliceLength[1]
, el = rankMap.find({static_cast<size_t>(Atrip::rank), it}) , el = this->rankMap.find({static_cast<size_t>(Atrip::rank), it})
, a = el % Nv , a = el % Nv
, b = el / Nv , b = el / Nv
; ;
sliceIntoVector( sources[it] sliceIntoVector<F>( this->sources[it]
, to, {0, 0}, {Nv, No} , to, {0, 0}, {Nv, No}
, from, {a, b, 0, 0}, {a+1, b+1, Nv, No} , from, {a, b, 0, 0}, {a+1, b+1, Nv, No}
); );
} }
}; };
struct ABHH : public SliceUnion { template <typename F=double>
ABHH( Tensor const& sourceTensor struct ABHH : public SliceUnion<F> {
ABHH( CTF::Tensor<F> const& sourceTensor
, size_t No , size_t No
, size_t Nv , size_t Nv
, size_t np , size_t np
, MPI_Comm child_world , MPI_Comm child_world
, MPI_Comm global_world , MPI_Comm global_world
) : SliceUnion( sourceTensor ) : SliceUnion<F>( sourceTensor
, {Slice::AB, Slice::BC, Slice::AC} , {Slice<F>::AB, Slice<F>::BC, Slice<F>::AC}
, {No, No} // size of the slices , {No, No} // size of the slices
, {Nv, Nv} // size of the parametrization , {Nv, Nv} // size of the parametrization
, np , np
, child_world , child_world
, global_world , global_world
, Slice::VABIJ , Slice<F>::VABIJ
, 6) { , 6) {
init(sourceTensor); init(sourceTensor);
} }
void sliceIntoBuffer(size_t it, Tensor &to, Tensor const& from) override { void sliceIntoBuffer(size_t it, CTF::Tensor<F> &to, CTF::Tensor<F> const& from) override {
const int Nv = from.lens[0] const int Nv = from.lens[0]
, No = sliceLength[1] , No = this->sliceLength[1]
, el = rankMap.find({static_cast<size_t>(Atrip::rank), it}) , el = this->rankMap.find({static_cast<size_t>(Atrip::rank), it})
, a = el % Nv , a = el % Nv
, b = el / Nv , b = el / Nv
; ;
sliceIntoVector( sources[it] sliceIntoVector<F>( this->sources[it]
, to, {0, 0}, {No, No} , to, {0, 0}, {No, No}
, from, {a, b, 0, 0}, {a+1, b+1, No, No} , from, {a, b, 0, 0}, {a+1, b+1, No, No}
); );
} }
@ -1429,39 +1434,40 @@ namespace atrip {
}; };
struct TABHH : public SliceUnion { template <typename F=double>
TABHH( Tensor const& sourceTensor struct TABHH : public SliceUnion<F> {
TABHH( CTF::Tensor<F> const& sourceTensor
, size_t No , size_t No
, size_t Nv , size_t Nv
, size_t np , size_t np
, MPI_Comm child_world , MPI_Comm child_world
, MPI_Comm global_world , MPI_Comm global_world
) : SliceUnion( sourceTensor ) : SliceUnion<F>( sourceTensor
, {Slice::AB, Slice::BC, Slice::AC} , {Slice<F>::AB, Slice<F>::BC, Slice<F>::AC}
, {No, No} // size of the slices , {No, No} // size of the slices
, {Nv, Nv} // size of the parametrization , {Nv, Nv} // size of the parametrization
, np , np
, child_world , child_world
, global_world , global_world
, Slice::TABIJ , Slice<F>::TABIJ
, 6) { , 6) {
init(sourceTensor); init(sourceTensor);
} }
void sliceIntoBuffer(size_t it, Tensor &to, Tensor const& from) override { void sliceIntoBuffer(size_t it, CTF::Tensor<F> &to, CTF::Tensor<F> const& from) override {
// TODO: maybe generalize this with ABHH // TODO: maybe generalize this with ABHH
const int Nv = from.lens[0] const int Nv = from.lens[0]
, No = sliceLength[1] , No = this->sliceLength[1]
, el = rankMap.find({static_cast<size_t>(Atrip::rank), it}) , el = this->rankMap.find({static_cast<size_t>(Atrip::rank), it})
, a = el % Nv , a = el % Nv
, b = el / Nv , b = el / Nv
; ;
sliceIntoVector( sources[it] sliceIntoVector<F>( this->sources[it]
, to, {0, 0}, {No, No} , to, {0, 0}, {No, No}
, from, {a, b, 0, 0}, {a+1, b+1, No, No} , from, {a, b, 0, 0}, {a+1, b+1, No, No}
); );
} }