perf(GraphEngine): using caching clawed back ~10% performance

This commit is contained in:
2025-12-06 12:10:43 -05:00
parent b6f452e74c
commit 4e2b3cb11f
4 changed files with 30 additions and 7 deletions

View File

@@ -875,6 +875,7 @@ namespace gridfire::engine {
mutable CppAD::ADFun<double> m_rhsADFun; ///< CppAD function for the right-hand side of the ODE. mutable CppAD::ADFun<double> m_rhsADFun; ///< CppAD function for the right-hand side of the ODE.
mutable CppAD::ADFun<double> m_epsADFun; ///< CppAD function for the energy generation rate. mutable CppAD::ADFun<double> m_epsADFun; ///< CppAD function for the energy generation rate.
mutable CppAD::sparse_jac_work m_jac_work; ///< Work object for sparse Jacobian calculations. mutable CppAD::sparse_jac_work m_jac_work; ///< Work object for sparse Jacobian calculations.
mutable std::vector<double> m_local_abundance_cache;
bool m_has_been_primed = false; ///< Flag indicating if the engine has been primed. bool m_has_been_primed = false; ///< Flag indicating if the engine has been primed.

View File

@@ -809,6 +809,8 @@ namespace gridfire::reaction {
std::vector<RateCoefficientSet> m_rates; ///< List of rate coefficient sets from each source. std::vector<RateCoefficientSet> m_rates; ///< List of rate coefficient sets from each source.
bool m_weak = false; bool m_weak = false;
mutable std::unordered_map<double, double> m_cached_rates;
private: private:
/** /**
* @brief Template implementation for calculating the total reaction rate. * @brief Template implementation for calculating the total reaction rate.

View File

@@ -655,6 +655,7 @@ namespace gridfire::engine {
} }
StepDerivatives<double> GraphEngine::calculateAllDerivativesUsingPrecomputation( StepDerivatives<double> GraphEngine::calculateAllDerivativesUsingPrecomputation(
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
const std::vector<double> &bare_rates, const std::vector<double> &bare_rates,
@@ -672,8 +673,13 @@ namespace gridfire::engine {
T9, T9,
rho rho
); );
m_local_abundance_cache.clear();
for (const auto& species: m_networkSpecies) {
m_local_abundance_cache.push_back(comp.contains(species) ? comp.getMolarAbundance(species) : 0.0);
}
StepDerivatives<double> result; StepDerivatives<double> result;
std::vector<double> dydt_scratch(m_networkSpecies.size(), 0.0);
// --- Optimized loop --- // --- Optimized loop ---
std::vector<double> molarReactionFlows; std::vector<double> molarReactionFlows;
@@ -693,15 +699,18 @@ namespace gridfire::engine {
const fourdst::atomic::Species& reactant = m_networkSpecies[reactantIndex]; const fourdst::atomic::Species& reactant = m_networkSpecies[reactantIndex];
const int power = precomputedReaction.reactant_powers[i]; const int power = precomputedReaction.reactant_powers[i];
if (!comp.contains(reactant)) { const double abundance = m_local_abundance_cache[reactantIndex];
forwardAbundanceProduct = 0.0;
break; // No need to continue if one of the reactants has zero abundance double factor;
} if (power == 1) { factor = abundance; }
const double factor = std::pow(comp.getMolarAbundance(reactant), power); else if (power == 2) { factor = abundance * abundance; }
else { factor = std::pow(abundance, power); }
if (!std::isfinite(factor)) { if (!std::isfinite(factor)) {
LOG_CRITICAL(m_logger, "Non-finite factor encountered in forward abundance product for reaction '{}'. Check input abundances for validity.", reaction->id()); LOG_CRITICAL(m_logger, "Non-finite factor encountered in forward abundance product for reaction '{}'. Check input abundances for validity.", reaction->id());
throw exceptions::BadRHSEngineError("Non-finite factor encountered in forward abundance product."); throw exceptions::BadRHSEngineError("Non-finite factor encountered in forward abundance product.");
} }
forwardAbundanceProduct *= factor; forwardAbundanceProduct *= factor;
} }
@@ -794,7 +803,8 @@ namespace gridfire::engine {
// Update the derivative for this species // Update the derivative for this species
const double dydt_increment = static_cast<double>(stoichiometricCoefficient) * R_j; const double dydt_increment = static_cast<double>(stoichiometricCoefficient) * R_j;
result.dydt.at(species) += dydt_increment; // result.dydt.at(species) += dydt_increment;
dydt_scratch[speciesIndex] += dydt_increment;
if (m_store_intermediate_reaction_contributions) { if (m_store_intermediate_reaction_contributions) {
result.reactionContributions.value()[species][std::string(reaction->id())] = dydt_increment; result.reactionContributions.value()[species][std::string(reaction->id())] = dydt_increment;
@@ -803,6 +813,11 @@ namespace gridfire::engine {
reactionCounter++; reactionCounter++;
} }
// load scratch into result.dydt
for (size_t i = 0; i < m_networkSpecies.size(); ++i) {
result.dydt[m_networkSpecies[i]] = dydt_scratch[i];
}
// --- Calculate the nuclear energy generation rate --- // --- Calculate the nuclear energy generation rate ---
double massProductionRate = 0.0; // [mol][s^-1] double massProductionRate = 0.0; // [mol][s^-1]
for (const auto & species : m_networkSpecies) { for (const auto & species : m_networkSpecies) {

View File

@@ -278,7 +278,12 @@ namespace gridfire::reaction {
double Ye, double Ye,
double mue, const std::vector<double> &Y, const std::unordered_map<size_t, Species>& index_to_species_map double mue, const std::vector<double> &Y, const std::unordered_map<size_t, Species>& index_to_species_map
) const { ) const {
return calculate_rate<double>(T9); if (m_cached_rates.contains(T9)) {
return m_cached_rates.at(T9);
}
const double rate = calculate_rate<double>(T9);
m_cached_rates[T9] = rate;
return rate;
} }
double LogicalReaclibReaction::calculate_log_rate_partial_deriv_wrt_T9( double LogicalReaclibReaction::calculate_log_rate_partial_deriv_wrt_T9(