From 5c177a85bc4d490517c2612b76eb6c685161bef6 Mon Sep 17 00:00:00 2001 From: Alejandro Gallo Date: Thu, 27 Jan 2022 20:48:38 +0100 Subject: [PATCH] Templatize main algorithm --- atrip.org | 127 +++++++++++++++++++++++++++++------------------------- 1 file changed, 68 insertions(+), 59 deletions(-) diff --git a/atrip.org b/atrip.org index a3aa110..7f19a7a 100644 --- a/atrip.org +++ b/atrip.org @@ -2028,7 +2028,8 @@ void Atrip::init() { MPI_Comm_size(MPI_COMM_WORLD, &Atrip::np); } -Atrip::Output Atrip::run(Atrip::Input const& in) { +template +Atrip::Output Atrip::run(Atrip::Input const& in) { const int np = Atrip::np; const int rank = Atrip::rank; @@ -2043,14 +2044,14 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { LOG(0,"Atrip") << "Nv: " << Nv << "\n"; // allocate the three scratches, see piecuch - std::vector Tijk(No*No*No) // doubles only (see piecuch) - , Zijk(No*No*No) // singles + doubles (see piecuch) - // we need local copies of the following tensors on every - // rank - , epsi(No) - , epsa(Nv) - , Tai(No * Nv) - ; + std::vector Tijk(No*No*No) // doubles only (see piecuch) + , Zijk(No*No*No) // singles + doubles (see piecuch) + // we need local copies of the following tensors on every + // rank + , epsi(No) + , epsa(Nv) + , Tai(No * Nv) + ; in.ei->read_all(epsi.data()); in.ea->read_all(epsa.data()); @@ -2079,20 +2080,20 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { chrono["nv-slices"].start(); // BUILD SLICES PARAMETRIZED BY NV ==================================={{{1 LOG(0,"Atrip") << "BUILD NV-SLICES\n"; - TAPHH taphh(*in.Tpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); - HHHA hhha(*in.Vhhhp, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); + TAPHH taphh(*in.Tpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); + HHHA hhha(*in.Vhhhp, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); chrono["nv-slices"].stop(); chrono["nv-nv-slices"].start(); // BUILD SLICES PARAMETRIZED BY NV x NV =============================={{{1 LOG(0,"Atrip") << "BUILD NV x NV-SLICES\n"; - ABPH abph(*in.Vppph, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); - ABHH abhh(*in.Vpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); - TABHH tabhh(*in.Tpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); + ABPH abph(*in.Vppph, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); + ABHH abhh(*in.Vpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); + TABHH tabhh(*in.Tpphh, (size_t)No, (size_t)Nv, (size_t)np, child_comm, universe); chrono["nv-nv-slices"].stop(); // all tensors - std::vector< SliceUnion* > unions = {&taphh, &hhha, &abph, &abhh, &tabhh}; + std::vector< SliceUnion* > unions = {&taphh, &hhha, &abph, &abhh, &tabhh}; //CONSTRUCT TUPLE LIST ==============================================={{{1 LOG(0,"Atrip") << "BUILD TUPLE LIST\n"; @@ -2126,18 +2127,20 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { = [&tuplesList](size_t const i) { return i >= tuplesList.size(); }; + using Database = typename Slice::Database; + using LocalDatabase = typename Slice::LocalDatabase; auto communicateDatabase = [ &unions , np , &chrono - ] (ABCTuple const& abc, MPI_Comm const& c) -> Slice::Database { + ] (ABCTuple const& abc, MPI_Comm const& c) -> typename Slice::Database { chrono["db:comm:type:do"].start(); - auto MPI_LDB_ELEMENT = Slice::mpi::localDatabaseElement(); + auto MPI_LDB_ELEMENT = Slice::mpi::localDatabaseElement(); chrono["db:comm:type:do"].stop(); chrono["db:comm:ldb"].start(); - Slice::LocalDatabase ldb; + LocalDatabase ldb; for (auto const& tensor: unions) { auto const& tensorDb = tensor->buildLocalDatabase(abc); @@ -2145,7 +2148,8 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { } chrono["db:comm:ldb"].stop(); - Slice::Database db(np * ldb.size(), ldb[0]); + typename + Slice::Database db(np * ldb.size(), ldb[0]); chrono["oneshot-db:comm:allgather"].start(); chrono["db:comm:allgather"].start(); @@ -2167,7 +2171,7 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { }; auto doIOPhase - = [&unions, &rank, &np, &universe, &chrono] (Slice::Database const& db) { + = [&unions, &rank, &np, &universe, &chrono] (typename Slice::Database const& db) { const size_t localDBLength = db.size() / np; @@ -2217,7 +2221,7 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { ; for (auto it = begin; it != end; ++it) { sendTag++; - Slice::LocalDatabaseElement const& el = *it; + typename Slice::LocalDatabaseElement const& el = *it; if (el.info.from.rank != rank) continue; @@ -2266,7 +2270,7 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { // START MAIN LOOP ======================================================{{{1 - Slice::Database db; + typename Slice::Database db; for ( size_t i = abcIndex.first, iteration = 1 ; i < abcIndex.second @@ -2373,30 +2377,31 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { ))) chrono["oneshot-doubles"].start(); chrono["doubles"].start(); - doublesContribution( abc, (size_t)No, (size_t)Nv - // -- VABCI - , abph.unwrapSlice(Slice::AB, abc) - , abph.unwrapSlice(Slice::AC, abc) - , abph.unwrapSlice(Slice::BC, abc) - , abph.unwrapSlice(Slice::BA, abc) - , abph.unwrapSlice(Slice::CA, abc) - , abph.unwrapSlice(Slice::CB, abc) - // -- VHHHA - , hhha.unwrapSlice(Slice::A, abc) - , hhha.unwrapSlice(Slice::B, abc) - , hhha.unwrapSlice(Slice::C, abc) - // -- TA - , taphh.unwrapSlice(Slice::A, abc) - , taphh.unwrapSlice(Slice::B, abc) - , taphh.unwrapSlice(Slice::C, abc) - // -- TABIJ - , tabhh.unwrapSlice(Slice::AB, abc) - , tabhh.unwrapSlice(Slice::AC, abc) - , tabhh.unwrapSlice(Slice::BC, abc) - // -- TIJK - , Tijk.data() - , chrono - ); + LOGREMOVE << "doubles " << iteration << "\n"; + doublesContribution( abc, (size_t)No, (size_t)Nv + // -- VABCI + , abph.unwrapSlice(Slice::AB, abc) + , abph.unwrapSlice(Slice::AC, abc) + , abph.unwrapSlice(Slice::BC, abc) + , abph.unwrapSlice(Slice::BA, abc) + , abph.unwrapSlice(Slice::CA, abc) + , abph.unwrapSlice(Slice::CB, abc) + // -- VHHHA + , hhha.unwrapSlice(Slice::A, abc) + , hhha.unwrapSlice(Slice::B, abc) + , hhha.unwrapSlice(Slice::C, abc) + // -- TA + , taphh.unwrapSlice(Slice::A, abc) + , taphh.unwrapSlice(Slice::B, abc) + , taphh.unwrapSlice(Slice::C, abc) + // -- TABIJ + , tabhh.unwrapSlice(Slice::AB, abc) + , tabhh.unwrapSlice(Slice::AC, abc) + , tabhh.unwrapSlice(Slice::BC, abc) + // -- TIJK + , Tijk.data() + , chrono + ); WITH_RANK << iteration << "-th doubles done\n"; chrono["doubles"].stop(); chrono["oneshot-doubles"].stop(); @@ -2414,12 +2419,12 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { for (size_t I(0); I < Zijk.size(); I++) Zijk[I] = Tijk[I]; chrono["reorder"].stop(); chrono["singles"].start(); - singlesContribution( No, Nv, abc - , Tai.data() - , abhh.unwrapSlice(Slice::AB, abc) - , abhh.unwrapSlice(Slice::AC, abc) - , abhh.unwrapSlice(Slice::BC, abc) - , Zijk.data()); + singlesContribution( No, Nv, abc + , Tai.data() + , abhh.unwrapSlice(Slice::AB, abc) + , abhh.unwrapSlice(Slice::AC, abc) + , abhh.unwrapSlice(Slice::BC, abc) + , Zijk.data()); chrono["singles"].stop(); } @@ -2431,13 +2436,13 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { int distinct(0); if (abc[0] == abc[1]) distinct++; if (abc[1] == abc[2]) distinct--; - const double epsabc(epsa[abc[0]] + epsa[abc[1]] + epsa[abc[2]]); + const F epsabc(epsa[abc[0]] + epsa[abc[1]] + epsa[abc[2]]); chrono["energy"].start(); if ( distinct == 0) - tupleEnergy = getEnergyDistinct(epsabc, epsi, Tijk, Zijk); + tupleEnergy = getEnergyDistinct(epsabc, epsi, Tijk, Zijk); else - tupleEnergy = getEnergySame(epsabc, epsi, Tijk, Zijk); + tupleEnergy = getEnergySame(epsabc, epsi, Tijk, Zijk); chrono["energy"].stop(); #if defined(HAVE_OCD) || defined(ATRIP_PRINT_TUPLES) @@ -2478,8 +2483,8 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { << " :abc " << pretty_print(abc) << " :abcN " << pretty_print(*abcNext) << "\n"; - for (auto const& slice: u->slices) - WITH_RANK << "__gc__:guts:" << slice.info << "\n"; + // for (auto const& slice: u->slices) + // WITH_RANK << "__gc__:guts:" << slice.info << "\n"; u->clearUnusedSlicesForNext(*abcNext); WITH_RANK << "__gc__: checking validity\n"; @@ -2487,13 +2492,13 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { #ifdef HAVE_OCD // check for validity of the slices for (auto type: u->sliceTypes) { - auto tuple = Slice::subtupleBySlice(abc, type); + auto tuple = Slice::subtupleBySlice(abc, type); for (auto& slice: u->slices) { if ( slice.info.type == type && slice.info.tuple == tuple && slice.isDirectlyFetchable() ) { - if (slice.info.state == Slice::Dispatched) + if (slice.info.state == Slice::Dispatched) throw std::domain_error( "This slice should not be undispatched! " + pretty_print(slice.info)); } @@ -2560,6 +2565,10 @@ Atrip::Output Atrip::run(Atrip::Input const& in) { return { - globalEnergy }; } +// instantiate +template Atrip::Output Atrip::run(Atrip::Input const& in); +template Atrip::Output Atrip::run(Atrip::Input const& in); + #+end_src