feat(precomputation): added precomputation

preformance speed up by a factor of ~5
This commit is contained in:
2025-07-01 14:30:45 -04:00
parent 4ee6f816d0
commit 5b4db3ea43
3 changed files with 225 additions and 19 deletions

View File

@@ -63,6 +63,7 @@ namespace gridfire {
*/ */
static constexpr double MIN_JACOBIAN_THRESHOLD = 1e-24; static constexpr double MIN_JACOBIAN_THRESHOLD = 1e-24;
/** /**
* @class GraphEngine * @class GraphEngine
* @brief A reaction network engine that uses a graph-based representation. * @brief A reaction network engine that uses a graph-based representation.
@@ -108,7 +109,7 @@ namespace gridfire {
* This constructor uses the given set of reactions to construct the * This constructor uses the given set of reactions to construct the
* reaction network. * reaction network.
*/ */
explicit GraphEngine(reaction::LogicalReactionSet reactions); explicit GraphEngine(const reaction::LogicalReactionSet &reactions);
/** /**
* @brief Calculates the right-hand side (dY/dt) and energy generation rate. * @brief Calculates the right-hand side (dY/dt) and energy generation rate.
@@ -302,8 +303,32 @@ namespace gridfire {
[[nodiscard]] screening::ScreeningType getScreeningModel() const override; [[nodiscard]] screening::ScreeningType getScreeningModel() const override;
void setPrecomputation(bool precompute);
[[nodiscard]] bool isPrecomputationEnabled() const;
private: private:
struct PrecomputedReaction {
size_t reaction_index;
std::vector<size_t> unique_reactant_indices;
std::vector<int> reactant_powers;
double symmetry_factor;
std::vector<size_t> affected_species_indices;
std::vector<int> stoichiometric_coefficients;
};
struct constants {
const double u = Constants::getInstance().get("u").value; ///< Atomic mass unit in g.
const double Na = Constants::getInstance().get("N_a").value; ///< Avogadro's number.
const double c = Constants::getInstance().get("c").value; ///< Speed of light in cm/s.
};
private:
Config& m_config = Config::getInstance();
quill::Logger* m_logger = LogManager::getInstance().getLogger("log");
constants m_constants;
reaction::LogicalReactionSet m_reactions; ///< Set of REACLIB reactions in the network. reaction::LogicalReactionSet m_reactions; ///< Set of REACLIB reactions in the network.
std::unordered_map<std::string_view, reaction::Reaction*> m_reactionIDMap; ///< Map from reaction ID to REACLIBReaction. //PERF: This makes copies of REACLIBReaction and could be a performance bottleneck. std::unordered_map<std::string_view, reaction::Reaction*> m_reactionIDMap; ///< Map from reaction ID to REACLIBReaction. //PERF: This makes copies of REACLIBReaction and could be a performance bottleneck.
@@ -319,9 +344,9 @@ namespace gridfire {
screening::ScreeningType m_screeningType = screening::ScreeningType::BARE; ///< Screening type for the reaction network. Default to no screening. screening::ScreeningType m_screeningType = screening::ScreeningType::BARE; ///< Screening type for the reaction network. Default to no screening.
std::unique_ptr<screening::ScreeningModel> m_screeningModel = screening::selectScreeningModel(m_screeningType); std::unique_ptr<screening::ScreeningModel> m_screeningModel = screening::selectScreeningModel(m_screeningType);
Config& m_config = Config::getInstance(); bool m_usePrecomputation = true; ///< Flag to enable or disable using precomputed reactions for efficiency. Mathematically, this should not change the results. Generally end users should not need to change this.
Constants& m_constants = Constants::getInstance(); ///< Access to physical constants.
quill::Logger* m_logger = LogManager::getInstance().getLogger("log"); std::vector<PrecomputedReaction> m_precomputedReactions; ///< Precomputed reactions for efficiency.
private: private:
/** /**
@@ -377,6 +402,8 @@ namespace gridfire {
*/ */
void recordADTape(); void recordADTape();
void precomputeNetwork();
/** /**
* @brief Validates mass and charge conservation across all reactions. * @brief Validates mass and charge conservation across all reactions.
* *
@@ -405,6 +432,13 @@ namespace gridfire {
double T9 double T9
); );
[[nodiscard]] StepDerivatives<double> calculateAllDerivativesUsingPrecomputation(
const std::vector<double> &Y_in,
const std::vector<double>& bare_rates,
double T9,
double rho
) const;
/** /**
* @brief Calculates the molar reaction flow for a given reaction. * @brief Calculates the molar reaction flow for a given reaction.
* *
@@ -522,9 +556,9 @@ namespace gridfire {
Y[i] = CppAD::CondExpLt(Y[i], zero, zero, Y[i]); // Ensure no negative abundances Y[i] = CppAD::CondExpLt(Y[i], zero, zero, Y[i]); // Ensure no negative abundances
} }
const T u = static_cast<T>(m_constants.get("u").value); // Atomic mass unit in grams const T u = static_cast<T>(m_constants.u); // Atomic mass unit in grams
const T N_A = static_cast<T>(m_constants.get("N_a").value); // Avogadro's number in mol^-1 const T N_A = static_cast<T>(m_constants.Na); // Avogadro's number in mol^-1
const T c = static_cast<T>(m_constants.get("c").value); // Speed of light in cm/s const T c = static_cast<T>(m_constants.c); // Speed of light in cm/s
// --- SINGLE LOOP OVER ALL REACTIONS --- // --- SINGLE LOOP OVER ALL REACTIONS ---
for (size_t reactionIndex = 0; reactionIndex < m_reactions.size(); ++reactionIndex) { for (size_t reactionIndex = 0; reactionIndex < m_reactions.size(); ++reactionIndex) {

View File

@@ -28,19 +28,34 @@ namespace gridfire {
): ):
m_reactions(build_reaclib_nuclear_network(composition, false)) { m_reactions(build_reaclib_nuclear_network(composition, false)) {
syncInternalMaps(); syncInternalMaps();
precomputeNetwork();
} }
GraphEngine::GraphEngine(reaction::LogicalReactionSet reactions) : GraphEngine::GraphEngine(
m_reactions(std::move(reactions)) { const reaction::LogicalReactionSet &reactions
syncInternalMaps(); ) :
} m_reactions(reactions) {
syncInternalMaps();
precomputeNetwork();
}
StepDerivatives<double> GraphEngine::calculateRHSAndEnergy( StepDerivatives<double> GraphEngine::calculateRHSAndEnergy(
const std::vector<double> &Y, const std::vector<double> &Y,
const double T9, const double T9,
const double rho const double rho
) const { ) const {
return calculateAllDerivatives<double>(Y, T9, rho); if (m_usePrecomputation) {
std::vector<double> bare_rates;
bare_rates.reserve(m_reactions.size());
for (const auto& reaction: m_reactions) {
bare_rates.push_back(reaction.calculate_rate(T9));
}
// --- The public facing interface can always use the precomputed version since taping is done internally ---
return calculateAllDerivativesUsingPrecomputation(Y, bare_rates, T9, rho);
} else {
return calculateAllDerivatives<double>(Y, T9, rho);
}
} }
@@ -200,11 +215,90 @@ namespace gridfire {
// This allows for dynamic network modification while retaining caching for networks which are very similar. // This allows for dynamic network modification while retaining caching for networks which are very similar.
if (validationReactionSet != m_reactions) { if (validationReactionSet != m_reactions) {
LOG_DEBUG(m_logger, "Reaction set not cached. Rebuilding the reaction set for T9={} and culling={}.", T9, culling); LOG_DEBUG(m_logger, "Reaction set not cached. Rebuilding the reaction set for T9={} and culling={}.", T9, culling);
m_reactions = std::move(validationReactionSet); m_reactions = validationReactionSet;
syncInternalMaps(); // Re-sync internal maps after updating reactions. Note this will also retrace the AD tape. syncInternalMaps(); // Re-sync internal maps after updating reactions. Note this will also retrace the AD tape.
} }
} }
StepDerivatives<double> GraphEngine::calculateAllDerivativesUsingPrecomputation(
const std::vector<double> &Y_in,
const std::vector<double> &bare_rates,
const double T9,
const double rho
) const {
// --- Calculate screening factors ---
const std::vector<double> screeningFactors = m_screeningModel->calculateScreeningFactors(
m_reactions,
m_networkSpecies,
Y_in,
T9,
rho
);
// --- Optimized loop ---
std::vector<double> molarReactionFlows;
molarReactionFlows.reserve(m_precomputedReactions.size());
for (const auto& precomp : m_precomputedReactions) {
double abundanceProduct = 1.0;
bool below_threshold = false;
for (size_t i = 0; i < precomp.unique_reactant_indices.size(); ++i) {
const size_t reactantIndex = precomp.unique_reactant_indices[i];
const int power = precomp.reactant_powers[i];
const double abundance = Y_in[reactantIndex];
if (abundance < MIN_ABUNDANCE_THRESHOLD) {
below_threshold = true;
break;
}
abundanceProduct *= std::pow(Y_in[reactantIndex], power);
}
if (below_threshold) {
molarReactionFlows.push_back(0.0);
continue; // Skip this reaction if any reactant is below the abundance threshold
}
const double bare_rate = bare_rates[precomp.reaction_index];
const double screeningFactor = screeningFactors[precomp.reaction_index];
const size_t numReactants = m_reactions[precomp.reaction_index].reactants().size();
const double molarReactionFlow =
screeningFactor *
bare_rate *
precomp.symmetry_factor *
abundanceProduct *
std::pow(rho, numReactants);
molarReactionFlows.push_back(molarReactionFlow);
}
// --- Assemble molar abundance derivatives ---
StepDerivatives<double> result;
result.dydt.assign(m_networkSpecies.size(), 0.0); // Initialize derivatives to zero
for (size_t j = 0; j < m_precomputedReactions.size(); ++j) {
const auto& precomp = m_precomputedReactions[j];
const double R_j = molarReactionFlows[j];
for (size_t i = 0; i < precomp.affected_species_indices.size(); ++i) {
const size_t speciesIndex = precomp.affected_species_indices[i];
const int stoichiometricCoefficient = precomp.stoichiometric_coefficients[i];
// Update the derivative for this species
result.dydt[speciesIndex] += static_cast<double>(stoichiometricCoefficient) * R_j / rho;
}
}
// --- Calculate the nuclear energy generation rate ---
double massProductionRate = 0.0; // [mol][s^-1]
for (size_t i = 0; i < m_networkSpecies.size(); ++i) {
const auto& species = m_networkSpecies[i];
massProductionRate += result.dydt[i] * species.mass() * m_constants.u;
}
result.nuclearEnergyGenerationRate = -massProductionRate * m_constants.Na * m_constants.c * m_constants.c; // [erg][s^-1][g^-1]
return result;
}
// --- Generate Stoichiometry Matrix --- // --- Generate Stoichiometry Matrix ---
void GraphEngine::generateStoichiometryMatrix() { void GraphEngine::generateStoichiometryMatrix() {
LOG_TRACE_L1(m_logger, "Generating stoichiometry matrix..."); LOG_TRACE_L1(m_logger, "Generating stoichiometry matrix...");
@@ -272,6 +366,14 @@ namespace gridfire {
return m_screeningType; return m_screeningType;
} }
void GraphEngine::setPrecomputation(const bool precompute) {
m_usePrecomputation = precompute;
}
bool GraphEngine::isPrecomputationEnabled() const {
return m_usePrecomputation;
}
double GraphEngine::calculateMolarReactionFlow( double GraphEngine::calculateMolarReactionFlow(
const reaction::Reaction &reaction, const reaction::Reaction &reaction,
const std::vector<double> &Y, const std::vector<double> &Y,
@@ -450,7 +552,7 @@ namespace gridfire {
} }
void GraphEngine::update(const NetIn &netIn) { void GraphEngine::update(const NetIn &netIn) {
return; // No-op for GraphEngine, as it does not support manually triggering updates. // No-op for GraphEngine, as it does not support manually triggering updates.
} }
void GraphEngine::recordADTape() { void GraphEngine::recordADTape() {
@@ -471,7 +573,7 @@ namespace gridfire {
// Their numeric values are irrelevant except for in so far as they avoid numerical instabilities. // Their numeric values are irrelevant except for in so far as they avoid numerical instabilities.
// Distribute total mass fraction uniformly between species in the dummy variable space // Distribute total mass fraction uniformly between species in the dummy variable space
const auto uniformMassFraction = static_cast<CppAD::AD<double>>(1.0 / numSpecies); const auto uniformMassFraction = static_cast<CppAD::AD<double>>(1.0 / static_cast<double>(numSpecies));
std::vector<CppAD::AD<double>> adInput(numADInputs, uniformMassFraction); std::vector<CppAD::AD<double>> adInput(numADInputs, uniformMassFraction);
adInput[numSpecies] = 1.0; // Dummy T9 adInput[numSpecies] = 1.0; // Dummy T9
adInput[numSpecies + 1] = 1.0; // Dummy rho adInput[numSpecies + 1] = 1.0; // Dummy rho
@@ -497,4 +599,53 @@ namespace gridfire {
LOG_TRACE_L1(m_logger, "AD tape recorded successfully for the RHS calculation. Number of independent variables: {}.", LOG_TRACE_L1(m_logger, "AD tape recorded successfully for the RHS calculation. Number of independent variables: {}.",
adInput.size()); adInput.size());
} }
void GraphEngine::precomputeNetwork() {
LOG_TRACE_L1(m_logger, "Pre-computing constant components of GraphNetwork state...");
// --- Reverse map for fast species lookups ---
std::unordered_map<fourdst::atomic::Species, size_t> speciesIndexMap;
for (size_t i = 0; i < m_networkSpecies.size(); ++i) {
speciesIndexMap[m_networkSpecies[i]] = i;
}
m_precomputedReactions.clear();
m_precomputedReactions.reserve(m_reactions.size());
for (size_t i = 0; i < m_reactions.size(); ++i) {
const auto& reaction = m_reactions[i];
PrecomputedReaction precomp;
precomp.reaction_index = i;
// --- Precompute reactant information ---
// Count occurrences for each reactant to determine powers and symmetry
std::unordered_map<size_t, int> reactantCounts;
for (const auto& reactant: reaction.reactants()) {
size_t reactantIndex = speciesIndexMap.at(reactant);
reactantCounts[reactantIndex]++;
}
double symmetryDenominator = 1.0;
for (const auto& [index, count] : reactantCounts) {
precomp.unique_reactant_indices.push_back(index);
precomp.reactant_powers.push_back(count);
symmetryDenominator *= 1.0/std::tgamma(count + 1);
}
precomp.symmetry_factor = symmetryDenominator;
// --- Precompute stoichiometry information ---
const auto stoichiometryMap = reaction.stoichiometry();
precomp.affected_species_indices.reserve(stoichiometryMap.size());
precomp.stoichiometric_coefficients.reserve(stoichiometryMap.size());
for (const auto& [species, coeff] : stoichiometryMap) {
precomp.affected_species_indices.push_back(speciesIndexMap.at(species));
precomp.stoichiometric_coefficients.push_back(coeff);
}
m_precomputedReactions.push_back(std::move(precomp));
}
}
} }

View File

@@ -19,9 +19,23 @@
#include "quill/Backend.h" #include "quill/Backend.h"
#include "quill/Frontend.h" #include "quill/Frontend.h"
#include <chrono>
#include <functional>
// Keep a copy of the previous handler // Keep a copy of the previous handler
static std::terminate_handler g_previousHandler = nullptr; static std::terminate_handler g_previousHandler = nullptr;
void measure_execution_time(const std::function<void()>& callback, const std::string& name)
{
// variable names in camelCase
const auto startTime = std::chrono::steady_clock::now();
callback();
const auto endTime = std::chrono::steady_clock::now();
const auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - startTime);
std::cout << "Execution time for " << name << ": "
<< duration.count()/1e9 << " s\n";
}
void quill_terminate_handler() void quill_terminate_handler()
{ {
// 1. Stop the Quill backend (flushes all sinks and joins thread) // 1. Stop the Quill backend (flushes all sinks and joins thread)
@@ -65,21 +79,28 @@ int main() {
NetOut netOut; NetOut netOut;
netIn.dt0 = 1e12; // netIn.dt0 = 1e12;
// approx8::Approx8Network approx8Network; // approx8::Approx8Network approx8Network;
// netOut = approx8Network.evaluate(netIn); // measure_execution_time([&]() {
// netOut = approx8Network.evaluate(netIn);
// }, "Approx8 Network Initialization");
// std::cout << "Approx8 Network H-1: " << netOut.composition.getMassFraction("H-1") << " in " << netOut.num_steps << " steps." << std::endl; // std::cout << "Approx8 Network H-1: " << netOut.composition.getMassFraction("H-1") << " in " << netOut.num_steps << " steps." << std::endl;
netIn.dt0 = 1e-15; netIn.dt0 = 1e-15;
GraphEngine ReaclibEngine(composition); GraphEngine ReaclibEngine(composition);
ReaclibEngine.setPrecomputation(true);
// AdaptiveEngineView adaptiveEngine(ReaclibEngine); // AdaptiveEngineView adaptiveEngine(ReaclibEngine);
io::SimpleReactionListFileParser parser{}; io::SimpleReactionListFileParser parser{};
FileDefinedEngineView approx8EngineView(ReaclibEngine, "approx8.net", parser); FileDefinedEngineView approx8EngineView(ReaclibEngine, "approx8.net", parser);
approx8EngineView.setScreeningModel(screening::ScreeningType::WEAK); approx8EngineView.setScreeningModel(screening::ScreeningType::WEAK);
solver::QSENetworkSolver solver(approx8EngineView); solver::QSENetworkSolver solver(approx8EngineView);
netOut = solver.evaluate(netIn); netOut = solver.evaluate(netIn);
std::cout << "QSE Graph Network H-1: " << netOut.composition.getMassFraction("H-1") << " in " << netOut.num_steps << " steps." << std::endl;
// measure_execution_time([&]() {
// netOut = solver.evaluate(netIn);
// }, "Approx8 Network Evaluation (Precomputation)");
// ReaclibEngine.setPrecomputation(false);
// std::cout << "Precomputation H-1: " << netOut.composition.getMassFraction("H-1") << " in " << netOut.num_steps << " steps." << std::endl;
} }