From 4e2b3cb11f3d6d88880c795e21aa106404b71731 Mon Sep 17 00:00:00 2001 From: Emily Boudreaux Date: Sat, 6 Dec 2025 12:10:43 -0500 Subject: [PATCH] perf(GraphEngine): using caching clawed back ~10% performance --- src/include/gridfire/engine/engine_graph.h | 1 + src/include/gridfire/reaction/reaction.h | 2 ++ src/lib/engine/engine_graph.cpp | 27 +++++++++++++++++----- src/lib/reaction/reaction.cpp | 7 +++++- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/include/gridfire/engine/engine_graph.h b/src/include/gridfire/engine/engine_graph.h index 0e044f7d..4ae79ba2 100644 --- a/src/include/gridfire/engine/engine_graph.h +++ b/src/include/gridfire/engine/engine_graph.h @@ -875,6 +875,7 @@ namespace gridfire::engine { mutable CppAD::ADFun m_rhsADFun; ///< CppAD function for the right-hand side of the ODE. mutable CppAD::ADFun m_epsADFun; ///< CppAD function for the energy generation rate. mutable CppAD::sparse_jac_work m_jac_work; ///< Work object for sparse Jacobian calculations. + mutable std::vector m_local_abundance_cache; bool m_has_been_primed = false; ///< Flag indicating if the engine has been primed. diff --git a/src/include/gridfire/reaction/reaction.h b/src/include/gridfire/reaction/reaction.h index d60d0281..8f1e8ab3 100644 --- a/src/include/gridfire/reaction/reaction.h +++ b/src/include/gridfire/reaction/reaction.h @@ -809,6 +809,8 @@ namespace gridfire::reaction { std::vector m_rates; ///< List of rate coefficient sets from each source. bool m_weak = false; + mutable std::unordered_map m_cached_rates; + private: /** * @brief Template implementation for calculating the total reaction rate. diff --git a/src/lib/engine/engine_graph.cpp b/src/lib/engine/engine_graph.cpp index 7ba0c75b..70151b9d 100644 --- a/src/lib/engine/engine_graph.cpp +++ b/src/lib/engine/engine_graph.cpp @@ -655,6 +655,7 @@ namespace gridfire::engine { } + StepDerivatives GraphEngine::calculateAllDerivativesUsingPrecomputation( const fourdst::composition::CompositionAbstract &comp, const std::vector &bare_rates, @@ -672,8 +673,13 @@ namespace gridfire::engine { T9, rho ); + m_local_abundance_cache.clear(); + for (const auto& species: m_networkSpecies) { + m_local_abundance_cache.push_back(comp.contains(species) ? comp.getMolarAbundance(species) : 0.0); + } StepDerivatives result; + std::vector dydt_scratch(m_networkSpecies.size(), 0.0); // --- Optimized loop --- std::vector molarReactionFlows; @@ -693,15 +699,18 @@ namespace gridfire::engine { const fourdst::atomic::Species& reactant = m_networkSpecies[reactantIndex]; const int power = precomputedReaction.reactant_powers[i]; - if (!comp.contains(reactant)) { - forwardAbundanceProduct = 0.0; - break; // No need to continue if one of the reactants has zero abundance - } - const double factor = std::pow(comp.getMolarAbundance(reactant), power); + const double abundance = m_local_abundance_cache[reactantIndex]; + + double factor; + if (power == 1) { factor = abundance; } + else if (power == 2) { factor = abundance * abundance; } + else { factor = std::pow(abundance, power); } + if (!std::isfinite(factor)) { LOG_CRITICAL(m_logger, "Non-finite factor encountered in forward abundance product for reaction '{}'. Check input abundances for validity.", reaction->id()); throw exceptions::BadRHSEngineError("Non-finite factor encountered in forward abundance product."); } + forwardAbundanceProduct *= factor; } @@ -794,7 +803,8 @@ namespace gridfire::engine { // Update the derivative for this species const double dydt_increment = static_cast(stoichiometricCoefficient) * R_j; - result.dydt.at(species) += dydt_increment; + // result.dydt.at(species) += dydt_increment; + dydt_scratch[speciesIndex] += dydt_increment; if (m_store_intermediate_reaction_contributions) { result.reactionContributions.value()[species][std::string(reaction->id())] = dydt_increment; @@ -803,6 +813,11 @@ namespace gridfire::engine { reactionCounter++; } + // load scratch into result.dydt + for (size_t i = 0; i < m_networkSpecies.size(); ++i) { + result.dydt[m_networkSpecies[i]] = dydt_scratch[i]; + } + // --- Calculate the nuclear energy generation rate --- double massProductionRate = 0.0; // [mol][s^-1] for (const auto & species : m_networkSpecies) { diff --git a/src/lib/reaction/reaction.cpp b/src/lib/reaction/reaction.cpp index debee6f6..310f4a91 100644 --- a/src/lib/reaction/reaction.cpp +++ b/src/lib/reaction/reaction.cpp @@ -278,7 +278,12 @@ namespace gridfire::reaction { double Ye, double mue, const std::vector &Y, const std::unordered_map& index_to_species_map ) const { - return calculate_rate(T9); + if (m_cached_rates.contains(T9)) { + return m_cached_rates.at(T9); + } + const double rate = calculate_rate(T9); + m_cached_rates[T9] = rate; + return rate; } double LogicalReaclibReaction::calculate_log_rate_partial_deriv_wrt_T9(