From 4e2d1143e5aae72f301361b40ccc4d95b88f809e Mon Sep 17 00:00:00 2001 From: Alejandro Gallo Date: Wed, 25 Jan 2023 16:25:09 +0100 Subject: [PATCH] Add getSize static method to calculate the size of sources in SliceUnion --- include/atrip/SliceUnion.hpp | 16 ++++++++++++ src/atrip/Atrip.cxx | 49 ++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/include/atrip/SliceUnion.hpp b/include/atrip/SliceUnion.hpp index fa79a54..b98563a 100644 --- a/include/atrip/SliceUnion.hpp +++ b/include/atrip/SliceUnion.hpp @@ -387,6 +387,22 @@ template } } + static size_t + getSize(const std::vector sliceLength, + const std::vector paramLength, + const size_t np, + const MPI_Comm global_world) { + const RankMap rankMap(paramLength, np, global_world); + const size_t + nSources = rankMap.nSources(), + sliceSize = std::accumulate(sliceLength.begin(), + sliceLength.end(), + 1UL, + std::multiplies()); + return nSources * sliceSize; + } + + // CONSTRUCTOR SliceUnion( std::vector::Type> sliceTypes_ , std::vector sliceLength_ diff --git a/src/atrip/Atrip.cxx b/src/atrip/Atrip.cxx index 51cad53..a560375 100644 --- a/src/atrip/Atrip.cxx +++ b/src/atrip/Atrip.cxx @@ -235,11 +235,54 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { MPI_Comm_size(child_comm, &child_size); } + // a, b, c, d, e, f and P => Nv + // H => No + // total_source_sizes contains a list of the number of elements + // in all sources of every tensor union, therefore nSlices * sliceSize + const std::vector total_source_sizes = { + // ABPH + SliceUnion::getSize({Nv, No}, {Nv, Nv}, (size_t)np, universe), + // ABHH + SliceUnion::getSize({No, No}, {Nv, Nv}, (size_t)np, universe), + // TABHH + SliceUnion::getSize({No, No}, {Nv, Nv}, (size_t)np, universe), + // TAPHH + SliceUnion::getSize({Nv, No, No}, {Nv}, (size_t)np, universe), + // HHHA + SliceUnion::getSize({No, No, No}, {Nv}, (size_t)np, universe), + }; + + const size_t + total_source_size = sizeof(DataFieldType) + * std::accumulate(total_source_sizes.begin(), + total_source_sizes.end(), + 0UL); + +#if defined(HAVE_CUDA) + DataPtr all_sources_pointer; + cuMemAlloc(&all_sources_pointer, total_source_size); +#else + DataPtr + all_sources_pointer = (DataPtr)malloc(total_source_size); +#endif + size_t _source_pointer_idx = 0; + // BUILD SLICES PARAMETRIZED BY NV x NV =============================={{{1 WITH_CHRONO("nv-nv-slices", LOG(0,"Atrip") << "building NV x NV slices\n"; + // TODO + // DataPtr offseted_pointer = all_sources_pointer + // * total_source_sizes[_source_pointer_idx++]; ABPH abph(*in.Vppph, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); + + // TODO + // DataPtr offseted_pointer = all_sources_pointer + // * total_source_sizes[_source_pointer_idx++]; ABHH abhh(*in.Vpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); + + // TODO + // DataPtr offseted_pointer = all_sources_pointer + // * total_source_sizes[_source_pointer_idx++]; TABHH tabhh(*in.Tpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); ) @@ -251,7 +294,13 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { // BUILD SLICES PARAMETRIZED BY NV ==================================={{{1 WITH_CHRONO("nv-slices", LOG(0,"Atrip") << "building NV slices\n"; + // TODO + // DataPtr offseted_pointer = all_sources_pointer + // * total_source_sizes[_source_pointer_idx++]; TAPHH taphh(*in.Tpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); + // TODO + // DataPtr offseted_pointer = all_sources_pointer + // * total_source_sizes[_source_pointer_idx++]; HHHA hhha(*in.Vhhhp, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); )