From 4651231d3b00ecb0431516099caed2249ebbb05d Mon Sep 17 00:00:00 2001 From: Gallo Alejandro Date: Fri, 12 Aug 2022 18:28:20 +0200 Subject: [PATCH] Update test bench for CUDA --- bench/test_main.cxx | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/bench/test_main.cxx b/bench/test_main.cxx index abc2a61..8f4336f 100644 --- a/bench/test_main.cxx +++ b/bench/test_main.cxx @@ -45,6 +45,16 @@ int main(int argc, char** argv) { checkpoint_percentage, "Percentage for checkpoints"); +#if defined(HAVE_CUDA) + size_t ooo_threads = 0, ooo_blocks = 0; + app.add_option("--ooo-blocks", + ooo_blocks, + "CUDA: Number of blocks per block for kernels going through ooo tensors"); + app.add_option("--ooo-threads", + ooo_threads, + "CUDA: Number of threads per block for kernels going through ooo tensors"); +#endif + CLI11_PARSE(app, argc, argv); CTF::World world(argc, argv); @@ -154,15 +164,24 @@ int main(int argc, char** argv) { .with_checkpointAtPercentage(checkpoint_percentage) .with_checkpointPath(checkpoint_path) .with_readCheckpointIfExists(!noCheckpoint) +#if defined(HAVE_CUDA) + .with_oooThreads(ooo_threads) + .with_oooBlocks(ooo_blocks) +#endif ; - auto out = atrip::Atrip::run(in); + try { + auto out = atrip::Atrip::run(in); + if (atrip::Atrip::rank == 0) + std::cout << "Energy: " << out.energy << std::endl; + } catch (const char* msg) { + if (atrip::Atrip::rank == 0) + std::cout << "Atrip throwed with msg:\n\t\t " << msg << "\n"; + } if (!in.deleteVppph) delete Vppph; - if (atrip::Atrip::rank == 0) - std::cout << "Energy: " << out.energy << std::endl; MPI_Finalize(); return 0;