perf(GraphEngine): using caching clawed back ~10% performance
This commit is contained in:
@@ -875,6 +875,7 @@ namespace gridfire::engine {
|
||||
mutable CppAD::ADFun<double> m_rhsADFun; ///< CppAD function for the right-hand side of the ODE.
|
||||
mutable CppAD::ADFun<double> m_epsADFun; ///< CppAD function for the energy generation rate.
|
||||
mutable CppAD::sparse_jac_work m_jac_work; ///< Work object for sparse Jacobian calculations.
|
||||
mutable std::vector<double> m_local_abundance_cache;
|
||||
|
||||
bool m_has_been_primed = false; ///< Flag indicating if the engine has been primed.
|
||||
|
||||
|
||||
@@ -809,6 +809,8 @@ namespace gridfire::reaction {
|
||||
std::vector<RateCoefficientSet> m_rates; ///< List of rate coefficient sets from each source.
|
||||
bool m_weak = false;
|
||||
|
||||
mutable std::unordered_map<double, double> m_cached_rates;
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Template implementation for calculating the total reaction rate.
|
||||
|
||||
@@ -655,6 +655,7 @@ namespace gridfire::engine {
|
||||
|
||||
}
|
||||
|
||||
|
||||
StepDerivatives<double> GraphEngine::calculateAllDerivativesUsingPrecomputation(
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
const std::vector<double> &bare_rates,
|
||||
@@ -672,8 +673,13 @@ namespace gridfire::engine {
|
||||
T9,
|
||||
rho
|
||||
);
|
||||
m_local_abundance_cache.clear();
|
||||
for (const auto& species: m_networkSpecies) {
|
||||
m_local_abundance_cache.push_back(comp.contains(species) ? comp.getMolarAbundance(species) : 0.0);
|
||||
}
|
||||
|
||||
StepDerivatives<double> result;
|
||||
std::vector<double> dydt_scratch(m_networkSpecies.size(), 0.0);
|
||||
|
||||
// --- Optimized loop ---
|
||||
std::vector<double> molarReactionFlows;
|
||||
@@ -693,15 +699,18 @@ namespace gridfire::engine {
|
||||
const fourdst::atomic::Species& reactant = m_networkSpecies[reactantIndex];
|
||||
const int power = precomputedReaction.reactant_powers[i];
|
||||
|
||||
if (!comp.contains(reactant)) {
|
||||
forwardAbundanceProduct = 0.0;
|
||||
break; // No need to continue if one of the reactants has zero abundance
|
||||
}
|
||||
const double factor = std::pow(comp.getMolarAbundance(reactant), power);
|
||||
const double abundance = m_local_abundance_cache[reactantIndex];
|
||||
|
||||
double factor;
|
||||
if (power == 1) { factor = abundance; }
|
||||
else if (power == 2) { factor = abundance * abundance; }
|
||||
else { factor = std::pow(abundance, power); }
|
||||
|
||||
if (!std::isfinite(factor)) {
|
||||
LOG_CRITICAL(m_logger, "Non-finite factor encountered in forward abundance product for reaction '{}'. Check input abundances for validity.", reaction->id());
|
||||
throw exceptions::BadRHSEngineError("Non-finite factor encountered in forward abundance product.");
|
||||
}
|
||||
|
||||
forwardAbundanceProduct *= factor;
|
||||
}
|
||||
|
||||
@@ -794,7 +803,8 @@ namespace gridfire::engine {
|
||||
|
||||
// Update the derivative for this species
|
||||
const double dydt_increment = static_cast<double>(stoichiometricCoefficient) * R_j;
|
||||
result.dydt.at(species) += dydt_increment;
|
||||
// result.dydt.at(species) += dydt_increment;
|
||||
dydt_scratch[speciesIndex] += dydt_increment;
|
||||
|
||||
if (m_store_intermediate_reaction_contributions) {
|
||||
result.reactionContributions.value()[species][std::string(reaction->id())] = dydt_increment;
|
||||
@@ -803,6 +813,11 @@ namespace gridfire::engine {
|
||||
reactionCounter++;
|
||||
}
|
||||
|
||||
// load scratch into result.dydt
|
||||
for (size_t i = 0; i < m_networkSpecies.size(); ++i) {
|
||||
result.dydt[m_networkSpecies[i]] = dydt_scratch[i];
|
||||
}
|
||||
|
||||
// --- Calculate the nuclear energy generation rate ---
|
||||
double massProductionRate = 0.0; // [mol][s^-1]
|
||||
for (const auto & species : m_networkSpecies) {
|
||||
|
||||
@@ -278,7 +278,12 @@ namespace gridfire::reaction {
|
||||
double Ye,
|
||||
double mue, const std::vector<double> &Y, const std::unordered_map<size_t, Species>& index_to_species_map
|
||||
) const {
|
||||
return calculate_rate<double>(T9);
|
||||
if (m_cached_rates.contains(T9)) {
|
||||
return m_cached_rates.at(T9);
|
||||
}
|
||||
const double rate = calculate_rate<double>(T9);
|
||||
m_cached_rates[T9] = rate;
|
||||
return rate;
|
||||
}
|
||||
|
||||
double LogicalReaclibReaction::calculate_log_rate_partial_deriv_wrt_T9(
|
||||
|
||||
Reference in New Issue
Block a user