perf(thread saftey): All Engines are now thread safe

Previously engines were not thread safe, a seperate engine would be
needed for every thread. This is no longer the case. This allows for
much more efficient parallel execution
This commit is contained in:
2025-12-12 12:08:47 -05:00
parent c7574a2f3d
commit e114c0e240
46 changed files with 3685 additions and 1604 deletions

View File

@@ -0,0 +1,178 @@
// ReSharper disable CppUnusedIncludeDirective
#include <iostream>
#include <fstream>
#include <chrono>
#include <thread>
#include <format>
#include "gridfire/gridfire.h"
#include <cppad/utility/thread_alloc.hpp> // Required for parallel_setup
#include "fourdst/composition/composition.h"
#include "fourdst/logging/logging.h"
#include "fourdst/atomic/species.h"
#include "fourdst/composition/utils.h"
#include "quill/Logger.h"
#include "quill/Backend.h"
#include <clocale>
#include "gridfire/reaction/reaclib.h"
#include <omp.h>
unsigned long get_thread_id() {
return static_cast<unsigned long>(omp_get_thread_num());
}
bool in_parallel() {
return omp_in_parallel() != 0;
}
gridfire::NetIn init(const double temp, const double rho, const double tMax) {
std::setlocale(LC_ALL, "");
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
logger->set_log_level(quill::LogLevel::TraceL2);
using namespace gridfire;
const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4};
const std::vector<std::string> symbols = {"H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"};
const fourdst::composition::Composition composition = fourdst::composition::buildCompositionFromMassFractions(symbols, X);
NetIn netIn;
netIn.composition = composition;
netIn.temperature = temp;
netIn.density = rho;
netIn.energy = 0;
netIn.tMax = tMax;
netIn.dt0 = 1e-12;
return netIn;
}
int main() {
using namespace gridfire;
constexpr size_t breaks = 1;
constexpr double temp = 1.5e7;
constexpr double rho = 1.5e2;
constexpr double tMax = 3.1536e+16/breaks;
const NetIn netIn = init(temp, rho, tMax);
policy::MainSequencePolicy stellarPolicy(netIn.composition);
const policy::ConstructionResults construct = stellarPolicy.construct();
std::println("Sandbox Engine Stack: {}", stellarPolicy);
std::println("Scratch Blob State: {}", *construct.scratch_blob);
constexpr size_t runs = 1000;
auto startTime = std::chrono::high_resolution_clock::now();
// arrays to store timings
std::array<std::chrono::duration<double>, runs> setup_times;
std::array<std::chrono::duration<double>, runs> eval_times;
std::array<NetOut, runs> serial_results;
for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now();
solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob);
solver.set_stdout_logging_enabled(false);
auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setup_times[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now();
const NetOut netOut = solver.evaluate(netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now();
serial_results[i] = netOut;
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
eval_times[i] = eval_elapsed;
}
auto endTime = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = endTime - startTime;
std::println("");
// Summarize serial timings
double total_setup_time = 0.0;
double total_eval_time = 0.0;
for (size_t i = 0; i < runs; ++i) {
total_setup_time += setup_times[i].count();
total_eval_time += eval_times[i].count();
}
std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1));
CppAD::thread_alloc::parallel_setup(
static_cast<size_t>(omp_get_max_threads()), // Max threads
[]() -> bool { return in_parallel(); }, // Function to get thread ID
[]() -> size_t { return get_thread_id(); } // Function to check parallel state
);
// OPTIONAL: Prevent CppAD from returning memory to the system
// during execution to reduce overhead (can speed up tight loops)
CppAD::thread_alloc::hold_memory(true);
std::array<NetOut, runs> parallelResults;
std::array<std::chrono::duration<double>, runs> setupTimes;
std::array<std::chrono::duration<double>, runs> evalTimes;
std::array<std::unique_ptr<gridfire::engine::scratch::StateBlob>, runs> workspaces;
for (size_t i = 0; i < runs; ++i) {
workspaces[i] = construct.scratch_blob->clone_structure();
}
// Parallel runs
startTime = std::chrono::high_resolution_clock::now();
#pragma omp parallel for
for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now();
solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]);
solver.set_stdout_logging_enabled(false);
auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setupTimes[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now();
parallelResults[i] = solver.evaluate(netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
evalTimes[i] = eval_elapsed;
}
endTime = std::chrono::high_resolution_clock::now();
elapsed = endTime - startTime;
std::println("");
// Summarize parallel timings
total_setup_time = 0.0;
total_eval_time = 0.0;
for (size_t i = 0; i < runs; ++i) {
total_setup_time += setupTimes[i].count();
total_eval_time += evalTimes[i].count();
}
std::println("Average Parallel Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) {
return result.composition.getMolarAbundance(fourdst::atomic::H_1);
}));
std::println("========== Summary ==========");
std::println("Serial Runs:");
std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs);
std::println(" Average Evaluation Time: {:.6f} seconds", total_eval_time / runs);
std::println("Parallel Runs:");
std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs);
std::println(" Average Evaluation Time: {:.6f} seconds", total_eval_time / runs);
std::println("Difference:");
std::println(" Setup Time Difference: {:.6f} seconds", (total_setup_time / runs) - (total_setup_time / runs));
std::println(" Evaluation Time Difference: {:.6f} seconds", (total_eval_time / runs) - (total_eval_time / runs));
std::println(" Setup Time Fractional Difference: {:.2f}%", ((total_setup_time / runs) - (total_setup_time / runs)) / (total_setup_time / runs) * 100.0);
std::println(" Evaluation Time Fractional Difference: {:.2f}%", ((total_eval_time / runs) - (total_eval_time / runs)) / (total_eval_time / runs) * 100.0);
}

View File

@@ -0,0 +1,5 @@
executable(
'gf_bench_single_zone_parallel',
'main.cpp',
dependencies: [gridfire_dep],
)

0
benchmarks/meson.build Normal file
View File

View File

@@ -35,11 +35,15 @@ subdir('src')
# Build the Python bindings
subdir('build-python')
# Buil the test suite
# Build the test suite
subdir('tests')
# Build the tool suite
subdir('tools')
# Build the benchmark suite
subdir('benchmarks')
# Build the pkg-config file
subdir('build-extra/pkg-config')

View File

@@ -10,4 +10,5 @@ option('python_target_version', type: 'string', value: '3.13', description: 'Tar
option('build_c_api', type: 'boolean', value: true, description: 'compile the C API')
option('build_tools', type: 'boolean', value: true, description: 'build the GridFire command line tools')
option('openmp_support', type: 'boolean', value: false, description: 'Enable OpenMP support for parallelization')
option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.')
option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.')
option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite')

View File

@@ -29,6 +29,7 @@
#pragma once
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/engine/scratchpads/blob.h"
#include <vector>
#include <string>
@@ -49,6 +50,7 @@ namespace gridfire::engine::diagnostics {
* @return std::optional<nlohmann::json> JSON object containing the limiting species report if `json` is true; otherwise, std::nullopt.
*/
std::optional<nlohmann::json> report_limiting_species(
scratch::StateBlob& ctx,
const DynamicEngine &engine,
const std::vector<double> &Y_full,
const std::vector<double> &E_full,
@@ -71,6 +73,7 @@ namespace gridfire::engine::diagnostics {
* @return std::optional<nlohmann::json> JSON object containing the species balance report if `json` is true; otherwise, std::nullopt.
*/
std::optional<nlohmann::json> inspect_species_balance(
scratch::StateBlob& ctx,
const DynamicEngine& engine,
const std::string& species_name,
const fourdst::composition::Composition &comp,
@@ -89,6 +92,7 @@ namespace gridfire::engine::diagnostics {
* @return std::optional<nlohmann::json> JSON object containing the Jacobian stiffness report if `json` is true; otherwise, std::nullopt.
*/
std::optional<nlohmann::json> inspect_jacobian_stiffness(
scratch::StateBlob& ctx,
const DynamicEngine &engine,
const fourdst::composition::Composition &comp,
double T9,

View File

@@ -180,4 +180,6 @@
#include "gridfire/engine/views/engine_views.h"
#include "gridfire/engine/procedures/engine_procedures.h"
#include "gridfire/engine/types/engine_types.h"
#include "gridfire/engine/diagnostics/dynamic_engine_diagnostics.h"
#include "gridfire/engine/diagnostics/dynamic_engine_diagnostics.h"
#include "gridfire/engine/scratchpads/scratchpads.h"

View File

@@ -6,9 +6,10 @@
#include "gridfire/screening/screening_types.h"
#include "gridfire/engine/types/reporting.h"
#include "gridfire/engine/types/building.h"
#include "gridfire/engine/types/jacobian.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "fourdst/composition/composition_abstract.h"
#include <vector>
@@ -136,7 +137,9 @@ namespace gridfire::engine {
* @brief Get the list of species in the network.
* @return Vector of Species objects representing all network species.
*/
[[nodiscard]] virtual const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const = 0;
[[nodiscard]] virtual const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
scratch::StateBlob& ctx
) const = 0;
/**
* @brief Calculate the right-hand side (dY/dt) and energy generation.
@@ -153,6 +156,7 @@ namespace gridfire::engine {
* rate for the current state.
*/
[[nodiscard]] virtual std::expected<StepDerivatives<double>, EngineStatus> calculateRHSAndEnergy(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -187,6 +191,7 @@ namespace gridfire::engine {
* for the current state. The matrix can then be accessed via getJacobianMatrixEntry().
*/
[[nodiscard]] virtual NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -205,6 +210,7 @@ namespace gridfire::engine {
* The matrix can then be accessed via getJacobianMatrixEntry().
*/
[[nodiscard]] virtual NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -228,6 +234,7 @@ namespace gridfire::engine {
* @see getJacobianMatrixEntry()
*/
[[nodiscard]] virtual NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -235,28 +242,6 @@ namespace gridfire::engine {
) const = 0;
/**
* @brief Generate the stoichiometry matrix for the network.
*
* This method must compute and store the stoichiometry matrix,
* which encodes the net change of each species in each reaction.
*/
virtual void generateStoichiometryMatrix() = 0;
/**
* @brief Get an entry from the stoichiometry matrix.
*
* @param species species to look up stoichiometry for.
* @param reaction reaction to find
* @return Stoichiometric coefficient for the species in the reaction.
*
* The stoichiometry matrix must have been generated by generateStoichiometryMatrix().
*/
[[nodiscard]] virtual int getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const reaction::Reaction& reaction
) const = 0;
/**
* @brief Calculate the molar reaction flow for a given reaction.
*
@@ -270,6 +255,7 @@ namespace gridfire::engine {
* under the current state.
*/
[[nodiscard]] virtual double calculateMolarReactionFlow(
scratch::StateBlob& ctx,
const reaction::Reaction& reaction,
const fourdst::composition::CompositionAbstract &comp,
double T9,
@@ -288,6 +274,7 @@ namespace gridfire::engine {
* generation rate with respect to temperature and density for the current state.
*/
[[nodiscard]] virtual EnergyDerivatives calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -298,18 +285,10 @@ namespace gridfire::engine {
*
* @return Reference to the LogicalReactionSet containing all reactions.
*/
[[nodiscard]] virtual const reaction::ReactionSet& getNetworkReactions() const = 0;
[[nodiscard]] virtual const reaction::ReactionSet& getNetworkReactions(
scratch::StateBlob& ctx
) const = 0;
/**
* @brief Set the reactions for the network.
*
* @param reactions The set of reactions to use in the network.
*
* This method replaces the current set of reactions in the network
* with the provided set. It marks the engine as stale, requiring
* regeneration of matrices and recalculation of rates.
*/
virtual void setNetworkReactions(const reaction::ReactionSet& reactions) = 0;
/**
* @brief Compute timescales for all species in the network.
@@ -323,6 +302,7 @@ namespace gridfire::engine {
* which can be used for timestep control, diagnostics, and reaction network culling.
*/
[[nodiscard]] virtual std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -340,13 +320,14 @@ namespace gridfire::engine {
* which can be useful for understanding reaction flows and equilibrium states.
*/
[[nodiscard]] virtual std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const = 0;
/**
* @brief Update the internal state of the engine.
* @brief Update the thread local scratch pad state of a network.
*
* @param netIn A struct containing the current network input, such as
* temperature, density, and composition.
@@ -365,47 +346,10 @@ namespace gridfire::engine {
*
* @post The internal state of the engine is updated to reflect the new conditions.
*/
virtual fourdst::composition::Composition update(const NetIn &netIn) = 0;
/**
* @brief Check if the engine's internal state is stale.
*
* @param netIn A struct containing the current network input, such as
* temperature, density, and composition.
* @return True if the engine's state is stale and needs to be updated; false otherwise.
*
* This method allows derived classes to determine if their internal state
* is out-of-date with respect to the provided network conditions. If the engine
* is stale, it may require a call to `update()` before performing calculations.
*
* @par Usage Example:
* @code
* NetIn input = { ... };
* if (myEngine.isStale(input)) {
* // Update the engine before proceeding
* }
* @endcode
*/
[[nodiscard]]
virtual bool isStale(const NetIn& netIn) = 0;
/**
* @brief Set the electron screening model.
*
* @param model The type of screening model to use for reaction rate calculations.
*
* This method allows changing the screening model at runtime. Screening corrections
* account for the electrostatic shielding of nuclei by electrons, which affects
* reaction rates in dense stellar plasmas.
*
* @par Usage Example:
* @code
* myEngine.setScreeningModel(screening::ScreeningType::WEAK);
* @endcode
*
* @post The engine will use the specified screening model for subsequent rate calculations.
*/
virtual void setScreeningModel(screening::ScreeningType model) = 0;
[[nodiscard]] virtual fourdst::composition::Composition project(
scratch::StateBlob& ctx,
const NetIn &netIn
) const = 0;
/**
* @brief Get the current electron screening model.
@@ -417,7 +361,9 @@ namespace gridfire::engine {
* screening::ScreeningType currentModel = myEngine.getScreeningModel();
* @endcode
*/
[[nodiscard]] virtual screening::ScreeningType getScreeningModel() const = 0;
[[nodiscard]] virtual screening::ScreeningType getScreeningModel(
scratch::StateBlob& ctx
) const = 0;
/**
* @brief Get the index of a species in the network.
@@ -428,18 +374,10 @@ namespace gridfire::engine {
* engine's internal representation. It is useful for accessing species
* data efficiently.
*/
[[nodiscard]] virtual size_t getSpeciesIndex(const fourdst::atomic::Species &species) const = 0;
/**
* @brief Map a NetIn object to a vector of molar abundances.
*
* @param netIn The input conditions for the network.
* @return A vector of molar abundances corresponding to the species in the network.
*
* This method converts the input conditions into a vector of molar abundances,
* which can be used for further calculations or diagnostics.
*/
[[nodiscard]] virtual std::vector<double> mapNetInToMolarAbundanceVector(const NetIn &netIn) const = 0;
[[nodiscard]] virtual size_t getSpeciesIndex(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const = 0;
/**
* @brief Prime the engine with initial conditions.
@@ -452,36 +390,10 @@ namespace gridfire::engine {
* rates, initializing internal data structures, and performing any necessary
* pre-computation.
*/
[[nodiscard]] virtual PrimingReport primeEngine(const NetIn &netIn) = 0;
/**
* @brief Get the depth of the network.
*
* @return The depth of the network, which may indicate the level of detail or
* complexity in the reaction network.
*
* This method is intended to provide information about the network's structure,
* such as how many layers of reactions or species are present. It can be useful
* for diagnostics and understanding the network's complexity.
*/
[[nodiscard]] virtual BuildDepthType getDepth() const {
throw std::logic_error("Network depth not supported by this engine.");
}
/**
* @brief Rebuild the network with a specified depth.
*
* @param comp The composition to rebuild the network with.
* @param depth The desired depth of the network.
*
* This method is intended to allow dynamic adjustment of the network's depth,
* which may involve adding or removing species and reactions based on the
* specified depth. However, not all engines support this operation.
*/
virtual void rebuild(const fourdst::composition::CompositionAbstract &comp, BuildDepthType depth) {
throw std::logic_error("Setting network depth not supported by this engine.");
// ReSharper disable once CppDFAUnreachableCode
}
[[nodiscard]] virtual PrimingReport primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) const = 0;
/**
* @brief Recursively collect composition from current engine and any sub engines if they exist.
@@ -497,7 +409,8 @@ namespace gridfire::engine {
* @return An updated composition which is a superset of comp. This may contain species which were culled, for
* example, by either QSE partitioning or reaction flow rate culling
*/
virtual fourdst::composition::Composition collectComposition(
[[nodiscard]] virtual fourdst::composition::Composition collectComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -512,11 +425,14 @@ namespace gridfire::engine {
* This method allows querying the current status of a specific species
* within the engine's network.
*/
[[nodiscard]] virtual SpeciesStatus getSpeciesStatus(const fourdst::atomic::Species& species) const = 0;
[[nodiscard]] virtual SpeciesStatus getSpeciesStatus(
scratch::StateBlob& ctx,
const fourdst::atomic::Species& species
) const = 0;
[[nodiscard]] virtual std::optional<StepDerivatives<double>> getMostRecentRHSCalculation() const {
return std::nullopt;
}
[[nodiscard]] virtual std::optional<StepDerivatives<double>> getMostRecentRHSCalculation(
scratch::StateBlob& ctx
) const = 0;
};
}

View File

@@ -14,6 +14,8 @@
#include "gridfire/engine/procedures/construction.h"
#include "gridfire/config/config.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "ankerl/unordered_dense.h"
#include <string>
@@ -31,10 +33,6 @@
#include "gridfire/reaction/weak/weak_interpolator.h"
#include "gridfire/reaction/weak/weak_rate_library.h"
// PERF: The function getNetReactionStoichiometry returns a map of species to their stoichiometric coefficients for a given reaction.
// this makes extra copies of the species, which is not ideal and could be optimized further.
// Even more relevant is the member m_reactionIDMap which makes copies of a REACLIBReaction for each reaction ID.
// REACLIBReactions are quite large data structures, so this could be a performance bottleneck.
namespace gridfire::engine {
/**
* @brief Alias for CppAD AD type for double precision.
@@ -154,6 +152,7 @@ namespace gridfire::engine {
* @see StepDerivatives
*/
[[nodiscard]] std::expected<StepDerivatives<double>, engine::EngineStatus> calculateRHSAndEnergy(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -177,6 +176,7 @@ namespace gridfire::engine {
* @see StepDerivatives
*/
[[nodiscard]] std::expected<StepDerivatives<double>, EngineStatus> calculateRHSAndEnergy(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract& comp,
double T9,
double rho,
@@ -198,6 +198,7 @@ namespace gridfire::engine {
* @see EnergyDerivatives
*/
[[nodiscard]] EnergyDerivatives calculateEpsDerivatives(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -222,12 +223,15 @@ namespace gridfire::engine {
* @see EnergyDerivatives
*/
[[nodiscard]] EnergyDerivatives calculateEpsDerivatives(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
const reaction::ReactionSet &activeReactions
) const;
void generate_jacobian_sparsity_pattern();
/**
* @brief Generates the Jacobian matrix for the current state.
*
@@ -242,6 +246,7 @@ namespace gridfire::engine {
* @see getJacobianMatrixEntry()
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -260,6 +265,7 @@ namespace gridfire::engine {
* @see generateJacobianMatrix()
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -282,20 +288,13 @@ namespace gridfire::engine {
* @see getJacobianMatrixEntry()
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
const SparsityPattern &sparsityPattern
) const override;
/**
* @brief Generates the stoichiometry matrix for the network.
*
* This method computes and stores the stoichiometry matrix,
* which encodes the net change of each species in each reaction.
*/
void generateStoichiometryMatrix() override;
/**
* @brief Calculates the molar reaction flow for a given reaction.
*
@@ -310,6 +309,7 @@ namespace gridfire::engine {
*
*/
[[nodiscard]] double calculateMolarReactionFlow(
scratch::StateBlob&,
const reaction::Reaction& reaction,
const fourdst::composition::CompositionAbstract &comp,
double T9,
@@ -320,51 +320,19 @@ namespace gridfire::engine {
* @brief Gets the list of species in the network.
* @return Vector of Species objects representing all network species.
*/
[[nodiscard]] const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
[[nodiscard]] const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
scratch::StateBlob &ctx
) const override;
/**
* @brief Gets the set of logical reactions in the network.
* @return Reference to the LogicalReactionSet containing all reactions.
*/
[[nodiscard]] const reaction::ReactionSet& getNetworkReactions() const override;
/**
* @brief Sets the reactions for the network.
*
* @param reactions The set of reactions to use in the network.
*
* This method replaces the current set of reactions in the network
* with the provided set. It marks the engine as stale, requiring
* regeneration of matrices and recalculation of rates.
*/
void setNetworkReactions(const reaction::ReactionSet& reactions) override;
/**
* @brief Gets the net stoichiometry for a given reaction.
*
* @param reaction The reaction for which to get the stoichiometry.
* @return Map of species to their stoichiometric coefficients.
*/
[[nodiscard]] static std::unordered_map<fourdst::atomic::Species, int> getNetReactionStoichiometry(
const reaction::Reaction& reaction
);
/**
* @brief Gets an entry from the stoichiometry matrix.
*
* @param species Species to look up stoichiometry for.
* @param reaction Reaction to find.
* @return Stoichiometric coefficient for the species in the reaction.
*
* The stoichiometry matrix must have been generated by `generateStoichiometryMatrix()`.
*
* @see generateStoichiometryMatrix()
*/
[[nodiscard]] int getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const reaction::Reaction& reaction
[[nodiscard]] const reaction::ReactionSet& getNetworkReactions(
scratch::StateBlob&
) const override;
/**
* @brief Computes timescales for all species in the network.
*
@@ -376,8 +344,8 @@ namespace gridfire::engine {
* This method estimates the timescale for abundance change of each species,
* which can be used for timestep control or diagnostics.
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, engine::EngineStatus>
getSpeciesTimescales(
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesTimescales(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -397,6 +365,7 @@ namespace gridfire::engine {
* calculations with different reaction sets without modifying the engine's internal state.
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesTimescales(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -415,8 +384,8 @@ namespace gridfire::engine {
* This method estimates the destruction timescale for each species,
* which can be useful for understanding reaction flows and equilibrium states.
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, engine::EngineStatus>
getSpeciesDestructionTimescales(
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesDestructionTimescales(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -436,6 +405,7 @@ namespace gridfire::engine {
* calculations with different reaction sets without modifying the engine's internal state.
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -452,24 +422,10 @@ namespace gridfire::engine {
*
* @return The updated composition that includes all species in the network.
*/
fourdst::composition::Composition update(
fourdst::composition::Composition project(
scratch::StateBlob& ctx,
const NetIn &netIn
) override;
/**
* @brief Checks if the engine view is stale and needs to be updated.
*
* @param netIn The current network input (unused).
* @return True if the view is stale, false otherwise.
*
* @deprecated This method is deprecated and will be removed in future versions.
* Stale states are returned as part of the results of methods that
* require the ability to report them.
*/
[[deprecated]] bool isStale(
const NetIn &netIn
) override;
) const override;
/**
* @brief Checks if a given species is involved in the network.
@@ -478,6 +434,7 @@ namespace gridfire::engine {
* @return True if the species is involved in the network, false otherwise.
*/
[[nodiscard]] bool involvesSpecies(
scratch::StateBlob& ctx,
const fourdst::atomic::Species& species
) const;
@@ -498,6 +455,7 @@ namespace gridfire::engine {
* @endcode
*/
void exportToDot(
scratch::StateBlob& ctx,
const std::string& filename
) const;
@@ -518,21 +476,10 @@ namespace gridfire::engine {
* @endcode
*/
void exportToCSV(
scratch::StateBlob& ctx,
const std::string& filename
) const;
/**
* @brief Sets the electron screening model for reaction rate calculations.
*
* @param model The type of screening model to use.
*
* This method allows changing the screening model at runtime. Screening corrections
* account for the electrostatic shielding of nuclei by electrons, which affects
* reaction rates in dense stellar plasmas.
*/
void setScreeningModel(
screening::ScreeningType model
) override;
/**
* @brief Gets the current electron screening model.
@@ -544,23 +491,9 @@ namespace gridfire::engine {
* screening::ScreeningType currentModel = engine.getScreeningModel();
* @endcode
*/
[[nodiscard]] screening::ScreeningType getScreeningModel() const override;
/**
* @brief Sets whether to precompute reaction rates.
*
* @param precompute True to enable precomputation, false to disable.
*
* This method allows enabling or disabling precomputation of reaction rates
* for performance optimization. When enabled, reaction rates are computed
* once and stored for later use.
*
* @post If precomputation is enabled, reaction rates will be precomputed and cached.
* If disabled, reaction rates will be computed on-the-fly as needed.
*/
void setPrecomputation(
bool precompute
);
[[nodiscard]] screening::ScreeningType getScreeningModel(
scratch::StateBlob& ctx
) const override;
/**
* @brief Checks if precomputation of reaction rates is enabled.
@@ -570,7 +503,9 @@ namespace gridfire::engine {
* This method allows checking the current state of precomputation for
* reaction rates in the engine.
*/
[[nodiscard]] bool isPrecomputationEnabled() const;
[[nodiscard]] bool isPrecomputationEnabled(
scratch::StateBlob& ctx
) const;
/**
* @brief Gets the partition function used for reaction rate calculations.
@@ -580,7 +515,9 @@ namespace gridfire::engine {
* This method provides access to the partition function used in the engine,
* which is essential for calculating thermodynamic properties and reaction rates.
*/
[[nodiscard]] const partition::PartitionFunction& getPartitionFunction() const;
[[nodiscard]] const partition::PartitionFunction& getPartitionFunction(
scratch::StateBlob& ctx
) const;
/**
* @brief Calculates the reverse rate for a given reaction.
@@ -648,23 +585,10 @@ namespace gridfire::engine {
* This method allows checking whether the engine is configured to use
* reverse reactions in its calculations.
*/
[[nodiscard]] bool isUsingReverseReactions() const;
[[nodiscard]] bool isUsingReverseReactions(
scratch::StateBlob& ctx
) const;
/**
* @brief Sets whether to use reverse reactions in the engine.
*
* @param useReverse True to enable reverse reactions, false to disable.
*
* This method allows enabling or disabling reverse reactions in the engine.
* If disabled, only forward reactions will be considered in calculations.
*
* @post If reverse reactions are enabled, the engine will consider both
* forward and reverse reactions in its calculations. If disabled,
* only forward reactions will be considered.
*/
void setUseReverseReactions(
bool useReverse
);
/**
* @brief Gets the index of a species in the network.
@@ -676,22 +600,10 @@ namespace gridfire::engine {
* species vector. If the species is not found, it returns -1.
*/
[[nodiscard]] size_t getSpeciesIndex(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
/**
* @brief Maps the NetIn object to a vector of molar abundances.
*
* @param netIn The NetIn object containing the input conditions.
* @return Vector of molar abundances corresponding to the species in the network.
*
* This method converts the NetIn object into a vector of molar abundances
* for each species in the network, which can be used for further calculations.
*/
[[deprecated]] [[nodiscard]] std::vector<double> mapNetInToMolarAbundanceVector(
const NetIn &netIn
) const override;
/**
* @brief Prepares the engine for calculations with initial conditions.
*
@@ -702,32 +614,10 @@ namespace gridfire::engine {
* setting up reactions, species, and precomputing necessary data.
*/
[[nodiscard]] PrimingReport primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) override;
) const override;
/**
* @brief Gets the depth of the network.
*
* @return The build depth of the network.
*
* This method returns the current build depth of the reaction network,
* which indicates how many levels of reactions are included in the network.
*/
[[nodiscard]] BuildDepthType getDepth() const override;
/**
* @brief Rebuilds the reaction network based on a new composition.
*
* @param comp The new composition to use for rebuilding the network.
* @param depth The build depth to use for the network.
*
* This method rebuilds the reaction network using the provided composition
* and build depth. It updates all internal data structures accordingly.
*/
void rebuild(
const fourdst::composition::CompositionAbstract &comp,
BuildDepthType depth
) override;
/**
* @brief This will return the input comp with the molar abundances of any species not registered in that but
@@ -744,7 +634,12 @@ namespace gridfire::engine {
* have a molar abundance set to 0.
* @throws BadCollectionError If the input composition contains species not present in the network species set
*/
fourdst::composition::Composition collectComposition(const fourdst::composition::CompositionAbstract &comp, double T9, double rho) const override;
fourdst::composition::Composition collectComposition(
scratch::StateBlob&,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
/**
* @brief Gets the status of a species in the network.
@@ -756,7 +651,10 @@ namespace gridfire::engine {
* returns its status (e.g., Active, Inactive, NotFound).
*/
[[nodiscard]]
SpeciesStatus getSpeciesStatus(const fourdst::atomic::Species &species) const override;
SpeciesStatus getSpeciesStatus(
scratch::StateBlob&,
const fourdst::atomic::Species &species
) const override;
[[nodiscard]] bool get_store_intermediate_reaction_contributions() const {
return m_store_intermediate_reaction_contributions;
@@ -766,7 +664,11 @@ namespace gridfire::engine {
m_store_intermediate_reaction_contributions = value;
}
[[nodiscard]] std::optional<StepDerivatives<double>> getMostRecentRHSCalculation() const override;
[[nodiscard]] std::optional<StepDerivatives<double>> getMostRecentRHSCalculation(
scratch::StateBlob&
) const override;
[[nodiscard]] const CppAD::ADFun<double>& getAuthoritativeADFun() const { return m_authoritativeADFun; }
private:
@@ -787,6 +689,7 @@ namespace gridfire::engine {
double reverse_symmetry_factor{}; ///< Symmetry factor for reverse reactions.
};
struct constants {
const double u = Constants::getInstance().get("u").value; ///< Atomic mass unit in g.
const double Na = Constants::getInstance().get("N_a").value; ///< Avogadro's number.
@@ -883,17 +786,7 @@ namespace gridfire::engine {
std::unordered_map<fourdst::atomic::Species, size_t> m_speciesToIndexMap; ///< Map from species to their index in the stoichiometry matrix.
std::unordered_map<size_t, fourdst::atomic::Species> m_indexToSpeciesMap; ///< Map from index to species in the stoichiometry matrix.
mutable CppAD::ADFun<double> m_rhsADFun; ///< CppAD function for the right-hand side of the ODE.
mutable CppAD::ADFun<double> m_epsADFun; ///< CppAD function for the energy generation rate.
mutable CppAD::sparse_jac_work m_jac_work; ///< Work object for sparse Jacobian calculations.
mutable std::vector<double> m_local_abundance_cache;
mutable std::unordered_map<size_t, StepDerivatives<double>> m_stepDerivativesCache;
mutable std::unordered_map<size_t, CppAD::sparse_rcv<std::vector<size_t>, std::vector<double>>> m_jacobianSubsetCache;
mutable std::unordered_map<size_t, CppAD::sparse_jac_work> m_jacWorkCache;
mutable std::optional<StepDerivatives<double>> m_most_recent_rhs_calculation;
bool m_has_been_primed = false; ///< Flag indicating if the engine has been primed.
std::unique_ptr<partition::PartitionFunction> m_partitionFunction; ///< Partition function for the network.
CppAD::sparse_rc<std::vector<size_t>> m_full_jacobian_sparsity_pattern; ///< Full sparsity pattern for the Jacobian matrix.
std::set<std::pair<size_t, size_t>> m_full_sparsity_set; ///< For quick lookups of the base sparsity pattern
@@ -904,14 +797,17 @@ namespace gridfire::engine {
std::unique_ptr<screening::ScreeningModel> m_screeningModel = screening::selectScreeningModel(m_screeningType);
bool m_usePrecomputation = true; ///< Flag to enable or disable using precomputed reactions for efficiency. Mathematically, this should not change the results. Generally end users should not need to change this.
std::vector<PrecomputedReaction> m_precomputed_reactions;
std::unordered_map<uint64_t, size_t> m_precomputed_reaction_index_map;
bool m_useReverseReactions = false; ///< Flag to enable or disable reverse reactions. If false, only forward reactions are considered.
bool m_store_intermediate_reaction_contributions = false; ///< Flag to enable or disable storing intermediate reaction contributions for debugging.
BuildDepthType m_depth;
std::vector<PrecomputedReaction> m_precomputedReactions; ///< Precomputed reactions for efficiency.
std::unordered_map<uint64_t, size_t> m_precomputedReactionIndexMap; ///< Set of hashed precomputed reactions for quick lookup.
std::unique_ptr<partition::PartitionFunction> m_partitionFunction; ///< Partition function for the network.
CppAD::ADFun<double> m_authoritativeADFun;
const size_t m_state_blob_offset;
private:
/**
@@ -957,25 +853,14 @@ namespace gridfire::engine {
*
* @throws std::runtime_error If there are no species in the network.
*/
void recordADTape() const;
void recordADTape();
void collectAtomicReverseRateAtomicBases();
void precomputeNetwork();
/**
* @brief Validates mass and charge conservation across all reactions.
*
* @return True if all reactions conserve mass and charge, false otherwise.
*
* This method checks that all reactions in the network conserve mass
* and charge. If any reaction does not conserve mass or charge, an
* error message is logged and false is returned.
*/
[[nodiscard]] bool validateConservation() const;
double compute_reaction_flow(
scratch::StateBlob& ctx,
const std::vector<double> &local_abundances,
const std::vector<double> &screening_factors,
const std::vector<double> &bare_rates,
@@ -988,10 +873,12 @@ namespace gridfire::engine {
) const;
std::pair<double, double> compute_neutrino_fluxes(
scratch::StateBlob& ctx,
double netFlow,
const reaction::Reaction &reaction) const;
PrecomputationKernelResults accumulate_flows_serial(
scratch::StateBlob& ctx,
const std::vector<double>& local_abundances,
const std::vector<double>& screening_factors,
const std::vector<double>& bare_rates,
@@ -1000,19 +887,9 @@ namespace gridfire::engine {
const reaction::ReactionSet& activeReactions
) const;
#ifdef GRIDFIRE_USE_OPENMP
PrecomputationKernelResults accumulate_flows_parallel(
const std::vector<double>& local_abundances,
const std::vector<double>& screening_factors,
const std::vector<double>& bare_rates,
const std::vector<double>& bare_reverse_rates,
double rho,
const reaction::ReactionSet& activeReactions
) const;
#endif
[[nodiscard]] StepDerivatives<double> calculateAllDerivativesUsingPrecomputation(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const std::vector<double> &bare_rates,
const std::vector<double> &bare_reverse_rates,
@@ -1308,7 +1185,7 @@ namespace gridfire::engine {
const T k_reaction = reaction.calculate_rate(T9, rho, Ye, mue, Y, m_indexToSpeciesMap);
// --- Cound the number of each reactant species to account for species multiplicity ---
std::unordered_map<fourdst::atomic::Species, int> reactant_counts;
std::unordered_map<fourdst::atomic::Species, size_t> reactant_counts;
reactant_counts.reserve(reaction.reactants().size());
for (const auto& reactant : reaction.reactants()) {
reactant_counts[reactant] = reaction.countReactantOccurrences(reactant);

View File

@@ -6,6 +6,8 @@
#include "fourdst/atomic/atomicSpecies.h"
#include "gridfire/engine/scratchpads/blob.h"
namespace gridfire::engine {
@@ -18,6 +20,7 @@ namespace gridfire::engine {
*
* Refer to priming.cpp for implementation details on logging, algorithmic steps, and error handling.
*
* @param ctx
* @param netIn Input network data containing initial composition, temperature, and density.
* @param engine DynamicEngine used to build and evaluate the reaction network.
* @param ignoredReactionTypes Types of reactions to ignore during priming (e.g., weak reactions).
@@ -27,8 +30,8 @@ namespace gridfire::engine {
* @return PrimingReport encapsulating the results of the priming operation.
*/
PrimingReport primeNetwork(
scratch::StateBlob &ctx,
const NetIn& netIn,
GraphEngine& engine,
const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
);
}

View File

@@ -0,0 +1,499 @@
/**
* @file blob.h
* @brief Container class for managing multiple scratchpad instances.
*
* This header defines the StateBlob class, which serves as a centralized
* registry for managing multiple scratchpad instances used by computational
* engines. It provides type-safe enrollment, retrieval, and cloning of
* scratchpads using compile-time type checking via C++20 concepts.
*
* @par Purpose
* The StateBlob provides:
* - A fixed-size array of scratchpad slots indexed by ScratchPadType
* - Type-safe enrollment ensuring one instance per scratchpad type
* - Compile-time verified retrieval with optional initialization checks
* - Deep cloning of all enrolled scratchpads for parallel execution
* - Status tracking for each scratchpad slot
*
* @par Examples
* @code{.cpp}
* #include "gridfire/engine/scratchpads/blob.h"
* #include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
*
* using namespace gridfire::engine::scratch;
*
* // Create a StateBlob and enroll scratchpads
* StateBlob blob;
* blob.enroll<GraphEngineScratchPad>();
*
* // Retrieve a scratchpad (returns std::expected)
* auto result = blob.get<GraphEngineScratchPad>();
* if (result.has_value()) {
* GraphEngineScratchPad* scratch = result.value();
* scratch->initialize(engine);
* }
*
* // Retrieve with initialization check
* auto checked = blob.get<GraphEngineScratchPad, true>();
* if (!checked.has_value()) {
* if (checked.error() == StateBlob::Error::SCRATCHPAD_NOT_INITIALIZED) {
* // Handle uninitialized scratchpad
* }
* }
*
* // Clone for parallel execution
* auto worker_blob = blob.clone_structure();
* @endcode
*
* @par Thread Safety
* The StateBlob class is **not thread-safe**. Each thread should have its own
* StateBlob instance. Use clone_structure() to create independent copies for
* parallel workers. The cloned blob contains deep copies of all enrolled
* scratchpads.
*
* @see AbstractScratchPad
* @see ScratchPadType
*/
#pragma once
#include "gridfire/engine/scratchpads/scratchpad_abstract.h"
#include "gridfire/engine/scratchpads/types.h"
#include "gridfire/exceptions/error_scratchpad.h"
#include <unordered_map>
#include <memory>
#include <expected>
#include <unordered_set>
namespace gridfire::engine::scratch {
/**
* @brief Concept that constrains types to valid scratchpad implementations.
*
* A type satisfies IsScratchPad if:
* - It derives from AbstractScratchPad
* - It has a static ID member convertible to ScratchPadType
*
* @tparam T The type to check against the concept.
*
* @par Examples
* @code{.cpp}
* // This will compile only if MyScratchPad satisfies IsScratchPad
* template <IsScratchPad T>
* void process_scratchpad(T& scratch) {
* // Use scratch...
* }
* @endcode
*/
template <typename T>
concept IsScratchPad = std::is_base_of_v<AbstractScratchPad, T>
&& requires { { T::ID } -> std::convertible_to<ScratchPadType>; };
/**
* @brief Container for managing a collection of typed scratchpad instances.
*
* StateBlob provides a centralized registry for scratchpads used by engines.
* It uses a fixed-size array indexed by ScratchPadType for O(1) access, with
* compile-time type safety enforced through the IsScratchPad concept.
*
* The blob supports:
* - Enrolling new scratchpad types (one instance per type)
* - Type-safe retrieval with optional initialization validation
* - Runtime retrieval by ScratchPadType enum value
* - Deep cloning for parallel execution contexts
* - Status queries for monitoring scratchpad states
*
* @par Thread Safety
* This class is **not thread-safe**. Each thread should have its own StateBlob
* instance. Use clone_structure() to create independent copies for parallel
* workers. Concurrent access to the same StateBlob instance requires external
* synchronization.
*/
class StateBlob {
public:
/**
* @brief Error codes for scratchpad operations.
*/
enum class Error : uint8_t {
SCRATCHPAD_NOT_FOUND, ///< Requested scratchpad type is not enrolled.
SCRATCHPAD_BAD_CAST, ///< Dynamic cast to requested type failed.
SCRATCHPAD_NOT_INITIALIZED, ///< Scratchpad exists but is not initialized.
SCRATCHPAD_TYPE_COLLISION, ///< Attempted to enroll duplicate type.
SCRATCHPAD_OUT_OF_BOUNDS, ///< ScratchPadType index exceeds array bounds.
SCRATCHPAD_UNKNOWN_ERROR ///< Unspecified error condition.
};
/**
* @brief Convert an Error enum value to a human-readable string.
*
* @param error The error code to convert.
* @return A string representation of the error.
*
* @par Examples
* @code{.cpp}
* auto result = blob.get<MyScratchPad>();
* if (!result.has_value()) {
* std::cerr << "Error: " << StateBlob::error_to_string(result.error());
* }
* @endcode
*/
static std::string error_to_string(const Error error) {
switch (error) {
case Error::SCRATCHPAD_NOT_FOUND:
return "SCRATCHPAD_NOT_FOUND";
case Error::SCRATCHPAD_BAD_CAST:
return "SCRATCHPAD_BAD_CAST";
case Error::SCRATCHPAD_NOT_INITIALIZED:
return "SCRATCHPAD_NOT_INITIALIZED";
case Error::SCRATCHPAD_TYPE_COLLISION:
return "SCRATCHPAD_TYPE_COLLISION";
case Error::SCRATCHPAD_OUT_OF_BOUNDS:
return "SCRATCHPAD_OUT_OF_BOUNDS";
default:
return "SCRATCHPAD_UNKNOWN_ERROR";
}
}
/**
* @brief Status codes for scratchpad slots.
*/
enum class ScratchPadStatus : uint8_t {
NOT_ENROLLED, ///< No scratchpad has been enrolled for this slot.
ENROLLED_NOT_INITIALIZED,///< Scratchpad enrolled but not yet initialized.
ENROLLED_INITIALIZED ///< Scratchpad enrolled and fully initialized.
};
public:
/// @brief Default constructor.
StateBlob() = default;
/// @brief Default destructor.
~StateBlob() = default;
/**
* @brief Enroll a new scratchpad type into the blob.
*
* Creates a new instance of the specified scratchpad type and registers it
* in the appropriate slot. Only one instance per type is allowed.
*
* @tparam CTX The scratchpad type to enroll (must satisfy IsScratchPad).
*
* @throws exceptions::ScratchPadError if a scratchpad of this type is already enrolled.
*
* @par Examples
* @code{.cpp}
* StateBlob blob;
* blob.enroll<GraphEngineScratchPad>();
* blob.enroll<AdaptiveEngineViewScratchPad>();
*
* // This would throw - duplicate enrollment
* // blob.enroll<GraphEngineScratchPad>();
* @endcode
*/
template <IsScratchPad CTX>
void enroll() {
constexpr auto index = static_cast<size_t>(CTX::ID);
static_assert(index < MAX_SCRATCHPADS, "ScratchPadType ID exceeds (maximum) allowed scratchpads.");
if (scratchpad_enrolled_flags[index]) {
throw exceptions::ScratchPadError("ScratchPad of this type has already been enrolled. Only one instance per type is allowed.");
}
scratchpads[index] = std::make_unique<CTX>();
scratchpad_enrolled_flags[index] = true;
}
/**
* @brief Retrieve a scratchpad by type.
*
* Returns a pointer to the enrolled scratchpad of the specified type.
* In debug builds, performs a dynamic_cast for type safety; in release
* builds, uses static_cast for performance.
*
* @tparam CTX The scratchpad type to retrieve (must satisfy IsScratchPad).
*
* @return std::expected containing the scratchpad pointer, or an Error if not found/invalid.
*
* @par Examples
* @code{.cpp}
* auto result = blob.get<GraphEngineScratchPad>();
* if (result.has_value()) {
* GraphEngineScratchPad* scratch = result.value();
* // Use scratch...
* } else {
* // Handle error
* std::cerr << StateBlob::error_to_string(result.error());
* }
* @endcode
*/
template <IsScratchPad CTX>
std::expected<CTX*, Error> get() const {
constexpr auto index = static_cast<size_t>(CTX::ID);
static_assert(index < MAX_SCRATCHPADS, "ScratchPadType ID exceeds maximum allowed scratchpads.");
AbstractScratchPad* scratchpad = scratchpads[index].get();
if (!scratchpad) {
return std::unexpected<Error>(Error::SCRATCHPAD_NOT_FOUND);
}
#if !defined(NDEBUG)
if (auto* cast_ptr = dynamic_cast<CTX*>(scratchpad)) {
return cast_ptr;
} else {
return std::unexpected<Error>(Error::SCRATCHPAD_BAD_CAST);
}
#else
return static_cast<CTX*>(scratchpad);
#endif
}
/**
* @brief Retrieve a scratchpad by type with optional initialization check.
*
* Returns a pointer to the enrolled scratchpad of the specified type.
* When MUST_BE_INITIALIZED is true, also verifies that the scratchpad
* has been initialized before returning it.
*
* @tparam CTX The scratchpad type to retrieve (must satisfy IsScratchPad).
* @tparam MUST_BE_INITIALIZED If true, returns an error for uninitialized scratchpads.
*
* @return std::expected containing the scratchpad pointer, or an Error if not found/invalid/uninitialized.
*
* @par Examples
* @code{.cpp}
* // Get only if initialized
* auto result = blob.get<GraphEngineScratchPad, true>();
* if (!result.has_value()) {
* if (result.error() == StateBlob::Error::SCRATCHPAD_NOT_INITIALIZED) {
* // Need to initialize first
* }
* }
* @endcode
*/
template <IsScratchPad CTX, bool MUST_BE_INITIALIZED>
std::expected<CTX*, Error> get() const {
constexpr auto index = static_cast<size_t>(CTX::ID);
static_assert(index < MAX_SCRATCHPADS, "ScratchPadType ID exceeds maximum allowed scratchpads.");
AbstractScratchPad* scratchpad = scratchpads[index].get();
if (!scratchpad) {
return std::unexpected<Error>(Error::SCRATCHPAD_NOT_FOUND);
}
#if !defined(NDEBUG)
if (auto* cast_ptr = dynamic_cast<CTX*>(scratchpad)) {
if constexpr (MUST_BE_INITIALIZED) {
if (!cast_ptr->is_initialized()) {
return std::unexpected<Error>(Error::SCRATCHPAD_NOT_INITIALIZED);
}
}
return cast_ptr;
} else {
return std::unexpected<Error>(Error::SCRATCHPAD_BAD_CAST);
}
#else
CTX* cast_ptr = static_cast<CTX*>(scratchpad);
if constexpr (MUST_BE_INITIALIZED) {
if (!cast_ptr->is_initialized()) {
return std::unexpected<Error>(Error::SCRATCHPAD_NOT_INITIALIZED);
}
}
return cast_ptr;
#endif
}
/**
* @brief Retrieve a scratchpad by runtime ScratchPadType value.
*
* Returns a pointer to the abstract base class for the scratchpad at
* the specified type index. Useful when the concrete type is not known
* at compile time.
*
* @param type The ScratchPadType enum value identifying the scratchpad.
*
* @return std::expected containing the AbstractScratchPad pointer, or an Error.
*
* @par Examples
* @code{.cpp}
* ScratchPadType type = ScratchPadType::GRAPH_ENGINE_SCRATCHPAD;
* auto result = blob.get(type);
* if (result.has_value()) {
* AbstractScratchPad* scratch = result.value();
* if (scratch->is_initialized()) {
* // Use scratch...
* }
* }
* @endcode
*/
[[nodiscard]] std::expected<AbstractScratchPad*, Error> get(const ScratchPadType type) const { // NOLINT(*-convert-member-functions-to-static)
const auto index = static_cast<size_t>(type);
if (index >= MAX_SCRATCHPADS) {
return std::unexpected<Error>(Error::SCRATCHPAD_OUT_OF_BOUNDS);
}
AbstractScratchPad* scratchpad = scratchpads[index].get();
if (!scratchpad) {
return std::unexpected<Error>(Error::SCRATCHPAD_NOT_FOUND);
}
return scratchpad;
}
/**
* @brief Create a deep copy of this blob with all enrolled scratchpads.
*
* Clones the blob structure and all enrolled scratchpads using their
* clone() methods. The resulting blob is independent and can be used
* in a separate thread.
*
* @return A unique pointer to the cloned StateBlob.
*
* @par Examples
* @code{.cpp}
* StateBlob blob;
* blob.enroll<GraphEngineScratchPad>();
*
* // Initialize the scratchpad
* auto scratch = blob.get<GraphEngineScratchPad>().value();
* scratch->initialize(engine);
*
* // Clone for parallel workers
* std::vector<std::unique_ptr<StateBlob>> worker_blobs;
* for (int i = 0; i < num_threads; ++i) {
* worker_blobs.push_back(blob.clone_structure());
* }
* @endcode
*/
[[nodiscard]] std::unique_ptr<StateBlob> clone_structure() const { // NOLINT(*-convert-member-functions-to-static)
auto new_blob = std::make_unique<StateBlob>();
for (size_t i = 0; i < MAX_SCRATCHPADS; ++i) {
if (scratchpad_enrolled_flags[i]) {
new_blob->scratchpads[i] = scratchpads[i]->clone();
new_blob->scratchpad_enrolled_flags[i] = true;
}
}
return new_blob;
}
/**
* @brief Get the set of all registered scratchpad types.
*
* @return An unordered set of ScratchPadType values for all enrolled scratchpads.
*
* @par Examples
* @code{.cpp}
* auto registered = blob.get_registered_scratchpads();
* for (auto type : registered) {
* std::cout << "Registered: " << static_cast<int>(type) << "\n";
* }
* @endcode
*/
[[nodiscard]] std::unordered_set<ScratchPadType> get_registered_scratchpads() const { // NOLINT(*-convert-member-functions-to-static)
std::unordered_set<ScratchPadType> sset;
for (size_t i = 0; i < MAX_SCRATCHPADS; ++i) {
if (scratchpad_enrolled_flags[i]) {
sset.insert(static_cast<ScratchPadType>(i));
}
}
return sset;
}
/**
* @brief Check if a specific scratchpad type is initialized.
*
* @tparam CTX The scratchpad type to check (must satisfy IsScratchPad).
*
* @return true if the scratchpad is enrolled and initialized.
*
* @throws exceptions::ScratchPadError if the scratchpad type is not enrolled.
*
* @par Examples
* @code{.cpp}
* if (blob.initialized<GraphEngineScratchPad>()) {
* // Safe to use the scratchpad
* }
* @endcode
*/
template <IsScratchPad CTX>
[[nodiscard]] bool initialized() const {
constexpr auto index = static_cast<size_t>(CTX::ID);
static_assert(index < MAX_SCRATCHPADS, "ScratchPadType ID exceeds maximum allowed scratchpads.");
if (!scratchpad_enrolled_flags[index]) {
throw exceptions::ScratchPadError("Cannot check initialization status: ScratchPad of this type is not enrolled.");
}
return scratchpads[index]->is_initialized();
}
/**
* @brief Get the status of a specific scratchpad type.
*
* @tparam CTX The scratchpad type to query (must satisfy IsScratchPad).
*
* @return The ScratchPadStatus indicating enrollment and initialization state.
*
* @par Examples
* @code{.cpp}
* auto status = blob.get_status<GraphEngineScratchPad>();
* switch (status) {
* case StateBlob::ScratchPadStatus::NOT_ENROLLED:
* blob.enroll<GraphEngineScratchPad>();
* break;
* case StateBlob::ScratchPadStatus::ENROLLED_NOT_INITIALIZED:
* // Need to initialize
* break;
* case StateBlob::ScratchPadStatus::ENROLLED_INITIALIZED:
* // Ready to use
* break;
* }
* @endcode
*/
template <IsScratchPad CTX>
[[nodiscard]] ScratchPadStatus get_status() const {
constexpr auto index = static_cast<size_t>(CTX::ID);
static_assert(index < MAX_SCRATCHPADS, "ScratchPadType ID exceeds maximum allowed scratchpads.");
if (!scratchpad_enrolled_flags[index]) {
return ScratchPadStatus::NOT_ENROLLED;
}
if (scratchpads[index]->is_initialized()) {
return ScratchPadStatus::ENROLLED_INITIALIZED;
} else {
return ScratchPadStatus::ENROLLED_NOT_INITIALIZED;
}
}
/**
* @brief Get a map of all scratchpad types to their current status.
*
* @return An unordered map from ScratchPadType to ScratchPadStatus for all slots.
*
* @par Examples
* @code{.cpp}
* auto status_map = blob.get_status_map();
* for (const auto& [type, status] : status_map) {
* if (status == StateBlob::ScratchPadStatus::ENROLLED_NOT_INITIALIZED) {
* std::cout << "Scratchpad " << static_cast<int>(type) << " needs initialization\n";
* }
* }
* @endcode
*/
[[nodiscard]] std::unordered_map<ScratchPadType, ScratchPadStatus> get_status_map() const { // NOLINT(*-convert-member-functions-to-static)
std::unordered_map<ScratchPadType, ScratchPadStatus> status_map;
for (size_t i = 0; i < MAX_SCRATCHPADS; ++i) {
auto type = static_cast<ScratchPadType>(i);
if (!scratchpad_enrolled_flags[i]) {
status_map[type] = ScratchPadStatus::NOT_ENROLLED;
} else if (scratchpads[i]->is_initialized()) {
status_map[type] = ScratchPadStatus::ENROLLED_INITIALIZED;
} else {
status_map[type] = ScratchPadStatus::ENROLLED_NOT_INITIALIZED;
}
}
return status_map;
}
private:
/// @brief Maximum number of scratchpad slots, derived from ScratchPadType::_COUNT.
static constexpr size_t MAX_SCRATCHPADS = static_cast<size_t>(ScratchPadType::_COUNT);
/// @brief Array of scratchpad instances indexed by ScratchPadType.
std::array<std::unique_ptr<AbstractScratchPad>, MAX_SCRATCHPADS> scratchpads{};
/// @brief Flags indicating which scratchpad slots have been enrolled.
std::array<bool, MAX_SCRATCHPADS> scratchpad_enrolled_flags{false};
};
} // namespace gridfire::engine::scratch

View File

@@ -0,0 +1,135 @@
/**
* @file engine_adaptive_scratchpad.h
* @brief Scratchpad implementation for the AdaptiveEngineView.
*
* This header defines the AdaptiveEngineViewScratchPad, a concrete implementation
* of AbstractScratchPad designed for use with the AdaptiveEngineView. It provides
* thread-local storage for active species and reactions that are dynamically
* determined during adaptive network computations.
*
* @par Purpose
* The AdaptiveEngineViewScratchPad stores:
* - The set of currently active species in the adaptive network
* - The set of currently active reactions based on network topology
* - Initialization state to track whether the scratchpad is ready for use
*
* @par Examples
* @code{.cpp}
* #include "gridfire/engine/scratchpads/engine_adaptive_scratchpad.h"
* #include "gridfire/engine/views/engine_adaptive.h"
*
* // Create and initialize the scratchpad from an AdaptiveEngineView
* gridfire::engine::scratch::AdaptiveEngineViewScratchPad scratch;
* AdaptiveEngineView engine = create_adaptive_engine();
* scratch.initialize(engine);
*
* if (scratch.is_initialized()) {
* // Access active species for computation
* for (const auto& species : scratch.active_species) {
* // Process each active species
* }
* }
*
* // Clone for parallel execution
* auto worker_scratch = scratch.clone();
* @endcode
*
* @par Thread Safety
* This class is **not thread-safe**. Each thread should have its own instance
* of AdaptiveEngineViewScratchPad. Use clone() to create independent copies
* for parallel workers.
*
* @see AbstractScratchPad
* @see AdaptiveEngineView
*/
#pragma once
#include "gridfire/engine/scratchpads/scratchpad_abstract.h"
#include "gridfire/engine/scratchpads/types.h"
#include "gridfire/engine/views/engine_adaptive.h"
namespace gridfire::engine::scratch {
/**
* @brief Scratchpad for storing working memory used by AdaptiveEngineView computations.
*
* AdaptiveEngineViewScratchPad provides temporary storage for the active species
* and reactions determined by the adaptive network algorithm. This allows the
* engine to avoid recalculating network topology on every evaluation.
*
* @par Thread Safety
* This class is **not thread-safe**. Each thread should operate on its own
* independent instance. Use clone() to create copies for parallel execution.
*/
struct AdaptiveEngineViewScratchPad final : AbstractScratchPad {
/// @brief Unique identifier for this scratchpad type.
static constexpr auto ID = ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD;
/// @brief Flag indicating whether the scratchpad has been initialized.
bool has_initialized = false;
/// @brief Vector of species currently active in the adaptive network.
std::vector<fourdst::atomic::Species> active_species;
/// @brief Set of reactions currently active in the adaptive network.
reaction::ReactionSet active_reactions;
/**
* @brief Check whether the scratchpad has been initialized.
* @return true if initialized, false otherwise.
*/
[[nodiscard]] bool is_initialized() const override { return has_initialized; }
/**
* @brief Initialize the scratchpad from an AdaptiveEngineView.
*
* Clears any existing state and prepares the scratchpad for use.
* This method is idempotent; calling it multiple times has no effect
* after the first successful initialization.
*
* @param engine The AdaptiveEngineView to initialize from.
*
* @par Examples
* @code{.cpp}
* AdaptiveEngineViewScratchPad scratch;
* AdaptiveEngineView engine = create_engine();
* scratch.initialize(engine);
* @endcode
*/
void initialize(const AdaptiveEngineView& engine) {
if (has_initialized) return;
active_species.clear();
active_reactions.clear();
has_initialized = true;
}
/**
* @brief Create a deep copy of this scratchpad.
*
* Creates an independent copy of all internal state, including
* active species and reactions. The clone can be modified without
* affecting the original.
*
* @return A unique pointer to the cloned scratchpad.
*
* @par Examples
* @code{.cpp}
* auto original = std::make_unique<AdaptiveEngineViewScratchPad>();
* original->initialize(engine);
*
* // Create independent copy for a worker thread
* auto worker_copy = original->clone();
* @endcode
*/
std::unique_ptr<AbstractScratchPad> clone() const override {
auto ptr = std::make_unique<AdaptiveEngineViewScratchPad>();
ptr->has_initialized = has_initialized;
ptr->active_species = active_species;
ptr->active_reactions = active_reactions;
return ptr;
}
};
} // namespace gridfire::engine::scratch

View File

@@ -0,0 +1,140 @@
/**
* @file engine_defined_scratchpad.h
* @brief Scratchpad implementation for the DefinedEngineView.
*
* This header defines the DefinedEngineViewScratchPad, a concrete implementation
* of AbstractScratchPad designed for use with engines that have a statically
* defined reaction network. It provides storage for the active species, reactions,
* and index mappings that enable efficient lookups during computations.
*
* @par Purpose
* The DefinedEngineViewScratchPad stores:
* - The set of active species in the defined network
* - The set of active reactions in the defined network
* - Index mappings for efficient species and reaction lookups
* - A cached vector representation of active species for performance
*
* @par Examples
* @code{.cpp}
* #include "gridfire/engine/scratchpads/engine_defined_scratchpad.h"
*
* // Create a scratchpad for a defined engine
* gridfire::engine::scratch::DefinedEngineViewScratchPad scratch;
*
* // Populate active species
* scratch.active_species.insert(species1);
* scratch.active_species.insert(species2);
*
* // Set up index mappings for efficient lookups
* scratch.species_index_map = {0, 1, 2, 3};
* scratch.reaction_index_map = {0, 1};
*
* if (scratch.is_initialized()) {
* // Use the scratchpad for computations
* for (size_t idx : scratch.species_index_map) {
* // Process species by index
* }
* }
*
* // Clone for parallel execution
* auto worker_scratch = scratch.clone();
* @endcode
*
* @par Thread Safety
* This class is **not thread-safe**. Each thread should have its own instance
* of DefinedEngineViewScratchPad. Use clone() to create independent copies
* for parallel workers.
*
* @see AbstractScratchPad
* @see DefinedEngineView
*/
#pragma once
#include "gridfire/engine/scratchpads/types.h"
#include "gridfire/engine/scratchpads/scratchpad_abstract.h"
#include "gridfire/reaction/reaction.h"
#include "fourdst/atomic/atomicSpecies.h"
#include <vector>
#include <memory>
#include <set>
namespace gridfire::engine::scratch {
/**
* @brief Scratchpad for storing working memory used by defined reaction network engines.
*
* DefinedEngineViewScratchPad provides storage for species and reaction data
* in engines with statically defined reaction networks. It includes index mappings
* for efficient lookups and an optional cache for the active species vector.
*
* Unlike adaptive scratchpads, the defined scratchpad is considered initialized
* by default (has_initialized = true), as the network structure is known at
* construction time.
*
* @par Thread Safety
* This class is **not thread-safe**. Each thread should operate on its own
* independent instance. Use clone() to create copies for parallel execution.
*/
struct DefinedEngineViewScratchPad final : AbstractScratchPad {
/// @brief Unique identifier for this scratchpad type.
constexpr static auto ID = ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD;
/// @brief Flag indicating whether the scratchpad is initialized (default: true).
bool has_initialized = true;
/// @brief Set of species active in the defined network.
std::set<fourdst::atomic::Species> active_species;
/// @brief Set of reactions active in the defined network.
reaction::ReactionSet active_reactions;
/// @brief Mapping from local indices to global species indices.
std::vector<size_t> species_index_map;
/// @brief Mapping from local indices to global reaction indices.
std::vector<size_t> reaction_index_map;
/// @brief Cached vector of active species for performance optimization.
/// @details This optional cache avoids repeated conversion from set to vector.
std::optional<std::vector<fourdst::atomic::Species>> active_species_vector_cache = std::nullopt;
/**
* @brief Check whether the scratchpad has been initialized.
* @return true if initialized (always true by default for defined networks).
*/
bool is_initialized() const override { return has_initialized; }
/**
* @brief Create a deep copy of this scratchpad.
*
* Creates an independent copy of all internal state, including
* active species, reactions, index mappings, and the species vector cache.
* The clone can be modified without affecting the original.
*
* @return A unique pointer to the cloned scratchpad.
*
* @par Examples
* @code{.cpp}
* DefinedEngineViewScratchPad scratch;
* scratch.active_species.insert(species);
* scratch.species_index_map = {0, 1, 2};
*
* // Create independent copy for a worker thread
* auto worker_copy = scratch.clone();
* @endcode
*/
std::unique_ptr<AbstractScratchPad> clone() const override {
auto pad = std::make_unique<DefinedEngineViewScratchPad>();
pad->has_initialized = this->has_initialized;
pad->active_species = this->active_species;
pad->active_reactions = this->active_reactions;
pad->species_index_map = this->species_index_map;
pad->reaction_index_map = this->reaction_index_map;
pad->active_species_vector_cache = this->active_species_vector_cache;
return pad;
}
};
} // namespace gridfire::engine::scratch

View File

@@ -0,0 +1,205 @@
/**
* @file engine_graph_scratchpad.h
* @brief Scratchpad implementation for the GraphEngine using CppAD automatic differentiation.
*
* This header defines the GraphEngineScratchPad, a concrete implementation of
* AbstractScratchPad designed for use with the GraphEngine. It provides thread-local
* storage for CppAD automatic differentiation functions, Jacobian computation work
* structures, and cached derivatives used during ODE integration.
*
* @par Purpose
* The GraphEngineScratchPad stores:
* - A local copy of the CppAD ADFun for RHS evaluation
* - Work structures for sparse Jacobian calculations
* - Cached abundance values for efficient reuse
* - Cached step derivatives and Jacobian subsets by timestep
* - The most recent RHS calculation for warm-starting
*
* @par Examples
* @code{.cpp}
* #include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
* #include "gridfire/engine/engine_graph.h"
*
* // Create and initialize the scratchpad from a GraphEngine
* gridfire::engine::scratch::GraphEngineScratchPad scratch;
* GraphEngine engine = create_graph_engine();
* scratch.initialize(engine);
*
* if (scratch.is_initialized()) {
* // Access the local ADFun for thread-safe evaluation
* auto& adfun = scratch.rhsADFun.value();
*
* // Use cached Jacobian work for efficient sparse computations
* auto& jac_work = scratch.jac_work;
* }
*
* // Clone for parallel execution
* auto worker_scratch = scratch.clone();
* @endcode
*
* @par Thread Safety
* This class is **not thread-safe**. Each thread must have its own instance
* of GraphEngineScratchPad because CppAD ADFun objects maintain internal state
* that is modified during evaluation. Use clone() to create independent copies
* for parallel workers, ensuring each thread has its own ADFun instance.
*
* @see AbstractScratchPad
* @see GraphEngine
* @see CppAD::ADFun
*/
#pragma once
#include <vector>
#include "gridfire/engine/scratchpads/scratchpad_abstract.h"
#include "gridfire/engine/scratchpads/types.h"
#include "gridfire/engine/engine_graph.h"
#include "gridfire/engine/engine_abstract.h"
#include "cppad/cppad.hpp"
#include <optional>
namespace gridfire::engine::scratch {
/**
* @brief Scratchpad for storing CppAD automatic differentiation state for GraphEngine.
*
* GraphEngineScratchPad provides thread-local storage for all CppAD-related
* objects needed during ODE integration with the GraphEngine. This includes
* the ADFun object for evaluating the right-hand side of the ODE and computing
* Jacobians, as well as various caches to improve performance.
*
* @par Thread Safety
* This class is **not thread-safe**. CppAD ADFun objects maintain internal
* state that is modified during Forward and Reverse mode operations. Each
* thread must have its own scratchpad instance. Use clone() to create
* independent copies for parallel execution.
*
* @note When cloning, if the rhsADFun has not been initialized, the clone
* will also be uninitialized and has_initialized will be false.
*/
struct GraphEngineScratchPad final : AbstractScratchPad {
/**
* @brief Result codes for ADFun registration operations.
*/
enum class ADFunRegistrationResult : uint8_t {
SUCCESS, ///< Registration completed successfully.
ALREADY_REGISTERED ///< ADFun was already registered; no action taken.
};
/// @brief CppAD function object for evaluating the ODE right-hand side.
/// @details Contains the computational graph for automatic differentiation.
std::optional<CppAD::ADFun<double>> rhsADFun;
/// @brief Work structure for sparse Jacobian calculations.
/// @details Reused across Jacobian evaluations to avoid reallocation.
CppAD::sparse_jac_work jac_work;
/// @brief Local cache of abundance values for efficient RHS evaluation.
std::vector<double> local_abundance_cache;
/// @brief Cache of step derivatives indexed by timestep identifier.
std::unordered_map<size_t, StepDerivatives<double>> stepDerivativesCache;
/// @brief Cache of sparse Jacobian subsets indexed by timestep identifier.
std::unordered_map<size_t, CppAD::sparse_rcv<std::vector<size_t>, std::vector<double>>> jacobianSubsetCache;
/// @brief Cache of Jacobian work structures indexed by timestep identifier.
std::unordered_map<size_t, CppAD::sparse_jac_work> jacWorkCache;
/// @brief The most recent RHS calculation result for warm-starting.
std::optional<StepDerivatives<double>> most_recent_rhs_calculation;
/// @brief Flag indicating whether the scratchpad has been initialized.
bool has_initialized = false;
/// @brief Unique identifier for this scratchpad type.
static constexpr auto ID = ScratchPadType::GRAPH_ENGINE_SCRATCHPAD;
/**
* @brief Check whether the scratchpad has been initialized.
* @return true if initialized with a valid ADFun, false otherwise.
*/
[[nodiscard]] bool is_initialized() const override { return has_initialized; }
/**
* @brief Initialize the scratchpad from a GraphEngine.
*
* Copies the authoritative ADFun from the engine and clears all caches.
* This method is idempotent; calling it multiple times has no effect
* after the first successful initialization.
*
* @param engine The GraphEngine to initialize from.
*
* @par Examples
* @code{.cpp}
* GraphEngineScratchPad scratch;
* GraphEngine engine = create_engine();
* scratch.initialize(engine);
*
* // Now safe to use for thread-local computations
* auto& adfun = scratch.rhsADFun.value();
* @endcode
*/
void initialize(const GraphEngine& engine) {
if (has_initialized) return;
const auto& sourceTape = engine.getAuthoritativeADFun();
rhsADFun.emplace();
*rhsADFun = sourceTape;
jac_work.clear();
local_abundance_cache.clear();
stepDerivativesCache.clear();
jacobianSubsetCache.clear();
jacWorkCache.clear();
most_recent_rhs_calculation = std::nullopt;
has_initialized = true;
}
/**
* @brief Create a deep copy of this scratchpad.
*
* Creates an independent copy of all internal state, including the
* CppAD ADFun object and all caches. The clone can be safely used
* in a separate thread without affecting the original.
*
* @return A unique pointer to the cloned scratchpad.
*
* @note If rhsADFun is not initialized, the clone will also be
* uninitialized (has_initialized = false).
*
* @par Examples
* @code{.cpp}
* GraphEngineScratchPad scratch;
* scratch.initialize(engine);
*
* // Create independent copies for parallel workers
* std::vector<std::unique_ptr<AbstractScratchPad>> worker_pads;
* for (int i = 0; i < num_threads; ++i) {
* worker_pads.push_back(scratch.clone());
* }
* @endcode
*/
[[nodiscard]] std::unique_ptr<AbstractScratchPad> clone() const override {
auto ptr = std::make_unique<GraphEngineScratchPad>();
if (!rhsADFun.has_value()) {
ptr->rhsADFun = std::nullopt;
ptr->has_initialized = false;
} else {
ptr->rhsADFun.emplace();
*ptr->rhsADFun = rhsADFun.value();
ptr->has_initialized = true;
}
ptr->jac_work = jac_work;
ptr->local_abundance_cache = local_abundance_cache;
ptr->stepDerivativesCache = stepDerivativesCache;
ptr->jacobianSubsetCache = jacobianSubsetCache;
ptr->jacWorkCache = jacWorkCache;
ptr->most_recent_rhs_calculation = most_recent_rhs_calculation;
return ptr;
}
};
} // namespace gridfire::engine::scratch

View File

@@ -0,0 +1,212 @@
/**
* @file engine_multiscale_scratchpad.h
* @brief Scratchpad implementation for the MultiscalePartitioningEngineView.
*
* This header defines the MultiscalePartitioningEngineViewScratchPad, a concrete
* implementation of AbstractScratchPad designed for use with multiscale partitioning
* algorithms. It provides thread-local storage for QSE (Quasi-Static Equilibrium)
* groups, solvers, species classifications, and SUNDIALS contexts required for
* solving stiff subsystems.
*
* @par Purpose
* The MultiscalePartitioningEngineViewScratchPad stores:
* - QSE groups representing clusters of species in quasi-static equilibrium
* - QSE solvers for each group (managed via unique_ptr)
* - Classification of species into dynamic vs. algebraic categories
* - A composition cache for efficient lookup of computed compositions
* - A SUNDIALS context for numerical solver operations
*
* @par Examples
* @code{.cpp}
* #include "gridfire/engine/scratchpads/engine_multiscale_scratchpad.h"
*
* // Create and initialize the scratchpad
* gridfire::engine::scratch::MultiscalePartitioningEngineViewScratchPad scratch;
* scratch.initialize();
*
* if (scratch.is_initialized()) {
* // Add QSE groups and species classifications
* scratch.dynamic_species.push_back(hydrogen);
* scratch.dynamic_species.push_back(helium);
* scratch.algebraic_species.push_back(carbon);
*
* // Use the SUNDIALS context for solver operations
* SUNContext ctx = scratch.sun_ctx;
* }
*
* // Clone for parallel execution (note: solvers not cloned)
* auto worker_scratch = scratch.clone();
* worker_scratch->initialize(); // Must re-initialize for new SUNContext
* @endcode
*
* @par Thread Safety
* This class is **not thread-safe**. Each thread must have its own instance
* because SUNDIALS contexts and QSE solvers maintain internal state. When
* cloning, the QSE solvers are not copied and the SUNContext is not cloned;
* the new instance must call initialize() to create its own context.
*
* @see AbstractScratchPad
* @see MultiscalePartitioningEngineView
* @see SUNContext
*/
#pragma once
#include "gridfire/engine/views/engine_multiscale.h"
#include "gridfire/engine/scratchpads/scratchpad_abstract.h"
#include "gridfire/engine/scratchpads/types.h"
#include "fourdst/atomic/atomicSpecies.h"
#include <vector>
#include <memory>
#include <unordered_map>
#include "sundials/sundials_context.h"
namespace gridfire::engine::scratch {
/**
* @brief Scratchpad for multiscale partitioning engine computations with QSE groups.
*
* MultiscalePartitioningEngineViewScratchPad provides thread-local storage for
* the multiscale partitioning algorithm, which separates species into fast
* (algebraic/QSE) and slow (dynamic) timescale groups. This enables efficient
* integration of stiff reaction networks by solving QSE subsystems separately.
*
* @par Thread Safety
* This class is **not thread-safe**. SUNDIALS contexts and QSE solver objects
* maintain internal state that cannot be shared across threads. Each thread
* must have its own scratchpad instance with its own SUNContext. When using
* clone(), the new instance starts uninitialized and must call initialize()
* to create a new SUNContext.
*
* @note The destructor properly cleans up SUNDIALS resources by freeing the
* SUNContext after clearing all solvers that depend on it.
*
* @warning QSE solvers are **not cloned** - they must be re-created in the
* cloned scratchpad as needed.
*/
struct MultiscalePartitioningEngineViewScratchPad final : AbstractScratchPad {
/// @brief Type alias for QSE group from the multiscale engine view.
using QSEGroup = MultiscalePartitioningEngineView::QSEGroup;
/// @brief Type alias for QSE solver from the multiscale engine view.
using QSESolver = MultiscalePartitioningEngineView::QSESolver;
/// @brief Type alias for atomic species.
using Species = fourdst::atomic::Species;
/// @brief Unique identifier for this scratchpad type.
static constexpr auto ID = ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD;
/// @brief Flag indicating whether the scratchpad has been initialized.
bool has_initialized = false;
/// @brief Vector of QSE groups representing equilibrium clusters.
std::vector<QSEGroup> qse_groups;
/// @brief Vector of QSE solvers, one per QSE group.
/// @note These are not cloned; new instances must create their own solvers.
std::vector<std::unique_ptr<QSESolver>> qse_solvers;
/// @brief Species that evolve on the dynamic (slow) timescale.
std::vector<Species> dynamic_species;
/// @brief Species that are solved algebraically (fast timescale/QSE).
std::vector<Species> algebraic_species;
/// @brief Cache of computed compositions indexed by a hash key.
std::unordered_map<uint64_t, fourdst::composition::Composition> composition_cache;
/// @brief SUNDIALS context for solver operations.
/// @note Must be freed in destructor after clearing solvers.
SUNContext sun_ctx = nullptr;
/**
* @brief Check whether the scratchpad has been initialized.
* @return true if initialized with a valid SUNContext, false otherwise.
*/
[[nodiscard]] bool is_initialized() const override { return has_initialized; }
/**
* @brief Initialize the scratchpad by creating a SUNDIALS context.
*
* Creates a new SUNContext for use with SUNDIALS solvers. This method
* is idempotent; calling it multiple times has no effect after the
* first successful initialization.
*
* @throws std::runtime_error if SUNContext creation fails.
*
* @par Examples
* @code{.cpp}
* MultiscalePartitioningEngineViewScratchPad scratch;
* scratch.initialize();
*
* // Now safe to use SUNContext for solver creation
* SUNContext ctx = scratch.sun_ctx;
* @endcode
*/
void initialize() {
if (has_initialized) return;
const int flag = SUNContext_Create(SUN_COMM_NULL, &sun_ctx);
if (flag != 0) {
throw std::runtime_error("Failed to create SUNContext in MultiscalePartitioningEngineViewScratchPad.");
}
has_initialized = true;
}
/**
* @brief Destructor that properly releases SUNDIALS resources.
*
* Clears all QSE solvers before freeing the SUNContext to ensure
* proper cleanup order and avoid dangling references.
*/
~MultiscalePartitioningEngineViewScratchPad() override {
qse_solvers.clear();
if (sun_ctx != nullptr) {
SUNContext_Free(&sun_ctx);
sun_ctx = nullptr;
}
}
/**
* @brief Create a partial copy of this scratchpad.
*
* Creates a copy with the QSE groups, species classifications, and
* composition cache.
*
* @return A unique pointer to the cloned scratchpad.
*
* @note The new instance will automatically initialize a new SUNContext and clone the QSE solvers.
*
* @par Examples
* @code{.cpp}
* MultiscalePartitioningEngineViewScratchPad scratch;
* scratch.initialize();
* scratch.dynamic_species = {hydrogen, helium};
*
* // Clone for a worker thread
* auto worker = scratch.clone();
*
* @endcode
*/
[[nodiscard]] std::unique_ptr<AbstractScratchPad> clone() const override {
auto clone_pad = std::make_unique<MultiscalePartitioningEngineViewScratchPad>();
clone_pad->qse_groups = this->qse_groups;
clone_pad->dynamic_species = this->dynamic_species;
clone_pad->algebraic_species = this->algebraic_species;
clone_pad->composition_cache = this->composition_cache;
clone_pad->initialize();
clone_pad->qse_solvers.reserve(this->qse_solvers.size());
for (const auto& solver : qse_solvers) {
clone_pad->qse_solvers.push_back(solver->clone(clone_pad->sun_ctx)); // Must rebind context to new SUNContext
}
return clone_pad;
}
};
} // namespace gridfire::engine::scratch

View File

@@ -0,0 +1,138 @@
#pragma once
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/engine/scratchpads/scratchpad_abstract.h"
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
#include "gridfire/engine/scratchpads/engine_adaptive_scratchpad.h"
#include "gridfire/engine/scratchpads/engine_multiscale_scratchpad.h"
#include "gridfire/engine/scratchpads/engine_defined_scratchpad.h"
#include "gridfire/engine/scratchpads/types.h"
#include "gridfire/utils/logging.h"
#include <format>
#include <string>
#include <string_view>
// 1. ScratchPadType: Inherit from string_view formatter for efficiency
template <>
struct std::formatter<gridfire::engine::scratch::ScratchPadType> : std::formatter<std::string_view> {
// Note: NOT static, marked const
auto format(gridfire::engine::scratch::ScratchPadType type, auto& ctx) const {
// Convert to string_view
std::string_view name = gridfire::engine::scratch::get_scratchpad_type_name(type);
// Delegate to the base class to handle width/fill/alignment
return std::formatter<std::string_view>::format(name, ctx);
}
};
// 2. AbstractScratchPad
template <>
struct std::formatter<gridfire::engine::scratch::AbstractScratchPad> : std::formatter<std::string> {
auto format(const gridfire::engine::scratch::AbstractScratchPad& pad, auto& ctx) const {
std::string str = std::format("AbstractScratchPad(Initialized: {})",
pad.is_initialized());
return std::formatter<std::string>::format(str, ctx);
}
};
// 3. GraphEngineScratchPad
template <>
struct std::formatter<gridfire::engine::scratch::GraphEngineScratchPad> : std::formatter<std::string> {
auto format(const gridfire::engine::scratch::GraphEngineScratchPad& pad, auto& ctx) const {
std::string str = std::format("GraphEngineScratchPad(Initialized: {}, HasADFun: {}, CachedStepDerivatives: {}, CachedJacobians: {})",
pad.has_initialized,
pad.rhsADFun.has_value(),
pad.stepDerivativesCache.size(),
pad.jacobianSubsetCache.size());
return std::formatter<std::string>::format(str, ctx);
}
};
// 4. AdaptiveEngineViewScratchPad
template<>
struct std::formatter<gridfire::engine::scratch::AdaptiveEngineViewScratchPad> : std::formatter<std::string> {
auto format(const gridfire::engine::scratch::AdaptiveEngineViewScratchPad& pad, auto& ctx) const {
std::string str = std::format("AdaptiveEngineViewScratchPad(Initialized: {}, Active Species: {}, Active Reactions: {})",
pad.has_initialized,
pad.active_species.size(),
pad.active_reactions.size());
return std::formatter<std::string>::format(str, ctx);
}
};
// 5. MultiscalePartitioningEngineViewScratchPad
template <>
struct std::formatter<gridfire::engine::scratch::MultiscalePartitioningEngineViewScratchPad> : std::formatter<std::string> {
auto format(const gridfire::engine::scratch::MultiscalePartitioningEngineViewScratchPad& pad, auto& ctx) const {
std::string str = std::format("MultiscalePartitioningEngineViewScratchPad(Initialized: {}, QSE Groups: {}, Dynamic Species: {}, Algebraic Species: {}, Cached Compositions: {})",
pad.has_initialized,
pad.qse_groups.size(),
pad.dynamic_species.size(),
pad.algebraic_species.size(),
pad.composition_cache.size());
return std::formatter<std::string>::format(str, ctx);
}
};
// 6. DefinedEngineViewScratchPad
template <>
struct std::formatter<gridfire::engine::scratch::DefinedEngineViewScratchPad> : std::formatter<std::string> {
auto format(const gridfire::engine::scratch::DefinedEngineViewScratchPad& pad, auto& ctx) const {
std::string str = std::format("DefinedEngineViewScratchPad(Initialized: {}, Active Species: {}, Active Reactions: {})",
pad.has_initialized,
pad.active_species.size(),
pad.active_reactions.size());
return std::formatter<std::string>::format(str, ctx);
}
};
// 7. StateBlob
template <>
struct std::formatter<gridfire::engine::scratch::StateBlob> : std::formatter<std::string> {
auto format(const gridfire::engine::scratch::StateBlob& blob, auto& ctx) const {
// Construct the full string representation
std::string str = std::format("StateBlob(Enrolled: {})",
gridfire::utils::iterable_to_delimited_string(
blob.get_registered_scratchpads(),
", ",
[&blob](const auto& type) {
auto result = blob.get(type);
if (!result.has_value()) {
return std::format("{}(Error: {})",
gridfire::engine::scratch::get_scratchpad_type_name(type),
gridfire::engine::scratch::StateBlob::error_to_string(result.error()));
}
gridfire::engine::scratch::AbstractScratchPad* scratchpad = result.value();
// We can reuse the formatters we defined above by dereferencing the cast pointers!
switch (type) {
case gridfire::engine::scratch::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD : {
auto* cast_pad = dynamic_cast<gridfire::engine::scratch::GraphEngineScratchPad*>(scratchpad);
// This works because we defined a formatter for GraphEngineScratchPad above
return std::format("{}", *cast_pad);
}
case gridfire::engine::scratch::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD : {
auto* cast_pad = dynamic_cast<gridfire::engine::scratch::MultiscalePartitioningEngineViewScratchPad*>(scratchpad);
return std::format("{}", *cast_pad);
}
case gridfire::engine::scratch::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD : {
auto* cast_pad = dynamic_cast<gridfire::engine::scratch::AdaptiveEngineViewScratchPad*>(scratchpad);
return std::format("{}", *cast_pad);
}
case gridfire::engine::scratch::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD : {
auto* cast_pad = dynamic_cast<gridfire::engine::scratch::DefinedEngineViewScratchPad*>(scratchpad);
return std::format("{}", *cast_pad);
}
default: {
return std::format("{}(Unknown ScratchPad Type)", gridfire::engine::scratch::get_scratchpad_type_name(type));
}
}
}
)
);
// Delegate to base class to write the string to the context
return std::formatter<std::string>::format(str, ctx);
}
};

View File

@@ -0,0 +1,137 @@
/**
* @file scratchpad_abstract.h
* @brief Abstract base class for scratchpad memory used during engine computations.
*
* This header defines the AbstractScratchPad interface, which provides a common
* contract for temporary working memory (scratchpads) used by computational engines.
* Scratchpads are designed to store intermediate results, cached computations, or
* pre-allocated buffers that can be reused across multiple computational steps,
* improving performance by avoiding repeated memory allocations.
*
* @par Purpose
* The scratchpad pattern allows engines to:
* - Pre-allocate working memory once and reuse it across iterations
* - Store intermediate computational results between solver steps
* - Enable efficient cloning for parallel execution contexts
* - Provide a type-erased interface for heterogeneous scratchpad management
*
* @par Examples
* @code{.cpp}
* // Define a concrete scratchpad for a specific engine
* class MyScratchPad : public gridfire::engine::scratch::AbstractScratchPad {
* public:
* MyScratchPad() : initialized_(false) {}
*
* void initialize(size_t size) {
* buffer_.resize(size);
* initialized_ = true;
* }
*
* [[nodiscard]] bool is_initialized() const override {
* return initialized_;
* }
*
* [[nodiscard]] std::unique_ptr<AbstractScratchPad> clone() const override {
* auto copy = std::make_unique<MyScratchPad>();
* copy->buffer_ = buffer_;
* copy->initialized_ = initialized_;
* return copy;
* }
*
* private:
* std::vector<double> buffer_;
* bool initialized_;
* };
*
* // Usage in an engine context
* auto scratch = std::make_unique<MyScratchPad>();
* scratch->initialize(1024);
*
* if (scratch->is_initialized()) {
* // Use scratchpad for computations
* auto parallel_scratch = scratch->clone(); // Clone for parallel worker
* }
* @endcode
*/
#pragma once
#include <memory>
namespace gridfire::engine::scratch {
/**
* @brief Abstract base struct for engine scratchpad memory.
*
* AbstractScratchPad defines the interface for temporary working memory
* containers used by computational engines. Implementations should provide
* storage for intermediate results, cached values, or pre-allocated buffers
* that persist across multiple computational steps.
*
* This interface enables polymorphic handling of different scratchpad types
* while ensuring proper resource management through virtual destruction and
* deep cloning capabilities.
*
* @par Thread Safety
* This interface is **not thread-safe** by design. Scratchpads are intended
* to be used as thread-local working memory. Each thread should operate on
* its own independent scratchpad instance. Use the clone() method to create
* separate copies for each thread in parallel execution contexts. Sharing a
* single scratchpad instance across multiple threads without external
* synchronization will result in undefined behavior.
*/
struct AbstractScratchPad {
/**
* @brief Virtual destructor for proper cleanup of derived classes.
*
* Ensures that resources held by concrete scratchpad implementations
* are properly released when the scratchpad is destroyed through a
* base class pointer.
*/
virtual ~AbstractScratchPad() = default;
/**
* @brief Check whether the scratchpad has been properly initialized.
*
* Derived classes should return true only after all necessary memory
* allocations and setup operations have been completed successfully.
*
* @return true if the scratchpad is initialized and ready for use.
* @return false if the scratchpad has not been initialized or initialization failed.
*
* @par Examples
* @code{.cpp}
* auto scratch = create_scratchpad();
* if (!scratch->is_initialized()) {
* throw std::runtime_error("Scratchpad not ready for computation");
* }
* @endcode
*/
[[nodiscard]] virtual bool is_initialized() const = 0;
/**
* @brief Create a deep copy of this scratchpad.
*
* Produces an independent clone of the scratchpad, including all internal
* state and allocated memory. This is essential for parallel execution
* scenarios where each thread requires its own working memory.
*
* @return A unique pointer to a newly allocated copy of this scratchpad.
*
* @note The returned clone should be fully independent; modifications to
* the clone must not affect the original, and vice versa.
*
* @par Examples
* @code{.cpp}
* std::unique_ptr<AbstractScratchPad> original = create_scratchpad();
*
* // Create independent copies for parallel workers
* std::vector<std::unique_ptr<AbstractScratchPad>> worker_scratches;
* for (int i = 0; i < num_threads; ++i) {
* worker_scratches.push_back(original->clone());
* }
* @endcode
*/
[[nodiscard]] virtual std::unique_ptr<AbstractScratchPad> clone() const = 0;
};
} // namespace gridfire::engine::scratch

View File

@@ -0,0 +1,167 @@
/**
* @file scratchpads.h
* @brief Unified header for the scratchpad memory management system.
*
* This is the main include file for the scratchpad subsystem. It provides
* a single include point for all scratchpad-related functionality, including
* the abstract base class, concrete implementations, type definitions,
* the StateBlob container, utility functions, and formatters for debugging.
*
* @par What are Scratchpads?
* Scratchpads are temporary working memory containers used by computational
* engines during ODE integration and reaction network calculations. They serve
* several critical purposes:
*
* - **Performance**: Pre-allocate memory once and reuse across iterations,
* avoiding repeated heap allocations during time-critical computations.
* - **Caching**: Store intermediate results (Jacobians, derivatives, etc.)
* that can be reused across solver steps.
* - **Thread Safety**: Provide thread-local storage for parallel execution,
* where each thread operates on its own independent scratchpad instance.
* - **State Management**: Encapsulate engine-specific working state separate
* from the engine's persistent configuration.
*
* @par Architecture Overview
* The scratchpad system consists of:
*
* - **AbstractScratchPad**: Base interface defining `is_initialized()` and `clone()`
* - **Concrete Scratchpads**: Engine-specific implementations
* - `GraphEngineScratchPad`: CppAD ADFun and Jacobian caches
* - `AdaptiveEngineViewScratchPad`: Active species/reactions for adaptive networks
* - `DefinedEngineViewScratchPad`: Species/reactions for static networks
* - `MultiscalePartitioningEngineViewScratchPad`: QSE groups and SUNDIALS context
* - **StateBlob**: Container managing multiple scratchpads with type-safe access
* - **Utilities**: Helper functions for exception-based retrieval
* - **Formatters**: std::format specializations for debugging output
*
* @par Why Use Scratchpads?
* During numerical integration of stiff reaction networks, engines must:
* 1. Evaluate right-hand side functions (species derivatives)
* 2. Compute sparse Jacobian matrices
* 3. Solve linear systems within Newton iterations
*
* These operations require substantial temporary memory. Without scratchpads,
* each evaluation would allocate and deallocate working buffers, causing:
* - Memory fragmentation
* - Cache thrashing
* - Unnecessary allocation overhead
*
* Scratchpads solve this by providing persistent, reusable working memory
* that lives for the duration of an integration step (or longer).
*
* @par Thread Safety Model
* The scratchpad system is designed for thread-local usage:
*
* - Scratchpads are **not thread-safe** by design
* - Each thread must have its own scratchpad instances
* - Use `StateBlob::clone_structure()` to create independent copies for workers
* - The original scratchpad/blob remains usable by the main thread
*
* @par Examples
* @code{.cpp}
* #include "gridfire/engine/scratchpads/scratchpads.h"
*
* using namespace gridfire::engine::scratch;
*
* // === Basic Usage ===
* // Create a StateBlob and enroll scratchpads
* StateBlob blob;
* blob.enroll<GraphEngineScratchPad>();
* blob.enroll<AdaptiveEngineViewScratchPad>();
*
* // Initialize scratchpads
* auto* graph_scratch = get_state<GraphEngineScratchPad>(blob);
* graph_scratch->initialize(engine);
*
* // Use initialized scratchpad
* auto* scratch = get_state<GraphEngineScratchPad, true>(blob); // Throws if not initialized
* auto& adfun = scratch->rhsADFun.value();
*
* // === Parallel Execution ===
* // Clone for worker threads
* std::vector<std::unique_ptr<StateBlob>> worker_blobs;
* for (int i = 0; i < num_threads; ++i) {
* worker_blobs.push_back(blob.clone_structure());
* }
*
* // Each worker uses its own blob
* #pragma omp parallel for
* for (int i = 0; i < work_items; ++i) {
* int tid = omp_get_thread_num();
* StateBlob& my_blob = *worker_blobs[tid];
* auto* my_scratch = get_state<GraphEngineScratchPad>(my_blob);
* // Thread-safe: each thread has its own scratchpad
* compute_with_scratchpad(*my_scratch);
* }
*
* // === Debugging ===
* // Use formatters for logging
* std::cout << std::format("Blob state: {}\n", blob);
* // Output: StateBlob(Enrolled: GraphEngineScratchPad(...), AdaptiveEngineViewScratchPad(...))
*
* // Check status
* auto status = blob.get_status<GraphEngineScratchPad>();
* if (status == StateBlob::ScratchPadStatus::ENROLLED_NOT_INITIALIZED) {
* // Need to initialize before use
* }
* @endcode
*
* @par Included Headers
* This header includes:
* - scratchpad_abstract.h - AbstractScratchPad base class
* - engine_graph_scratchpad.h - GraphEngineScratchPad
* - engine_adaptive_scratchpad.h - AdaptiveEngineViewScratchPad
* - engine_multiscale_scratchpad.h - MultiscalePartitioningEngineViewScratchPad
* - engine_defined_scratchpad.h - DefinedEngineViewScratchPad
* - types.h - ScratchPadType enumeration
* - blob.h - StateBlob container
* - utils.h - get_state() helper functions
* - formatters.h - std::format specializations
*
* @see AbstractScratchPad
* @see StateBlob
* @see ScratchPadType
*/
/**
* @namespace gridfire::engine::scratch
* @brief Scratchpad memory management for computational engines.
*
* The scratch namespace contains all components related to temporary working
* memory management for GridFire's computational engines. This includes the
* abstract scratchpad interface, concrete implementations for each engine type,
* the StateBlob container for managing multiple scratchpads, and utilities
* for convenient access.
*
* @par Key Components
* - AbstractScratchPad: Interface for all scratchpad types
* - GraphEngineScratchPad: Working memory for CppAD-based graph engines
* - AdaptiveEngineViewScratchPad: Storage for adaptive network computations
* - DefinedEngineViewScratchPad: Storage for static network computations
* - MultiscalePartitioningEngineViewScratchPad: QSE solver state management
* - StateBlob: Type-safe container for multiple scratchpads
* - ScratchPadType: Enumeration of registered scratchpad types
* - IsScratchPad: Concept constraining valid scratchpad types
* - get_state(): Utility functions for exception-based retrieval
*
* @par Design Philosophy
* The scratchpad system follows these design principles:
*
* 1. **Separation of Concerns**: Working memory is separate from engine configuration
* 2. **Type Safety**: Compile-time verification of scratchpad types via concepts
* 3. **Performance**: O(1) access via enum-indexed arrays
* 4. **Thread Locality**: Each thread owns its scratchpad instances
* 5. **Clonability**: Deep copying enables parallel execution patterns
*/
#pragma once
#include "gridfire/engine/scratchpads/scratchpad_abstract.h"
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
#include "gridfire/engine/scratchpads/engine_adaptive_scratchpad.h"
#include "gridfire/engine/scratchpads/engine_multiscale_scratchpad.h"
#include "gridfire/engine/scratchpads/engine_defined_scratchpad.h"
#include "gridfire/engine/scratchpads/types.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/engine/scratchpads/utils.h"
#include "gridfire/engine/scratchpads/formatters.h"

View File

@@ -0,0 +1,141 @@
/**
* @file types.h
* @brief Type definitions and utilities for the scratchpad system.
*
* This header defines the ScratchPadType enumeration which identifies all
* registered scratchpad types in the system, along with utility functions
* for querying scratchpad type information at compile-time and runtime.
*
* @par Purpose
* The types header provides:
* - A centralized enumeration of all scratchpad types
* - Compile-time constant for the maximum number of scratchpad types
* - Runtime conversion of scratchpad types to human-readable names
*
* @par Adding New Scratchpad Types
* To add a new scratchpad type:
* 1. Add a new enumerator before `_COUNT` in ScratchPadType
* 2. Add a corresponding case in get_scratchpad_type_name()
* 3. Create the concrete scratchpad class with a static `ID` member set to the new type
*
* @par Examples
* @code{.cpp}
* #include "gridfire/engine/scratchpads/types.h"
*
* using namespace gridfire::engine::scratch;
*
* // Get the maximum number of scratchpad types at compile time
* constexpr size_t max_types = get_max_scratchpad_types();
* std::array<bool, max_types> enrolled_flags{};
*
* // Get a human-readable name for a scratchpad type
* ScratchPadType type = ScratchPadType::GRAPH_ENGINE_SCRATCHPAD;
* std::string_view name = get_scratchpad_type_name(type);
* // name == "GraphEngineScratchPad"
*
* // Iterate over all scratchpad types
* for (size_t i = 0; i < get_max_scratchpad_types(); ++i) {
* auto type = static_cast<ScratchPadType>(i);
* std::cout << get_scratchpad_type_name(type) << "\n";
* }
* @endcode
*
* @see AbstractScratchPad
* @see StateBlob
*/
#pragma once
#include <cstdint>
#include <string_view>
namespace gridfire::engine::scratch {
/**
* @brief Enumeration of all registered scratchpad types.
*
* Each scratchpad implementation must have a unique type identifier in this
* enumeration. The concrete scratchpad class should define a static `ID`
* member initialized to its corresponding ScratchPadType value.
*
* @note The `_COUNT` enumerator is a sentinel value used to determine the
* total number of scratchpad types. It must always be the last entry.
* Do not use `_COUNT` as an actual scratchpad type.
*/
enum class ScratchPadType : uint8_t {
GRAPH_ENGINE_SCRATCHPAD, ///< GraphEngineScratchPad for CppAD-based engines.
MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD, ///< MultiscalePartitioningEngineViewScratchPad for QSE partitioning.
ADAPTIVE_ENGINE_VIEW_SCRATCHPAD, ///< AdaptiveEngineViewScratchPad for adaptive networks.
DEFINED_ENGINE_VIEW_SCRATCHPAD, ///< DefinedEngineViewScratchPad for static networks.
PRIMING_ENGINE_VIEW_SCRATCHPAD, ///< PrimingEngineViewScratchPad for engine priming.
_COUNT ///< Sentinel value representing the total number of scratchpad types. Do not use as a type.
};
/**
* @brief Get the maximum number of scratchpad types at compile time.
*
* Returns the total count of registered scratchpad types, derived from
* the ScratchPadType::_COUNT sentinel value. This is useful for sizing
* fixed-size arrays that need a slot for each scratchpad type.
*
* @return The number of valid scratchpad types (excluding _COUNT).
*
* @par Examples
* @code{.cpp}
* // Use at compile time for array sizing
* constexpr size_t NUM_TYPES = get_max_scratchpad_types();
* std::array<std::unique_ptr<AbstractScratchPad>, NUM_TYPES> scratchpads;
*
* // Use in static_assert
* static_assert(get_max_scratchpad_types() > 0, "No scratchpad types defined");
* @endcode
*/
consteval size_t get_max_scratchpad_types() {
return static_cast<size_t>(ScratchPadType::_COUNT);
}
/**
* @brief Convert a ScratchPadType to a human-readable name.
*
* Returns a string view containing the class name associated with
* the given scratchpad type. Useful for logging, debugging, and
* error messages.
*
* @param scratchpad_type The scratchpad type to convert.
*
* @return A string view containing the scratchpad class name, or
* "UnknownScratchPadType" for unrecognized values.
*
* @par Examples
* @code{.cpp}
* ScratchPadType type = ScratchPadType::GRAPH_ENGINE_SCRATCHPAD;
* std::cout << "Using: " << get_scratchpad_type_name(type) << "\n";
* // Output: "Using: GraphEngineScratchPad"
*
* // Use in error messages
* throw std::runtime_error(
* std::format("Failed to initialize {}", get_scratchpad_type_name(type))
* );
* @endcode
*/
constexpr std::string_view get_scratchpad_type_name(const ScratchPadType scratchpad_type) {
if constexpr (get_max_scratchpad_types() == 0) {
return {""};
}
switch (scratchpad_type) {
case ScratchPadType::GRAPH_ENGINE_SCRATCHPAD:
return "GraphEngineScratchPad";
case ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD:
return "MultiscalePartitioningEngineViewScratchPad";
case ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD:
return "AdaptiveEngineViewScratchPad";
case ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD:
return "DefinedEngineViewScratchPad";
case ScratchPadType::PRIMING_ENGINE_VIEW_SCRATCHPAD:
return "PrimingEngineViewScratchPad";
default:
return "UnknownScratchPadType";
}
}
} // namespace gridfire::engine::scratch

View File

@@ -0,0 +1,240 @@
/**
* @file utils.h
* @brief Utility functions for convenient scratchpad retrieval with exception handling.
*
* This header provides helper functions that wrap StateBlob's get() methods,
* converting error codes into exceptions for simpler error handling. These
* utilities eliminate the need to manually check std::expected results and
* switch on error codes at every call site.
*
* @par Purpose
* The utility functions provide:
* - Exception-based error handling instead of std::expected
* - Consistent error messages across the codebase
* - Both mutable and const-correct overloads
* - Optional initialization checking via template parameter
*
* @par Examples
* @code{.cpp}
* #include "gridfire/engine/scratchpads/utils.h"
* #include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
*
* using namespace gridfire::engine::scratch;
*
* void compute(StateBlob& blob) {
* // Simple retrieval - throws if not found
* GraphEngineScratchPad* scratch = get_state<GraphEngineScratchPad>(blob);
* scratch->initialize(engine);
*
* // Retrieval with initialization check - throws if not initialized
* auto* initialized_scratch = get_state<GraphEngineScratchPad, true>(blob);
* // Safe to use - guaranteed to be initialized
* }
*
* void read_only_access(const StateBlob& blob) {
* // Const overload for read-only access
* const GraphEngineScratchPad* scratch = get_state<GraphEngineScratchPad>(blob);
* bool ready = scratch->is_initialized();
* }
* @endcode
*
* @par Thread Safety
* These functions inherit the thread safety characteristics of StateBlob.
* They are **not thread-safe** - each thread should operate on its own
* StateBlob instance.
*
* @see StateBlob
* @see IsScratchPad
* @see exceptions::ScratchPadError
*/
#pragma once
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/exceptions/error_scratchpad.h"
namespace gridfire::engine::scratch {
/**
* @brief Retrieve a scratchpad from a StateBlob, throwing on error.
*
* Convenience wrapper around StateBlob::get() that converts error codes
* into exceptions. Use this when you expect the scratchpad to exist and
* want exception-based error handling.
*
* @tparam CTX The scratchpad type to retrieve (must satisfy IsScratchPad).
*
* @param ctx The StateBlob to retrieve the scratchpad from.
*
* @return Pointer to the requested scratchpad.
*
* @throws exceptions::ScratchPadError if the scratchpad is not found,
* cannot be cast to the requested type, or any other error occurs.
*
* @par Examples
* @code{.cpp}
* StateBlob blob;
* blob.enroll<GraphEngineScratchPad>();
*
* // Retrieve the scratchpad - throws if not enrolled
* GraphEngineScratchPad* scratch = get_state<GraphEngineScratchPad>(blob);
* scratch->initialize(engine);
* @endcode
*/
template <IsScratchPad CTX>
CTX* get_state(StateBlob& ctx) {
auto result = ctx.get<CTX>();
if (!result.has_value()) {
switch (result.error()) {
case StateBlob::Error::SCRATCHPAD_NOT_FOUND:
throw exceptions::ScratchPadError("Requested scratchpad not found in StateBlob.");
case StateBlob::Error::SCRATCHPAD_BAD_CAST:
throw exceptions::ScratchPadError("Failed to cast scratchpad to the requested type.");
case StateBlob::Error::SCRATCHPAD_TYPE_COLLISION:
throw exceptions::ScratchPadError("Scratchpad type collision detected in StateBlob.");
default:
throw exceptions::ScratchPadError("Unknown error occurred while retrieving scratchpad from StateBlob.");
}
}
return result.value();
}
/**
* @brief Retrieve a const scratchpad from a const StateBlob, throwing on error.
*
* Const-correct overload of get_state() for read-only access to scratchpads.
* Use this when you have a const reference to the StateBlob and only need
* to read from the scratchpad.
*
* @tparam CTX The scratchpad type to retrieve (must satisfy IsScratchPad).
*
* @param ctx The const StateBlob to retrieve the scratchpad from.
*
* @return Const pointer to the requested scratchpad.
*
* @throws exceptions::ScratchPadError if the scratchpad is not found,
* cannot be cast to the requested type, or any other error occurs.
*
* @par Examples
* @code{.cpp}
* void inspect(const StateBlob& blob) {
* const GraphEngineScratchPad* scratch = get_state<GraphEngineScratchPad>(blob);
* if (scratch->is_initialized()) {
* // Read from scratch...
* }
* }
* @endcode
*/
template <IsScratchPad CTX>
const CTX* get_state(const StateBlob& ctx) {
auto result = ctx.get<CTX>();
if (!result.has_value()) {
switch (result.error()) {
case StateBlob::Error::SCRATCHPAD_NOT_FOUND:
throw exceptions::ScratchPadError("Requested scratchpad not found in StateBlob.");
case StateBlob::Error::SCRATCHPAD_BAD_CAST:
throw exceptions::ScratchPadError("Failed to cast scratchpad to the requested type.");
case StateBlob::Error::SCRATCHPAD_TYPE_COLLISION:
throw exceptions::ScratchPadError("Scratchpad type collision detected in StateBlob.");
default:
throw exceptions::ScratchPadError("Unknown error occurred while retrieving scratchpad from StateBlob.");
}
}
return result.value();
}
/**
* @brief Retrieve a scratchpad with optional initialization check, throwing on error.
*
* Extended version of get_state() that can optionally verify the scratchpad
* is initialized before returning it. When MUST_BE_INITIALIZED is true, an
* exception is thrown if the scratchpad exists but hasn't been initialized.
*
* @tparam CTX The scratchpad type to retrieve (must satisfy IsScratchPad).
* @tparam MUST_BE_INITIALIZED If true, throws when the scratchpad is not initialized.
*
* @param ctx The StateBlob to retrieve the scratchpad from.
*
* @return Pointer to the requested scratchpad (guaranteed initialized if MUST_BE_INITIALIZED is true).
*
* @throws exceptions::ScratchPadError if the scratchpad is not found,
* cannot be cast, is not initialized (when required), or any other error.
*
* @par Examples
* @code{.cpp}
* // Ensure scratchpad is initialized before use
* try {
* auto* scratch = get_state<GraphEngineScratchPad, true>(blob);
* // Guaranteed to be initialized here
* use_scratchpad(*scratch);
* } catch (const exceptions::ScratchPadError& e) {
* // Handle missing or uninitialized scratchpad
* }
* @endcode
*/
template <IsScratchPad CTX, bool MUST_BE_INITIALIZED>
CTX* get_state(StateBlob& ctx) {
auto result = ctx.get<CTX, MUST_BE_INITIALIZED>();
if (!result.has_value()) {
switch (result.error()) {
case StateBlob::Error::SCRATCHPAD_NOT_FOUND:
throw exceptions::ScratchPadError("Requested scratchpad not found in StateBlob.");
case StateBlob::Error::SCRATCHPAD_BAD_CAST:
throw exceptions::ScratchPadError("Failed to cast scratchpad to the requested type.");
case StateBlob::Error::SCRATCHPAD_TYPE_COLLISION:
throw exceptions::ScratchPadError("Scratchpad type collision detected in StateBlob.");
case StateBlob::Error::SCRATCHPAD_NOT_INITIALIZED:
throw exceptions::ScratchPadError("Requested scratchpad not initialized in StateBlob. If this is acceptable behavior, use scratch::get_state<>() without the MUST_BE_INITIALIZED template parameter.");
default:
throw exceptions::ScratchPadError("Unknown error occurred while retrieving scratchpad from StateBlob.");
}
}
return result.value();
}
/**
* @brief Retrieve a const scratchpad with optional initialization check, throwing on error.
*
* Const-correct overload of the initialization-checking get_state(). Combines
* read-only access with optional initialization verification.
*
* @tparam CTX The scratchpad type to retrieve (must satisfy IsScratchPad).
* @tparam MUST_BE_INITIALIZED If true, throws when the scratchpad is not initialized.
*
* @param ctx The const StateBlob to retrieve the scratchpad from.
*
* @return Const pointer to the requested scratchpad (guaranteed initialized if MUST_BE_INITIALIZED is true).
*
* @throws exceptions::ScratchPadError if the scratchpad is not found,
* cannot be cast, is not initialized (when required), or any other error.
*
* @par Examples
* @code{.cpp}
* void validate(const StateBlob& blob) {
* // Get const access, ensuring initialization
* const auto* scratch = get_state<GraphEngineScratchPad, true>(blob);
* // Safe to read - guaranteed initialized
* const auto& adfun = scratch->rhsADFun.value();
* }
* @endcode
*/
template <IsScratchPad CTX, bool MUST_BE_INITIALIZED>
const CTX* get_state(const StateBlob& ctx) {
auto result = ctx.get<CTX, MUST_BE_INITIALIZED>();
if (!result.has_value()) {
switch (result.error()) {
case StateBlob::Error::SCRATCHPAD_NOT_FOUND:
throw exceptions::ScratchPadError("Requested scratchpad not found in StateBlob.");
case StateBlob::Error::SCRATCHPAD_BAD_CAST:
throw exceptions::ScratchPadError("Failed to cast scratchpad to the requested type.");
case StateBlob::Error::SCRATCHPAD_TYPE_COLLISION:
throw exceptions::ScratchPadError("Scratchpad type collision detected in StateBlob.");
case StateBlob::Error::SCRATCHPAD_NOT_INITIALIZED:
throw exceptions::ScratchPadError("Requested scratchpad is not initialized. If this is acceptable behavior, use scratch::get_state<>() without the MUST_BE_INITIALIZED template parameter.");
default:
throw exceptions::ScratchPadError("Unknown error occurred while retrieving scratchpad from StateBlob.");
}
}
return result.value();
}
} // namespace gridfire::engine::scratch

View File

@@ -11,6 +11,7 @@
#include "fourdst/logging/logging.h"
#include "gridfire/engine/procedures/construction.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "quill/Logger.h"
@@ -62,6 +63,7 @@ namespace gridfire::engine {
/**
* @brief Updates the active species and reactions based on the current conditions.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param netIn The current network input, containing temperature, density, and composition.
*
* This method performs the reaction flow calculation, reaction culling, connectivity analysis,
@@ -76,19 +78,23 @@ namespace gridfire::engine {
* @see AdaptiveEngineView::constructSpeciesIndexMap()
* @see AdaptiveEngineView::constructReactionIndexMap()
*/
fourdst::composition::Composition update(const NetIn &netIn) override;
bool isStale(const NetIn& netIn) override;
fourdst::composition::Composition project(
scratch::StateBlob& ctx,
const NetIn &netIn
) const override;
/**
* @brief Gets the list of active species in the network.
* @return A const reference to the vector of active species.
*/
[[nodiscard]] const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
[[nodiscard]] const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
scratch::StateBlob& ctx
) const override;
/**
* @brief Calculates the right-hand side (dY/dt) and energy generation for the active species.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param comp The current composition of the system.
* @param T9 The temperature in units of 10^9 K.
* @param rho The density in g/cm^3.
@@ -103,7 +109,8 @@ namespace gridfire::engine {
* @throws std::runtime_error If the AdaptiveEngineView is stale (i.e., `update()` has not been called).
* @see AdaptiveEngineView::update()
*/
[[nodiscard]] std::expected<StepDerivatives<double>, engine::EngineStatus> calculateRHSAndEnergy(
[[nodiscard]] std::expected<StepDerivatives<double>, EngineStatus> calculateRHSAndEnergy(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -113,12 +120,14 @@ namespace gridfire::engine {
/**
*
* @param ctx The scratchpad context for storing thread-local data.
* @param comp The current composition of the system.
* @param T9 The temperature in units of 10^9 K.
* @param rho The density in g/cm^3.
* @return A struct containing the derivatives of the energy generation rate with respect to temperature and density.
*/
[[nodiscard]] EnergyDerivatives calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -127,6 +136,7 @@ namespace gridfire::engine {
/**
* @brief Generates the Jacobian matrix for the active species.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param comp The current composition of the system.
* @param T9 The temperature in units of 10^9 K.
* @param rho The density in g/cm^3.
@@ -138,6 +148,7 @@ namespace gridfire::engine {
* @see AdaptiveEngineView::update()
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -146,6 +157,7 @@ namespace gridfire::engine {
/**
* @brief Generates the Jacobian matrix for some set of active species such that that set is a subset of the active species in the view.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param comp The current composition of the system.
* @param T9 The temperature in units of 10^9 K.
* @param rho The density in g/cm^3.
@@ -158,6 +170,7 @@ namespace gridfire::engine {
* @see AdaptiveEngineView::update()
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -167,6 +180,7 @@ namespace gridfire::engine {
/**
* @brief Generates the Jacobian matrix for the active species with a given sparsity pattern.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param comp The current composition of the system.
* @param T9 The temperature in units of 10^9 K.
* @param rho The density in g/cm^3.
@@ -179,45 +193,18 @@ namespace gridfire::engine {
* @see AdaptiveEngineView::update()
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
const SparsityPattern &sparsityPattern
) const override;
/**
* @brief Generates the stoichiometry matrix for the active reactions and species.
*
* This method calls the base engine to generate the stoichiometry matrix.
*
* @throws std::runtime_error If the AdaptiveEngineView is stale (i.e., `update()` has not been called).
* @note The stoichiometry matrix generated by the base engine is assumed to be consistent with
* the active species and reactions in this view.
*/
void generateStoichiometryMatrix() override;
/**
* @brief Gets an entry from the stoichiometry matrix for the active species and reactions.
*
* @param species The species for which to get the stoichiometric coefficient.
* @param reaction The reaction for which to get the stoichiometric coefficient.
* @return The stoichiometric coefficient for the given species and reaction.
*
* This method maps the culled indices to the full network indices and calls the base engine
* to get the stoichiometry matrix entry.
*
* @throws std::runtime_error If the AdaptiveEngineView is stale (i.e., `update()` has not been called).
* @throws std::out_of_range If the culled index is out of bounds for the species or reaction index map.
* @see AdaptiveEngineView::update()
*/
[[nodiscard]] int getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const reaction::Reaction& reaction
) const override;
/**
* @brief Calculates the molar reaction flow for a given reaction in the active network.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param reaction The reaction for which to calculate the flow.
* @param comp Composition object containing current abundances.
* @param T9 Temperature in units of 10^9 K.
@@ -231,6 +218,7 @@ namespace gridfire::engine {
* @throws std::runtime_error If the reaction is not part of the active reactions in the adaptive engine view.
*/
[[nodiscard]] double calculateMolarReactionFlow(
scratch::StateBlob& ctx,
const reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp,
double T9,
@@ -242,22 +230,14 @@ namespace gridfire::engine {
*
* @return Reference to the LogicalReactionSet containing all active reactions.
*/
[[nodiscard]] const reaction::ReactionSet& getNetworkReactions() const override;
/**
* @brief Sets the reaction set for the base engine.
*
* This method delegates the call to the base engine to set the reaction set.
*
* @param reactions The ReactionSet to set in the base engine.
*
* @post The reaction set of the base engine is updated.
*/
void setNetworkReactions(const reaction::ReactionSet& reactions) override;
[[nodiscard]] const reaction::ReactionSet& getNetworkReactions(
scratch::StateBlob& ctx
) const override;
/**
* @brief Computes timescales for all active species in the network.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param comp Composition object containing current abundances.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -269,6 +249,7 @@ namespace gridfire::engine {
* @throws std::runtime_error If the AdaptiveEngineView is stale (i.e., `update()` has not been called).
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -277,6 +258,7 @@ namespace gridfire::engine {
/**
* @brief Computes destruction timescales for all active species in the network.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param comp Composition object containing current abundances.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -288,6 +270,7 @@ namespace gridfire::engine {
* @throws std::runtime_error If the AdaptiveEngineView is stale (i.e., `update()` has not been called).
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -297,24 +280,9 @@ namespace gridfire::engine {
* @brief Gets the base engine.
* @return A const reference to the base engine.
*/
[[nodiscard]] const DynamicEngine& getBaseEngine() const override { return m_baseEngine; }
/**
* @brief Sets the screening model for the base engine.
*
* This method delegates the call to the base engine to set the electron screening model.
*
* @param model The electron screening model to set.
*
* @par Usage Example:
* @code
* AdaptiveEngineView engineView(...);
* engineView.setScreeningModel(screening::ScreeningType::WEAK);
* @endcode
*
* @post The screening model of the base engine is updated.
*/
void setScreeningModel(screening::ScreeningType model) override;
[[nodiscard]] const DynamicEngine& getBaseEngine() const override {
return m_baseEngine;
}
/**
* @brief Gets the screening model from the base engine.
@@ -329,32 +297,24 @@ namespace gridfire::engine {
* screening::ScreeningType model = engineView.getScreeningModel();
* @endcode
*/
[[nodiscard]] screening::ScreeningType getScreeningModel() const override;
[[nodiscard]] screening::ScreeningType getScreeningModel(
scratch::StateBlob& ctx
) const override;
/**
* @brief Gets the index of a species in the active species list.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param species The species for which to get the index.
* @return The index of the species in the active species list.
*
* @throws std::runtime_error If the AdaptiveEngineView is stale (i.e., `update()` has not been called).
* @throws std::out_of_range If the species is not part of the active species in the adaptive engine view.
*/
[[nodiscard]] size_t getSpeciesIndex(const fourdst::atomic::Species &species) const override;
/**
* @brief Maps the molar abundance vector from the active species to the full network species.
*
* @param netIn The current network input, containing temperature, density, and composition.
* @return A vector of molar abundances for all species in the full network.
*
* This method constructs a molar abundance vector for the full network by mapping the
* abundances from the active species in `netIn` to their corresponding indices in the
* full network. Species not present in `netIn` are assigned an abundance of zero.
*
* @throws std::runtime_error If the AdaptiveEngineView is stale (i.e., `update()` has not been called).
*/
[[nodiscard]] std::vector<double> mapNetInToMolarAbundanceVector(const NetIn &netIn) const override;
[[nodiscard]] size_t getSpeciesIndex(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
/**
* @brief Primes the engine with the given network input.
@@ -364,12 +324,16 @@ namespace gridfire::engine {
*
* This method delegates the priming operation to the base engine.
*/
[[nodiscard]] PrimingReport primeEngine(const NetIn &netIn) override;
[[nodiscard]] PrimingReport primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) const override;
/**
* @brief Collect the composition of the base engine, ensure all active species are registered, and pass
* the composition back to the caller.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param comp The current composition of the system.
* @param T9 The temperature in units of 10^9 K.
* @param rho The density in g/cm^3.
@@ -377,17 +341,32 @@ namespace gridfire::engine {
* @note This function ensures that the state of both the base engine and the adaptive view are synchronized in the
* result back to the caller
*/
[[nodiscard]] fourdst::composition::Composition collectComposition(const fourdst::composition::CompositionAbstract &comp, double T9, double rho) const override;
[[nodiscard]] fourdst::composition::Composition collectComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
/**
* @brief Gets the status of a species in the network.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param species The species for which to get the status.
* @return The SpeciesStatus indicating the status of the species.
*
* This method delegates the call to the base engine to get the species status. If the base engine says that the species is active but it is not in the active species list of this view, the status is returned as INACTIVE_FLOW.
* This method delegates the call to the base engine to get the species status. If the base engine says that
* the species is active, but it is not in the active species list of this view, the status is returned as
* INACTIVE_FLOW.
*/
[[nodiscard]] SpeciesStatus getSpeciesStatus(const fourdst::atomic::Species &species) const override;
[[nodiscard]] SpeciesStatus getSpeciesStatus(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
[[nodiscard]] std::optional<StepDerivatives<double>>getMostRecentRHSCalculation(
scratch::StateBlob &ctx
) const override;
private:
using LogManager = fourdst::logging::LogManager;
@@ -399,17 +378,6 @@ namespace gridfire::engine {
/** @brief The underlying engine to which this view delegates calculations. */
DynamicEngine& m_baseEngine;
/** @brief The set of species that are currently active in the network. */
std::vector<fourdst::atomic::Species> m_activeSpecies;
/** @brief The set of reactions that are currently active in the network. */
reaction::ReactionSet m_activeReactions;
/** @brief A flag indicating whether the view is stale and needs to be updated. */
bool m_isStale = true;
mutable std::unordered_map<size_t, fourdst::composition::Composition> m_collected_composition_cache;
private:
/**
* @brief A struct to hold a reaction and its flow rate.
@@ -419,14 +387,6 @@ namespace gridfire::engine {
double flowRate;
};
private:
/**
* @brief Validates that the AdaptiveEngineView is not stale.
*
* @throws std::runtime_error If the AdaptiveEngineView is stale (i.e., `update()` has not been called).
*/
void validateState() const;
/**
* @brief Calculates the molar reaction flow rate for all reactions in the full network.
*
@@ -435,8 +395,11 @@ namespace gridfire::engine {
* and composition). It also constructs a vector of molar abundances for all species in the
* full network.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param netIn The current network input, containing temperature, density, and composition.
* @return A pair with the first element a vector of ReactionFlow structs, each containing a pointer to a reaction and its calculated flow rate and the second being a composition object where species which were not present in netIn but are present in the definition of the base engine are registered but have 0 mass fraction.
* @return A pair with the first element a vector of ReactionFlow structs, each containing a pointer to a
* reaction and its calculated flow rate and the second being a composition object where species which were not
* present in netIn but are present in the definition of the base engine are registered but have 0 mass fraction
*
* @par Algorithm:
* 1. Iterates through all species in the base engine's network.
@@ -447,6 +410,7 @@ namespace gridfire::engine {
* 6. Stores the reaction pointer and its flow rate in a `ReactionFlow` struct and adds it to the returned vector.
*/
[[nodiscard]] std::pair<std::vector<ReactionFlow>, fourdst::composition::Composition> calculateAllReactionFlows(
scratch::StateBlob& ctx,
const NetIn& netIn
) const;
/**
@@ -456,6 +420,7 @@ namespace gridfire::engine {
* starting from the initial fuel species. A species is considered part of the initial fuel if its
* mass fraction is above a certain threshold (`ABUNDANCE_FLOOR`).
*
* @param ctx The scratchpad context for storing thread-local data.
* @param netIn The current network input, containing the initial composition.
* @return An unordered set of all reachable species.
*
@@ -467,6 +432,7 @@ namespace gridfire::engine {
* 5. The process continues until a full pass over all reactions does not add any new species to the `reachable` set.
*/
[[nodiscard]] std::unordered_set<fourdst::atomic::Species> findReachableSpecies(
scratch::StateBlob& ctx,
const NetIn& netIn
) const;
/**
@@ -476,6 +442,7 @@ namespace gridfire::engine {
* above an absolute culling threshold. The threshold is calculated by multiplying the
* maximum flow rate by a relative culling threshold read from the configuration.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param allFlows A vector of all reactions and their flow rates.
* @param reachableSpecies A set of all species reachable from the initial fuel.
* @param comp The current composition of the system.
@@ -490,6 +457,7 @@ namespace gridfire::engine {
* 5. The pointers to the kept reactions are stored in a vector and returned.
*/
[[nodiscard]] std::vector<const reaction::Reaction*> cullReactionsByFlow(
scratch::StateBlob& ctx,
const std::vector<ReactionFlow>& allFlows,
const std::unordered_set<fourdst::atomic::Species>& reachableSpecies,
const fourdst::composition::Composition& comp,
@@ -498,11 +466,10 @@ namespace gridfire::engine {
typedef std::pair<std::unordered_set<const reaction::Reaction*>, std::unordered_set<fourdst::atomic::Species>> RescueSet;
[[nodiscard]] RescueSet rescueEdgeSpeciesDestructionChannel(
scratch::StateBlob& ctx,
const fourdst::composition::Composition& comp,
double T9,
double rho,
const std::vector<fourdst::atomic::Species>& activeSpecies,
const reaction::ReactionSet& activeReactions
double rho
) const;
/**
* @brief Finalizes the set of active species and reactions.
@@ -512,6 +479,7 @@ namespace gridfire::engine {
* determined by collecting all reactants and products from the final reactions.
* The active species list is then sorted by mass.
*
* @param ctx The scratchpad context for storing thread-local data.
* @param finalReactions A vector of pointers to the reactions to be included in the active set.
*
* @post
@@ -520,7 +488,8 @@ namespace gridfire::engine {
* - `m_activeSpecies` is sorted by atomic mass.
*/
void finalizeActiveSet(
scratch::StateBlob& ctx,
const std::vector<const reaction::Reaction*>& finalReactions
);
) const;
};
}

View File

@@ -8,6 +8,8 @@
#include "gridfire/config/config.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "fourdst/config/config.h"
#include "fourdst/logging/logging.h"
@@ -30,7 +32,9 @@ namespace gridfire::engine {
* @brief Gets the list of active species in the network defined by the file.
* @return A const reference to the vector of active species.
*/
[[nodiscard]] const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
[[nodiscard]] const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
scratch::StateBlob& ctx
) const override;
// --- DynamicEngine Interface ---
/**
@@ -45,7 +49,8 @@ namespace gridfire::engine {
*
* @throws std::runtime_error If the view is stale (i.e., `update()` has not been called after `setNetworkFile()`).
*/
[[nodiscard]] std::expected<StepDerivatives<double>, engine::EngineStatus> calculateRHSAndEnergy(
[[nodiscard]] std::expected<StepDerivatives<double>, EngineStatus> calculateRHSAndEnergy(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -53,6 +58,7 @@ namespace gridfire::engine {
) const override;
[[nodiscard]] EnergyDerivatives calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -68,6 +74,7 @@ namespace gridfire::engine {
* @throws std::runtime_error If the view is stale.
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -84,6 +91,7 @@ namespace gridfire::engine {
* @throws std::runtime_error If the view is stale.
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -101,33 +109,13 @@ namespace gridfire::engine {
* @throws std::runtime_error If the view is stale.
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
const SparsityPattern &sparsityPattern
) const override;
/**
* @brief Generates the stoichiometry matrix for the active reactions and species.
*
* @throws std::runtime_error If the view is stale.
*/
void generateStoichiometryMatrix() override;
/**
* @brief Gets an entry from the stoichiometry matrix for the active species and reactions.
*
* @param species The species for which to get the stoichiometric coefficient.
* @param reaction The reaction for which to get the stoichiometric coefficient.
* @return The stoichiometric coefficient for the given species and reaction.
*
* @throws std::runtime_error If the view is stale.
* @throws std::out_of_range If an index is out of bounds.
*/
[[nodiscard]] int getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const reaction::Reaction& reaction
) const override;
/**
* @brief Calculates the molar reaction flow for a given reaction in the active network.
*
@@ -140,6 +128,7 @@ namespace gridfire::engine {
* @throws std::runtime_error If the view is stale or if the reaction is not in the active set.
*/
[[nodiscard]] double calculateMolarReactionFlow(
scratch::StateBlob& ctx,
const reaction::Reaction& reaction,
const fourdst::composition::CompositionAbstract &comp,
double T9,
@@ -152,16 +141,9 @@ namespace gridfire::engine {
*
* @throws std::runtime_error If the view is stale.
*/
[[nodiscard]] const reaction::ReactionSet& getNetworkReactions() const override;
/**
* @brief Sets the active reactions in the network.
*
* @param reactions The ReactionSet containing the reactions to set as active.
*
* @post The view is marked as stale and will need to be updated.
*/
void setNetworkReactions(const reaction::ReactionSet& reactions) override;
[[nodiscard]] const reaction::ReactionSet& getNetworkReactions(
scratch::StateBlob& ctx
) const override;
/**
* @brief Computes timescales for all active species in the network.
@@ -173,8 +155,8 @@ namespace gridfire::engine {
*
* @throws std::runtime_error If the view is stale.
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, engine::EngineStatus>
getSpeciesTimescales(
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus>getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -190,8 +172,8 @@ namespace gridfire::engine {
*
* @throws std::runtime_error If the view is stale.
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, engine::EngineStatus>
getSpeciesDestructionTimescales(
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -208,29 +190,19 @@ namespace gridfire::engine {
*
* @post If the view was stale, it is rebuilt and is no longer stale.
*/
fourdst::composition::Composition update(const NetIn &netIn) override;
/**
* @brief Checks if the engine view is stale.
*
* @param netIn The current network input (unused).
* @return True if the view is stale and needs to be updated; false otherwise.
*/
[[deprecated]] bool isStale(const NetIn& netIn) override;
/**
* @brief Sets the screening model for the base engine.
*
* @param model The screening model to set.
*/
void setScreeningModel(screening::ScreeningType model) override;
fourdst::composition::Composition project(
scratch::StateBlob& ctx,
const NetIn &netIn
) const override;
/**
* @brief Gets the screening model from the base engine.
*
* @return The current screening model type.
*/
[[nodiscard]] screening::ScreeningType getScreeningModel() const override;
[[nodiscard]] screening::ScreeningType getScreeningModel(
scratch::StateBlob& ctx
) const override;
/** @brief Maps a species from the full network to its index in the defined active network.
*
@@ -239,21 +211,20 @@ namespace gridfire::engine {
*
* @throws std::runtime_error If the species is not in the active set.
*/
[[nodiscard]] size_t getSpeciesIndex(const fourdst::atomic::Species &species) const override;
/**
* @brief Map from a NetIn object to a vector of molar abundances for the active species.
* @param netIn The NetIn object containing the full network abundances.
* @return A vector of molar abundances for the active species.
*/
[[nodiscard]] std::vector<double> mapNetInToMolarAbundanceVector(const NetIn &netIn) const override;
[[nodiscard]] size_t getSpeciesIndex(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
/**
* @brief Prime the engine view for calculations. This will delegate to the base engine.
* @param netIn The current network input.
* @return The PrimingReport from the base engine.
*/
[[nodiscard]] PrimingReport primeEngine(const NetIn &netIn) override;
[[nodiscard]] PrimingReport primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) const override;
/**
* @brief Collects a Composition object from the base engine.
@@ -263,7 +234,12 @@ namespace gridfire::engine {
* @param rho The density in g/cm^3.
* @return A composition object representing the state of the engine stack and the current view.
*/
[[nodiscard]] fourdst::composition::Composition collectComposition(const fourdst::composition::CompositionAbstract &comp, double T9, double rho) const override;
[[nodiscard]] fourdst::composition::Composition collectComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
/**
* @brief Gets the status of a species in the active network.
@@ -271,25 +247,19 @@ namespace gridfire::engine {
* @param species The species for which to get the status.
* @return The SpeciesStatus indicating if the species is active, inactive, or not present.
*/
[[nodiscard]] SpeciesStatus getSpeciesStatus(const fourdst::atomic::Species &species) const override;
[[nodiscard]] SpeciesStatus getSpeciesStatus(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
[[nodiscard]] std::optional<StepDerivatives<double>>getMostRecentRHSCalculation(
scratch::StateBlob &ctx
) const override;
protected:
bool m_isStale = true;
GraphEngine& m_baseEngine;
private:
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log"); ///< Logger instance for trace and debug information.
///< Active species in the defined engine.
std::set<fourdst::atomic::Species> m_activeSpecies;
///< Cache for the active species vector to avoid dangling references.
mutable std::optional<std::vector<fourdst::atomic::Species>> m_activeSpeciesVectorCache = std::nullopt;
///< Active reactions in the defined engine.
reaction::ReactionSet m_activeReactions;
///< Maps indices of active species to indices in the full network.
std::vector<size_t> m_speciesIndexMap;
///< Maps indices of active reactions to indices in the full network.
std::vector<size_t> m_reactionIndexMap;
quill::Logger* m_logger = LogManager::getInstance().getLogger("log"); ///< Logger instance for trace and debug information.
private:
/**
* @brief Constructs the species index map.
@@ -301,7 +271,9 @@ namespace gridfire::engine {
*
* @throws std::runtime_error If an active species is not found in the base engine's species list.
*/
[[nodiscard]] std::vector<size_t> constructSpeciesIndexMap() const;
[[nodiscard]] std::vector<size_t> constructSpeciesIndexMap(
scratch::StateBlob& ctx
) const;
/**
* @brief Constructs the reaction index map.
@@ -313,7 +285,9 @@ namespace gridfire::engine {
*
* @throws std::runtime_error If an active reaction is not found in the base engine's reaction list.
*/
[[nodiscard]] std::vector<size_t> constructReactionIndexMap() const;
[[nodiscard]] std::vector<size_t> constructReactionIndexMap(
scratch::StateBlob& ctx
) const;
/**
* @brief Maps a vector of culled abundances to a vector of full abundances.
@@ -322,7 +296,10 @@ namespace gridfire::engine {
* @return A vector of abundances for the full network, with the abundances of the active
* species copied from the defined vector.
*/
[[nodiscard]] std::vector<double> mapViewToFull(const std::vector<double>& defined) const;
[[nodiscard]] std::vector<double> mapViewToFull(
scratch::StateBlob& ctx,
const std::vector<double>& defined
) const;
/**
* @brief Maps a vector of full abundances to a vector of culled abundances.
@@ -331,7 +308,10 @@ namespace gridfire::engine {
* @return A vector of abundances for the active species, with the abundances of the active
* species copied from the full vector.
*/
[[nodiscard]] std::vector<double> mapFullToView(const std::vector<double>& full) const;
[[nodiscard]] static std::vector<double> mapFullToView(
scratch::StateBlob& ctx,
const std::vector<double>& full
);
/**
* @brief Maps a culled species index to a full species index.
@@ -341,7 +321,10 @@ namespace gridfire::engine {
*
* @throws std::out_of_range If the defined index is out of bounds for the species index map.
*/
[[nodiscard]] size_t mapViewToFullSpeciesIndex(size_t definedSpeciesIndex) const;
[[nodiscard]] size_t mapViewToFullSpeciesIndex(
scratch::StateBlob& ctx,
size_t definedSpeciesIndex
) const;
/**
* @brief Maps a culled reaction index to a full reaction index.
@@ -351,11 +334,15 @@ namespace gridfire::engine {
*
* @throws std::out_of_range If the defined index is out of bounds for the reaction index map.
*/
[[nodiscard]] size_t mapViewToFullReactionIndex(size_t definedReactionIndex) const;
[[nodiscard]] size_t mapViewToFullReactionIndex(
scratch::StateBlob& ctx,
size_t definedReactionIndex
) const;
void validateNetworkState() const;
void collect(const std::vector<std::string>& peNames);
void collect(
scratch::StateBlob& ctx,
const std::vector<std::string>& peNames
) const;
};
@@ -368,6 +355,9 @@ namespace gridfire::engine {
);
[[nodiscard]] std::string getNetworkFile() const { return m_fileName; }
[[nodiscard]] const io::NetworkFileParser& getParser() const { return m_parser; }
private:
using LogManager = LogManager;
Config<config::GridFireConfig> m_config;

View File

@@ -4,6 +4,8 @@
#include "gridfire/engine/views/engine_view_abstract.h"
#include "gridfire/engine/engine_graph.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "sundials/sundials_linearsolver.h"
#include "sundials/sundials_matrix.h"
#include "sundials/sundials_nvector.h"
@@ -81,19 +83,20 @@ namespace gridfire::engine {
*/
explicit MultiscalePartitioningEngineView(DynamicEngine& baseEngine);
~MultiscalePartitioningEngineView() override;
/**
* @brief Gets the list of species in the network.
* @return A const reference to the vector of `Species` objects representing all species
* in the underlying base engine. This view does not alter the species list itself,
* only how their abundances are evolved.
*/
[[nodiscard]] const std::vector<fourdst::atomic::Species> & getNetworkSpecies() const override;
[[nodiscard]] const std::vector<fourdst::atomic::Species> & getNetworkSpecies(
scratch::StateBlob& ctx
) const override;
/**
* @brief Calculates the right-hand side (dY/dt) and energy generation.
*
* @param ctx The scratch data for thread-local storage.
* @param comp The current composition.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -120,6 +123,7 @@ namespace gridfire::engine {
* (T9, rho, Y_full). This indicates `update()` was not called recently enough.
*/
[[nodiscard]] std::expected<StepDerivatives<double>, engine::EngineStatus> calculateRHSAndEnergy(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -128,12 +132,14 @@ namespace gridfire::engine {
/**
* @brief Calculates the energy generation rate derivatives with respect to abundances.
* @param ctx The scratch data for thread-local storage.
* @param comp The current composition.
* @param T9 The temperature in units of 10^9 K.
* @param rho The density in g/cm^3.
* @return The energy generation rate derivatives (dEps/dT and dEps/drho).
*/
[[nodiscard]] EnergyDerivatives calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -142,6 +148,7 @@ namespace gridfire::engine {
/**
* @brief Generates the Jacobian matrix for the current state.
*
* @param ctx The scratch data for thread-local storage.
* @param comp The current composition.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -163,6 +170,7 @@ namespace gridfire::engine {
* without a valid partition.
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -171,6 +179,7 @@ namespace gridfire::engine {
/**
* @brief Generates the Jacobian matrix for a subset of active species.
*
* @param ctx The scratch data for thread-local storage.
* @param comp The current composition.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -192,6 +201,7 @@ namespace gridfire::engine {
* @throws exceptions::StaleEngineError If the QSE cache misses.
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -201,6 +211,7 @@ namespace gridfire::engine {
/**
* @brief Generates the Jacobian matrix using a sparsity pattern.
*
* @param ctx The scratch data for thread-local storage.
* @param comp The current composition.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -220,47 +231,17 @@ namespace gridfire::engine {
* @throws exceptions::StaleEngineError If the QSE cache misses.
*/
[[nodiscard]] NetworkJacobian generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
const SparsityPattern &sparsityPattern
) const override;
/**
* @brief Generates the stoichiometry matrix for the network.
*
* @par Purpose
* To prepare the stoichiometry matrix for later queries.
*
* @par How
* This method delegates directly to the base engine's `generateStoichiometryMatrix()`.
* The stoichiometry is based on the full, unpartitioned network.
*/
void generateStoichiometryMatrix() override;
/**
* @brief Gets an entry from the stoichiometry matrix.
*
* @param species Species to look up stoichiometry for.
* @param reaction Reaction to find.
* @return Stoichiometric coefficient for the species in the reaction.
*
* @par Purpose
* To query the stoichiometric relationship between a species and a reaction.
*
* @par How
* This method delegates directly to the base engine's `getStoichiometryMatrixEntry()`.
*
* @pre `generateStoichiometryMatrix()` must have been called.
*/
[[nodiscard]] int getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const reaction::Reaction& reaction
) const override;
/**
* @brief Calculates the molar reaction flow for a given reaction.
*
* @param ctx The scratch data for thread-local storage.
* @param reaction The reaction for which to calculate the flow.
* @param comp The current composition.
* @param T9 Temperature in units of 10^9 K.
@@ -281,6 +262,7 @@ namespace gridfire::engine {
* @throws StaleEngineError If the QSE cache misses.
*/
[[nodiscard]] double calculateMolarReactionFlow(
scratch::StateBlob& ctx,
const reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp,
double T9,
@@ -293,31 +275,14 @@ namespace gridfire::engine {
* @return A const reference to the `LogicalReactionSet` from the base engine,
* containing all reactions in the full network.
*/
[[nodiscard]] const reaction::ReactionSet & getNetworkReactions() const override;
/**
* @brief Sets the set of logical reactions in the network.
*
* @param reactions The set of logical reactions to use.
*
* @par Purpose
* To modify the reaction network.
*
* @par How
* This operation is not supported by the `MultiscalePartitioningEngineView` as it
* would invalidate the partitioning logic. It logs a critical error and throws an
* exception. Network modifications should be done on the base engine before it is
* wrapped by this view.
*
* @throws exceptions::UnableToSetNetworkReactionsError Always.
*/
void setNetworkReactions(
const reaction::ReactionSet &reactions
) override;
[[nodiscard]] const reaction::ReactionSet & getNetworkReactions(
scratch::StateBlob& ctx
) const override;
/**
* @brief Computes timescales for all species in the network.
*
* @param ctx The scratch data for thread-local storage.
* @param comp The current composition.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -335,8 +300,8 @@ namespace gridfire::engine {
* @pre The engine must have a valid QSE cache entry for the given state.
* @throws StaleEngineError If the QSE cache misses.
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, engine::EngineStatus>
getSpeciesTimescales(
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -345,6 +310,7 @@ namespace gridfire::engine {
/**
* @brief Computes destruction timescales for all species in the network.
*
* @param ctx The scratch data for thread-local storage.
* @param comp The current composition.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -362,8 +328,8 @@ namespace gridfire::engine {
* @pre The engine must have a valid QSE cache entry for the given state.
* @throws StaleEngineError If the QSE cache misses.
*/
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, engine::EngineStatus>
getSpeciesDestructionTimescales(
[[nodiscard]] std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -372,6 +338,7 @@ namespace gridfire::engine {
/**
* @brief Updates the internal state of the engine, performing partitioning and QSE equilibration.
*
* @param ctx The scratch data for thread-local storage.
* @param netIn A struct containing the current network input: temperature, density, and composition.
* @return The new composition after QSE species have been brought to equilibrium.
*
@@ -396,38 +363,11 @@ namespace gridfire::engine {
* The `m_qse_abundance_cache` is populated with the QSE solution for the given state.
* The returned composition reflects the new equilibrium.
*/
fourdst::composition::Composition update(
fourdst::composition::Composition project(
scratch::StateBlob& ctx,
const NetIn &netIn
) override;
) const override;
/**
* @brief Checks if the engine's internal state is stale relative to the provided conditions.
*
* @param netIn A struct containing the current network input.
* @return `true` if the engine is stale, `false` otherwise.
*
* @par Purpose
* To determine if `update()` needs to be called.
*
* @par How
* It creates a `QSECacheKey` from the `netIn` data and checks for its
* existence in the `m_qse_abundance_cache`. A cache miss indicates the engine is
* stale because it does not have a valid QSE partition for the current conditions.
* It also queries the base engine's `isStale()` method.
*/
bool isStale(const NetIn& netIn) override;
/**
* @brief Sets the electron screening model.
*
* @param model The type of screening model to use for reaction rate calculations.
*
* @par How
* This method delegates directly to the base engine's `setScreeningModel()`.
*/
void setScreeningModel(
screening::ScreeningType model
) override;
/**
* @brief Gets the current electron screening model.
@@ -437,20 +377,23 @@ namespace gridfire::engine {
* @par How
* This method delegates directly to the base engine's `getScreeningModel()`.
*/
[[nodiscard]] screening::ScreeningType getScreeningModel() const override;
[[nodiscard]] screening::ScreeningType getScreeningModel(
scratch::StateBlob& ctx
) const override;
/**
* @brief Gets the base engine.
*
* @return A const reference to the base engine.
*/
const DynamicEngine & getBaseEngine() const override;
[[nodiscard]] const DynamicEngine & getBaseEngine() const override;
/**
* @brief Partitions the network based on timescales from a `NetIn` struct.
*
* @param ctx The scratch data for thread-local storage.
* @param netIn A struct containing the current network input.
*
* @par Purpose
@@ -461,12 +404,14 @@ namespace gridfire::engine {
* primary `partitionNetwork` method.
*/
fourdst::composition::Composition partitionNetwork(
scratch::StateBlob& ctx,
const NetIn &netIn
);
) const;
/**
* @brief Exports the network to a DOT file for visualization.
*
* @param ctx The scratch data for thread-local storage.
* @param filename The name of the DOT file to create.
* @param comp Composition object
* @param T9 Temperature in units of 10^9 K.
@@ -480,6 +425,7 @@ namespace gridfire::engine {
* currently add any partitioning information to the output graph.
*/
void exportToDot(
scratch::StateBlob &ctx,
const std::string& filename,
const fourdst::composition::Composition &comp,
double T9,
@@ -489,24 +435,17 @@ namespace gridfire::engine {
/**
* @brief Gets the index of a species in the full network.
*
* @param ctx The scratch data for thread-local storage.
* @param species The species to get the index of.
* @return The index of the species in the base engine's network.
*
* @par How
* This method delegates directly to the base engine's `getSpeciesIndex()`.
*/
[[nodiscard]] size_t getSpeciesIndex(const fourdst::atomic::Species &species) const override;
/**
* @brief Maps a `NetIn` struct to a molar abundance vector for the full network.
*
* @param netIn A struct containing the current network input.
* @return A vector of molar abundances corresponding to the species order in the base engine.
*
* @par How
* This method delegates directly to the base engine's `mapNetInToMolarAbundanceVector()`.
*/
[[nodiscard]] std::vector<double> mapNetInToMolarAbundanceVector(const NetIn &netIn) const override;
[[nodiscard]] size_t getSpeciesIndex(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
/**
* @brief Primes the engine with a specific species.
@@ -521,7 +460,10 @@ namespace gridfire::engine {
* This method delegates directly to the base engine's `primeEngine()`. The
* multiscale view does not currently interact with the priming process.
*/
[[nodiscard]] PrimingReport primeEngine(const NetIn &netIn) override;
[[nodiscard]] PrimingReport primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) const override;
/**
* @brief Gets the fast species in the network.
@@ -536,7 +478,10 @@ namespace gridfire::engine {
*
* @pre `partitionNetwork()` must have been called.
*/
[[nodiscard]] std::vector<fourdst::atomic::Species> getFastSpecies() const;
[[nodiscard]] std::vector<fourdst::atomic::Species> getFastSpecies(
scratch::StateBlob& ctx
) const;
/**
* @brief Gets the dynamic species in the network.
*
@@ -550,11 +495,14 @@ namespace gridfire::engine {
*
* @pre `partitionNetwork()` must have been called.
*/
[[nodiscard]] const std::vector<fourdst::atomic::Species>& getDynamicSpecies() const;
[[nodiscard]] static const std::vector<fourdst::atomic::Species>& getDynamicSpecies(
scratch::StateBlob& ctx
);
/**
* @brief Checks if a species is involved in the partitioned network.
*
* @param ctx The scratch data for thread-local storage.
* @param species The species to check.
* @return `true` if the species is in either the dynamic or algebraic sets, `false` otherwise.
*
@@ -566,25 +514,37 @@ namespace gridfire::engine {
*
* @pre `partitionNetwork()` must have been called.
*/
bool involvesSpecies(const fourdst::atomic::Species &species) const;
static bool involvesSpecies(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
);
/**
* @brief Check if a species is involved in the QSE (algebraic) set.
* @param ctx The scratch data for thread-local storage.
* @param species The species to check.
* @return Boolean indicating if the species is in the algebraic set.
*/
bool involvesSpeciesInQSE(const fourdst::atomic::Species &species) const;
static bool involvesSpeciesInQSE(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
);
/**
* @brief Check if a species is involved in the dynamic set.
* @param ctx The scratch data for thread-local storage.
* @param species The species to check.
* @return Boolean indicating if the species is in the dynamic set.
*/
bool involvesSpeciesInDynamic(const fourdst::atomic::Species &species) const;
static bool involvesSpeciesInDynamic(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
);
/**
* @brief Gets a normalized composition with QSE species equilibrated.
*
* @param ctx The scratch data for thread-local storage.
* @param comp The input composition.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -601,23 +561,36 @@ namespace gridfire::engine {
* @pre The engine must have a valid QSE partition for the given state.
* @throws StaleEngineError If the QSE cache misses.
*/
fourdst::composition::Composition getNormalizedEquilibratedComposition(const fourdst::composition::CompositionAbstract& comp, double T9, double rho, bool trust) const;
fourdst::composition::Composition getNormalizedEquilibratedComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract& comp,
double T9,
double rho,
bool trust
) const;
/**
* @brief Collect the composition from this and sub engines.
* @details This method operates by injecting the current equilibrium abundances for algebraic species into
* the composition object so that they can be bubbled up to the caller.
* @param ctx The scratch data for thread-local storage.
* @param comp Input Composition
* @param T9
* @param rho
* @return New composition which is comp + any edits from lower levels + the equilibrium abundances of all algebraic species.
* @throws BadCollectionError: if there is a species in the algebraic species set which does not show up in the reported composition from the base engine.:w
*/
fourdst::composition::Composition collectComposition(const fourdst::composition::CompositionAbstract &comp, double T9, double rho) const override;
fourdst::composition::Composition collectComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
/**
* @brief Gets the status of a species in the network.
*
* @param ctx The scratch data for thread-local storage.
* @param species The species to query.
* @return The `SpeciesStatus` indicating if the species is dynamic, algebraic, or not involved.
*
@@ -630,9 +603,15 @@ namespace gridfire::engine {
*
* @pre `partitionNetwork()` must have been called.
*/
SpeciesStatus getSpeciesStatus(const fourdst::atomic::Species &species) const override;
SpeciesStatus getSpeciesStatus(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
private:
[[nodiscard]] std::optional<StepDerivatives<double>>getMostRecentRHSCalculation(
scratch::StateBlob &
) const override;
public:
/**
* @brief Struct representing a QSE group.
*
@@ -678,9 +657,9 @@ namespace gridfire::engine {
return os;
}
bool contains(const fourdst::atomic::Species& species) const;
bool containsAlgebraic(const fourdst::atomic::Species &species) const;
bool containsSeed(const fourdst::atomic::Species &species) const;
[[nodiscard]] bool contains(const fourdst::atomic::Species& species) const;
[[nodiscard]] bool containsAlgebraic(const fourdst::atomic::Species &species) const;
[[nodiscard]] bool containsSeed(const fourdst::atomic::Species &species) const;
};
class QSESolver {
@@ -701,6 +680,7 @@ namespace gridfire::engine {
~QSESolver();
fourdst::composition::Composition solve(
scratch::StateBlob& ctx,
const fourdst::composition::Composition& comp,
double T9,
double rho
@@ -709,6 +689,9 @@ namespace gridfire::engine {
size_t solves() const;
void log_diagnostics(const QSEGroup &group, const fourdst::composition::Composition &comp) const;
std::unique_ptr<QSESolver> clone() const;
std::unique_ptr<QSESolver> clone(SUNContext sun_ctx) const;
private:
static int sys_func(
@@ -733,8 +716,7 @@ namespace gridfire::engine {
const std::unordered_map<fourdst::atomic::Species, size_t>& qse_solve_species_index_map;
const std::vector<fourdst::atomic::Species>& qse_solve_species;
const QSESolver& instance;
std::vector<double> row_scaling_factors;
const double initial_group_mass;
scratch::StateBlob& ctx;
};
private:
@@ -783,46 +765,12 @@ namespace gridfire::engine {
* @brief The base engine to which this view delegates calculations.
*/
DynamicEngine& m_baseEngine;
/**
* @brief The list of identified equilibrium groups.
*/
std::vector<QSEGroup> m_qse_groups;
/**
* @brief A set of solvers, one for each QSE group
*/
std::vector<std::unique_ptr<QSESolver>> m_qse_solvers;
/**
* @brief The simplified set of species presented to the solver (the "slow" species).
*/
std::vector<fourdst::atomic::Species> m_dynamic_species;
/**
* @brief Species that are treated as algebraic (in QSE) in the QSE groups.
*/
std::vector<fourdst::atomic::Species> m_algebraic_species;
/**
* @brief Map from species to their calculated abundances in the QSE state.
*/
std::unordered_map<fourdst::atomic::Species, double> m_algebraic_abundances;
/**
* @brief Indices of all species considered active in the current partition (dynamic + algebraic).
*/
std::vector<size_t> m_activeSpeciesIndices;
/**
* @brief Indices of all reactions involving only active species.
*/
std::vector<size_t> m_activeReactionIndices;
mutable std::unordered_map<uint64_t, fourdst::composition::Composition> m_composition_cache;
SUNContext m_sun_ctx = nullptr;
private:
/**
* @brief Partitions the network by timescale.
*
* @param ctx The scratch data for thread-local storage.
* @param comp Vector of current molar abundances for all species.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -839,12 +787,14 @@ namespace gridfire::engine {
* (e.g., a factor of 100).
*/
std::vector<std::vector<fourdst::atomic::Species>> partitionByTimescale(
scratch::StateBlob& ctx,
const fourdst::composition::Composition &comp,
double T9,
double rho
) const;
std::pair<bool, reaction::ReactionSet> group_is_a_qse_cluster(
scratch::StateBlob& ctx,
const fourdst::composition::Composition &comp,
double T9,
double rho,
@@ -852,6 +802,7 @@ namespace gridfire::engine {
) const;
bool group_is_a_qse_pipeline(
scratch::StateBlob& ctx,
const fourdst::composition::Composition &comp,
double T9,
double rho,
@@ -861,6 +812,7 @@ namespace gridfire::engine {
/**
* @brief Validates candidate QSE groups using flux analysis.
*
* @param ctx The scratch data for thread-local storage.
* @param candidate_groups A vector of candidate QSE groups.
* @param comp Vector of current molar abundances for the full network.
* @param T9 Temperature in units of 10^9 K.
@@ -879,6 +831,7 @@ namespace gridfire::engine {
* to the returned vector.
*/
FluxValidationResult validateGroupsWithFluxAnalysis(
scratch::StateBlob& ctx,
const std::vector<QSEGroup> &candidate_groups,
const fourdst::composition::Composition &comp,
double T9,
@@ -888,6 +841,7 @@ namespace gridfire::engine {
/**
* @brief Solves for the QSE abundances of the algebraic species in a given state.
*
* @param ctx The scratch data for thread-local storage.
* @param comp Vector of current molar abundances for all species in the base engine.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
@@ -905,15 +859,17 @@ namespace gridfire::engine {
* @pre The input state (Y_full, T9, rho) must be a valid physical state.
* @post The algebraic species in the QSE cache are updated with the new equilibrium abundances.
*/
fourdst::composition::Composition solveQSEAbundances(
auto solveQSEAbundances(
scratch::StateBlob &ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const;
) const -> fourdst::composition::Composition;
/**
* @brief Identifies the pool with the slowest mean timescale.
*
* @param ctx The scratch data for thread-local storage.
* @param pools A vector of vectors of species indices, where each inner vector represents a
* timescale pool.
* @param comp Vector of current molar abundances for the full network.
@@ -929,6 +885,7 @@ namespace gridfire::engine {
* pool and returns the index of the pool with the maximum mean timescale.
*/
size_t identifyMeanSlowestPool(
scratch::StateBlob& ctx,
const std::vector<std::vector<fourdst::atomic::Species>>& pools,
const fourdst::composition::Composition &comp,
double T9,
@@ -938,6 +895,7 @@ namespace gridfire::engine {
/**
* @brief Builds a connectivity graph from a species pool.
*
* @param ctx The scratch data for thread-local storage.
* @param species_pool A vector of species indices representing a species pool.
* @param comp
* @param T9
@@ -954,13 +912,17 @@ namespace gridfire::engine {
* that reaction that are also in the pool.
*/
std::unordered_map<fourdst::atomic::Species, std::vector<fourdst::atomic::Species>> buildConnectivityGraph(
const std::vector<fourdst::atomic::Species>& species_pool, const fourdst::composition::Composition &comp, double T9, double
rho
scratch::StateBlob& ctx,
const std::vector<fourdst::atomic::Species>& species_pool,
const fourdst::composition::Composition &comp,
double T9,
double rho
) const;
/**
* @brief Constructs candidate QSE groups from connected timescale pools.
*
* @param ctx The scratch data for thread-local storage.
* @param candidate_pools A vector of vectors of species indices, where each inner vector
* represents a connected pool of species with similar fast timescales.
* @param comp Vector of current molar abundances.
@@ -978,6 +940,7 @@ namespace gridfire::engine {
* @post A list of candidate `QSEGroup` objects is returned.
*/
std::vector<QSEGroup> constructCandidateGroups(
scratch::StateBlob& ctx,
const std::vector<std::vector<fourdst::atomic::Species>>& candidate_pools,
const fourdst::composition::Composition &comp,
double T9,
@@ -987,6 +950,7 @@ namespace gridfire::engine {
/**
* @brief Analyzes the connectivity of timescale pools.
*
* @param ctx The scratch data for thread-local storage.
* @param timescale_pools A vector of vectors of species indices, where each inner vector
* represents a timescale pool.
* @param comp
@@ -1005,11 +969,15 @@ namespace gridfire::engine {
* The resulting components from all pools are collected and returned.
*/
std::vector<std::vector<fourdst::atomic::Species>> analyzeTimescalePoolConnectivity(
const std::vector<std::vector<fourdst::atomic::Species>> &timescale_pools, const fourdst::composition::Composition &
comp, double T9, double rho
scratch::StateBlob& ctx,
const std::vector<std::vector<fourdst::atomic::Species>> &timescale_pools,
const fourdst::composition::Composition &comp,
double T9,
double rho
) const;
std::vector<QSEGroup> pruneValidatedGroups(
scratch::StateBlob& ctx,
const std::vector<QSEGroup> &groups,
const std::vector<reaction::ReactionSet> &groupReactions,
const fourdst::composition::Composition &comp,
@@ -1018,9 +986,13 @@ namespace gridfire::engine {
) const;
static std::vector<QSEGroup> merge_coupled_groups(
scratch::StateBlob& ctx,
const std::vector<QSEGroup> &groups,
const std::vector<reaction::ReactionSet> &groupReactions
);
public:
};
}

View File

@@ -1,10 +1,13 @@
#pragma once
#include "gridfire/engine/views/engine_defined.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "fourdst/logging/logging.h"
#include "fourdst/atomic/atomicSpecies.h"
#include "quill/Logger.h"
#include <vector>
@@ -35,7 +38,11 @@ namespace gridfire::engine {
* @throws std::out_of_range If primingSymbol is not found in the species registry.
* @throws std::runtime_error If no reactions contain the priming species.
*/
NetworkPrimingEngineView(const std::string& primingSymbol, GraphEngine& baseEngine);
NetworkPrimingEngineView(
scratch::StateBlob& ctx,
const std::string& primingSymbol,
GraphEngine& baseEngine
);
/**
* @brief Constructs the view using an existing Species object.
*
@@ -45,7 +52,11 @@ namespace gridfire::engine {
* @post The view will contain only reactions that involve the priming species.
* @throws std::runtime_error If no reactions contain the priming species.
*/
NetworkPrimingEngineView(const fourdst::atomic::Species& primingSpecies, GraphEngine& baseEngine);
NetworkPrimingEngineView(
scratch::StateBlob& ctx,
const fourdst::atomic::Species& primingSpecies,
GraphEngine& baseEngine
);
private:
@@ -63,6 +74,7 @@ namespace gridfire::engine {
* @throws std::runtime_error If no reactions involve the priming species.
*/
[[nodiscard]] std::vector<std::string> constructPrimingReactionSet(
scratch::StateBlob& ctx,
const fourdst::atomic::Species& primingSpecies,
const GraphEngine& baseEngine
) const;

View File

@@ -0,0 +1,14 @@
#pragma once
#include <string>
#include <stdexcept>
#include "gridfire/exceptions/error_gridfire.h"
namespace gridfire::exceptions {
class ScratchPadError : public GridFireError {
public:
explicit ScratchPadError(const std::string& msg)
: GridFireError("ScratchPadError: " + msg) {}
};
}

View File

@@ -5,6 +5,7 @@
#include "gridfire/reaction/reaction.h"
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/engine/scratchpads/blob.h"
/**
* @brief Namespace for generative input/output functionalities.
@@ -49,16 +50,17 @@ namespace gridfire::io::gen {
* This function converts the given dynamic engine into a Python script
* that can be used to recreate the engine's functionality in Python.
*/
std::string exportEngineToPy(const engine::DynamicEngine& engine);
std::string exportEngineToPy(engine::scratch::StateBlob& ctx, engine::DynamicEngine& engine);
/**
* @brief Exports a dynamic engine to a Python file.
*
* @param ctx
* @param engine The dynamic engine to export.
* @param fileName The name of the file to write the Python script to.
*
* This function writes the Python script representation of the given
* dynamic engine to the specified file.
*/
void exportEngineToPy(const engine::DynamicEngine& engine, const std::string& fileName);
void exportEngineToPy(engine::scratch::StateBlob &ctx, const engine::DynamicEngine& engine, const std::string& fileName);
}

View File

@@ -19,6 +19,7 @@
#include "gridfire/reaction/reaction.h"
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/partition/partition.h"
#include "gridfire/utils/logging.h"
#include <string>
#include <set>
@@ -43,6 +44,28 @@ namespace gridfire::policy {
INITIALIZED_VERIFIED
};
inline std::string NetworkPolicyStatusToString(NetworkPolicyStatus status) {
switch (status) {
case NetworkPolicyStatus::UNINITIALIZED:
return "UNINITIALIZED";
case NetworkPolicyStatus::INITIALIZED_UNVERIFIED:
return "INITIALIZED_UNVERIFIED";
case NetworkPolicyStatus::MISSING_KEY_REACTION:
return "MISSING_KEY_REACTION";
case NetworkPolicyStatus::MISSING_KEY_SPECIES:
return "MISSING_KEY_SPECIES";
case NetworkPolicyStatus::INITIALIZED_VERIFIED:
return "INITIALIZED_VERIFIED";
default:
return "UNKNOWN_STATUS";
}
}
struct ConstructionResults {
const engine::DynamicEngine& engine;
std::unique_ptr<engine::scratch::StateBlob> scratch_blob;
};
/**
* @class NetworkPolicy
* @brief Abstract interface for policies that construct DynamicEngine networks from a seed composition.
@@ -139,7 +162,7 @@ namespace gridfire::policy {
* NetOut out = solver.evaluate(netIn, true);
* @endcode
*/
[[nodiscard]] virtual engine::DynamicEngine& construct() = 0;
[[nodiscard]] virtual ConstructionResults construct() = 0;
/**
* @brief Returns the current verification/construction status of the policy.
@@ -160,6 +183,8 @@ namespace gridfire::policy {
[[nodiscard]] virtual std::vector<engine::EngineTypes> get_engine_types_stack() const = 0;
[[nodiscard]] virtual const std::unique_ptr<partition::PartitionFunction>& get_partition_function() const = 0;
[[nodiscard]] virtual std::unique_ptr<engine::scratch::StateBlob> get_stack_scratch_blob() const = 0;
};
/**
@@ -217,3 +242,27 @@ namespace gridfire::policy {
};
}
// 1. Define the BASE specialization first
template<>
struct std::formatter<gridfire::policy::NetworkPolicy> {
static constexpr auto parse(const format_parse_context& ctx) { return ctx.begin(); }
template <typename FormatContext>
auto format(const gridfire::policy::NetworkPolicy& policy, FormatContext& ctx) const {
std::vector<gridfire::engine::EngineTypes> engine_types = policy.get_engine_types_stack();
std::ranges::reverse(engine_types);
return format_to(
ctx.out(),
"{}(Status: {}, Engine Stack Size: {}, Engine Stack: <(TOP) [{}] (BOTTOM)>)",
policy.name(),
gridfire::policy::NetworkPolicyStatusToString(policy.get_status()),
policy.get_engine_stack().size(),
gridfire::utils::iterable_to_delimited_string(
engine_types,
" -> ",
[](const auto& type) { return gridfire::engine::engine_type_to_string(type); }
)
);
}
};

View File

@@ -26,7 +26,6 @@
#include "fourdst/composition/composition.h"
#include "fourdst/atomic/atomicSpecies.h"
#include "gridfire/partition/composite/partition_composite.h"
#include "gridfire/policy/chains.h"
@@ -135,7 +134,7 @@ namespace gridfire::policy {
* // ... run solver ...
* @endcode
*/
engine::DynamicEngine& construct() override;
ConstructionResults construct() override;
/**
* @brief Gets the current status of the policy.
@@ -148,6 +147,8 @@ namespace gridfire::policy {
[[nodiscard]] std::vector<engine::EngineTypes> get_engine_types_stack() const override;
[[nodiscard]] const std::unique_ptr<partition::PartitionFunction>& get_partition_function() const override;
[[nodiscard]] std::unique_ptr<engine::scratch::StateBlob> get_stack_scratch_blob() const override;
private:
std::set<fourdst::atomic::Species> m_seed_species; ///< The set of seed species required by this policy. These are H-1, He-3, He-4, C-12, N-14, O-16, Ne-20, Mg-24.
@@ -159,12 +160,12 @@ namespace gridfire::policy {
NetworkPolicyStatus m_status = NetworkPolicyStatus::UNINITIALIZED; ///< The current status of the policy.
private:
static std::unique_ptr<partition::PartitionFunction> build_partition_function();
[[nodiscard]] NetworkPolicyStatus check_status() const;
[[nodiscard]] NetworkPolicyStatus check_status(engine::scratch::StateBlob& ctx) const;
public:
};
}
}
template<>
struct std::formatter<gridfire::policy::MainSequencePolicy> : std::formatter<gridfire::policy::NetworkPolicy> {};

View File

@@ -78,14 +78,17 @@ namespace gridfire::solver {
* std::cout << "Final energy: " << out.energy << " erg/g\n";
* @endcode
*/
class CVODESolverStrategy final : public SingleZoneDynamicNetworkSolverStrategy {
class CVODESolverStrategy final : public SingleZoneDynamicNetworkSolver {
public:
/**
* @brief Construct the CVODE strategy and create a SUNDIALS context.
* @param engine DynamicEngine used for RHS/Jacobian evaluation and network access.
* @throws std::runtime_error If SUNContext_Create fails.
*/
explicit CVODESolverStrategy(engine::DynamicEngine& engine);
explicit CVODESolverStrategy(
const engine::DynamicEngine& engine,
const engine::scratch::StateBlob& ctx
);
/**
* @brief Destructor: cleans CVODE/SUNDIALS resources and frees SUNContext.
*/
@@ -185,6 +188,7 @@ namespace gridfire::solver {
const size_t currentConvergenceFailures; ///< Total number of convergence failures
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
/**
* @brief Construct a context snapshot.
@@ -201,7 +205,8 @@ namespace gridfire::solver {
const std::vector<fourdst::atomic::Species>& networkSpecies,
size_t currentConvergenceFailure,
size_t currentNonlinearIterations,
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
engine::scratch::StateBlob& state_ctx
);
/**
@@ -226,7 +231,8 @@ namespace gridfire::solver {
*/
struct CVODEUserData {
CVODESolverStrategy* solver_instance{}; // Pointer back to the class instance
engine::DynamicEngine* engine{};
engine::scratch::StateBlob& ctx;
const engine::DynamicEngine* engine{};
double T9{};
double rho{};
double energy{};
@@ -302,7 +308,7 @@ namespace gridfire::solver {
* sorted table of species with the highest error ratios; then invokes diagnostic routines to
* inspect Jacobian stiffness and species balance.
*/
void log_step_diagnostics(const CVODEUserData& user_data, bool displayJacobianStiffness, bool
void log_step_diagnostics(engine::scratch::StateBlob &ctx, const CVODEUserData& user_data, bool displayJacobianStiffness, bool
displaySpeciesBalance, bool to_file, std::optional<std::string> filename) const;
private:
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).

View File

@@ -37,7 +37,7 @@ namespace gridfire::solver {
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe() const = 0;
};
/**
* @class SingleZoneNetworkSolverStrategy
* @class SingleZoneNetworkSolver
* @brief Abstract base class for network solver strategies.
*
* This class defines the interface for network solver strategies, which are responsible
@@ -47,18 +47,23 @@ namespace gridfire::solver {
* @tparam EngineT The type of engine to use with this solver strategy. Must inherit from Engine.
*/
template <IsEngine EngineT>
class SingleZoneNetworkSolverStrategy {
class SingleZoneNetworkSolver {
public:
/**
* @brief Constructor for the NetworkSolverStrategy.
* @param engine The engine to use for evaluating the network.
*/
explicit SingleZoneNetworkSolverStrategy(EngineT& engine) : m_engine(engine) {};
explicit SingleZoneNetworkSolver(
const EngineT& engine,
const engine::scratch::StateBlob& ctx
) :
m_engine(engine),
m_scratch_blob(ctx.clone_structure()) {};
/**
* @brief Virtual destructor.
*/
virtual ~SingleZoneNetworkSolverStrategy() = default;
virtual ~SingleZoneNetworkSolver() = default;
/**
* @brief Evaluates the network for a given timestep.
@@ -92,14 +97,21 @@ namespace gridfire::solver {
*/
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected:
EngineT& m_engine; ///< The engine used by this solver strategy.
const EngineT& m_engine; ///< The engine used by this solver strategy.
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob;
};
template <IsEngine EngineT>
class MultiZoneNetworkSolverStrategy {
class MultiZoneNetworkSolver {
public:
explicit MultiZoneNetworkSolverStrategy(EngineT& engine) : m_engine(engine) {};
virtual ~MultiZoneNetworkSolverStrategy() = default;
explicit MultiZoneNetworkSolver(
const EngineT& engine,
const engine::scratch::StateBlob& ctx
) :
m_engine(engine),
m_scratch_blob_structure(ctx.clone_structure()){};
virtual ~MultiZoneNetworkSolver() = default;
virtual std::vector<NetOut> evaluate(
const std::vector<NetIn>& netIns,
@@ -108,12 +120,13 @@ namespace gridfire::solver {
virtual void set_callback(const std::any& callback) = 0;
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected:
EngineT& m_engine; ///< The engine used by this solver strategy.
const EngineT& m_engine; ///< The engine used by this solver strategy.
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob_structure;
};
/**
* @brief Type alias for a network solver strategy that uses a DynamicEngine.
*/
using SingleZoneDynamicNetworkSolverStrategy = SingleZoneNetworkSolverStrategy<engine::DynamicEngine>;
using MultiZoneDynamicNetworkSolverStrategy = MultiZoneNetworkSolverStrategy<engine::DynamicEngine>;
using SingleZoneDynamicNetworkSolver = SingleZoneNetworkSolver<engine::DynamicEngine>;
using MultiZoneDynamicNetworkSolver = MultiZoneNetworkSolver<engine::DynamicEngine>;
}

View File

@@ -2,6 +2,7 @@
#include "gridfire/engine/engine_abstract.h"
#include "fourdst/composition/composition.h"
#include "gridfire/engine/scratchpads/blob.h"
#include <string>
#include <functional>
@@ -15,6 +16,7 @@ namespace gridfire::utils {
* It then formats this information into a neatly aligned ASCII table, which
* is suitable for logging or printing to the console.
*
* @param ctx
* @param engine A constant reference to a `DynamicEngine` object, used to
* calculate the species timescales.
* @param composition The current composition of the plasma
@@ -58,10 +60,10 @@ namespace gridfire::utils {
* @endcode
*/
std::string formatNuclearTimescaleLogString(
engine::scratch::StateBlob &ctx,
const engine::DynamicEngine& engine,
const fourdst::composition::Composition& composition,
double T9,
double rho
double T9, double rho
);
template <typename T>

View File

@@ -3,12 +3,15 @@
#include "gridfire/utils/table_format.h"
#include "fourdst/atomic/species.h"
#include "gridfire/engine/scratchpads/blob.h"
#include <vector>
#include <string>
#include <algorithm>
namespace gridfire::engine::diagnostics {
std::optional<nlohmann::json> report_limiting_species(
scratch::StateBlob& ctx,
const DynamicEngine &engine,
const std::vector<double> &Y_full,
const std::vector<double> &E_full,
@@ -24,7 +27,7 @@ namespace gridfire::engine::diagnostics {
double abundance;
};
const auto& species_list = engine.getNetworkSpecies();
const auto& species_list = engine.getNetworkSpecies(ctx);
std::vector<SpeciesError> errors;
for (size_t i = 0; i < species_list.size(); ++i) {
@@ -75,6 +78,7 @@ namespace gridfire::engine::diagnostics {
}
std::optional<nlohmann::json> inspect_species_balance(
scratch::StateBlob& ctx,
const DynamicEngine& engine,
const std::string& species_name,
const fourdst::composition::Composition &comp,
@@ -90,11 +94,11 @@ namespace gridfire::engine::diagnostics {
double total_creation_flow = 0.0;
double total_destruction_flow = 0.0;
for (const auto& reaction : engine.getNetworkReactions()) {
for (const auto& reaction : engine.getNetworkReactions(ctx)) {
const int stoichiometry = reaction->stoichiometry(species_obj);
if (stoichiometry == 0) continue;
const double flow = engine.calculateMolarReactionFlow(*reaction, comp, T9, rho);
const double flow = engine.calculateMolarReactionFlow(ctx, *reaction, comp, T9, rho);
if (stoichiometry > 0) {
creation_ids.emplace_back(reaction->id());
@@ -157,17 +161,18 @@ namespace gridfire::engine::diagnostics {
}
std::optional<nlohmann::json> inspect_jacobian_stiffness(
scratch::StateBlob& ctx,
const DynamicEngine &engine,
const fourdst::composition::Composition &comp,
const double T9,
const double rho,
const bool json
) {
NetworkJacobian jac = engine.generateJacobianMatrix(comp, T9, rho);
NetworkJacobian jac = engine.generateJacobianMatrix(ctx, comp, T9, rho);
jac = regularize_jacobian(jac, comp);
const auto& species_list = engine.getNetworkSpecies();
const auto& species_list = engine.getNetworkSpecies(ctx);
double max_diag = 0.0;
double max_off_diag = 0.0;

View File

@@ -8,11 +8,16 @@
#include "gridfire/utils/hashing.h"
#include "gridfire/utils/table_format.h"
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/engine/scratchpads/utils.h"
#include "fourdst/atomic/species.h"
#include "fourdst/atomic/atomicSpecies.h"
#include "quill/LogMacros.h"
// ReSharper disable once CppUnusedIncludeDirective
#include <cstdint>
#include <set>
#include <stdexcept>
@@ -28,9 +33,6 @@
#include "cppad/utility/sparse_rc.hpp"
#include "cppad/utility/sparse_rcv.hpp"
#ifdef GRIDFIRE_USE_OPENMP
#include <omp.h>
#endif
namespace {
@@ -115,8 +117,9 @@ namespace gridfire::engine {
const NetworkConstructionFlags reactionTypes ) :
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)),
m_partitionFunction(partitionFunction.clone()),
m_depth(buildDepth),
m_partitionFunction(partitionFunction.clone())
m_state_blob_offset(0) // For a base engine the offset is always 0
{
syncInternalMaps();
}
@@ -125,33 +128,37 @@ namespace gridfire::engine {
const reaction::ReactionSet &reactions
) :
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
m_reactions(reactions)
m_reactions(reactions),
m_state_blob_offset(0)
{
syncInternalMaps();
}
std::expected<StepDerivatives<double>, EngineStatus> GraphEngine::calculateRHSAndEnergy(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
bool trust
) const {
return calculateRHSAndEnergy(comp, T9, rho, m_reactions);
return calculateRHSAndEnergy(ctx, comp, T9, rho, m_reactions);
}
std::expected<StepDerivatives<double>, EngineStatus> GraphEngine::calculateRHSAndEnergy(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const reaction::ReactionSet &activeReactions
) const {
auto* state = scratch::get_state<scratch::GraphEngineScratchPad, true>(ctx);
LOG_TRACE_L3(m_logger, "Calculating RHS and Energy in GraphEngine at T9 = {}, rho = {}.", T9, rho);
const double Ye = comp.getElectronAbundance();
const std::vector<double> molarAbundances = comp.getMolarAbundanceVector();
if (m_usePrecomputation) {
const std::size_t state_hash = utils::hash_state(comp, T9, rho, activeReactions);
if (m_stepDerivativesCache.contains(state_hash)) {
return m_stepDerivativesCache.at(state_hash);
if (state->stepDerivativesCache.contains(state_hash)) {
return state->stepDerivativesCache.at(state_hash);
}
LOG_TRACE_L3(m_logger, "Using precomputation for reaction rates in GraphEngine calculateRHSAndEnergy.");
std::vector<double> bare_rates;
@@ -171,9 +178,9 @@ namespace gridfire::engine {
LOG_TRACE_L3(m_logger, "Precomputed {} forward and {} reverse reaction rates for active reactions.", bare_rates.size(), bare_reverse_rates.size());
// --- The public facing interface can always use the precomputed version since taping is done internally ---
StepDerivatives<double> result = calculateAllDerivativesUsingPrecomputation(comp, bare_rates, bare_reverse_rates, T9, rho, activeReactions);
m_stepDerivativesCache.insert(std::make_pair(state_hash, result));
m_most_recent_rhs_calculation = result;
StepDerivatives<double> result = calculateAllDerivativesUsingPrecomputation(ctx, comp, bare_rates, bare_reverse_rates, T9, rho, activeReactions);
state->stepDerivativesCache.insert(std::make_pair(state_hash, result));
state->most_recent_rhs_calculation = result;
return result;
} else {
LOG_TRACE_L2(m_logger, "Not using precomputation for reaction rates in GraphEngine calculateRHSAndEnergy.");
@@ -194,25 +201,28 @@ namespace gridfire::engine {
return false;
}
);
m_most_recent_rhs_calculation = result;
state->most_recent_rhs_calculation = result;
return result;
}
}
EnergyDerivatives GraphEngine::calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
return calculateEpsDerivatives(comp, T9, rho, m_reactions);
return calculateEpsDerivatives(ctx, comp, T9, rho, m_reactions);
}
EnergyDerivatives GraphEngine::calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const reaction::ReactionSet &activeReactions
) const {
auto* state = scratch::get_state<scratch::GraphEngineScratchPad, true>(ctx);
const size_t numSpecies = m_networkSpecies.size();
const size_t numADInputs = numSpecies + 2; // +2 for T9 and rho
@@ -236,10 +246,11 @@ namespace gridfire::engine {
w[numSpecies] = 1.0; // We want the derivative of the energy generation rate
// Sweep the tape forward to record the function value at x
m_rhsADFun.Forward(0, x);
assert(state->rhsADFun.has_value() && "AD tape for energy derivatives has not been recorded.");
state->rhsADFun.value().Forward(0, x);
// Extract the gradient at the previously evaluated point x using reverse mode
const std::vector<double> eps_derivatives = m_rhsADFun.Reverse(1, w);
const std::vector<double> eps_derivatives = state->rhsADFun.value().Reverse(1, w);
const double dEps_dT9 = eps_derivatives[numSpecies];
const double dEps_dRho = eps_derivatives[numSpecies + 1];
@@ -252,19 +263,8 @@ namespace gridfire::engine {
return {dEps_dT, dEps_dRho};
}
void GraphEngine::syncInternalMaps() {
LOG_INFO(m_logger, "Synchronizing internal maps for REACLIB graph network (serif::network::GraphNetwork)...");
collectNetworkSpecies();
populateReactionIDMap();
populateSpeciesToIndexMap();
collectAtomicReverseRateAtomicBases();
generateStoichiometryMatrix();
recordADTape(); // Record the AD tape for the RHS of the ODE (dY/di and dEps/di) for all independent variables i
[[maybe_unused]] const size_t inputSize = m_rhsADFun.Domain();
const size_t outputSize = m_rhsADFun.Range();
void GraphEngine::generate_jacobian_sparsity_pattern() {
const size_t outputSize = m_authoritativeADFun.Range();
// Create a range x range identity pattern
CppAD::sparse_rc<std::vector<size_t>> patternIn(outputSize, outputSize, outputSize);
@@ -272,9 +272,8 @@ namespace gridfire::engine {
patternIn.set(i, i, i);
}
m_rhsADFun.rev_jac_sparsity(patternIn, false, false, false, m_full_jacobian_sparsity_pattern);
m_authoritativeADFun.rev_jac_sparsity(patternIn, false, false, false, m_full_jacobian_sparsity_pattern);
m_jac_work.clear();
m_full_sparsity_set.clear();
const auto& rows = m_full_jacobian_sparsity_pattern.row();
const auto& cols = m_full_jacobian_sparsity_pattern.col();
@@ -285,6 +284,17 @@ namespace gridfire::engine {
m_full_sparsity_set.insert(std::make_pair(rows[k], cols[k]));
}
}
}
void GraphEngine::syncInternalMaps() {
LOG_INFO(m_logger, "Synchronizing internal maps for REACLIB graph network (serif::network::GraphNetwork)...");
collectNetworkSpecies();
populateReactionIDMap();
populateSpeciesToIndexMap();
collectAtomicReverseRateAtomicBases();
recordADTape(); // Record the AD tape for the RHS of the ODE (dY/di and dEps/di) for all independent variables i
generate_jacobian_sparsity_pattern();
precomputeNetwork();
LOG_INFO(m_logger, "Internal maps synchronized. Network contains {} species and {} reactions.",
@@ -344,81 +354,26 @@ namespace gridfire::engine {
}
// --- Basic Accessors and Queries ---
const std::vector<fourdst::atomic::Species>& GraphEngine::getNetworkSpecies() const {
const std::vector<fourdst::atomic::Species>& GraphEngine::getNetworkSpecies(scratch::StateBlob &ctx) const {
return m_networkSpecies;
}
const reaction::ReactionSet& GraphEngine::getNetworkReactions() const {
const reaction::ReactionSet& GraphEngine::getNetworkReactions(
scratch::StateBlob& ctx
) const {
return m_reactions;
}
void GraphEngine::setNetworkReactions(const reaction::ReactionSet &reactions) {
m_reactions = reactions;
syncInternalMaps();
}
bool GraphEngine::involvesSpecies(const fourdst::atomic::Species& species) const {
bool GraphEngine::involvesSpecies(
scratch::StateBlob& ctx,
const fourdst::atomic::Species& species
) const {
const bool found = m_networkSpeciesMap.contains(species.name());
return found;
}
// --- Validation Methods ---
bool GraphEngine::validateConservation() const {
LOG_TRACE_L1(m_logger, "Validating mass (A) and charge (Z) conservation across all reactions in the network.");
for (const auto& reaction : m_reactions) {
uint64_t totalReactantA = 0;
uint64_t totalReactantZ = 0;
uint64_t totalProductA = 0;
uint64_t totalProductZ = 0;
// Calculate total A and Z for reactants
for (const auto& reactant : reaction->reactants()) {
auto it = m_networkSpeciesMap.find(reactant.name());
if (it != m_networkSpeciesMap.end()) {
totalReactantA += it->second.a();
totalReactantZ += it->second.z();
} else {
// This scenario indicates a severe data integrity issue:
// a reactant is part of a reaction but not in the network's species map.
LOG_ERROR(m_logger, "CRITICAL ERROR: Reactant species '{}' in reaction '{}' not found in network species map during conservation validation.",
reactant.name(), reaction->id());
return false;
}
}
// Calculate total A and Z for products
for (const auto& product : reaction->products()) {
auto it = m_networkSpeciesMap.find(product.name());
if (it != m_networkSpeciesMap.end()) {
totalProductA += it->second.a();
totalProductZ += it->second.z();
} else {
// Similar critical error for product species
LOG_ERROR(m_logger, "CRITICAL ERROR: Product species '{}' in reaction '{}' not found in network species map during conservation validation.",
product.name(), reaction->id());
return false;
}
}
// Compare totals for conservation
if (totalReactantA != totalProductA) {
LOG_ERROR(m_logger, "Mass number (A) not conserved for reaction '{}': Reactants A={} vs Products A={}.",
reaction->id(), totalReactantA, totalProductA);
return false;
}
if (totalReactantZ != totalProductZ) {
LOG_ERROR(m_logger, "Atomic number (Z) not conserved for reaction '{}': Reactants Z={} vs Products Z={}.",
reaction->id(), totalReactantZ, totalProductZ);
return false;
}
}
LOG_TRACE_L1(m_logger, "Mass (A) and charge (Z) conservation validated successfully for all reactions.");
return true; // All reactions passed the conservation check
}
double GraphEngine::compute_reaction_flow(
scratch::StateBlob& ctx,
const std::vector<double> &local_abundances,
const std::vector<double> &screening_factors,
const std::vector<double> &bare_rates,
@@ -488,6 +443,7 @@ namespace gridfire::engine {
}
std::pair<double, double> GraphEngine::compute_neutrino_fluxes(
scratch::StateBlob& ctx,
const double netFlow,
const reaction::Reaction &reaction
) const {
@@ -518,6 +474,7 @@ namespace gridfire::engine {
}
GraphEngine::PrecomputationKernelResults GraphEngine::accumulate_flows_serial(
scratch::StateBlob& ctx,
const std::vector<double> &local_abundances,
const std::vector<double> &screening_factors,
const std::vector<double> &bare_rates,
@@ -529,19 +486,20 @@ namespace gridfire::engine {
results.dydt_vector.resize(m_networkSpecies.size(), 0.0);
std::vector<double> molarReactionFlows;
molarReactionFlows.reserve(m_precomputedReactions.size());
molarReactionFlows.reserve(m_precomputed_reactions.size());
size_t reactionCounter = 0;
std::vector<size_t> reactionIndices;
reactionIndices.reserve(m_precomputedReactions.size());
reactionIndices.reserve(m_precomputed_reactions.size());
for (const auto& reaction : activeReactions) {
uint64_t reactionHash = reaction->hash(0);
const size_t reactionIndex = m_precomputedReactionIndexMap.at(reactionHash);
const size_t reactionIndex = m_precomputed_reaction_index_map.at(reactionHash);
reactionIndices.push_back(reactionIndex);
const PrecomputedReaction& precomputedReaction = m_precomputedReactions[reactionIndex];
const PrecomputedReaction& precomputedReaction = m_precomputed_reactions[reactionIndex];
double netFlow = compute_reaction_flow(
ctx,
local_abundances,
screening_factors,
bare_rates,
@@ -554,7 +512,7 @@ namespace gridfire::engine {
molarReactionFlows.push_back(netFlow);
auto [local_neutrino_loss, local_neutrino_flux] = compute_neutrino_fluxes(netFlow, *reaction);
auto [local_neutrino_loss, local_neutrino_flux] = compute_neutrino_fluxes(ctx, netFlow, *reaction);
results.total_neutrino_energy_loss_rate += local_neutrino_loss;
results.total_neutrino_flux += local_neutrino_flux;
@@ -565,7 +523,7 @@ namespace gridfire::engine {
reactionCounter = 0;
for (const auto& [reaction, j]: std::views::zip(activeReactions, reactionIndices)) {
const auto& precomp = m_precomputedReactions[j];
const auto& precomp = m_precomputed_reactions[j];
const double R_j = molarReactionFlows[reactionCounter];
for (size_t i = 0; i < precomp.affected_species_indices.size(); ++i) {
@@ -746,28 +704,24 @@ namespace gridfire::engine {
}
bool GraphEngine::isUsingReverseReactions() const {
bool GraphEngine::isUsingReverseReactions(
scratch::StateBlob& ctx
) const {
return m_useReverseReactions;
}
void GraphEngine::setUseReverseReactions(const bool useReverse) {
m_useReverseReactions = useReverse;
syncInternalMaps();
}
size_t GraphEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
size_t GraphEngine::getSpeciesIndex(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const {
return m_speciesToIndexMap.at(species); // Returns the index of the species in the stoichiometry matrix
}
std::vector<double> GraphEngine::mapNetInToMolarAbundanceVector(const NetIn &netIn) const {
std::vector<double> Y(m_networkSpecies.size(), 0.0); // Initialize with zeros
for (const auto& [sp, y] : netIn.composition) {
Y[getSpeciesIndex(sp)] = y; // Map species to their molar abundance
}
return Y; // Return the vector of molar abundances
}
PrimingReport GraphEngine::primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
PrimingReport GraphEngine::primeEngine(const NetIn &netIn) {
NetIn fullNetIn;
fourdst::composition::Composition composition;
@@ -787,27 +741,13 @@ namespace gridfire::engine {
reactionTypesToIgnore = {reaction::ReactionType::WEAK};
}
auto primingReport = primeNetwork(fullNetIn, *this, reactionTypesToIgnore);
auto primingReport = primeNetwork(ctx, fullNetIn, *this, reactionTypesToIgnore);
m_has_been_primed = true;
return primingReport;
}
BuildDepthType GraphEngine::getDepth() const {
return m_depth;
}
void GraphEngine::rebuild(const fourdst::composition::CompositionAbstract &comp, const BuildDepthType depth) {
if (depth != m_depth) {
m_depth = depth;
m_reactions = build_nuclear_network(comp, m_weakRateInterpolator, m_depth);
syncInternalMaps(); // Resync internal maps after changing the depth
} else {
LOG_DEBUG(m_logger, "Rebuild requested with the same depth. No changes made to the network.");
}
}
fourdst::composition::Composition GraphEngine::collectComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -827,7 +767,10 @@ namespace gridfire::engine {
return result;
}
SpeciesStatus GraphEngine::getSpeciesStatus(const fourdst::atomic::Species &species) const {
SpeciesStatus GraphEngine::getSpeciesStatus(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const {
if (m_networkSpeciesMap.contains(species.name())) {
return SpeciesStatus::ACTIVE;
}
@@ -835,15 +778,18 @@ namespace gridfire::engine {
}
std::optional<StepDerivatives<double>> GraphEngine::getMostRecentRHSCalculation() const {
if (!m_most_recent_rhs_calculation.has_value()) {
std::optional<StepDerivatives<double>> GraphEngine::getMostRecentRHSCalculation(
scratch::StateBlob& ctx
) const {
const auto *state = scratch::get_state<scratch::GraphEngineScratchPad, true>(ctx);
if (!state->most_recent_rhs_calculation.has_value()) {
return std::nullopt;
}
return m_most_recent_rhs_calculation.value();
return state->most_recent_rhs_calculation.value();
}
StepDerivatives<double> GraphEngine::calculateAllDerivativesUsingPrecomputation(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const std::vector<double> &bare_rates,
const std::vector<double> &bare_reverse_rates,
@@ -851,6 +797,7 @@ namespace gridfire::engine {
const double rho,
const reaction::ReactionSet &activeReactions
) const {
auto *state = scratch::get_state<scratch::GraphEngineScratchPad, true>(ctx);
LOG_TRACE_L3(m_logger, "Computing screening factors for {} active reactions.", activeReactions.size());
// --- Calculate screening factors ---
const std::vector<double> screeningFactors = m_screeningModel->calculateScreeningFactors(
@@ -860,17 +807,17 @@ namespace gridfire::engine {
T9,
rho
);
m_local_abundance_cache.clear();
state->local_abundance_cache.clear();
for (const auto& species: m_networkSpecies) {
m_local_abundance_cache.push_back(comp.contains(species) ? comp.getMolarAbundance(species) : 0.0);
state->local_abundance_cache.push_back(comp.contains(species) ? comp.getMolarAbundance(species) : 0.0);
}
StepDerivatives<double> result;
std::vector<double> dydt_scratch(m_networkSpecies.size(), 0.0);
#ifndef GRIDFIRE_USE_OPENMP
const auto [dydt_vector, total_neutrino_energy_loss_rate, total_neutrino_flux] = accumulate_flows_serial(
m_local_abundance_cache,
ctx,
state->local_abundance_cache,
screeningFactors,
bare_rates,
bare_reverse_rates,
@@ -880,19 +827,6 @@ namespace gridfire::engine {
dydt_scratch = dydt_vector;
result.neutrinoEnergyLossRate = total_neutrino_energy_loss_rate;
result.totalNeutrinoFlux = total_neutrino_flux;
#else
const auto [dydt_vector, total_neutrino_energy_loss_rate, total_neutrino_flux] = accumulate_flows_parallel(
m_local_abundance_cache,
screeningFactors,
bare_rates,
bare_reverse_rates,
rho,
activeReactions
);
dydt_scratch = dydt_vector;
result.neutrinoEnergyLossRate = total_neutrino_energy_loss_rate;
result.totalNeutrinoFlux = total_neutrino_flux;
#endif
// load scratch into result.dydt
for (size_t i = 0; i < m_networkSpecies.size(); ++i) {
@@ -910,33 +844,26 @@ namespace gridfire::engine {
}
// --- Generate Stoichiometry Matrix ---
void GraphEngine::generateStoichiometryMatrix() {
return; // Deprecated
}
void GraphEngine::setScreeningModel(const screening::ScreeningType model) {
m_screeningModel = screening::selectScreeningModel(model);
m_screeningType = model;
}
screening::ScreeningType GraphEngine::getScreeningModel() const {
screening::ScreeningType GraphEngine::getScreeningModel(
scratch::StateBlob& ctx
) const {
return m_screeningType;
}
void GraphEngine::setPrecomputation(const bool precompute) {
m_usePrecomputation = precompute;
}
bool GraphEngine::isPrecomputationEnabled() const {
bool GraphEngine::isPrecomputationEnabled(
scratch::StateBlob& ctx
) const {
return m_usePrecomputation;
}
const partition::PartitionFunction & GraphEngine::getPartitionFunction() const {
const partition::PartitionFunction & GraphEngine::getPartitionFunction(
scratch::StateBlob& ctx
) const {
return *m_partitionFunction;
}
double GraphEngine::calculateMolarReactionFlow(
scratch::StateBlob& ctx,
const reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
@@ -962,10 +889,12 @@ namespace gridfire::engine {
}
NetworkJacobian GraphEngine::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
auto *state = scratch::get_state<scratch::GraphEngineScratchPad, true>(ctx);
fourdst::composition::Composition mutableComp;
for (const auto& species : m_networkSpecies) {
mutableComp.registerSpecies(species);
@@ -986,10 +915,11 @@ namespace gridfire::engine {
adInput[numSpecies + 1] = rho; // rho
// 2. Calculate the full jacobian
const std::vector<double> dotY = m_rhsADFun.Jacobian(adInput);
assert(state->rhsADFun.has_value() && "RHS ADFun not recorded before Jacobian generation.");
const std::vector<double> dotY = state->rhsADFun.value().Jacobian(adInput);
// 3. Pack jacobian vector into sparse matrix
Eigen::SparseMatrix<double> jacobianMatrix(numSpecies, numSpecies);
Eigen::SparseMatrix<double> jacobianMatrix(static_cast<long>(numSpecies), static_cast<long>(numSpecies));
std::vector<Eigen::Triplet<double> > triplets;
for (size_t i = 0; i < numSpecies; ++i) {
for (size_t j = 0; j < numSpecies; ++j) {
@@ -1013,12 +943,15 @@ namespace gridfire::engine {
}
NetworkJacobian GraphEngine::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const std::vector<fourdst::atomic::Species> &activeSpecies
) const {
// PERF: For small k it may make sense to implement a purley forward mode AD computation, some heuristic could be used to switch between the two methods based on k and total network species
// PERF: For small k it may make sense to implement a purley forward mode AD computation,
// some heuristic could be used to switch between the two methods based on k and
// total network species
const size_t k_active = activeSpecies.size();
// --- 1. Get the list of global indices ---
@@ -1026,8 +959,8 @@ namespace gridfire::engine {
active_indices.reserve(k_active);
for (const auto& species : activeSpecies) {
assert(involvesSpecies(species));
active_indices.push_back(getSpeciesIndex(species));
assert(involvesSpecies(ctx, species));
active_indices.push_back(getSpeciesIndex(ctx, species));
}
// --- 2. Build the k x k sparsity pattern ---
@@ -1041,15 +974,17 @@ namespace gridfire::engine {
}
// --- 3. Call the sparse reverse-mode implementation ---
return generateJacobianMatrix(comp, T9, rho, sparsityPattern);
return generateJacobianMatrix(ctx, comp, T9, rho, sparsityPattern);
}
NetworkJacobian GraphEngine::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const SparsityPattern &sparsityPattern
) const {
auto *state = scratch::get_state<scratch::GraphEngineScratchPad, true>(ctx);
// --- Compute the intersection of the requested sparsity pattern with the full sparsity pattern ---
SparsityPattern intersectionSparsityPattern;
for (const auto& entry : sparsityPattern) {
@@ -1097,30 +1032,32 @@ namespace gridfire::engine {
}
// --- Check cache for existing subset ---
if (!m_jacobianSubsetCache.contains(sparsity_hash)) {
m_jacobianSubsetCache.emplace(sparsity_hash, CppAD_sparsity_pattern);
m_jac_work.clear();
if (!state->jacobianSubsetCache.contains(sparsity_hash)) {
state->jacobianSubsetCache.emplace(sparsity_hash, CppAD_sparsity_pattern);
state->jac_work.clear();
} else {
if (m_jacWorkCache.contains(sparsity_hash)) {
m_jac_work.clear();
m_jac_work = m_jacWorkCache.at(sparsity_hash);
if (state->jacWorkCache.contains(sparsity_hash)) {
state->jac_work.clear();
state->jac_work = state->jacWorkCache.at(sparsity_hash);
}
}
auto& jac_subset = m_jacobianSubsetCache.at(sparsity_hash);
m_rhsADFun.sparse_jac_rev(
auto& jac_subset = state->jacobianSubsetCache.at(sparsity_hash);
assert(state->rhsADFun.has_value() && "RHS ADFun not recorded before Jacobian generation.");
state->rhsADFun.value().sparse_jac_rev(
x,
jac_subset, // Sparse Jacobian output
m_full_jacobian_sparsity_pattern,
"cppad",
m_jac_work // Work vector for CppAD
state->jac_work // Work vector for CppAD
);
// --- Stash the now populated work vector in the cache if not already present ---
if (!m_jacWorkCache.contains(sparsity_hash)) {
m_jacWorkCache.emplace(sparsity_hash, m_jac_work);
if (!state->jacWorkCache.contains(sparsity_hash)) {
state->jacWorkCache.emplace(sparsity_hash, state->jac_work);
}
Eigen::SparseMatrix<double> jacobianMatrix(numSpecies, numSpecies);
Eigen::SparseMatrix<double> jacobianMatrix(static_cast<long>(numSpecies), static_cast<long>(numSpecies));
std::vector<Eigen::Triplet<double> > triplets;
for (size_t k = 0; k < nnz; ++k) {
const size_t row = jac_subset.row()[k];
@@ -1142,20 +1079,10 @@ namespace gridfire::engine {
return jac;
}
std::unordered_map<fourdst::atomic::Species, int> GraphEngine::getNetReactionStoichiometry(
const reaction::Reaction &reaction
) {
return reaction.stoichiometry();
}
int GraphEngine::getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const reaction::Reaction &reaction
void GraphEngine::exportToDot(
scratch::StateBlob& ctx,
const std::string &filename
) const {
return reaction.stoichiometry(species);
}
void GraphEngine::exportToDot(const std::string &filename) const {
LOG_TRACE_L1(m_logger, "Exporting network graph to DOT file: {}", filename);
std::ofstream dotFile(filename);
@@ -1203,7 +1130,10 @@ namespace gridfire::engine {
LOG_TRACE_L1(m_logger, "Successfully exported network to {}", filename);
}
void GraphEngine::exportToCSV(const std::string &filename) const {
void GraphEngine::exportToCSV(
scratch::StateBlob& ctx,
const std::string &filename
) const {
LOG_TRACE_L1(m_logger, "Exporting network graph to CSV file: {}", filename);
std::ofstream csvFile(filename, std::ios::out | std::ios::trunc);
@@ -1241,14 +1171,16 @@ namespace gridfire::engine {
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> GraphEngine::getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
return getSpeciesTimescales(comp, T9, rho, m_reactions);
return getSpeciesTimescales(ctx, comp, T9, rho, m_reactions);
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> GraphEngine::getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
@@ -1287,14 +1219,16 @@ namespace gridfire::engine {
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> GraphEngine::getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
return getSpeciesDestructionTimescales(comp, T9, rho, m_reactions);
return getSpeciesDestructionTimescales(ctx, comp, T9, rho, m_reactions);
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, EngineStatus> GraphEngine::getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
@@ -1337,7 +1271,10 @@ namespace gridfire::engine {
return speciesDestructionTimescales;
}
fourdst::composition::Composition GraphEngine::update(const NetIn &netIn) {
fourdst::composition::Composition GraphEngine::project(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
fourdst::composition::Composition baseUpdatedComposition = netIn.composition;
for (const auto& species : m_networkSpecies) {
if (!netIn.composition.contains(species)) {
@@ -1347,11 +1284,8 @@ namespace gridfire::engine {
return baseUpdatedComposition;
}
bool GraphEngine::isStale(const NetIn &netIn) {
return false;
}
void GraphEngine::recordADTape() {
void GraphEngine::recordADTape() const {
LOG_TRACE_L1(m_logger, "Recording AD tape for the RHS calculation...");
// Task 1: Set dimensions and initialize the matrix
@@ -1415,13 +1349,14 @@ namespace gridfire::engine {
);
dependentVector.push_back(result.nuclearEnergyGenerationRate);
m_rhsADFun.Dependent(adInput, dependentVector);
m_rhsADFun.optimize();
m_authoritativeADFun.Dependent(adInput, dependentVector);
m_authoritativeADFun.optimize();
LOG_TRACE_L1(m_logger, "AD tape recorded successfully for the RHS and Eps calculation. Number of independent variables: {}.", adInput.size());
}
void GraphEngine::collectAtomicReverseRateAtomicBases() {
void GraphEngine::collectAtomicReverseRateAtomicBases(
) {
m_atomicReverseRates.clear();
m_atomicReverseRates.reserve(m_reactions.size());
@@ -1434,7 +1369,7 @@ namespace gridfire::engine {
}
}
void GraphEngine::precomputeNetwork() {
void GraphEngine::precomputeNetwork() {
LOG_TRACE_L1(m_logger, "Pre-computing constant components of GraphNetwork state...");
// --- Reverse map for fast species lookups ---
@@ -1443,10 +1378,10 @@ namespace gridfire::engine {
speciesIndexMap[m_networkSpecies[i]] = i;
}
m_precomputedReactions.clear();
m_precomputedReactions.reserve(m_reactions.size());
m_precomputedReactionIndexMap.clear();
m_precomputedReactionIndexMap.reserve(m_reactions.size());
m_precomputed_reactions.clear();
m_precomputed_reactions.reserve(m_reactions.size());
m_precomputed_reaction_index_map.clear();
m_precomputed_reaction_index_map.reserve(m_reactions.size());
for (size_t i = 0; i < m_reactions.size(); ++i) {
const auto& reaction = m_reactions[i];
@@ -1456,7 +1391,7 @@ namespace gridfire::engine {
uint64_t reactionHash = reaction.hash(0);
precomp.reaction_hash = reactionHash;
m_precomputedReactionIndexMap[reactionHash] = i;
m_precomputed_reaction_index_map[reactionHash] = i;
// --- Precompute forward reaction information ---
// Count occurrences for each reactant to determine powers and symmetry
@@ -1506,7 +1441,7 @@ namespace gridfire::engine {
precomp.stoichiometric_coefficients.push_back(coeff);
}
m_precomputedReactions.push_back(std::move(precomp));
m_precomputed_reactions.push_back(std::move(precomp));
}
LOG_TRACE_L1(m_logger, "Pre-computation complete. Precomputed data for {} reactions.", m_precomputedReactions.size());
}
@@ -1523,6 +1458,7 @@ namespace gridfire::engine {
if ( p != 0) { return false; }
const double T9 = tx[0];
// We can pass a dummy comp and rho because reverse rates should only be calculated for strong reactions whose
// rates of progression do not depend on composition or density.
const fourdst::composition::Composition dummyComp;
@@ -1612,68 +1548,4 @@ namespace gridfire::engine {
return true;
}
#ifdef GRIDFIRE_USE_OPENMP
GraphEngine::PrecomputationKernelResults GraphEngine::accumulate_flows_parallel(
const std::vector<double> &local_abundances,
const std::vector<double> &screening_factors,
const std::vector<double> &bare_rates,
const std::vector<double> &bare_reverse_rates,
const double rho,
const reaction::ReactionSet &activeReactions
) const {
int n_threads = omp_get_max_threads();
std::vector<std::vector<double>> thread_local_dydt(n_threads, std::vector<double>(m_networkSpecies.size(), 0.0));
double total_neutrino_energy_loss_rate = 0.0;
double total_neutrino_flux = 0.0;
#pragma omp parallel for schedule(static) reduction(+:total_neutrino_energy_loss_rate, total_neutrino_flux)
for (size_t k = 0; k < activeReactions.size(); ++k) {
int t_id = omp_get_thread_num();
const auto& reaction = activeReactions[k];
const size_t reactionIndex = m_precomputedReactionIndexMap.at(reaction.hash(0));
const PrecomputedReaction& precomputedReaction = m_precomputedReactions[reactionIndex];
double netFlow = compute_reaction_flow(
local_abundances,
screening_factors,
bare_rates,
bare_reverse_rates,
rho,
reactionIndex,
reaction,
reactionIndex,
precomputedReaction
);
auto [neutrinoEnergyLossRate, neutrinoFlux] = compute_neutrino_fluxes(
netFlow,
reaction
);
total_neutrino_energy_loss_rate += neutrinoEnergyLossRate;
total_neutrino_flux += neutrinoFlux;
for (size_t i = 0; i < precomputedReaction.affected_species_indices.size(); ++i) {
thread_local_dydt[t_id][precomputedReaction.affected_species_indices[i]] +=
netFlow * precomputedReaction.stoichiometric_coefficients[i];
}
}
PrecomputationKernelResults results;
results.total_neutrino_energy_loss_rate = total_neutrino_energy_loss_rate;
results.total_neutrino_flux = total_neutrino_flux;
results.dydt_vector.resize(m_networkSpecies.size(), 0.0);
#pragma omp parallel for schedule(static)
for (size_t i = 0; i < m_networkSpecies.size(); ++i) {
double sum = 0.0;
for (int t = 0; t < n_threads; ++t) sum += thread_local_dydt[t][i];
results.dydt_vector[i] = sum;
}
return results;
}
#endif
}

View File

@@ -9,6 +9,9 @@
#include "gridfire/types/types.h"
#include "gridfire/exceptions/error_solver.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
#include "fourdst/logging/logging.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "quill/Logger.h"
@@ -20,12 +23,12 @@ namespace gridfire::engine {
using fourdst::atomic::Species;
PrimingReport primeNetwork(
const NetIn& netIn,
GraphEngine& engine,
const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
scratch::StateBlob &ctx,
const NetIn& netIn,
const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
) {
const auto logger = LogManager::getInstance().getLogger("log");
solver::CVODESolverStrategy integrator(engine);
solver::CVODESolverStrategy integrator(engine, ctx);
// Do not need high precision for priming
integrator.set_absTol(1e-3);
@@ -70,7 +73,7 @@ namespace gridfire::engine {
minAbundance = y;
}
}
double abundanceForUnprimedSpecies = minAbundance / 1e10;
const double abundanceForUnprimedSpecies = minAbundance / 1e10;
for (const auto& sp : unprimedSpecies) {
LOG_TRACE_L1(logger, "Clamping Species {}: initial abundance {}, primed abundance {} to {}", sp.name(), netIn.composition.getMolarAbundance(sp), report.primedComposition.getMolarAbundance(sp), abundanceForUnprimedSpecies);
report.primedComposition.setMolarAbundance(sp, abundanceForUnprimedSpecies);

View File

@@ -9,6 +9,10 @@
#include "gridfire/exceptions/error_engine.h"
#include "gridfire/utils/hashing.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/engine/scratchpads/utils.h"
#include "gridfire/engine/scratchpads/engine_adaptive_scratchpad.h"
#include "quill/LogMacros.h"
#include "quill/Logger.h"
@@ -17,23 +21,24 @@ namespace gridfire::engine {
AdaptiveEngineView::AdaptiveEngineView(
DynamicEngine &baseEngine
) :
m_baseEngine(baseEngine),
m_activeSpecies(baseEngine.getNetworkSpecies()),
m_activeReactions(baseEngine.getNetworkReactions())
{}
m_baseEngine(baseEngine) {}
fourdst::composition::Composition AdaptiveEngineView::update(const NetIn &netIn) {
m_activeReactions.clear();
m_activeSpecies.clear();
fourdst::composition::Composition AdaptiveEngineView::project(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
auto *state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
state->active_reactions.clear();
state->active_species.clear();
fourdst::composition::Composition baseUpdatedComposition = m_baseEngine.update(netIn);
fourdst::composition::Composition baseUpdatedComposition = m_baseEngine.project(ctx, netIn);
NetIn updatedNetIn = netIn;
updatedNetIn.composition = baseUpdatedComposition;
LOG_TRACE_L1(m_logger, "Updating AdaptiveEngineView with new network input...");
auto [allFlows, composition] = calculateAllReactionFlows(updatedNetIn);
auto [allFlows, composition] = calculateAllReactionFlows(ctx, updatedNetIn);
double maxFlow = 0.0;
@@ -44,51 +49,50 @@ namespace gridfire::engine {
}
LOG_DEBUG(m_logger, "Maximum reaction flow rate in adaptive engine view: {:0.3E} [mol/s]", maxFlow);
const std::unordered_set<Species> reachableSpecies = findReachableSpecies(updatedNetIn);
const std::unordered_set<Species> reachableSpecies = findReachableSpecies(ctx, updatedNetIn);
LOG_DEBUG(m_logger, "Found {} reachable species in adaptive engine view.", reachableSpecies.size());
const std::vector<const reaction::Reaction*> finalReactions = cullReactionsByFlow(allFlows, reachableSpecies, composition, maxFlow);
const std::vector<const reaction::Reaction*> finalReactions = cullReactionsByFlow(ctx, allFlows, reachableSpecies, composition, maxFlow);
finalizeActiveSet(finalReactions);
finalizeActiveSet(ctx, finalReactions);
auto [rescuedReactions, rescuedSpecies] = rescueEdgeSpeciesDestructionChannel(composition, netIn.temperature/1e9, netIn.density, m_activeSpecies, m_activeReactions);
auto [rescuedReactions, rescuedSpecies] = rescueEdgeSpeciesDestructionChannel(
ctx,
composition,
netIn.temperature/1e9,
netIn.density
);
for (const auto& reactionPtr : rescuedReactions) {
m_activeReactions.add_reaction(*reactionPtr);
state->active_reactions.add_reaction(*reactionPtr);
}
for (const auto& species : rescuedSpecies) {
if (!std::ranges::contains(m_activeSpecies, species) && m_baseEngine.getSpeciesStatus(species) == SpeciesStatus::ACTIVE) {
m_activeSpecies.push_back(species);
if (!std::ranges::contains(state->active_species, species) && m_baseEngine.getSpeciesStatus(ctx, species) == SpeciesStatus::ACTIVE) {
state->active_species.push_back(species);
}
}
m_isStale = false;
LOG_INFO(m_logger, "AdaptiveEngineView updated successfully with {} active species and {} active reactions.", m_activeSpecies.size(), m_activeReactions.size());
LOG_INFO(m_logger, "AdaptiveEngineView updated successfully with {} active species and {} active reactions.", state->active_species.size(), state->active_reactions.size());
return updatedNetIn.composition;
}
bool AdaptiveEngineView::isStale(const NetIn &netIn) {
return m_isStale || m_baseEngine.isStale(netIn);
}
const std::vector<Species> & AdaptiveEngineView::getNetworkSpecies() const {
return m_activeSpecies;
const std::vector<Species> & AdaptiveEngineView::getNetworkSpecies(scratch::StateBlob& ctx) const {
return scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx)->active_species;
}
std::expected<StepDerivatives<double>, EngineStatus> AdaptiveEngineView::calculateRHSAndEnergy(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho, bool trust
) const {
LOG_TRACE_L2(m_logger, "Calculating RHS and Energy in AdaptiveEngineView at T9 = {}, rho = {}.", T9, rho);
validateState();
const fourdst::composition::Composition collectedComp = collectComposition(comp, T9, rho);
const fourdst::composition::Composition collectedComp = collectComposition(ctx, comp, T9, rho);
auto result = m_baseEngine.calculateRHSAndEnergy(collectedComp, T9, rho, true);
auto result = m_baseEngine.calculateRHSAndEnergy(ctx, collectedComp, T9, rho, true);
LOG_TRACE_L2(m_logger, "Base engine calculation of RHS and Energy complete.");
if (!result) {
@@ -100,99 +104,89 @@ namespace gridfire::engine {
}
EnergyDerivatives AdaptiveEngineView::calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateState();
return m_baseEngine.calculateEpsDerivatives(comp, T9, rho);
return m_baseEngine.calculateEpsDerivatives(ctx, comp, T9, rho);
}
NetworkJacobian AdaptiveEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
return generateJacobianMatrix(comp, T9, rho, m_activeSpecies);
const auto *state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
return generateJacobianMatrix(ctx, comp, T9, rho, state->active_species);
}
NetworkJacobian AdaptiveEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const std::vector<Species> &activeSpecies
) const {
validateState();
return m_baseEngine.generateJacobianMatrix(comp, T9, rho, activeSpecies);
const auto *state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
return m_baseEngine.generateJacobianMatrix(ctx, comp, T9, rho, state->active_species);
}
NetworkJacobian AdaptiveEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const SparsityPattern &sparsityPattern
) const {
validateState();
return m_baseEngine.generateJacobianMatrix(comp, T9, rho, sparsityPattern);
}
void AdaptiveEngineView::generateStoichiometryMatrix() {
validateState();
m_baseEngine.generateStoichiometryMatrix();
}
int AdaptiveEngineView::getStoichiometryMatrixEntry(
const Species &species,
const reaction::Reaction& reaction
) const {
validateState();
return m_baseEngine.getStoichiometryMatrixEntry(species, reaction);
return m_baseEngine.generateJacobianMatrix(ctx, comp, T9, rho, sparsityPattern);
}
double AdaptiveEngineView::calculateMolarReactionFlow(
scratch::StateBlob& ctx,
const reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateState();
if (!m_activeReactions.contains(reaction)) {
const auto *state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
if (!state->active_reactions.contains(reaction)) {
LOG_ERROR(m_logger, "Reaction '{}' is not part of the active reactions in the adaptive engine view.", reaction.id());
m_logger -> flush_log();
throw std::runtime_error("Reaction not found in active reactions: " + std::string(reaction.id()));
}
return m_baseEngine.calculateMolarReactionFlow(reaction, comp, T9, rho);
return m_baseEngine.calculateMolarReactionFlow(ctx, reaction, comp, T9, rho);
}
const reaction::ReactionSet & AdaptiveEngineView::getNetworkReactions() const {
return m_activeReactions;
}
void AdaptiveEngineView::setNetworkReactions(const reaction::ReactionSet &reactions) {
LOG_CRITICAL(m_logger, "AdaptiveEngineView does not support setting network reactions directly. Use update() with NetIn instead. Perhaps you meant to call this on the base engine?");
throw exceptions::UnableToSetNetworkReactionsError("AdaptiveEngineView does not support setting network reactions directly. Use update() with NetIn instead. Perhaps you meant to call this on the base engine?");
const reaction::ReactionSet & AdaptiveEngineView::getNetworkReactions(
scratch::StateBlob& ctx
) const {
return scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx) -> active_reactions;
}
std::expected<std::unordered_map<Species, double>, EngineStatus> AdaptiveEngineView::getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateState();
const auto result = m_baseEngine.getSpeciesTimescales(comp, T9, rho);
const auto result = m_baseEngine.getSpeciesTimescales(ctx, comp, T9, rho);
if (!result) {
return std::unexpected{result.error()};
}
const auto* state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
const std::unordered_map<Species, double>& fullTimescales = result.value();
std::unordered_map<Species, double> culledTimescales;
culledTimescales.reserve(m_activeSpecies.size());
for (const auto& active_species : m_activeSpecies) {
culledTimescales.reserve(state->active_species.size());
for (const auto& active_species : state->active_species) {
if (fullTimescales.contains(active_species)) {
culledTimescales[active_species] = fullTimescales.at(active_species);
}
@@ -202,21 +196,23 @@ namespace gridfire::engine {
}
std::expected<std::unordered_map<Species, double>, EngineStatus> AdaptiveEngineView::getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateState();
const auto result = m_baseEngine.getSpeciesDestructionTimescales(comp, T9, rho);
const auto result = m_baseEngine.getSpeciesDestructionTimescales(ctx, comp, T9, rho);
if (!result) {
return std::unexpected{result.error()};
}
const auto* state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
const std::unordered_map<Species, double>& destructionTimescales = result.value();
std::unordered_map<Species, double> culledTimescales;
culledTimescales.reserve(m_activeSpecies.size());
for (const auto& active_species : m_activeSpecies) {
culledTimescales.reserve(state->active_species.size());
for (const auto& active_species : state->active_species) {
if (destructionTimescales.contains(active_species)) {
culledTimescales[active_species] = destructionTimescales.at(active_species);
}
@@ -224,34 +220,29 @@ namespace gridfire::engine {
return culledTimescales;
}
void AdaptiveEngineView::setScreeningModel(const screening::ScreeningType model) {
m_baseEngine.setScreeningModel(model);
screening::ScreeningType AdaptiveEngineView::getScreeningModel(
scratch::StateBlob& ctx
) const {
return m_baseEngine.getScreeningModel(ctx);
}
screening::ScreeningType AdaptiveEngineView::getScreeningModel() const {
return m_baseEngine.getScreeningModel();
}
std::vector<double> AdaptiveEngineView::mapNetInToMolarAbundanceVector(const NetIn &netIn) const {
std::vector<double> Y(m_activeSpecies.size(), 0.0); // Initialize with zeros
for (const auto& [species, y] : netIn.composition) {
Y[getSpeciesIndex(species)] = y; // Map species to their molar abundance
}
return Y; // Return the vector of molar abundances
}
PrimingReport AdaptiveEngineView::primeEngine(const NetIn &netIn) {
return m_baseEngine.primeEngine(netIn);
PrimingReport AdaptiveEngineView::primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
return m_baseEngine.primeEngine(ctx, netIn);
}
fourdst::composition::Composition AdaptiveEngineView::collectComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
fourdst::composition::Composition result = m_baseEngine.collectComposition(comp, T9, rho);
const auto* state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
fourdst::composition::Composition result = m_baseEngine.collectComposition(ctx, comp, T9, rho);
for (const auto& species : m_activeSpecies) {
for (const auto& species : state->active_species) {
if (!result.contains(species)) {
result.registerSpecies(species);
}
@@ -260,18 +251,32 @@ namespace gridfire::engine {
return result;
}
SpeciesStatus AdaptiveEngineView::getSpeciesStatus(const fourdst::atomic::Species &species) const {
const SpeciesStatus status = m_baseEngine.getSpeciesStatus(species);
if (status == SpeciesStatus::ACTIVE && std::ranges::find(m_activeSpecies, species) == m_activeSpecies.end()) {
SpeciesStatus AdaptiveEngineView::getSpeciesStatus(
scratch::StateBlob& ctx,
const Species &species
) const {
const auto* state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
const SpeciesStatus status = m_baseEngine.getSpeciesStatus(ctx, species);
if (status == SpeciesStatus::ACTIVE && std::ranges::find(state->active_species, species) == state->active_species.end()) {
return SpeciesStatus::INACTIVE_FLOW;
}
return status;
}
size_t AdaptiveEngineView::getSpeciesIndex(const fourdst::atomic::Species &species) const {
const auto it = std::ranges::find(m_activeSpecies, species);
if (it != m_activeSpecies.end()) {
return static_cast<int>(std::distance(m_activeSpecies.begin(), it));
std::optional<StepDerivatives<double>> AdaptiveEngineView::getMostRecentRHSCalculation(
scratch::StateBlob &ctx
) const {
return m_baseEngine.getMostRecentRHSCalculation(ctx);
}
size_t AdaptiveEngineView::getSpeciesIndex(
scratch::StateBlob& ctx,
const Species &species
) const {
const auto *state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
const auto it = std::ranges::find(state->active_species, species);
if (it != state->active_species.end()) {
return static_cast<int>(std::distance(state->active_species.begin(), it));
} else {
LOG_ERROR(m_logger, "Species '{}' not found in active species list.", species.name());
m_logger->flush_log();
@@ -279,18 +284,11 @@ namespace gridfire::engine {
}
}
void AdaptiveEngineView::validateState() const {
if (m_isStale) {
LOG_ERROR(m_logger, "AdaptiveEngineView is stale. Please call update() before calculating RHS and energy.");
m_logger->flush_log();
throw std::runtime_error("AdaptiveEngineView is stale. Please call update() before calculating RHS and energy.");
}
}
std::pair<std::vector<AdaptiveEngineView::ReactionFlow>, fourdst::composition::Composition> AdaptiveEngineView::calculateAllReactionFlows(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
const auto& fullSpeciesList = m_baseEngine.getNetworkSpecies();
const auto& fullSpeciesList = m_baseEngine.getNetworkSpecies(ctx);
fourdst::composition::Composition composition = netIn.composition;
for (const auto& species: fullSpeciesList) {
@@ -304,10 +302,10 @@ namespace gridfire::engine {
const double rho = netIn.density; // Density in g/cm^3
std::vector<ReactionFlow> reactionFlows;
const auto& fullReactionSet = m_baseEngine.getNetworkReactions();
const auto& fullReactionSet = m_baseEngine.getNetworkReactions(ctx);
reactionFlows.reserve(fullReactionSet.size());
for (const auto& reaction : fullReactionSet) {
const double flow = m_baseEngine.calculateMolarReactionFlow(*reaction, composition, T9, rho);
const double flow = m_baseEngine.calculateMolarReactionFlow(ctx, *reaction, composition, T9, rho);
reactionFlows.push_back({reaction.get(), flow});
LOG_TRACE_L3(m_logger, "Reaction '{}' has flow rate: {:0.3E} [mol/s/g]", reaction->id(), flow);
}
@@ -315,13 +313,14 @@ namespace gridfire::engine {
}
std::unordered_set<Species> AdaptiveEngineView::findReachableSpecies(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
std::unordered_set<Species> reachable;
std::queue<Species> to_vist;
constexpr double ABUNDANCE_FLOOR = 1e-12; // Abundance floor for a species to be considered part of the initial fuel
for (const auto& species: m_baseEngine.getNetworkSpecies()) {
for (const auto& species: m_baseEngine.getNetworkSpecies(ctx)) {
if (netIn.composition.contains(species) && netIn.composition.getMassFraction(std::string(species.name())) > ABUNDANCE_FLOOR) {
if (!reachable.contains(species)) {
to_vist.push(species);
@@ -334,7 +333,7 @@ namespace gridfire::engine {
bool new_species_found_in_pass = true;
while (new_species_found_in_pass) {
new_species_found_in_pass = false;
for (const auto& reaction: m_baseEngine.getNetworkReactions()) {
for (const auto& reaction: m_baseEngine.getNetworkReactions(ctx)) {
bool all_reactants_reachable = true;
for (const auto& reactant: reaction->reactants()) {
if (!reachable.contains(reactant)) {
@@ -358,6 +357,7 @@ namespace gridfire::engine {
}
std::vector<const reaction::Reaction *> AdaptiveEngineView::cullReactionsByFlow(
scratch::StateBlob& ctx,
const std::vector<ReactionFlow> &allFlows,
const std::unordered_set<fourdst::atomic::Species> &reachableSpecies,
const fourdst::composition::Composition &comp,
@@ -401,21 +401,22 @@ namespace gridfire::engine {
}
AdaptiveEngineView::RescueSet AdaptiveEngineView::rescueEdgeSpeciesDestructionChannel(
scratch::StateBlob& ctx,
const fourdst::composition::Composition &comp,
const double T9,
const double rho,
const std::vector<Species> &activeSpecies,
const reaction::ReactionSet &activeReactions
const double rho
) const {
const auto result = m_baseEngine.getSpeciesTimescales(comp, T9, rho);
const auto result = m_baseEngine.getSpeciesTimescales(ctx, comp, T9, rho);
if (!result) {
LOG_CRITICAL(m_logger, "Failed to get species timescales due to base engine failure");
m_logger->flush_log();
throw exceptions::EngineError("Failed to get species timescales due base engine failure");
}
const auto* state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
std::unordered_map<Species, double> timescales = result.value();
std::set<Species> onlyProducedSpecies;
for (const auto& reaction : activeReactions) {
for (const auto& reaction : state->active_reactions) {
const std::vector<Species>& products = reaction->products();
onlyProducedSpecies.insert(products.begin(), products.end());
}
@@ -424,7 +425,7 @@ namespace gridfire::engine {
std::erase_if(
onlyProducedSpecies,
[&](const Species &species) {
for (const auto& reaction : activeReactions) {
for (const auto& reaction : state->active_reactions) {
if (reaction->contains_reactant(species)) {
return true; // If any active reaction consumes the species then erase it from the set.
}
@@ -444,14 +445,14 @@ namespace gridfire::engine {
std::unordered_map<Species, const reaction::Reaction*> reactionsToRescue;
for (const auto& species : onlyProducedSpecies) {
double maxSpeciesConsumptionRate = 0.0;
for (const auto& reaction : m_baseEngine.getNetworkReactions()) {
for (const auto& reaction : m_baseEngine.getNetworkReactions(ctx)) {
const bool speciesToCheckIsConsumed = reaction->contains_reactant(species);
if (!speciesToCheckIsConsumed) {
continue; // If the species is not consumed by this reaction, skip it.
}
bool allOtherReactantsAreAvailable = true;
for (const auto& reactant : reaction->reactants()) {
const bool reactantIsAvailable = std::ranges::contains(activeSpecies, reactant);
const bool reactantIsAvailable = std::ranges::contains(state->active_species, reactant);
if (!reactantIsAvailable && reactant != species) {
allOtherReactantsAreAvailable = false;
}
@@ -547,31 +548,33 @@ namespace gridfire::engine {
}
void AdaptiveEngineView::finalizeActiveSet(
scratch::StateBlob& ctx,
const std::vector<const reaction::Reaction *> &finalReactions
) {
) const {
auto* state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx);
std::unordered_set<Species>finalSpeciesSet;
m_activeReactions.clear();
state->active_reactions.clear();
for (const auto* reactionPtr: finalReactions) {
m_activeReactions.add_reaction(*reactionPtr);
state->active_reactions.add_reaction(*reactionPtr);
for (const auto& reactant : reactionPtr->reactants()) {
const SpeciesStatus reactantStatus = m_baseEngine.getSpeciesStatus(reactant);
const SpeciesStatus reactantStatus = m_baseEngine.getSpeciesStatus(ctx, reactant);
if (!finalSpeciesSet.contains(reactant) && (reactantStatus == SpeciesStatus::ACTIVE || reactantStatus == SpeciesStatus::EQUILIBRIUM)) {
LOG_TRACE_L3(m_logger, "Adding reactant '{}' to active species set through reaction {}.", reactant.name(), reactionPtr->id());
finalSpeciesSet.insert(reactant);
}
}
for (const auto& product : reactionPtr->products()) {
const SpeciesStatus productStatus = m_baseEngine.getSpeciesStatus(product);
const SpeciesStatus productStatus = m_baseEngine.getSpeciesStatus(ctx, product);
if (!finalSpeciesSet.contains(product) && (productStatus == SpeciesStatus::ACTIVE || productStatus == SpeciesStatus::EQUILIBRIUM)) {
LOG_TRACE_L3(m_logger, "Adding product '{}' to active species set through reaction {}.", product.name(), reactionPtr->id());
finalSpeciesSet.insert(product);
}
}
}
m_activeSpecies.clear();
m_activeSpecies = std::vector<Species>(finalSpeciesSet.begin(), finalSpeciesSet.end());
state->active_species.clear();
state->active_species = std::vector<Species>(finalSpeciesSet.begin(), finalSpeciesSet.end());
std::ranges::sort(
m_activeSpecies,
state->active_species,
[](const Species &a, const Species &b) { return a.mass() < b.mass(); }
);
}

View File

@@ -5,6 +5,10 @@
#include "fourdst/atomic/atomicSpecies.h"
#include "fourdst/composition/decorators/composition_masked.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/engine/scratchpads/engine_defined_scratchpad.h"
#include "gridfire/engine/scratchpads/utils.h"
#include "quill/LogMacros.h"
#include <string>
@@ -23,30 +27,34 @@ namespace gridfire::engine {
GraphEngine& baseEngine
) :
m_baseEngine(baseEngine) {
collect(peNames);
// collect(peNames);
}
const DynamicEngine & DefinedEngineView::getBaseEngine() const {
return m_baseEngine;
}
const std::vector<Species> & DefinedEngineView::getNetworkSpecies() const {
if (m_activeSpeciesVectorCache.has_value()) {
return m_activeSpeciesVectorCache.value();
const std::vector<Species> & DefinedEngineView::getNetworkSpecies(
scratch::StateBlob& ctx
) const {
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
if (state->active_species_vector_cache.has_value()) {
return state->active_species_vector_cache.value();
}
m_activeSpeciesVectorCache = std::vector<Species>(m_activeSpecies.begin(), m_activeSpecies.end());
return m_activeSpeciesVectorCache.value();
state->active_species_vector_cache = std::vector<Species>(state->active_species.begin(), state->active_species.end());
return state->active_species_vector_cache.value();
}
std::expected<StepDerivatives<double>, EngineStatus> DefinedEngineView::calculateRHSAndEnergy(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho, bool trust
) const {
validateNetworkState();
auto *state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
const fourdst::composition::MaskedComposition masked(comp, m_activeSpecies | std::ranges::to<std::vector>());
const auto result = m_baseEngine.calculateRHSAndEnergy(masked, T9, rho, m_activeReactions);
const fourdst::composition::MaskedComposition masked(comp, state->active_species | std::ranges::to<std::vector>());
const auto result = m_baseEngine.calculateRHSAndEnergy(ctx, masked, T9, rho, state->active_reactions);
if (!result) {
return std::unexpected{result.error()};
@@ -56,37 +64,39 @@ namespace gridfire::engine {
}
EnergyDerivatives DefinedEngineView::calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateNetworkState();
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
const fourdst::composition::MaskedComposition masked(comp, state->active_species | std::ranges::to<std::vector>());
const fourdst::composition::MaskedComposition masked(comp, m_activeSpecies | std::ranges::to<std::vector>());
return m_baseEngine.calculateEpsDerivatives(masked, T9, rho, m_activeReactions);
return m_baseEngine.calculateEpsDerivatives(ctx, masked, T9, rho, state->active_reactions);
}
NetworkJacobian DefinedEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateNetworkState();
if (!m_activeSpeciesVectorCache.has_value()) {
m_activeSpeciesVectorCache = std::vector<Species>(m_activeSpecies.begin(), m_activeSpecies.end());
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
if (!state->active_species_vector_cache.has_value()) {
state->active_species_vector_cache = std::vector<Species>(state->active_species.begin(), state->active_species.end());
}
const fourdst::composition::MaskedComposition masked(comp, m_activeSpecies | std::ranges::to<std::vector>());
return m_baseEngine.generateJacobianMatrix(masked, T9, rho, m_activeSpeciesVectorCache.value());
const fourdst::composition::MaskedComposition masked(comp, state->active_species | std::ranges::to<std::vector>());
return m_baseEngine.generateJacobianMatrix(ctx, masked, T9, rho, state->active_species_vector_cache.value());
}
NetworkJacobian DefinedEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const std::vector<fourdst::atomic::Species> &activeSpecies
const std::vector<Species> &activeSpecies
) const {
validateNetworkState();
const std::set<fourdst::atomic::Species> activeSpeciesSet(
activeSpecies.begin(),
@@ -94,96 +104,65 @@ namespace gridfire::engine {
);
const fourdst::composition::MaskedComposition masked(comp, activeSpeciesSet | std::ranges::to<std::vector>());
return m_baseEngine.generateJacobianMatrix(masked, T9, rho, activeSpecies);
return m_baseEngine.generateJacobianMatrix(ctx, masked, T9, rho, activeSpecies);
}
NetworkJacobian DefinedEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const SparsityPattern &sparsityPattern
) const {
validateNetworkState();
const fourdst::composition::MaskedComposition masked(comp, m_activeSpecies | std::ranges::to<std::vector>());
return m_baseEngine.generateJacobianMatrix(masked, T9, rho, sparsityPattern);
}
void DefinedEngineView::generateStoichiometryMatrix() {
validateNetworkState();
m_baseEngine.generateStoichiometryMatrix();
}
int DefinedEngineView::getStoichiometryMatrixEntry(
const Species& species,
const reaction::Reaction& reaction
) const {
validateNetworkState();
if (!m_activeSpecies.contains(species)) {
LOG_ERROR(m_logger, "Species '{}' is not part of the active species in the DefinedEngineView.", species.name());
m_logger -> flush_log();
throw std::runtime_error("Species not found in active species: " + std::string(species.name()));
}
if (!m_activeReactions.contains(reaction)) {
LOG_ERROR(m_logger, "Reaction '{}' is not part of the active reactions in the DefinedEngineView.", reaction.id());
m_logger -> flush_log();
throw std::runtime_error("Reaction not found in active reactions: " + std::string(reaction.id()));
}
return m_baseEngine.getStoichiometryMatrixEntry(species, reaction);
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
const fourdst::composition::MaskedComposition masked(comp, state->active_species | std::ranges::to<std::vector>());
return m_baseEngine.generateJacobianMatrix(ctx, masked, T9, rho, sparsityPattern);
}
double DefinedEngineView::calculateMolarReactionFlow(
scratch::StateBlob& ctx,
const reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateNetworkState();
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
if (!m_activeReactions.contains(reaction)) {
if (!state->active_reactions.contains(reaction)) {
LOG_ERROR(m_logger, "Reaction '{}' is not part of the active reactions in the DefinedEngineView.", reaction.id());
m_logger -> flush_log();
throw std::runtime_error("Reaction not found in active reactions: " + std::string(reaction.id()));
}
const fourdst::composition::MaskedComposition masked(comp, m_activeSpecies | std::ranges::to<std::vector>());
return m_baseEngine.calculateMolarReactionFlow(reaction, masked, T9, rho);
const fourdst::composition::MaskedComposition masked(comp, state->active_species | std::ranges::to<std::vector>());
return m_baseEngine.calculateMolarReactionFlow(ctx, reaction, masked, T9, rho);
}
const reaction::ReactionSet & DefinedEngineView::getNetworkReactions() const {
validateNetworkState();
const reaction::ReactionSet & DefinedEngineView::getNetworkReactions(
scratch::StateBlob& ctx
) const {
return m_activeReactions;
}
void DefinedEngineView::setNetworkReactions(const reaction::ReactionSet &reactions) {
std::vector<std::string> peNames;
for (const auto& reaction : reactions) {
peNames.emplace_back(reaction->id());
}
collect(peNames);
m_activeSpeciesVectorCache = std::nullopt; // Invalidate species vector cache
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
return state->active_reactions;
}
std::expected<std::unordered_map<Species, double>, EngineStatus> DefinedEngineView::getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateNetworkState();
const fourdst::composition::MaskedComposition masked(comp, m_activeSpecies | std::ranges::to<std::vector>());
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
const fourdst::composition::MaskedComposition masked(comp, state->active_species | std::ranges::to<std::vector>());
const auto result = m_baseEngine.getSpeciesTimescales(masked, T9, rho, m_activeReactions);
const auto result = m_baseEngine.getSpeciesTimescales(ctx, masked, T9, rho, state->active_reactions);
if (!result) {
return std::unexpected{result.error()};
}
const auto& fullTimescales = result.value();
std::unordered_map<Species, double> definedTimescales;
for (const auto& active_species : m_activeSpecies) {
for (const auto& active_species : state->active_species) {
if (fullTimescales.contains(active_species)) {
definedTimescales[active_species] = fullTimescales.at(active_species);
}
@@ -192,14 +171,15 @@ namespace gridfire::engine {
}
std::expected<std::unordered_map<Species, double>, EngineStatus> DefinedEngineView::getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
validateNetworkState();
const fourdst::composition::MaskedComposition masked(comp, m_activeSpecies | std::ranges::to<std::vector>());
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
const fourdst::composition::MaskedComposition masked(comp, state->active_species| std::ranges::to<std::vector>());
const auto result = m_baseEngine.getSpeciesDestructionTimescales(masked, T9, rho, m_activeReactions);
const auto result = m_baseEngine.getSpeciesDestructionTimescales(ctx, masked, T9, rho, state->active_reactions);
if (!result) {
return std::unexpected{result.error()};
@@ -208,7 +188,7 @@ namespace gridfire::engine {
const auto& destructionTimescales = result.value();
std::unordered_map<Species, double> definedTimescales;
for (const auto& active_species : m_activeSpecies) {
for (const auto& active_species : state->active_species){
if (destructionTimescales.contains(active_species)) {
definedTimescales[active_species] = destructionTimescales.at(active_species);
}
@@ -216,29 +196,28 @@ namespace gridfire::engine {
return definedTimescales;
}
fourdst::composition::Composition DefinedEngineView::update(const NetIn &netIn) {
return m_baseEngine.update(netIn);
fourdst::composition::Composition DefinedEngineView::project(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
return m_baseEngine.project(ctx, netIn);
}
bool DefinedEngineView::isStale(const NetIn &netIn) {
return m_baseEngine.isStale(netIn);
screening::ScreeningType DefinedEngineView::getScreeningModel(
scratch::StateBlob& ctx
) const {
return m_baseEngine.getScreeningModel(ctx);
}
void DefinedEngineView::setScreeningModel(const screening::ScreeningType model) {
m_baseEngine.setScreeningModel(model);
}
size_t DefinedEngineView::getSpeciesIndex(
scratch::StateBlob& ctx,
const Species &species
) const {
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
screening::ScreeningType DefinedEngineView::getScreeningModel() const {
return m_baseEngine.getScreeningModel();
}
size_t DefinedEngineView::getSpeciesIndex(const Species &species) const {
// TODO: We are working to phase out all of these methods, its probably broken but it also should no longer be used and will be removed soon
validateNetworkState();
const auto it = std::ranges::find(m_activeSpecies, species);
if (it != m_activeSpecies.end()) {
return static_cast<int>(std::distance(m_activeSpecies.begin(), it));
const auto it = std::ranges::find(state->active_species, species);
if (it != state->active_species.end()) {
return static_cast<int>(std::distance(state->active_species.begin(), it));
} else {
LOG_ERROR(m_logger, "Species '{}' not found in active species list.", species.name());
m_logger->flush_log();
@@ -246,29 +225,23 @@ namespace gridfire::engine {
}
}
std::vector<double> DefinedEngineView::mapNetInToMolarAbundanceVector(const NetIn &netIn) const {
std::vector<double> Y(m_activeSpecies.size(), 0.0); // Initialize with zeros
for (const auto& [sp, y] : netIn.composition) {
auto it = std::ranges::find(m_activeSpecies, sp);
if (it != m_activeSpecies.end()) {
Y[getSpeciesIndex(sp)] = y; // Map species to their molar abundance
}
}
return Y; // Return the vector of molar abundances
}
PrimingReport DefinedEngineView::primeEngine(const NetIn &netIn) {
return m_baseEngine.primeEngine(netIn);
PrimingReport DefinedEngineView::primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
return m_baseEngine.primeEngine(ctx, netIn);
}
fourdst::composition::Composition DefinedEngineView::collectComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
fourdst::composition::Composition result = m_baseEngine.collectComposition(comp, T9, rho);
fourdst::composition::Composition result = m_baseEngine.collectComposition(ctx, comp, T9, rho);
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
for (const auto& species : m_activeSpecies) {
for (const auto& species : state->active_species) {
if (!result.contains(species)) {
result.registerSpecies(species);
}
@@ -276,18 +249,30 @@ namespace gridfire::engine {
return result;
}
SpeciesStatus DefinedEngineView::getSpeciesStatus(const Species &species) const {
const SpeciesStatus status = m_baseEngine.getSpeciesStatus(species);
if (status == SpeciesStatus::ACTIVE && !m_activeSpecies.contains(species)) {
SpeciesStatus DefinedEngineView::getSpeciesStatus(
scratch::StateBlob& ctx,
const Species &species
) const {
const auto *state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
const SpeciesStatus status = m_baseEngine.getSpeciesStatus(ctx, species);
if (status == SpeciesStatus::ACTIVE && !state->active_species.contains(species)) {
return SpeciesStatus::INACTIVE_FLOW;
}
return status;
}
std::vector<size_t> DefinedEngineView::constructSpeciesIndexMap() const {
std::optional<StepDerivatives<double>> DefinedEngineView::getMostRecentRHSCalculation(
scratch::StateBlob &ctx
) const {
return m_baseEngine.getMostRecentRHSCalculation(ctx);
}
std::vector<size_t> DefinedEngineView::constructSpeciesIndexMap(
scratch::StateBlob& ctx
) const {
LOG_TRACE_L3(m_logger, "Constructing species index map for DefinedEngineView...");
std::unordered_map<Species, size_t> fullSpeciesReverseMap;
const auto& fullSpeciesList = m_baseEngine.getNetworkSpecies();
const auto& fullSpeciesList = m_baseEngine.getNetworkSpecies(ctx);
fullSpeciesReverseMap.reserve(fullSpeciesList.size());
@@ -296,9 +281,10 @@ namespace gridfire::engine {
}
std::vector<size_t> speciesIndexMap;
speciesIndexMap.reserve(m_activeSpecies.size());
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
speciesIndexMap.reserve(state->active_species.size());
for (const auto& active_species : m_activeSpecies) {
for (const auto& active_species : state->active_species) {
auto it = fullSpeciesReverseMap.find(active_species);
if (it != fullSpeciesReverseMap.end()) {
speciesIndexMap.push_back(it->second);
@@ -313,12 +299,15 @@ namespace gridfire::engine {
}
std::vector<size_t> DefinedEngineView::constructReactionIndexMap() const {
std::vector<size_t> DefinedEngineView::constructReactionIndexMap(
scratch::StateBlob& ctx
) const {
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
LOG_TRACE_L3(m_logger, "Constructing reaction index map for DefinedEngineView...");
// --- Step 1: Create a reverse map using the reaction's unique ID as the key. ---
std::unordered_map<std::string_view, size_t> fullReactionReverseMap;
const auto& fullReactionSet = m_baseEngine.getNetworkReactions();
const auto& fullReactionSet = m_baseEngine.getNetworkReactions(ctx);
fullReactionReverseMap.reserve(fullReactionSet.size());
for (size_t i_full = 0; i_full < fullReactionSet.size(); ++i_full) {
@@ -327,9 +316,9 @@ namespace gridfire::engine {
// --- Step 2: Build the final index map using the active reaction set. ---
std::vector<size_t> reactionIndexMap;
reactionIndexMap.reserve(m_activeReactions.size());
reactionIndexMap.reserve(state->active_reactions.size());
for (const auto& active_reaction_ptr : m_activeReactions) {
for (const auto& active_reaction_ptr : state->active_reactions) {
auto it = fullReactionReverseMap.find(active_reaction_ptr->id());
if (it != fullReactionReverseMap.end()) {
@@ -345,54 +334,66 @@ namespace gridfire::engine {
return reactionIndexMap;
}
std::vector<double> DefinedEngineView::mapViewToFull(const std::vector<double>& culled) const {
std::vector<double> full(m_baseEngine.getNetworkSpecies().size(), 0.0);
std::vector<double> DefinedEngineView::mapViewToFull(
scratch::StateBlob& ctx,
const std::vector<double>& culled
) const {
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
std::vector<double> full(m_baseEngine.getNetworkSpecies(ctx).size(), 0.0);
for (size_t i_culled = 0; i_culled < culled.size(); ++i_culled) {
const size_t i_full = m_speciesIndexMap[i_culled];
const size_t i_full = state->species_index_map[i_culled];
full[i_full] += culled[i_culled];
}
return full;
}
std::vector<double> DefinedEngineView::mapFullToView(const std::vector<double>& full) const {
std::vector<double> culled(m_activeSpecies.size(), 0.0);
for (size_t i_culled = 0; i_culled < m_activeSpecies.size(); ++i_culled) {
const size_t i_full = m_speciesIndexMap[i_culled];
std::vector<double> DefinedEngineView::mapFullToView(
scratch::StateBlob& ctx,
const std::vector<double>& full
) {
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
std::vector<double> culled(state->active_species.size(), 0.0);
for (size_t i_culled = 0; i_culled < state->active_species.size(); ++i_culled) {
const size_t i_full = state->species_index_map[i_culled];
culled[i_culled] = full[i_full];
}
return culled;
}
size_t DefinedEngineView::mapViewToFullSpeciesIndex(size_t culledSpeciesIndex) const {
if (culledSpeciesIndex >= m_speciesIndexMap.size()) {
LOG_ERROR(m_logger, "Defined index {} is out of bounds for species index map of size {}.", culledSpeciesIndex, m_speciesIndexMap.size());
size_t DefinedEngineView::mapViewToFullSpeciesIndex(
scratch::StateBlob& ctx,
size_t culledSpeciesIndex
) const {
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
if (culledSpeciesIndex >= state->species_index_map.size()) {
LOG_ERROR(m_logger, "Defined index {} is out of bounds for species index map of size {}.", culledSpeciesIndex, state->species_index_map.size());
m_logger->flush_log();
throw std::out_of_range("Defined index " + std::to_string(culledSpeciesIndex) + " is out of bounds for species index map of size " + std::to_string(m_speciesIndexMap.size()) + ".");
throw std::out_of_range("Defined index " + std::to_string(culledSpeciesIndex) + " is out of bounds for species index map of size " + std::to_string(state->species_index_map.size()) + ".");
}
return m_speciesIndexMap[culledSpeciesIndex];
return state->species_index_map[culledSpeciesIndex];
}
size_t DefinedEngineView::mapViewToFullReactionIndex(size_t culledReactionIndex) const {
if (culledReactionIndex >= m_reactionIndexMap.size()) {
LOG_ERROR(m_logger, "Defined index {} is out of bounds for reaction index map of size {}.", culledReactionIndex, m_reactionIndexMap.size());
size_t DefinedEngineView::mapViewToFullReactionIndex(
scratch::StateBlob& ctx,
size_t culledReactionIndex
) const {
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
if (culledReactionIndex >= state->reaction_index_map.size()) {
LOG_ERROR(m_logger, "Defined index {} is out of bounds for reaction index map of size {}.", culledReactionIndex, state->reaction_index_map.size());
m_logger->flush_log();
throw std::out_of_range("Defined index " + std::to_string(culledReactionIndex) + " is out of bounds for reaction index map of size " + std::to_string(m_reactionIndexMap.size()) + ".");
throw std::out_of_range("Defined index " + std::to_string(culledReactionIndex) + " is out of bounds for reaction index map of size " + std::to_string(state->reaction_index_map.size()) + ".");
}
return m_reactionIndexMap[culledReactionIndex];
return state->reaction_index_map[culledReactionIndex];
}
void DefinedEngineView::validateNetworkState() const {
if (m_isStale) {
LOG_ERROR(m_logger, "DefinedEngineView is stale. Please call update() with a valid NetIn object.");
m_logger->flush_log();
throw std::runtime_error("DefinedEngineView is stale. Please call update() with a valid NetIn object.");
}
}
void DefinedEngineView::collect(const std::vector<std::string> &peNames) {
void DefinedEngineView::collect(
scratch::StateBlob& ctx,
const std::vector<std::string> &peNames
) const {
auto* state = scratch::get_state<scratch::DefinedEngineViewScratchPad, true>(ctx);
std::unordered_set<Species> seenSpecies;
const auto& fullNetworkReactionSet = m_baseEngine.getNetworkReactions();
const auto& fullNetworkReactionSet = m_baseEngine.getNetworkReactions(ctx);
for (const auto& peName : peNames) {
if (!fullNetworkReactionSet.contains(peName)) {
LOG_ERROR(m_logger, "Reaction with name '{}' not found in the base engine's network reactions. Aborting...", peName);
@@ -403,16 +404,16 @@ namespace gridfire::engine {
for (const auto& reactant : reaction->reactants()) {
if (!seenSpecies.contains(reactant)) {
seenSpecies.insert(reactant);
m_activeSpecies.emplace(reactant);
state->active_species.emplace(reactant);
}
}
for (const auto& product : reaction->products()) {
if (!seenSpecies.contains(product)) {
seenSpecies.insert(product);
m_activeSpecies.emplace(product);
state->active_species.emplace(product);
}
}
m_activeReactions.add_reaction(*reaction);
state->active_reactions.add_reaction(*reaction);
}
LOG_TRACE_L3(m_logger, "DefinedEngineView built with {} active species and {} active reactions.", m_activeSpecies.size(), m_activeReactions.size());
LOG_TRACE_L3(m_logger, "Active species: {}", [this]() -> std::string {
@@ -437,9 +438,8 @@ namespace gridfire::engine {
}
return result;
}());
m_speciesIndexMap = constructSpeciesIndexMap();
m_reactionIndexMap = constructReactionIndexMap();
m_isStale = false;
state->species_index_map = constructSpeciesIndexMap(ctx);
state->reaction_index_map = constructReactionIndexMap(ctx);
}

View File

@@ -4,6 +4,10 @@
#include "gridfire/utils/sundials.h"
#include "gridfire/utils/logging.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/engine/scratchpads/utils.h"
#include "gridfire/engine/scratchpads/engine_multiscale_scratchpad.h"
#include <stdexcept>
#include <vector>
#include <ranges>
@@ -151,18 +155,6 @@ namespace {
return reactantSample != productSample;
}
void QuietErrorRouter(int line, const char *func, const char *file, const char *msg,
SUNErrCode err_code, void *err_user_data, SUNContext sunctx) {
// LIST OF ERRORS TO IGNORE
if (err_code == KIN_LINESEARCH_NONCONV) {
return;
}
// For everything else, use the default SUNDIALS logger (or your own)
SUNLogErrHandlerFn(line, func, file, msg, err_code, err_user_data, sunctx);
}
struct DisjointSet {
std::vector<size_t> parent;
explicit DisjointSet(const size_t n) {
@@ -170,7 +162,7 @@ namespace {
std::iota(parent.begin(), parent.end(), 0);
}
size_t find(const size_t i) {
size_t find(const size_t i) { // NOLINT(*-no-recursion)
if (parent.at(i) == i) return i;
return parent.at(i) = find(parent.at(i)); // Path compression
}
@@ -192,28 +184,16 @@ namespace gridfire::engine {
MultiscalePartitioningEngineView::MultiscalePartitioningEngineView(
DynamicEngine& baseEngine
) : m_baseEngine(baseEngine) {
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
if (flag != 0) {
LOG_CRITICAL(m_logger, "Error while creating SUNContext in MultiscalePartitioningEngineView");
throw std::runtime_error("Error creating SUNContext in MultiscalePartitioningEngineView");
}
SUNContext_PushErrHandler(m_sun_ctx, QuietErrorRouter, nullptr);
}
MultiscalePartitioningEngineView::~MultiscalePartitioningEngineView() {
LOG_TRACE_L1(m_logger, "Cleaning up MultiscalePartitioningEngineView...");
m_qse_solvers.clear();
if (m_sun_ctx) {
SUNContext_Free(&m_sun_ctx);
m_sun_ctx = nullptr;
}
}
const std::vector<Species> & MultiscalePartitioningEngineView::getNetworkSpecies() const {
return m_baseEngine.getNetworkSpecies();
const std::vector<Species> & MultiscalePartitioningEngineView::getNetworkSpecies(
scratch::StateBlob& ctx
) const {
return m_baseEngine.getNetworkSpecies(ctx);
}
std::expected<StepDerivatives<double>, EngineStatus> MultiscalePartitioningEngineView::calculateRHSAndEnergy(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
@@ -232,7 +212,7 @@ namespace gridfire::engine {
}
return ss.str();
}());
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, trust);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, trust);
LOG_TRACE_L2(m_logger, "Equilibrated composition prior to calling base engine is {}", [&qseComposition, &comp]() -> std::string {
std::stringstream ss;
size_t i = 0;
@@ -249,7 +229,7 @@ namespace gridfire::engine {
return ss.str();
}());
const auto result = m_baseEngine.calculateRHSAndEnergy(qseComposition, T9, rho, false);
const auto result = m_baseEngine.calculateRHSAndEnergy(ctx, qseComposition, T9, rho, false);
LOG_TRACE_L2(m_logger, "Base engine calculation of RHS and Energy complete.");
if (!result) {
@@ -258,9 +238,10 @@ namespace gridfire::engine {
}
auto deriv = result.value();
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
LOG_TRACE_L2(m_logger, "Zeroing out algebraic species derivatives.");
for (const auto& species : m_algebraic_species) {
for (const auto& species : state->algebraic_species) {
deriv.dydt[species] = 0.0; // Fix the algebraic species to the equilibrium abundances we calculate.
}
LOG_TRACE_L2(m_logger, "Done Zeroing out algebraic species derivatives.");
@@ -268,24 +249,28 @@ namespace gridfire::engine {
}
EnergyDerivatives MultiscalePartitioningEngineView::calculateEpsDerivatives(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, false);
return m_baseEngine.calculateEpsDerivatives(qseComposition, T9, rho);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
return m_baseEngine.calculateEpsDerivatives(ctx, qseComposition, T9, rho);
}
NetworkJacobian MultiscalePartitioningEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, false);
return m_baseEngine.generateJacobianMatrix(qseComposition, T9, rho, m_dynamic_species);
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
return m_baseEngine.generateJacobianMatrix(ctx, qseComposition, T9, rho, state->dynamic_species);
}
NetworkJacobian MultiscalePartitioningEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
@@ -293,7 +278,7 @@ namespace gridfire::engine {
) const {
bool activeSpeciesIsSubset = true;
for (const auto& species : activeSpecies) {
if (!involvesSpecies(species)) activeSpeciesIsSubset = false;
if (!involvesSpecies(ctx, species)) activeSpeciesIsSubset = false;
}
if (!activeSpeciesIsSubset) {
std::string msg = std::format(
@@ -301,7 +286,7 @@ namespace gridfire::engine {
[&]() -> std::string {
std::stringstream ss;
for (const auto& species : activeSpecies) {
if (!this->involvesSpecies(species)) {
if (!involvesSpecies(ctx, species)) {
ss << species << " ";
}
}
@@ -314,114 +299,104 @@ namespace gridfire::engine {
std::vector<Species> dynamicActiveSpeciesIntersection;
for (const auto& species : activeSpecies) {
if (involvesSpeciesInDynamic(species)) {
if (involvesSpeciesInDynamic(ctx, species)) {
dynamicActiveSpeciesIntersection.push_back(species);
}
}
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, false);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
return m_baseEngine.generateJacobianMatrix(qseComposition, T9, rho, dynamicActiveSpeciesIntersection);
return m_baseEngine.generateJacobianMatrix(ctx, qseComposition, T9, rho, dynamicActiveSpeciesIntersection);
}
NetworkJacobian MultiscalePartitioningEngineView::generateJacobianMatrix(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
const SparsityPattern &sparsityPattern
) const {
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, false);
return m_baseEngine.generateJacobianMatrix(qseComposition, T9, rho, sparsityPattern);
}
void MultiscalePartitioningEngineView::generateStoichiometryMatrix() {
m_baseEngine.generateStoichiometryMatrix();
}
int MultiscalePartitioningEngineView::getStoichiometryMatrixEntry(
const Species& species,
const reaction::Reaction& reaction
) const {
return m_baseEngine.getStoichiometryMatrixEntry(species, reaction);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
return m_baseEngine.generateJacobianMatrix(ctx, qseComposition, T9, rho, sparsityPattern);
}
double MultiscalePartitioningEngineView::calculateMolarReactionFlow(
scratch::StateBlob& ctx,
const reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, false);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
return m_baseEngine.calculateMolarReactionFlow(reaction, qseComposition, T9, rho);
return m_baseEngine.calculateMolarReactionFlow(ctx, reaction, qseComposition, T9, rho);
}
const reaction::ReactionSet & MultiscalePartitioningEngineView::getNetworkReactions() const {
return m_baseEngine.getNetworkReactions();
}
void MultiscalePartitioningEngineView::setNetworkReactions(const reaction::ReactionSet &reactions) {
LOG_CRITICAL(m_logger, "setNetworkReactions is not supported in MultiscalePartitioningEngineView. Did you mean to call this on the base engine?");
throw exceptions::UnableToSetNetworkReactionsError("setNetworkReactions is not supported in MultiscalePartitioningEngineView. Did you mean to call this on the base engine?");
const reaction::ReactionSet & MultiscalePartitioningEngineView::getNetworkReactions(
scratch::StateBlob& ctx
) const {
return m_baseEngine.getNetworkReactions(ctx);
}
std::expected<std::unordered_map<Species, double>, EngineStatus> MultiscalePartitioningEngineView::getSpeciesTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, false);
const auto result = m_baseEngine.getSpeciesTimescales(qseComposition, T9, rho);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
const auto result = m_baseEngine.getSpeciesTimescales(ctx, qseComposition, T9, rho);
if (!result) {
return std::unexpected{result.error()};
}
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
std::unordered_map<Species, double> speciesTimescales = result.value();
for (const auto& algebraicSpecies : m_algebraic_species) {
for (const auto& algebraicSpecies : state->algebraic_species) {
speciesTimescales[algebraicSpecies] = std::numeric_limits<double>::infinity(); // Algebraic species have infinite timescales.
}
return speciesTimescales;
}
std::expected<std::unordered_map<Species, double>, EngineStatus> MultiscalePartitioningEngineView::getSpeciesDestructionTimescales(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, false);
const auto result = m_baseEngine.getSpeciesDestructionTimescales(qseComposition, T9, rho);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
const auto result = m_baseEngine.getSpeciesDestructionTimescales(ctx, qseComposition, T9, rho);
if (!result) {
return std::unexpected{result.error()};
}
std::unordered_map<Species, double> speciesDestructionTimescales = result.value();
for (const auto& algebraicSpecies : m_algebraic_species) {
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
for (const auto& algebraicSpecies : state->algebraic_species) {
speciesDestructionTimescales[algebraicSpecies] = std::numeric_limits<double>::infinity(); // Algebraic species have infinite destruction timescales.
}
return speciesDestructionTimescales;
}
fourdst::composition::Composition MultiscalePartitioningEngineView::update(const NetIn &netIn) {
const fourdst::composition::Composition baseUpdatedComposition = m_baseEngine.update(netIn);
fourdst::composition::Composition MultiscalePartitioningEngineView::project(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
const fourdst::composition::Composition baseUpdatedComposition = m_baseEngine.project(ctx, netIn);
NetIn baseUpdatedNetIn = netIn;
baseUpdatedNetIn.composition = baseUpdatedComposition;
fourdst::composition::Composition equilibratedComposition = partitionNetwork(baseUpdatedNetIn);
m_composition_cache.clear();
fourdst::composition::Composition equilibratedComposition = partitionNetwork(ctx, baseUpdatedNetIn);
auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
state->composition_cache.clear();
return equilibratedComposition;
}
bool MultiscalePartitioningEngineView::isStale(const NetIn &netIn) {
return m_baseEngine.isStale(netIn);
}
void MultiscalePartitioningEngineView::setScreeningModel(
const screening::ScreeningType model
) {
m_baseEngine.setScreeningModel(model);
}
screening::ScreeningType MultiscalePartitioningEngineView::getScreeningModel() const {
return m_baseEngine.getScreeningModel();
screening::ScreeningType MultiscalePartitioningEngineView::getScreeningModel(
scratch::StateBlob& ctx
) const {
return m_baseEngine.getScreeningModel(ctx);
}
const DynamicEngine & MultiscalePartitioningEngineView::getBaseEngine() const {
@@ -429,8 +404,11 @@ namespace gridfire::engine {
}
std::vector<std::vector<Species>> MultiscalePartitioningEngineView::analyzeTimescalePoolConnectivity(
const std::vector<std::vector<Species>> &timescale_pools, const fourdst::composition::Composition &comp, double T9, double
rho
scratch::StateBlob& ctx,
const std::vector<std::vector<Species>> &timescale_pools,
const fourdst::composition::Composition &comp,
double T9,
double rho
) const {
std::vector<std::vector<Species>> final_connected_pools;
@@ -440,7 +418,7 @@ namespace gridfire::engine {
}
// For each timescale pool, we need to analyze connectivity.
auto connectivity_graph = buildConnectivityGraph(pool, comp, T9, rho);
auto connectivity_graph = buildConnectivityGraph(ctx, pool, comp, T9, rho);
auto components = findConnectedComponentsBFS(connectivity_graph, pool);
final_connected_pools.insert(final_connected_pools.end(), components.begin(), components.end());
}
@@ -449,20 +427,21 @@ namespace gridfire::engine {
}
std::vector<MultiscalePartitioningEngineView::QSEGroup> MultiscalePartitioningEngineView::pruneValidatedGroups(
scratch::StateBlob& ctx,
const std::vector<QSEGroup> &groups,
const std::vector<reaction::ReactionSet> &groupReactions,
const fourdst::composition::Composition &comp,
const double T9,
const double rho
) const {
const auto result = m_baseEngine.getSpeciesTimescales(comp, T9, rho);
const auto result = m_baseEngine.getSpeciesTimescales(ctx, comp, T9, rho);
if (!result) {
throw std::runtime_error("Base engine returned stale error during pruneValidatedGroups timescale retrieval.");
}
std::unordered_map<Species, double> speciesTimescales = result.value();
const std::unordered_map<Species, double>& speciesTimescales = result.value();
std::vector<QSEGroup> newGroups;
for (const auto &[group, reactions] : std::views::zip(groups, groupReactions)) {
if (reactions.size() == 0) { // If a QSE group has gotten here it should have reactions associated with it. If it doesn't that is a serious error.
if (reactions.empty()) { // If a QSE group has gotten here it should have reactions associated with it. If it doesn't that is a serious error.
LOG_CRITICAL(m_logger, "No reactions specified for QSE group {} during pruning analysis.", group.toString(false));
throw std::runtime_error("No reactions specified for QSE group " + group.toString(false) + " during pruneValidatedGroups flux analysis.");
}
@@ -475,7 +454,7 @@ namespace gridfire::engine {
for (const auto& species : group.algebraic_species) {
mean_molar_abundance += comp.getMolarAbundance(species);
}
mean_molar_abundance /= group.algebraic_species.size();
mean_molar_abundance /= static_cast<double>(group.algebraic_species.size());
{ // Safety Valve to ensure valid log scaling
if (mean_molar_abundance <= 0) {
LOG_CRITICAL(m_logger, "Non-positive mean molar abundance {} calculated for QSE group during pruning analysis.", mean_molar_abundance);
@@ -484,7 +463,7 @@ namespace gridfire::engine {
}
for (const auto& reaction : reactions) {
const double flux = m_baseEngine.calculateMolarReactionFlow(*reaction, comp, T9, rho);
const double flux = m_baseEngine.calculateMolarReactionFlow(ctx, *reaction, comp, T9, rho);
size_t hash = reaction->hash(0);
if (reactionFluxes.contains(hash)) {
throw std::runtime_error("Duplicate reaction hash found during pruneValidatedGroups flux analysis.");
@@ -624,7 +603,7 @@ namespace gridfire::engine {
for (const auto &species : g.algebraic_species) {
meanTimescale += speciesTimescales.at(species);
}
meanTimescale /= g.algebraic_species.size();
meanTimescale /= static_cast<double>(g.algebraic_species.size());
g.mean_timescale = meanTimescale;
newGroups.push_back(g);
}
@@ -634,6 +613,7 @@ namespace gridfire::engine {
}
std::vector<MultiscalePartitioningEngineView::QSEGroup> MultiscalePartitioningEngineView::merge_coupled_groups(
scratch::StateBlob& ctx,
const std::vector<QSEGroup> &groups,
const std::vector<reaction::ReactionSet> &groupReactions
) {
@@ -688,10 +668,12 @@ namespace gridfire::engine {
}
fourdst::composition::Composition MultiscalePartitioningEngineView::partitionNetwork(
scratch::StateBlob& ctx,
const NetIn &netIn
) {
) const {
auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
// --- Step 0. Prime the network ---
const PrimingReport primingReport = m_baseEngine.primeEngine(netIn);
const PrimingReport primingReport = m_baseEngine.primeEngine(ctx, netIn);
const fourdst::composition::Composition& comp = primingReport.primedComposition;
const double T9 = netIn.temperature / 1e9;
const double rho = netIn.density;
@@ -699,25 +681,25 @@ namespace gridfire::engine {
// --- Step 0.5 Clear previous state ---
LOG_TRACE_L1(m_logger, "Partitioning network...");
LOG_TRACE_L1(m_logger, "Clearing previous state...");
m_qse_groups.clear();
m_qse_solvers.clear();
m_dynamic_species.clear();
m_algebraic_species.clear();
m_composition_cache.clear(); // We need to clear the cache now cause the same comp, temp, and density may result in a different value
state->qse_groups.clear();
state->qse_solvers.clear();
state->dynamic_species.clear();
state->algebraic_species.clear();
state->composition_cache.clear(); // We need to clear the cache now cause the same comp, temp, and density may result in a different value
// --- Step 1. Identify distinct timescale regions ---
LOG_TRACE_L1(m_logger, "Identifying fast reactions...");
const std::vector<std::vector<Species>> timescale_pools = partitionByTimescale(comp, T9, rho);
const std::vector<std::vector<Species>> timescale_pools = partitionByTimescale(ctx, comp, T9, rho);
LOG_TRACE_L1(m_logger, "Found {} timescale pools.", timescale_pools.size());
// --- Step 2. Select the mean slowest pool as the base dynamical group ---
LOG_TRACE_L1(m_logger, "Identifying mean slowest pool...");
const size_t mean_slowest_pool_index = identifyMeanSlowestPool(timescale_pools, comp, T9, rho);
const size_t mean_slowest_pool_index = identifyMeanSlowestPool(ctx, timescale_pools, comp, T9, rho);
LOG_TRACE_L1(m_logger, "Mean slowest pool index: {}", mean_slowest_pool_index);
// --- Step 3. Push the slowest pool into the dynamic species list ---
for (const auto& slowSpecies : timescale_pools[mean_slowest_pool_index]) {
m_dynamic_species.push_back(slowSpecies);
state->dynamic_species.push_back(slowSpecies);
}
// --- Step 4. Pack Candidate QSE Groups ---
@@ -729,40 +711,40 @@ namespace gridfire::engine {
}
LOG_TRACE_L1(m_logger, "Preforming connectivity analysis on timescale pools...");
const std::vector<std::vector<Species>> connected_pools = analyzeTimescalePoolConnectivity(candidate_pools, comp, T9, rho);
const std::vector<std::vector<Species>> connected_pools = analyzeTimescalePoolConnectivity(ctx, candidate_pools, comp, T9, rho);
LOG_TRACE_L1(m_logger, "Found {} connected pools (compared to {} timescale pools) for QSE analysis.", connected_pools.size(), timescale_pools.size());
// --- Step 5. Identify potential seed species for each candidate pool ---
LOG_TRACE_L1(m_logger, "Identifying potential seed species for candidate pools...");
const std::vector<QSEGroup> candidate_groups = constructCandidateGroups(connected_pools, comp, T9, rho);
const std::vector<QSEGroup> candidate_groups = constructCandidateGroups(ctx, connected_pools, comp, T9, rho);
LOG_TRACE_L1(m_logger, "Found {} candidate QSE groups for further analysis ({})", candidate_groups.size(), utils::iterable_to_delimited_string(candidate_groups));
LOG_TRACE_L1(m_logger, "Validating candidate groups with flux analysis...");
const auto [validated_groups, invalidate_groups, validated_group_reactions] = validateGroupsWithFluxAnalysis(candidate_groups, comp, T9, rho);
const auto [validated_groups, invalidate_groups, validated_group_reactions] = validateGroupsWithFluxAnalysis(ctx, candidate_groups, comp, T9, rho);
LOG_TRACE_L1(m_logger, "Validated {} group(s) QSE groups. {}", validated_groups.size(), utils::iterable_to_delimited_string(validated_groups));
LOG_TRACE_L1(m_logger, "Pruning groups based on log abundance-normalized flux analysis...");
const std::vector<QSEGroup> prunedGroups = pruneValidatedGroups(validated_groups, validated_group_reactions, comp, T9, rho);
const std::vector<QSEGroup> prunedGroups = pruneValidatedGroups(ctx, validated_groups, validated_group_reactions, comp, T9, rho);
LOG_TRACE_L1(m_logger, "After Pruning remaining groups are: {}", utils::iterable_to_delimited_string(prunedGroups));
LOG_TRACE_L1(m_logger, "Re-validating pruned groups with flux analysis...");
auto [pruned_validated_groups, _, pruned_validated_reactions] = validateGroupsWithFluxAnalysis(prunedGroups, comp, T9, rho);
auto [pruned_validated_groups, _, pruned_validated_reactions] = validateGroupsWithFluxAnalysis(ctx, prunedGroups, comp, T9, rho);
LOG_TRACE_L1(m_logger, "After re-validation, {} QSE groups remain. ({})",pruned_validated_groups.size(), utils::iterable_to_delimited_string(pruned_validated_groups));
LOG_TRACE_L1(m_logger, "Merging coupled QSE groups...");
const std::vector<QSEGroup> merged_groups = merge_coupled_groups(pruned_validated_groups, pruned_validated_reactions);
const std::vector<QSEGroup> merged_groups = merge_coupled_groups(ctx, pruned_validated_groups, pruned_validated_reactions);
LOG_TRACE_L1(m_logger, "After merging, {} QSE groups remain. ({})", merged_groups.size(), utils::iterable_to_delimited_string(merged_groups));
m_qse_groups = pruned_validated_groups;
state->qse_groups = pruned_validated_groups;
LOG_TRACE_L1(m_logger, "Pushing all identified algebraic species into algebraic set...");
for (const auto& group : m_qse_groups) {
for (const auto& group : state->qse_groups) {
// Add algebraic species to the algebraic set
for (const auto& species : group.algebraic_species) {
if (std::ranges::find(m_algebraic_species, species) == m_algebraic_species.end()) {
m_algebraic_species.push_back(species);
if (std::ranges::find(state->algebraic_species, species) == state->algebraic_species.end()) {
state->algebraic_species.push_back(species);
}
}
}
@@ -771,46 +753,47 @@ namespace gridfire::engine {
LOG_INFO(
m_logger,
"Partitioning complete. Found {} dynamic species, {} algebraic (QSE) species ({}) spread over {} QSE group{}.",
m_dynamic_species.size(),
m_algebraic_species.size(),
utils::iterable_to_delimited_string(m_algebraic_species),
m_qse_groups.size(),
m_qse_groups.size() == 1 ? "" : "s"
state->dynamic_species.size(),
state->algebraic_species.size(),
utils::iterable_to_delimited_string(state->algebraic_species),
state->qse_groups.size(),
state->qse_groups.size() == 1 ? "" : "s"
);
// Sort the QSE groups by mean timescale so that fastest groups get equilibrated first (as these may feed slower groups)
LOG_TRACE_L1(m_logger, "Sorting algebraic set by mean timescale...");
std::ranges::sort(m_qse_groups, [](const QSEGroup& a, const QSEGroup& b) {
std::ranges::sort(state->qse_groups, [](const QSEGroup& a, const QSEGroup& b) {
return a.mean_timescale < b.mean_timescale;
});
LOG_TRACE_L1(m_logger, "Finalizing dynamic species list...");
for (const auto& species : m_baseEngine.getNetworkSpecies()) {
const bool involvesAlgebraic = involvesSpeciesInQSE(species);
if (std::ranges::find(m_dynamic_species, species) == m_dynamic_species.end() && !involvesAlgebraic) {
m_dynamic_species.push_back(species);
for (const auto& species : m_baseEngine.getNetworkSpecies(ctx)) {
const bool involvesAlgebraic = involvesSpeciesInQSE(ctx, species);
if (std::ranges::find(state->dynamic_species, species) == state->dynamic_species.end() && !involvesAlgebraic) {
state->dynamic_species.push_back(species);
}
}
LOG_TRACE_L1(m_logger, "Final dynamic species set: {}", utils::iterable_to_delimited_string(m_dynamic_species));
LOG_TRACE_L1(m_logger, "Creating QSE solvers for each identified QSE group...");
for (const auto& group : m_qse_groups) {
for (const auto& group : state->qse_groups) {
std::vector<Species> groupAlgebraicSpecies;
for (const auto& species : group.algebraic_species) {
groupAlgebraicSpecies.push_back(species);
}
m_qse_solvers.push_back(std::make_unique<QSESolver>(groupAlgebraicSpecies, m_baseEngine, m_sun_ctx));
state->qse_solvers.push_back(std::make_unique<QSESolver>(groupAlgebraicSpecies, m_baseEngine, state->sun_ctx));
}
LOG_TRACE_L1(m_logger, "{} QSE solvers created.", m_qse_solvers.size());
LOG_TRACE_L1(m_logger, "Calculating final equilibrated composition...");
fourdst::composition::Composition result = getNormalizedEquilibratedComposition(comp, T9, rho, false);
fourdst::composition::Composition result = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
LOG_TRACE_L1(m_logger, "Final equilibrated composition calculated...");
return result;
}
void MultiscalePartitioningEngineView::exportToDot(
scratch::StateBlob &ctx,
const std::string &filename,
const fourdst::composition::Composition &comp,
const double T9,
@@ -822,22 +805,24 @@ namespace gridfire::engine {
throw std::runtime_error("Failed to open file for writing: " + filename);
}
const auto& all_species = m_baseEngine.getNetworkSpecies();
const auto& all_reactions = m_baseEngine.getNetworkReactions();
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
const auto& all_species = m_baseEngine.getNetworkSpecies(ctx);
const auto& all_reactions = m_baseEngine.getNetworkReactions(ctx);
// --- 1. Pre-computation and Categorization ---
// Categorize species into algebraic, seed, and core dynamic
std::unordered_set<Species> algebraic_species;
std::unordered_set<Species> seed_species;
for (const auto& group : m_qse_groups) {
for (const auto& group : state->qse_groups) {
if (group.is_in_equilibrium) {
algebraic_species.insert(group.algebraic_species.begin(), group.algebraic_species.end());
seed_species.insert(group.seed_species.begin(), group.seed_species.end());
}
}
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(comp, T9, rho, false);
const fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
// Calculate reaction flows and find min/max for logarithmic scaling of transparency
std::vector<double> reaction_flows;
reaction_flows.reserve(all_reactions.size());
@@ -845,7 +830,7 @@ namespace gridfire::engine {
double max_log_flow = std::numeric_limits<double>::lowest();
for (const auto& reaction : all_reactions) {
double flow = std::abs(m_baseEngine.calculateMolarReactionFlow(*reaction, qseComposition, T9, rho));
double flow = std::abs(m_baseEngine.calculateMolarReactionFlow(ctx, *reaction, qseComposition, T9, rho));
reaction_flows.push_back(flow);
if (flow > 1e-99) { // Avoid log(0)
double log_flow = std::log10(flow);
@@ -875,7 +860,7 @@ namespace gridfire::engine {
fillcolor = "#e0f2fe"; // Light Blue: Algebraic (in QSE)
} else if (seed_species.contains(species)) {
fillcolor = "#a7f3d0"; // Light Green: Seed (Dynamic, feeds a QSE group)
} else if (std::ranges::contains(m_dynamic_species, species)) {
} else if (std::ranges::contains(state->dynamic_species, species)) {
fillcolor = "#dcfce7"; // Pale Green: Core Dynamic
}
dotFile << " \"" << species.name() << "\" [label=\"" << species.name() << "\", fillcolor=\"" << fillcolor << "\"];\n";
@@ -918,7 +903,7 @@ namespace gridfire::engine {
// Draw a prominent box around the algebraic species of each valid QSE group.
dotFile << " // --- QSE Group Clusters ---\n";
int group_counter = 0;
for (const auto& group : m_qse_groups) {
for (const auto& group : state->qse_groups) {
if (!group.is_in_equilibrium || group.algebraic_species.empty()) {
continue;
}
@@ -1019,58 +1004,64 @@ namespace gridfire::engine {
dotFile.close();
}
std::vector<double> MultiscalePartitioningEngineView::mapNetInToMolarAbundanceVector(const NetIn &netIn) const {
std::vector<double> Y(netIn.composition.size(), 0.0); // Initialize with zeros
for (const auto& [sp, y] : netIn.composition) {
Y[getSpeciesIndex(sp)] = y; // Map species to their molar abundance
}
return Y; // Return the vector of molar abundances
}
std::vector<Species> MultiscalePartitioningEngineView::getFastSpecies() const {
const auto& all_species = m_baseEngine.getNetworkSpecies();
std::vector<Species> MultiscalePartitioningEngineView::getFastSpecies(
scratch::StateBlob& ctx
) const {
const auto& all_species = m_baseEngine.getNetworkSpecies(ctx);
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
std::vector<Species> fast_species;
fast_species.reserve(all_species.size() - m_dynamic_species.size());
fast_species.reserve(all_species.size() - state->dynamic_species.size());
for (const auto& species : all_species) {
auto it = std::ranges::find(m_dynamic_species, species);
if (it == m_dynamic_species.end()) {
auto it = std::ranges::find(state->dynamic_species, species);
if (it == state->dynamic_species.end()) {
fast_species.push_back(species);
}
}
return fast_species;
}
const std::vector<Species> & MultiscalePartitioningEngineView::getDynamicSpecies() const {
return m_dynamic_species;
const std::vector<Species> & MultiscalePartitioningEngineView::getDynamicSpecies(
scratch::StateBlob& ctx
) {
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
return state->dynamic_species;
}
PrimingReport MultiscalePartitioningEngineView::primeEngine(const NetIn &netIn) {
return m_baseEngine.primeEngine(netIn);
PrimingReport MultiscalePartitioningEngineView::primeEngine(
scratch::StateBlob& ctx,
const NetIn &netIn
) const {
return m_baseEngine.primeEngine(ctx, netIn);
}
bool MultiscalePartitioningEngineView::involvesSpecies(
scratch::StateBlob& ctx,
const Species &species
) const {
if (involvesSpeciesInQSE(species)) return true; // check this first since the vector is likely to be smaller so short circuit cost is less on average
if (involvesSpeciesInDynamic(species)) return true;
) {
if (involvesSpeciesInQSE(ctx, species)) return true; // check this first since the vector is likely to be smaller so short circuit cost is less on average
if (involvesSpeciesInDynamic(ctx, species)) return true;
return false;
}
bool MultiscalePartitioningEngineView::involvesSpeciesInQSE(
scratch::StateBlob& ctx,
const Species &species
) const {
return std::ranges::find(m_algebraic_species, species) != m_algebraic_species.end();
) {
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
return std::ranges::find(state->algebraic_species, species) != state->algebraic_species.end();
}
bool MultiscalePartitioningEngineView::involvesSpeciesInDynamic(
scratch::StateBlob& ctx,
const Species &species
) const {
return std::ranges::find(m_dynamic_species, species) != m_dynamic_species.end();
) {
const auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
return std::ranges::find(state->dynamic_species, species) != state->dynamic_species.end();
}
fourdst::composition::Composition MultiscalePartitioningEngineView::getNormalizedEquilibratedComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract& comp,
const double T9,
const double rho,
@@ -1086,54 +1077,69 @@ namespace gridfire::engine {
std::hash<double>()(rho)
};
auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
const uint64_t composite_hash = XXHash64::hash(hashes.begin(), sizeof(uint64_t) * 3, 0);
if (m_composition_cache.contains(composite_hash)) {
if (state->composition_cache.contains(composite_hash)) {
LOG_TRACE_L3(m_logger, "Cache Hit in Multiscale Partitioning Engine View for composition at T9 = {}, rho = {}.", T9, rho);
return m_composition_cache.at(composite_hash);
return state->composition_cache.at(composite_hash);
}
LOG_TRACE_L3(m_logger, "Cache Miss in Multiscale Partitioning Engine View for composition at T9 = {}, rho = {}. Solving QSE abundances...", T9, rho);
// Only solve if the composition and thermodynamic conditions have not been cached yet
fourdst::composition::Composition qseComposition(solveQSEAbundances(comp, T9, rho));
fourdst::composition::Composition qseComposition(solveQSEAbundances(ctx, comp, T9, rho));
m_composition_cache[composite_hash] = qseComposition;
state->composition_cache[composite_hash] = qseComposition;
return qseComposition;
}
fourdst::composition::Composition MultiscalePartitioningEngineView::collectComposition(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
const fourdst::composition::Composition result = m_baseEngine.collectComposition(comp, T9, rho);
const fourdst::composition::Composition result = m_baseEngine.collectComposition(ctx, comp, T9, rho);
fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(result, T9, rho, false);
fourdst::composition::Composition qseComposition = getNormalizedEquilibratedComposition(ctx, result, T9, rho, false);
return qseComposition;
}
SpeciesStatus MultiscalePartitioningEngineView::getSpeciesStatus(const Species &species) const {
const SpeciesStatus status = m_baseEngine.getSpeciesStatus(species);
if (status == SpeciesStatus::ACTIVE && involvesSpeciesInQSE(species)) {
SpeciesStatus MultiscalePartitioningEngineView::getSpeciesStatus(
scratch::StateBlob& ctx,
const Species &species
) const {
const SpeciesStatus status = m_baseEngine.getSpeciesStatus(ctx, species);
if (status == SpeciesStatus::ACTIVE && involvesSpeciesInQSE(ctx, species)) {
return SpeciesStatus::EQUILIBRIUM;
}
return status;
}
size_t MultiscalePartitioningEngineView::getSpeciesIndex(const Species &species) const {
return m_baseEngine.getSpeciesIndex(species);
std::optional<StepDerivatives<double>> MultiscalePartitioningEngineView::getMostRecentRHSCalculation(
scratch::StateBlob& ctx
) const {
return m_baseEngine.getMostRecentRHSCalculation(ctx);
}
size_t MultiscalePartitioningEngineView::getSpeciesIndex(
scratch::StateBlob& ctx,
const Species &species
) const {
return m_baseEngine.getSpeciesIndex(ctx, species);
}
std::vector<std::vector<Species>> MultiscalePartitioningEngineView::partitionByTimescale(
scratch::StateBlob& ctx,
const fourdst::composition::Composition &comp,
const double T9,
const double rho
) const {
LOG_TRACE_L1(m_logger, "Partitioning by timescale...");
const auto destructionTimescale= m_baseEngine.getSpeciesDestructionTimescales(comp, T9, rho);
const auto netTimescale = m_baseEngine.getSpeciesTimescales(comp, T9, rho);
const auto destructionTimescale= m_baseEngine.getSpeciesDestructionTimescales(ctx, comp, T9, rho);
const auto netTimescale = m_baseEngine.getSpeciesTimescales(ctx, comp, T9, rho);
if (!destructionTimescale || !netTimescale) {
LOG_CRITICAL(m_logger, "Failed to compute species timescales for partitioning due to base engine error.");
@@ -1155,7 +1161,7 @@ namespace gridfire::engine {
}()
);
const auto& all_species = m_baseEngine.getNetworkSpecies();
const auto& all_species = m_baseEngine.getNetworkSpecies(ctx);
std::vector<std::pair<double, Species>> sorted_destruction_timescales;
for (const auto & species : all_species) {
@@ -1311,6 +1317,7 @@ namespace gridfire::engine {
}
std::pair<bool, reaction::ReactionSet> MultiscalePartitioningEngineView::group_is_a_qse_cluster(
scratch::StateBlob& ctx,
const fourdst::composition::Composition &comp,
const double T9,
const double rho,
@@ -1332,8 +1339,8 @@ namespace gridfire::engine {
double coupling_flux = 0.0;
double leakage_flux = 0.0;
for (const auto& reaction: m_baseEngine.getNetworkReactions()) {
const double flow = std::abs(m_baseEngine.calculateMolarReactionFlow(*reaction, comp, T9, rho));
for (const auto& reaction: m_baseEngine.getNetworkReactions(ctx)) {
const double flow = std::abs(m_baseEngine.calculateMolarReactionFlow(ctx, *reaction, comp, T9, rho));
if (flow == 0.0) {
continue; // Skip reactions with zero flow
}
@@ -1422,10 +1429,11 @@ namespace gridfire::engine {
}
bool MultiscalePartitioningEngineView::group_is_a_qse_pipeline(
const fourdst::composition::Composition &comp,
const double T9,
const double rho,
const QSEGroup &group
scratch::StateBlob& ctx,
const fourdst::composition::Composition &comp,
const double T9,
const double rho,
const QSEGroup &group
) const {
// Total fluxes (Standard check)
double total_prod = 0.0;
@@ -1435,8 +1443,8 @@ namespace gridfire::engine {
double charged_prod = 0.0;
double charged_dest = 0.0;
for (const auto& reaction : m_baseEngine.getNetworkReactions()) {
const double flow = m_baseEngine.calculateMolarReactionFlow(*reaction, comp, T9, rho);
for (const auto& reaction : m_baseEngine.getNetworkReactions(ctx)) {
const double flow = m_baseEngine.calculateMolarReactionFlow(ctx, *reaction, comp, T9, rho);
if (std::abs(flow) < 1.0e-99) continue;
int groupNetStoichiometry = 0;
@@ -1476,6 +1484,7 @@ namespace gridfire::engine {
MultiscalePartitioningEngineView::FluxValidationResult MultiscalePartitioningEngineView::validateGroupsWithFluxAnalysis(
scratch::StateBlob& ctx,
const std::vector<QSEGroup> &candidate_groups,
const fourdst::composition::Composition &comp,
const double T9, const double rho
@@ -1487,10 +1496,10 @@ namespace gridfire::engine {
group_reactions.reserve(candidate_groups.size());
for (auto& group : candidate_groups) {
// Values for measuring the flux coupling vs leakage
auto [leakage_coupled, group_reaction_set] = group_is_a_qse_cluster(comp, T9, rho, group);
auto [leakage_coupled, group_reaction_set] = group_is_a_qse_cluster(ctx, comp, T9, rho, group);
bool is_flow_balanced = group_is_a_qse_pipeline(comp, T9, rho, group);
bool is_flow_balanced = group_is_a_qse_pipeline(ctx, comp, T9, rho, group);
if (leakage_coupled) {
LOG_TRACE_L1(m_logger, "{} is in equilibrium due to high coupling flux", group.toString(false));
@@ -1516,21 +1525,23 @@ namespace gridfire::engine {
}
fourdst::composition::Composition MultiscalePartitioningEngineView::solveQSEAbundances(
scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
) const {
LOG_TRACE_L2(m_logger, "Solving for QSE abundances...");
auto* state = scratch::get_state<scratch::MultiscalePartitioningEngineViewScratchPad, true>(ctx);
fourdst::composition::Composition outputComposition(comp);
std::vector<Species> species;
std::vector<double> abundances;
species.reserve(m_algebraic_species.size());
abundances.reserve(m_algebraic_species.size());
species.reserve(state->algebraic_species.size());
abundances.reserve(state->algebraic_species.size());
for (const auto& [group, solver]: std::views::zip(m_qse_groups, m_qse_solvers)) {
const fourdst::composition::Composition& groupResult = solver->solve(outputComposition, T9, rho);
for (const auto& [group, solver]: std::views::zip(state->qse_groups, state->qse_solvers)) {
const fourdst::composition::Composition& groupResult = solver->solve(ctx, outputComposition, T9, rho);
for (const auto& [sp, y] : groupResult) {
if (!std::isfinite(y)) {
LOG_CRITICAL(m_logger, "Non-finite abundance {} computed for species {} in QSE group solve at T9 = {}, rho = {}.",
@@ -1553,12 +1564,13 @@ namespace gridfire::engine {
}
size_t MultiscalePartitioningEngineView::identifyMeanSlowestPool(
scratch::StateBlob& ctx,
const std::vector<std::vector<Species>> &pools,
const fourdst::composition::Composition &comp,
const double T9,
const double rho
) const {
const auto& result = m_baseEngine.getSpeciesDestructionTimescales(comp, T9, rho);
const auto& result = m_baseEngine.getSpeciesDestructionTimescales(ctx, comp, T9, rho);
if (!result) {
LOG_CRITICAL(m_logger, "Failed to get species destruction timescales due base engine failure");
m_logger->flush_log();
@@ -1603,6 +1615,7 @@ namespace gridfire::engine {
}
std::unordered_map<Species, std::vector<Species>> MultiscalePartitioningEngineView::buildConnectivityGraph(
scratch::StateBlob& ctx,
const std::vector<Species> &species_pool,
const fourdst::composition::Composition &comp,
double T9,
@@ -1622,7 +1635,7 @@ namespace gridfire::engine {
std::map<size_t, std::vector<reaction::LogicalReaclibReaction*>> speciesReactionMap;
std::vector<const reaction::LogicalReaclibReaction*> candidate_reactions;
for (const auto& reaction : m_baseEngine.getNetworkReactions()) {
for (const auto& reaction : m_baseEngine.getNetworkReactions(ctx)) {
const std::vector<Species> &reactants = reaction->reactants();
const std::vector<Species> &products = reaction->products();
@@ -1660,13 +1673,14 @@ namespace gridfire::engine {
}
std::vector<MultiscalePartitioningEngineView::QSEGroup> MultiscalePartitioningEngineView::constructCandidateGroups(
scratch::StateBlob& ctx,
const std::vector<std::vector<Species>> &candidate_pools,
const fourdst::composition::Composition &comp,
const double T9,
const double rho
) const {
const auto& all_reactions = m_baseEngine.getNetworkReactions();
const auto& result = m_baseEngine.getSpeciesDestructionTimescales(comp, T9, rho);
const auto& all_reactions = m_baseEngine.getNetworkReactions(ctx);
const auto& result = m_baseEngine.getSpeciesDestructionTimescales(ctx, comp, T9, rho);
if (!result) {
LOG_ERROR(m_logger, "Failed to get species destruction timescales due base engine failure");
m_logger->flush_log();
@@ -1694,7 +1708,7 @@ namespace gridfire::engine {
}
}
if (has_external_reactant) {
double flow = std::abs(m_baseEngine.calculateMolarReactionFlow(*reaction, comp, T9, rho));
double flow = std::abs(m_baseEngine.calculateMolarReactionFlow(ctx, *reaction, comp, T9, rho));
LOG_TRACE_L3(m_logger, "Found bridge reaction {} with flow {} for species {}.", reaction->id(), flow, ash.name());
bridge_reactions.emplace_back(reaction.get(), flow);
}
@@ -1872,6 +1886,7 @@ namespace gridfire::engine {
}
fourdst::composition::Composition MultiscalePartitioningEngineView::QSESolver::solve(
scratch::StateBlob& ctx,
const fourdst::composition::Composition &comp,
const double T9,
const double rho
@@ -1885,7 +1900,8 @@ namespace gridfire::engine {
result,
m_speciesMap,
m_species,
*this
*this,
ctx
};
utils::check_sundials_flag(KINSetUserData(m_kinsol_mem, &data), "KINSetUserData", utils::SUNDIALS_RET_CODE_TYPES::KINSOL);
@@ -1905,9 +1921,9 @@ namespace gridfire::engine {
}
StepDerivatives<double> rhsGuess;
auto cached_rhs = m_engine.getMostRecentRHSCalculation();
auto cached_rhs = m_engine.getMostRecentRHSCalculation(ctx);
if (!cached_rhs) {
const auto initial_rhs = m_engine.calculateRHSAndEnergy(result, T9, rho, false);
const auto initial_rhs = m_engine.calculateRHSAndEnergy(ctx, result, T9, rho, false);
if (!initial_rhs) {
throw std::runtime_error("In QSE solver failed to calculate initial RHS for caching");
}
@@ -2063,6 +2079,16 @@ namespace gridfire::engine {
getLogger()->flush_log(true);
}
std::unique_ptr<MultiscalePartitioningEngineView::QSESolver> MultiscalePartitioningEngineView::QSESolver::clone() const {
auto new_solver = std::make_unique<QSESolver>(m_species, m_engine, m_sun_ctx);
return new_solver;
}
std::unique_ptr<MultiscalePartitioningEngineView::QSESolver> MultiscalePartitioningEngineView::QSESolver::clone(SUNContext sun_ctx) const {
auto new_solver = std::make_unique<QSESolver>(m_species, m_engine, sun_ctx);
return new_solver;
}
int MultiscalePartitioningEngineView::QSESolver::sys_func(
const N_Vector y,
@@ -2086,7 +2112,7 @@ namespace gridfire::engine {
data->comp.setMolarAbundance(species, y_data[index]);
}
const auto result = data->engine.calculateRHSAndEnergy(data->comp, data->T9, data->rho, false);
const auto result = data->engine.calculateRHSAndEnergy(data->ctx, data->comp, data->T9, data->rho, false);
if (!result) {
return 1; // Potentially recoverable error
@@ -2102,7 +2128,7 @@ namespace gridfire::engine {
for (const auto &s: map | std::views::keys) {
const double v = dydt.at(s);
if (!std::isfinite(v)) {
invalid_species.push_back(std::make_pair(s, v));
invalid_species.emplace_back(s, v);
}
}
std::string msg = std::format("Non-finite dydt values encountered for species: {}",
@@ -2150,6 +2176,7 @@ namespace gridfire::engine {
}
const NetworkJacobian jac = data->engine.generateJacobianMatrix(
data->ctx,
data->comp,
data->T9,
data->rho,
@@ -2159,9 +2186,6 @@ namespace gridfire::engine {
sunrealtype* J_data = SUNDenseMatrix_Data(J);
const sunindextype N = SUNDenseMatrix_Columns(J);
if (data->row_scaling_factors.size() != static_cast<size_t>(N)) {
data->row_scaling_factors.resize(N, 0.0);
}
for (const auto& [row_species, row_idx]: map) {
double max_value = std::numeric_limits<double>::lowest();

View File

@@ -1,5 +1,5 @@
#include "gridfire/engine/views/engine_priming.h"
#include "gridfire/solver/solver.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "fourdst/atomic/species.h"
@@ -11,16 +11,17 @@
#include <unordered_set>
#include <unordered_map>
namespace gridfire::engine {
using fourdst::atomic::species;
NetworkPrimingEngineView::NetworkPrimingEngineView(
scratch::StateBlob& ctx,
const std::string &primingSymbol,
GraphEngine &baseEngine
) :
DefinedEngineView(
constructPrimingReactionSet(
ctx,
species.at(primingSymbol),
baseEngine
),
@@ -29,26 +30,27 @@ namespace gridfire::engine {
m_primingSpecies(species.at(primingSymbol)) {}
NetworkPrimingEngineView::NetworkPrimingEngineView(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &primingSpecies,
GraphEngine &baseEngine
) :
DefinedEngineView(
constructPrimingReactionSet(
ctx,
primingSpecies,
baseEngine
),
baseEngine
),
m_primingSpecies(primingSpecies) {
}
m_primingSpecies(primingSpecies) {}
std::vector<std::string> NetworkPrimingEngineView::constructPrimingReactionSet(
scratch::StateBlob& ctx,
const fourdst::atomic::Species &primingSpecies,
const GraphEngine &baseEngine
) const {
std::unordered_set<std::string> primeReactions;
for (const auto &reaction : baseEngine.getNetworkReactions()) {
for (const auto &reaction : baseEngine.getNetworkReactions(ctx)) {
if (reaction->contains(primingSpecies)) {
primeReactions.insert(std::string(reaction->id()));
}

View File

@@ -9,6 +9,7 @@
#include <optional>
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/engine/scratchpads/blob.h"
namespace {
template <typename T>
@@ -137,8 +138,11 @@ namespace gridfire::io::gen {
}
std::string exportEngineToPy(const engine::DynamicEngine& engine) {
auto reactions = engine.getNetworkReactions();
std::string exportEngineToPy(
engine::scratch::StateBlob& ctx,
const engine::DynamicEngine& engine
) {
auto reactions = engine.getNetworkReactions(ctx);
std::vector<std::string> functions;
functions.emplace_back(R"(import numpy as np
from typing import Dict, List, Tuple, Callable)");
@@ -150,8 +154,8 @@ from typing import Dict, List, Tuple, Callable)");
return join<std::string>(functions, "\n\n");
}
void exportEngineToPy(const engine::DynamicEngine &engine, const std::string &fileName) {
const std::string funcCode = exportEngineToPy(engine);
void exportEngineToPy(engine::scratch::StateBlob &ctx, const engine::DynamicEngine &engine, const std::string &fileName) {
const std::string funcCode = exportEngineToPy(ctx, engine);
std::ofstream outFile(fileName);
outFile << funcCode;
outFile.close();

View File

@@ -5,6 +5,13 @@
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/engine/engine_graph.h"
#include "gridfire/engine/views/engine_views.h"
#include "gridfire/utils/logging.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "gridfire/engine/scratchpads/utils.h"
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
#include "gridfire/engine/scratchpads/engine_adaptive_scratchpad.h"
#include "gridfire/engine/scratchpads/engine_multiscale_scratchpad.h"
#include "fourdst/atomic/species.h"
#include "fourdst/composition/utils.h"
@@ -47,17 +54,13 @@ namespace gridfire::policy {
m_partition_function = build_partition_function();
}
engine::DynamicEngine& MainSequencePolicy::construct() {
ConstructionResults MainSequencePolicy::construct() {
m_network_stack.clear();
m_network_stack.emplace_back(
std::make_unique<engine::GraphEngine>(m_initializing_composition, *m_partition_function, engine::NetworkBuildDepth::ThirdOrder, engine::NetworkConstructionFlags::DEFAULT)
);
auto& graphRepr = dynamic_cast<engine::GraphEngine&>(*m_network_stack.back().get());
graphRepr.setUseReverseReactions(false);
m_network_stack.emplace_back(
std::make_unique<engine::MultiscalePartitioningEngineView>(*m_network_stack.back().get())
);
@@ -65,8 +68,9 @@ namespace gridfire::policy {
std::make_unique<engine::AdaptiveEngineView>(*m_network_stack.back().get())
);
std::unique_ptr<engine::scratch::StateBlob> scratch_blob = get_stack_scratch_blob();
m_status = NetworkPolicyStatus::INITIALIZED_UNVERIFIED;
m_status = check_status();
m_status = check_status(*scratch_blob);
switch (m_status) {
case NetworkPolicyStatus::MISSING_KEY_REACTION:
@@ -80,7 +84,7 @@ namespace gridfire::policy {
case NetworkPolicyStatus::INITIALIZED_VERIFIED:
break;
}
return *m_network_stack.back();
return {.engine = *m_network_stack.back(), .scratch_blob = std::move(scratch_blob)};
}
inline std::unique_ptr<partition::PartitionFunction> MainSequencePolicy::build_partition_function() {
@@ -115,13 +119,49 @@ namespace gridfire::policy {
return m_partition_function;
}
inline NetworkPolicyStatus MainSequencePolicy::check_status() const {
std::unique_ptr<engine::scratch::StateBlob> MainSequencePolicy::get_stack_scratch_blob() const {
if (m_network_stack.empty()) {
throw exceptions::PolicyError("Cannot get stack scratch blob from MainSequencePolicy: Engine stack is empty. Call construct() first.");
}
auto blob = std::make_unique<engine::scratch::StateBlob>();
blob->enroll<engine::scratch::GraphEngineScratchPad>();
blob->enroll<engine::scratch::AdaptiveEngineViewScratchPad>();
blob->enroll<engine::scratch::MultiscalePartitioningEngineViewScratchPad>();
const engine::GraphEngine* graph_engine = dynamic_cast<engine::GraphEngine*>(m_network_stack.front().get());
if (!graph_engine) {
throw exceptions::PolicyError("Cannot get stack scratch blob from MainSequencePolicy: The base engine is not a GraphEngine. This indicates a serious internal inconsistency and should be reported to the GridFire developers, thank you.");
}
const engine::MultiscalePartitioningEngineView* multiscale_engine = dynamic_cast<engine::MultiscalePartitioningEngineView*>(m_network_stack[1].get());
if (!multiscale_engine) {
throw exceptions::PolicyError("Cannot get stack scratch blob from MainSequencePolicy: The middle engine is not a MultiscalePartitioningEngineView. This indicates a serious internal inconsistency and should be reported to the GridFire developers, thank you.");
}
const engine::AdaptiveEngineView* adaptive_engine = dynamic_cast<engine::AdaptiveEngineView*>(m_network_stack.back().get());
if (!adaptive_engine) {
throw exceptions::PolicyError("Cannot get stack scratch blob from MainSequencePolicy: The top engine is not an AdaptiveEngineView. This indicates a serious internal inconsistency and should be reported to the GridFire developers, thank you.");
}
auto* graph_engine_state = engine::scratch::get_state<engine::scratch::GraphEngineScratchPad, false>(*blob);
graph_engine_state->initialize(*graph_engine);
auto* multiscale_engine_state = engine::scratch::get_state<engine::scratch::MultiscalePartitioningEngineViewScratchPad, false>(*blob);
multiscale_engine_state->initialize();
auto* adaptive_engine_state = engine::scratch::get_state<engine::scratch::AdaptiveEngineViewScratchPad, false>(*blob);
adaptive_engine_state->initialize(*adaptive_engine);
return blob;
}
inline NetworkPolicyStatus MainSequencePolicy::check_status(engine::scratch::StateBlob& ctx) const {
for (const auto& species : m_seed_species) {
if (!m_initializing_composition.contains(species)) {
return NetworkPolicyStatus::MISSING_KEY_SPECIES;
}
}
const reaction::ReactionSet& baseReactions = m_network_stack.front()->getNetworkReactions();
const reaction::ReactionSet& baseReactions = m_network_stack.front()->getNetworkReactions(ctx);
for (const auto& reaction : m_reaction_policy->get_reactions()) {
const bool result = baseReactions.contains(*reaction);
if (!result) {
@@ -130,4 +170,4 @@ namespace gridfire::policy {
}
return NetworkPolicyStatus::INITIALIZED_VERIFIED;
}
}
}

View File

@@ -19,7 +19,6 @@
#include "fourdst/atomic/species.h"
#include "fourdst/composition/exceptions/exceptions_composition.h"
#include "gridfire/engine/engine_graph.h"
#include "gridfire/engine/types/engine_types.h"
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
#include "gridfire/trigger/procedures/trigger_pprint.h"
#include "gridfire/exceptions/error_solver.h"
@@ -41,7 +40,8 @@ namespace gridfire::solver {
const std::vector<fourdst::atomic::Species> &networkSpecies,
const size_t currentConvergenceFailure,
const size_t currentNonlinearIterations,
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
scratch::StateBlob& ctx
) :
t(t),
state(state),
@@ -54,7 +54,8 @@ namespace gridfire::solver {
networkSpecies(networkSpecies),
currentConvergenceFailures(currentConvergenceFailure),
currentNonlinearIterations(currentNonlinearIterations),
reactionContributionMap(reactionContributionMap)
reactionContributionMap(reactionContributionMap),
state_ctx(ctx)
{}
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const {
@@ -74,8 +75,11 @@ namespace gridfire::solver {
}
CVODESolverStrategy::CVODESolverStrategy(DynamicEngine &engine): SingleZoneNetworkSolverStrategy<DynamicEngine>(engine) {
// TODO: In order to support MPI this function must be changed
CVODESolverStrategy::CVODESolverStrategy(
const DynamicEngine &engine,
const scratch::StateBlob& ctx
): SingleZoneNetworkSolver<DynamicEngine>(engine, ctx) {
// PERF: In order to support MPI this function must be changed
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
if (flag < 0) {
throw std::runtime_error("Failed to create SUNDIALS context (SUNDIALS Errno: " + std::to_string(flag) + ")");
@@ -137,10 +141,10 @@ namespace gridfire::solver {
(!resourcesExist ? "CVODE resources do not exist" :
"Input composition inconsistent with previous state"));
LOG_TRACE_L1(m_logger, "Starting engine update chain...");
equilibratedComposition = m_engine.update(netIn);
equilibratedComposition = m_engine.project(*m_scratch_blob, netIn);
LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!");
size_t numSpecies = m_engine.getNetworkSpecies().size();
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
uint64_t N = numSpecies + 1;
LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N);
@@ -153,10 +157,10 @@ namespace gridfire::solver {
} else {
LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", m_last_size);
const size_t numSpecies = m_engine.getNetworkSpecies().size();
const size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
for (size_t i = 0; i < numSpecies; i++) {
const auto& species = m_engine.getNetworkSpecies()[i];
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
if (netIn.composition.contains(species)) {
y_data[i] = netIn.composition.getMolarAbundance(species);
} else {
@@ -170,10 +174,12 @@ namespace gridfire::solver {
equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state
}
size_t numSpecies = m_engine.getNetworkSpecies().size();
CVODEUserData user_data;
user_data.solver_instance = this;
user_data.engine = &m_engine;
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
CVODEUserData user_data {
.solver_instance = this,
.ctx = *m_scratch_blob,
.engine = &m_engine,
};
LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!");
double current_time = 0;
@@ -199,7 +205,7 @@ namespace gridfire::solver {
while (current_time < netIn.tMax) {
user_data.T9 = T9;
user_data.rho = netIn.density;
user_data.networkSpecies = &m_engine.getNetworkSpecies();
user_data.networkSpecies = &m_engine.getNetworkSpecies(*m_scratch_blob);
user_data.captured_exception.reset();
utils::check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
@@ -247,7 +253,7 @@ namespace gridfire::solver {
);
}
for (size_t i = 0; i < numSpecies; ++i) {
const auto& species = m_engine.getNetworkSpecies()[i];
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
if (y_data[i] > 0.0) {
postStep.setMolarAbundance(species, y_data[i]);
}
@@ -260,7 +266,7 @@ namespace gridfire::solver {
LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string {
std::stringstream ss;
for (size_t i = 0; i < numSpecies; ++i) {
const auto& species = m_engine.getNetworkSpecies()[i];
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")";
if (i < numSpecies - 1) {
ss << ", ";
@@ -285,10 +291,11 @@ namespace gridfire::solver {
netIn.density,
n_steps,
m_engine,
m_engine.getNetworkSpecies(),
m_engine.getNetworkSpecies(*m_scratch_blob),
convFail_diff,
iter_diff,
rcMap
rcMap,
*m_scratch_blob
);
prev_nonlinear_iterations = nliters + total_nonlinear_iterations;
@@ -300,7 +307,7 @@ namespace gridfire::solver {
trigger->step(ctx);
if (m_detailed_step_logging) {
log_step_diagnostics(user_data, true, true, true, "step_" + std::to_string(total_steps + n_steps) + ".json");
log_step_diagnostics(*m_scratch_blob, user_data, true, true, true, "step_" + std::to_string(total_steps + n_steps) + ".json");
}
if (trigger->check(ctx)) {
@@ -326,7 +333,7 @@ namespace gridfire::solver {
fourdst::composition::Composition temp_comp;
std::vector<double> mass_fractions;
auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies().size());
auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*m_scratch_blob).size());
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) {
LOG_ERROR(
@@ -338,8 +345,8 @@ namespace gridfire::solver {
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
}
for (const auto& species: m_engine.getNetworkSpecies()) {
const size_t sid = m_engine.getSpeciesIndex(species);
for (const auto& species: m_engine.getNetworkSpecies(*m_scratch_blob)) {
const size_t sid = m_engine.getSpeciesIndex(*m_scratch_blob, species);
temp_comp.registerSpecies(species);
double y = end_of_step_abundances[sid];
if (y > 0.0) {
@@ -349,7 +356,7 @@ namespace gridfire::solver {
#ifndef NDEBUG
for (long int i = 0; i < num_species_at_stop; ++i) {
const auto& species = m_engine.getNetworkSpecies()[i];
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) {
throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification.");
}
@@ -384,7 +391,7 @@ namespace gridfire::solver {
"Prior to Engine Update active reactions are: {}",
[&]() -> std::string {
std::stringstream ss;
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions();
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob);
size_t count = 0;
for (const auto& reaction : reactions) {
ss << reaction -> id();
@@ -396,7 +403,7 @@ namespace gridfire::solver {
return ss.str();
}()
);
fourdst::composition::Composition currentComposition = m_engine.update(netInTemp);
fourdst::composition::Composition currentComposition = m_engine.project(*m_scratch_blob, netInTemp);
LOG_DEBUG(
m_logger,
"After to Engine update composition is (molar abundance) {}",
@@ -443,7 +450,7 @@ namespace gridfire::solver {
"After Engine Update active reactions are: {}",
[&]() -> std::string {
std::stringstream ss;
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions();
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob);
size_t count = 0;
for (const auto& reaction : reactions) {
ss << reaction -> id();
@@ -459,10 +466,10 @@ namespace gridfire::solver {
m_logger,
"Due to a triggered engine update the composition was updated from size {} to {} species.",
num_species_at_stop,
m_engine.getNetworkSpecies().size()
m_engine.getNetworkSpecies(*m_scratch_blob).size()
);
numSpecies = m_engine.getNetworkSpecies().size();
numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
size_t N = numSpecies + 1;
LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update...");
@@ -490,15 +497,15 @@ namespace gridfire::solver {
accumulated_energy += y_data[numSpecies];
std::vector<double> y_vec(y_data, y_data + numSpecies);
for (size_t i = 0; i < y_vec.size(); ++i) {
if (y_vec[i] < 0 && std::abs(y_vec[i]) < 1e-16) {
y_vec[i] = 0.0; // Regularize tiny negative abundances to zero
for (double & i : y_vec) {
if (i < 0 && std::abs(i) < 1e-16) {
i = 0.0; // Regularize tiny negative abundances to zero
}
}
LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies);
fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(), y_vec);
fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec);
LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string {
std::ostringstream ss;
size_t i = 0;
@@ -513,7 +520,7 @@ namespace gridfire::solver {
}());
LOG_INFO(m_logger, "Collecting final composition...");
fourdst::composition::Composition outputComposition = m_engine.collectComposition(topLevelComposition, netIn.temperature/1e9, netIn.density);
fourdst::composition::Composition outputComposition = m_engine.collectComposition(*m_scratch_blob, topLevelComposition, netIn.temperature/1e9, netIn.density);
assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size());
@@ -538,6 +545,7 @@ namespace gridfire::solver {
LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives...");
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
*m_scratch_blob,
outputComposition,
T9,
netIn.density
@@ -640,7 +648,7 @@ namespace gridfire::solver {
const auto* solver_instance = data->solver_instance;
LOG_TRACE_L2(solver_instance->m_logger, "CVODE Jacobian wrapper starting");
const size_t numSpecies = engine->getNetworkSpecies().size();
const size_t numSpecies = engine->getNetworkSpecies(data->ctx).size();
sunrealtype* y_data = N_VGetArrayPointer(y);
@@ -653,7 +661,7 @@ namespace gridfire::solver {
}
}
std::vector<double> y_vec(y_data, y_data + numSpecies);
fourdst::composition::Composition composition(engine->getNetworkSpecies(), y_vec);
fourdst::composition::Composition composition(engine->getNetworkSpecies(data->ctx), y_vec);
LOG_TRACE_L2(solver_instance->m_logger, "Generating Jacobian matrix at time {} with {} species in composition (mean molecular mass: {})", t, composition.size(), composition.getMeanParticleMass());
LOG_TRACE_L2(solver_instance->m_logger, "Composition is {}", [&composition]() -> std::string {
std::stringstream ss;
@@ -669,11 +677,11 @@ namespace gridfire::solver {
}());
LOG_TRACE_L2(solver_instance->m_logger, "Generating Jacobian matrix at time {}", t);
NetworkJacobian jac = engine->generateJacobianMatrix(composition, data->T9, data->rho);
NetworkJacobian jac = engine->generateJacobianMatrix(data->ctx, composition, data->T9, data->rho);
LOG_TRACE_L2(solver_instance->m_logger, "Regularizing Jacobian matrix at time {}", t);
jac = regularize_jacobian(jac, composition, solver_instance->m_logger);
LOG_TRACE_L2(solver_instance->m_logger, "Done regularizing Jacobian matrix at time {}", t);
if (jac.infs().size() != 0 || jac.nans().size() != 0) {
if (!jac.infs().empty() || !jac.nans().empty()) {
auto infString = [&jac]() -> std::string {
std::stringstream ss;
size_t i = 0;
@@ -685,7 +693,7 @@ namespace gridfire::solver {
}
i++;
}
if (entries.size() == 0) {
if (entries.empty()) {
ss << "None";
}
return ss.str();
@@ -701,7 +709,7 @@ namespace gridfire::solver {
}
i++;
}
if (entries.size() == 0) {
if (entries.empty()) {
ss << "None";
}
return ss.str();
@@ -724,9 +732,9 @@ namespace gridfire::solver {
LOG_TRACE_L2(solver_instance->m_logger, "Transferring Jacobian matrix data to SUNDenseMatrix format at time {}", t);
for (size_t j = 0; j < numSpecies; ++j) {
const fourdst::atomic::Species& species_j = engine->getNetworkSpecies()[j];
const fourdst::atomic::Species& species_j = engine->getNetworkSpecies(data->ctx)[j];
for (size_t i = 0; i < numSpecies; ++i) {
const fourdst::atomic::Species& species_i = engine->getNetworkSpecies()[i];
const fourdst::atomic::Species& species_i = engine->getNetworkSpecies(data->ctx)[i];
// J(i,j) = d(f_i)/d(y_j)
// Column-major order format for SUNDenseMatrix: J_data[j*N + i] indexes J(i,j)
const double dYi_dt = jac(species_i, species_j);
@@ -752,7 +760,7 @@ namespace gridfire::solver {
N_Vector ydot,
const CVODEUserData *data
) const {
const size_t numSpecies = m_engine.getNetworkSpecies().size();
const size_t numSpecies = m_engine.getNetworkSpecies(data->ctx).size();
sunrealtype* y_data = N_VGetArrayPointer(y);
// Solver constraints should keep these values very close to 0 but floating point noise can still result in very
@@ -764,10 +772,10 @@ namespace gridfire::solver {
}
}
std::vector<double> y_vec(y_data, y_data + numSpecies);
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(), y_vec);
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec);
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size());
const auto result = m_engine.calculateRHSAndEnergy(composition, data->T9, data->rho, false);
const auto result = m_engine.calculateRHSAndEnergy(*m_scratch_blob, composition, data->T9, data->rho, false);
if (!result) {
LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()));
throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())));
@@ -797,7 +805,7 @@ namespace gridfire::solver {
}());
for (size_t i = 0; i < numSpecies; ++i) {
fourdst::atomic::Species species = m_engine.getNetworkSpecies()[i];
fourdst::atomic::Species species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
ydot_data[i] = dydt.at(species);
}
ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate
@@ -822,7 +830,7 @@ namespace gridfire::solver {
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
for (size_t i = 0; i < numSpecies; i++) {
const auto& species = m_engine.getNetworkSpecies()[i];
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
if (composition.contains(species)) {
y_data[i] = composition.getMolarAbundance(species);
} else {
@@ -893,11 +901,11 @@ namespace gridfire::solver {
}
void CVODESolverStrategy::log_step_diagnostics(
scratch::StateBlob &ctx,
const CVODEUserData &user_data,
bool displayJacobianStiffness,
bool displaySpeciesBalance,
bool to_file,
std::optional<std::string> filename
bool to_file, std::optional<std::string> filename
) const {
if (to_file && !filename.has_value()) {
LOG_ERROR(m_logger, "Filename must be provided when logging diagnostics to file.");
@@ -982,7 +990,7 @@ namespace gridfire::solver {
std::vector<double> Y_full(y_data, y_data + num_components - 1);
std::vector<double> E_full(y_err_data, y_err_data + num_components - 1);
auto result = diagnostics::report_limiting_species(*user_data.engine, Y_full, E_full, relTol, absTol, 10, to_file);
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, relTol, absTol, 10, to_file);
if (to_file && result.has_value()) {
j["Limiting_Species"] = result.value();
}
@@ -1005,11 +1013,11 @@ namespace gridfire::solver {
err_ratios[i] = err_ratio;
}
fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(), Y_full);
fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(composition, user_data.T9, user_data.rho);
fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*m_scratch_blob), Y_full);
fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*m_scratch_blob, composition, user_data.T9, user_data.rho);
auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(collectedComposition, user_data.T9, user_data.rho);
auto netTimescales = user_data.engine->getSpeciesTimescales(collectedComposition, user_data.T9, user_data.rho);
auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho);
auto netTimescales = user_data.engine->getSpeciesTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho);
bool timescaleOkay = false;
if (destructionTimescales && netTimescales) timescaleOkay = true;
@@ -1029,7 +1037,7 @@ namespace gridfire::solver {
if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp));
else destructionTimescales_list.emplace_back(std::numeric_limits<double>::infinity());
speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(sp)));
speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*m_scratch_blob, sp)));
}
utils::Column<fourdst::atomic::Species> speciesColumn("Species", species_list);
@@ -1093,7 +1101,7 @@ namespace gridfire::solver {
// --- 4. Call Your Jacobian and Balance Diagnostics ---
if (displayJacobianStiffness) {
auto jStiff = diagnostics::inspect_jacobian_stiffness(*user_data.engine, composition, user_data.T9, user_data.rho, to_file);
auto jStiff = diagnostics::inspect_jacobian_stiffness(ctx, *user_data.engine, composition, user_data.T9, user_data.rho, to_file);
if (to_file && jStiff.has_value()) {
j["Jacobian_Stiffness_Diagnostics"] = jStiff.value();
}
@@ -1103,7 +1111,7 @@ namespace gridfire::solver {
const size_t num_species_to_inspect = std::min(sorted_species.size(), static_cast<size_t>(5));
for (size_t i = 0; i < num_species_to_inspect; ++i) {
const auto& species = sorted_species[i];
auto sbr = diagnostics::inspect_species_balance(*user_data.engine, std::string(species.name()), composition, user_data.T9, user_data.rho, to_file);
auto sbr = diagnostics::inspect_species_balance(ctx, *user_data.engine, std::string(species.name()), composition, user_data.T9, user_data.rho, to_file);
if (to_file && sbr.has_value()) {
j[std::string("Species_Balance_Diagnostics_") + species.name().data()] = sbr.value();
}

View File

@@ -1,5 +1,6 @@
#include "gridfire/utils/logging.h"
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/engine/scratchpads/blob.h"
#include <sstream>
#include <iomanip>
@@ -9,12 +10,12 @@
#include <string>
std::string gridfire::utils::formatNuclearTimescaleLogString(
engine::scratch::StateBlob &ctx,
const engine::DynamicEngine& engine,
const fourdst::composition::Composition& composition,
const double T9,
const double rho
const double T9, const double rho
) {
auto const& result = engine.getSpeciesTimescales(composition, T9, rho);
auto const& result = engine.getSpeciesTimescales(ctx, composition, T9, rho);
if (!result) {
std::ostringstream ss;
ss << "Failed to get species timescales: " << engine::EngineStatus_to_string(result.error());

View File

@@ -16,7 +16,7 @@ gridfire_sources = files(
'lib/io/network_file.cpp',
'lib/io/generative/python.cpp',
'lib/solver/strategies/CVODE_solver_strategy.cpp',
'lib/solver/strategies/SpectralSolverStrategy.cpp',
# 'lib/solver/strategies/SpectralSolverStrategy.cpp',
'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp',
'lib/screening/screening_types.cpp',
'lib/screening/screening_weak.cpp',

View File

@@ -1,9 +1,12 @@
// ReSharper disable CppUnusedIncludeDirective
#include <iostream>
#include <fstream>
#include <chrono>
#include <thread>
#include <format>
#include "gridfire/gridfire.h"
#include <cppad/utility/thread_alloc.hpp> // Required for parallel_setup
#include "fourdst/composition/composition.h"
#include "fourdst/logging/logging.h"
@@ -17,7 +20,15 @@
#include <clocale>
#include "gridfire/reaction/reaclib.h"
#include <omp.h>
unsigned long get_thread_id() {
return static_cast<unsigned long>(omp_get_thread_num());
}
bool in_parallel() {
return omp_in_parallel() != 0;
}
static std::terminate_handler g_previousHandler = nullptr;
static std::vector<std::pair<double, std::unordered_map<std::string, std::pair<double, double>>>> g_callbackHistory;
@@ -110,14 +121,14 @@ void log_results(const gridfire::NetOut& netOut, const gridfire::NetIn& netIn) {
std::vector<std::string> rowLabels = [&]() -> std::vector<std::string> {
std::vector<std::string> labels;
for (const auto& species : logSpecies) {
labels.push_back(std::string(species.name()));
labels.emplace_back(species.name());
}
labels.push_back("ε");
labels.push_back("dε/dT");
labels.push_back("dε/dρ");
labels.push_back("Eν");
labels.push_back("Fν");
labels.push_back("<μ>");
labels.emplace_back("ε");
labels.emplace_back("dε/dT");
labels.emplace_back("dε/dρ");
labels.emplace_back("Eν");
labels.emplace_back("Fν");
labels.emplace_back("<μ>");
return labels;
}();
@@ -145,13 +156,13 @@ void record_abundance_history_callback(const gridfire::solver::CVODESolverStrate
const auto& engine = ctx.engine;
// std::unordered_map<std::string, std::pair<double, double>> abundances;
std::vector<double> Y;
for (const auto& species : engine.getNetworkSpecies()) {
const size_t sid = engine.getSpeciesIndex(species);
for (const auto& species : engine.getNetworkSpecies(ctx.state_ctx)) {
const size_t sid = engine.getSpeciesIndex(ctx.state_ctx, species);
double y = N_VGetArrayPointer(ctx.state)[sid];
Y.push_back(y > 0.0 ? y : 0.0); // Regularize tiny negative abundances to zero
}
fourdst::composition::Composition comp(engine.getNetworkSpecies(), Y);
fourdst::composition::Composition comp(engine.getNetworkSpecies(ctx.state_ctx), Y);
std::unordered_map<std::string, std::pair<double, double>> abundances;
@@ -225,45 +236,116 @@ void callback_main(const gridfire::solver::CVODESolverStrategy::TimestepContext&
record_abundance_history_callback(ctx);
}
int main(int argc, char** argv) {
int main() {
using namespace gridfire;
CLI::App app{"GridFire Sandbox Application."};
constexpr size_t breaks = 100;
constexpr size_t breaks = 1;
double temp = 1.5e7;
double rho = 1.5e2;
double tMax = 3.1536e+17/breaks;
double tMax = 3.1536e+16/breaks;
app.add_option("-t,--temp", temp, "Temperature in K (Default 1.5e7K)");
app.add_option("-r,--rho", rho, "Density in g/cm^3 (Default 1.5e2g/cm^3)");
app.add_option("--tmax", tMax, "Maximum simulation time in s (Default 3.1536e17s)");
CLI11_PARSE(app, argc, argv);
NetIn netIn = init(temp, rho, tMax);
const NetIn netIn = init(temp, rho, tMax);
policy::MainSequencePolicy stellarPolicy(netIn.composition);
stellarPolicy.construct();
engine::DynamicEngine& engine = stellarPolicy.construct();
policy::ConstructionResults construct = stellarPolicy.construct();
std::println("Sandbox Engine Stack: {}", stellarPolicy);
std::println("Scratch Blob State: {}", *construct.scratch_blob);
solver::CVODESolverStrategy solver(engine);
solver.set_stdout_logging_enabled(false);
// solver.set_callback(solver::CVODESolverStrategy::TimestepCallback(callback_main));
fourdst::composition::Composition reinputComp = netIn.composition;
NetOut netOut;
const auto timer = std::chrono::high_resolution_clock::now();
for (int i = 0; i < breaks; ++i) {
NetIn in({.composition = reinputComp, .temperature = temp, .density = rho, .tMax = tMax, .dt0 = 1e-12});
netOut = solver.evaluate(in, false, false);
reinputComp = netOut.composition;
constexpr size_t runs = 1000;
auto startTime = std::chrono::high_resolution_clock::now();
// arrays to store timings
std::array<std::chrono::duration<double>, runs> setup_times;
std::array<std::chrono::duration<double>, runs> eval_times;
std::array<NetOut, runs> serial_results;
for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now();
std::print("Run {}/{}\r", i + 1, runs);
solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob);
// solver.set_callback(solver::CVODESolverStrategy::TimestepCallback(callback_main));
solver.set_stdout_logging_enabled(false);
auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setup_times[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now();
const NetOut netOut = solver.evaluate(netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now();
serial_results[i] = netOut;
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
eval_times[i] = eval_elapsed;
// log_results(netOut, netIn);
}
auto endTime = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = endTime - startTime;
std::println("");
// Summarize serial timings
double total_setup_time = 0.0;
double total_eval_time = 0.0;
for (size_t i = 0; i < runs; ++i) {
total_setup_time += setup_times[i].count();
total_eval_time += eval_times[i].count();
}
std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1));
CppAD::thread_alloc::parallel_setup(
static_cast<size_t>(omp_get_max_threads()), // Max threads
[]() -> bool { return in_parallel(); }, // Function to get thread ID
[]() -> size_t { return get_thread_id(); } // Function to check parallel state
);
// OPTIONAL: Prevent CppAD from returning memory to the system
// during execution to reduce overhead (can speed up tight loops)
CppAD::thread_alloc::hold_memory(true);
std::array<NetOut, runs> parallelResults;
std::array<std::chrono::duration<double>, runs> setupTimes;
std::array<std::chrono::duration<double>, runs> evalTimes;
std::array<std::unique_ptr<gridfire::engine::scratch::StateBlob>, runs> workspaces;
for (size_t i = 0; i < runs; ++i) {
workspaces[i] = construct.scratch_blob->clone_structure();
}
const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - timer).count();
std::cout << "Average execution time over run: " << duration/breaks << " ms" << std::endl;
std::cout << "Total execution time over " << breaks << " runs: " << duration << " ms" << std::endl;
log_results(netOut, netIn);
// log_callback_data(temp);
}
// Parallel runs
startTime = std::chrono::high_resolution_clock::now();
#pragma omp parallel for
for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now();
solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]);
solver.set_stdout_logging_enabled(false);
auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setupTimes[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now();
parallelResults[i] = solver.evaluate(netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
evalTimes[i] = eval_elapsed;
}
endTime = std::chrono::high_resolution_clock::now();
elapsed = endTime - startTime;
std::println("");
// Summarize parallel timings
total_setup_time = 0.0;
total_eval_time = 0.0;
for (size_t i = 0; i < runs; ++i) {
total_setup_time += setupTimes[i].count();
total_eval_time += evalTimes[i].count();
}
std::println("Average Parallel Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) {
return result.composition.getMolarAbundance(fourdst::atomic::H_1);
}));
}

View File

@@ -4,8 +4,8 @@ executable(
dependencies: [gridfire_dep, cli11_dep],
)
executable(
'spectral_sandbox',
'spectral_main.cpp',
dependencies: [gridfire_dep, cli11_dep]
)
#executable(
# 'spectral_sandbox',
# 'spectral_main.cpp',
# dependencies: [gridfire_dep, cli11_dep]
#)

3
utils/cloc/ignore.txt Normal file
View File

@@ -0,0 +1,3 @@
include/gridfire/partition/rauscher_thielemann_partition_data.h
include/gridfire/reaction/reactions_data.h
include/gridfire/reaction/weak/weak_rate_library.h