feat(solver): added callback functions to solver in C++ and python

This commit is contained in:
2025-07-31 15:04:57 -04:00
parent 5b74155477
commit 24049b2658
482 changed files with 4318 additions and 1467 deletions

View File

@@ -296,16 +296,65 @@ namespace gridfire {
*/
[[nodiscard]] virtual screening::ScreeningType getScreeningModel() const = 0;
/**
* @brief Get the index of a species in the network.
*
* @param species The species to look up.
*
* This method allows querying the index of a specific species in the
* engine's internal representation. It is useful for accessing species
* data efficiently.
*/
[[nodiscard]] virtual int getSpeciesIndex(const fourdst::atomic::Species &species) const = 0;
/**
* @brief Map a NetIn object to a vector of molar abundances.
*
* @param netIn The input conditions for the network.
* @return A vector of molar abundances corresponding to the species in the network.
*
* This method converts the input conditions into a vector of molar abundances,
* which can be used for further calculations or diagnostics.
*/
[[nodiscard]] virtual std::vector<double> mapNetInToMolarAbundanceVector(const NetIn &netIn) const = 0;
/**
* @brief Prime the engine with initial conditions.
*
* @param netIn The input conditions for the network.
* @return PrimingReport containing information about the priming process.
*
* This method is used to prepare the engine for calculations by setting up
* initial conditions, reactions, and species. It may involve compiling reaction
* rates, initializing internal data structures, and performing any necessary
* pre-computation.
*/
[[nodiscard]] virtual PrimingReport primeEngine(const NetIn &netIn) = 0;
/**
* @brief Get the depth of the network.
*
* @return The depth of the network, which may indicate the level of detail or
* complexity in the reaction network.
*
* This method is intended to provide information about the network's structure,
* such as how many layers of reactions or species are present. It can be useful
* for diagnostics and understanding the network's complexity.
*/
[[nodiscard]] virtual BuildDepthType getDepth() const {
throw std::logic_error("Network depth not supported by this engine.");
}
/**
* @brief Rebuild the network with a specified depth.
*
* @param comp The composition to rebuild the network with.
* @param depth The desired depth of the network.
*
* This method is intended to allow dynamic adjustment of the network's depth,
* which may involve adding or removing species and reactions based on the
* specified depth. However, not all engines support this operation.
*/
virtual void rebuild(const fourdst::composition::Composition& comp, BuildDepthType depth) {
throw std::logic_error("Setting network depth not supported by this engine.");
}

View File

@@ -335,21 +335,87 @@ namespace gridfire {
const std::string& filename
) const;
void setScreeningModel(screening::ScreeningType) override;
/**
* @brief Sets the electron screening model for reaction rate calculations.
*
* @param model The type of screening model to use.
*
* This method allows changing the screening model at runtime. Screening corrections
* account for the electrostatic shielding of nuclei by electrons, which affects
* reaction rates in dense stellar plasmas.
*/
void setScreeningModel(screening::ScreeningType model) override;
/**
* @brief Gets the current electron screening model.
*
* @return The currently active screening model type.
*
* Example usage:
* @code
* screening::ScreeningType currentModel = engine.getScreeningModel();
* @endcode
*/
[[nodiscard]] screening::ScreeningType getScreeningModel() const override;
/**
* @brief Sets whether to precompute reaction rates.
*
* @param precompute True to enable precomputation, false to disable.
*
* This method allows enabling or disabling precomputation of reaction rates
* for performance optimization. When enabled, reaction rates are computed
* once and stored for later use.
*/
void setPrecomputation(bool precompute);
/**
* @brief Checks if precomputation of reaction rates is enabled.
*
* @return True if precomputation is enabled, false otherwise.
*
* This method allows checking the current state of precomputation for
* reaction rates in the engine.
*/
[[nodiscard]] bool isPrecomputationEnabled() const;
/**
* @brief Gets the partition function used for reaction rate calculations.
*
* @return Reference to the PartitionFunction object.
*
* This method provides access to the partition function used in the engine,
* which is essential for calculating thermodynamic properties and reaction rates.
*/
[[nodiscard]] const partition::PartitionFunction& getPartitionFunction() const;
/**
* @brief Calculates the reverse rate for a given reaction.
*
* @param reaction The reaction for which to calculate the reverse rate.
* @param T9 Temperature in units of 10^9 K.
* @return Reverse rate for the reaction (e.g., mol/g/s).
*
* This method computes the reverse rate based on the forward rate and
* thermodynamic properties of the reaction.
*/
[[nodiscard]] double calculateReverseRate(
const reaction::Reaction &reaction,
double T9
) const;
/**
* @brief Calculates the reverse rate for a two-body reaction.
*
* @param reaction The reaction for which to calculate the reverse rate.
* @param T9 Temperature in units of 10^9 K.
* @param forwardRate The forward rate of the reaction.
* @param expFactor Exponential factor for the reaction.
* @return Reverse rate for the two-body reaction (e.g., mol/g/s).
*
* This method computes the reverse rate using the forward rate and
* thermodynamic properties of the reaction.
*/
[[nodiscard]] double calculateReverseRateTwoBody(
const reaction::Reaction &reaction,
const double T9,
@@ -363,23 +429,82 @@ namespace gridfire {
const double reverseRate
) const;
/**
* @brief Checks if reverse reactions are enabled.
*
* @return True if reverse reactions are enabled, false otherwise.
*
* This method allows checking whether the engine is configured to use
* reverse reactions in its calculations.
*/
[[nodiscard]] bool isUsingReverseReactions() const;
/**
* @brief Sets whether to use reverse reactions in the engine.
*
* @param useReverse True to enable reverse reactions, false to disable.
*
* This method allows enabling or disabling reverse reactions in the engine.
* If disabled, only forward reactions will be considered in calculations.
*/
void setUseReverseReactions(bool useReverse);
/**
* @brief Gets the index of a species in the network.
*
* @param species The species for which to get the index.
* @return Index of the species in the network, or -1 if not found.
*
* This method returns the index of the given species in the network's
* species vector. If the species is not found, it returns -1.
*/
[[nodiscard]] int getSpeciesIndex(
const fourdst::atomic::Species& species
) const override;
/**
* @brief Maps the NetIn object to a vector of molar abundances.
*
* @param netIn The NetIn object containing the input conditions.
* @return Vector of molar abundances corresponding to the species in the network.
*
* This method converts the NetIn object into a vector of molar abundances
* for each species in the network, which can be used for further calculations.
*/
[[nodiscard]] std::vector<double> mapNetInToMolarAbundanceVector(const NetIn &netIn) const override;
/**
* @brief Prepares the engine for calculations with initial conditions.
*
* @param netIn The input conditions for the network.
* @return PrimingReport containing information about the priming process.
*
* This method initializes the engine with the provided input conditions,
* setting up reactions, species, and precomputing necessary data.
*/
[[nodiscard]] PrimingReport primeEngine(const NetIn &netIn) override;
/**
* @brief Gets the depth of the network.
*
* @return The build depth of the network.
*
* This method returns the current build depth of the reaction network,
* which indicates how many levels of reactions are included in the network.
*/
[[nodiscard]] BuildDepthType getDepth() const override;
/**
* @brief Rebuilds the reaction network based on a new composition.
*
* @param comp The new composition to use for rebuilding the network.
* @param depth The build depth to use for the network.
*
* This method rebuilds the reaction network using the provided composition
* and build depth. It updates all internal data structures accordingly.
*/
void rebuild(const fourdst::composition::Composition& comp, const BuildDepthType depth) override;
private:
struct PrecomputedReaction {
// Forward cacheing

View File

@@ -10,10 +10,12 @@ namespace gridfire {
/**
* @brief Configuration struct for the QSE cache.
*
* @purpose This struct defines the tolerances used to determine if a QSE cache key
* @par Purpose
* This struct defines the tolerances used to determine if a QSE cache key
* is considered a hit. It allows for tuning the sensitivity of the cache.
*
* @how It works by providing binning widths for temperature, density, and abundances.
* @par How
* It works by providing binning widths for temperature, density, and abundances.
* When a `QSECacheKey` is created, it uses these tolerances to discretize the
* continuous physical values into bins. If two sets of conditions fall into the
* same bins, they will produce the same hash and be considered a cache hit.
@@ -33,12 +35,14 @@ namespace gridfire {
/**
* @brief Key struct for the QSE abundance cache.
*
* @purpose This struct is used as the key for the QSE abundance cache (`m_qse_abundance_cache`)
* @par Purpose
* This struct is used as the key for the QSE abundance cache (`m_qse_abundance_cache`)
* within the `MultiscalePartitioningEngineView`. Its primary goal is to avoid
* expensive re-partitioning and QSE solves for thermodynamic conditions that are
* "close enough" to previously computed ones.
*
* @how It works by storing the temperature (`m_T9`), density (`m_rho`), and species
* @par How
* It works by storing the temperature (`m_T9`), density (`m_rho`), and species
* abundances (`m_Y`). A pre-computed hash is generated in the constructor by
* calling the `hash()` method. This method discretizes the continuous physical
* values into bins using the tolerances defined in `QSECacheConfig`. The `operator==`
@@ -78,7 +82,8 @@ namespace gridfire {
*
* @return The computed hash value.
*
* @how This method combines the hashes of the binned temperature, density, and
* @par How
* This method combines the hashes of the binned temperature, density, and
* each species abundance. The `bin()` static method is used for discretization.
*/
size_t hash() const;
@@ -89,7 +94,8 @@ namespace gridfire {
* @param tol The tolerance (bin width) to use for binning.
* @return The bin number as a long integer.
*
* @how The algorithm is `floor(value / tol)`.
* @par How
* The algorithm is `floor(value / tol)`.
*/
static long bin(double value, double tol);
@@ -124,14 +130,16 @@ namespace gridfire {
* @class MultiscalePartitioningEngineView
* @brief An engine view that partitions the reaction network into multiple groups based on timescales.
*
* @purpose This class is designed to accelerate the integration of stiff nuclear reaction networks.
* @par Purpose
* This class is designed to accelerate the integration of stiff nuclear reaction networks.
* It identifies species that react on very short timescales ("fast" species) and treats them
* as being in Quasi-Steady-State Equilibrium (QSE). Their abundances are solved for algebraically,
* removing their stiff differential equations from the system. The remaining "slow" or "dynamic"
* species are integrated normally. This significantly improves the stability and performance of
* the solver.
*
* @how The core logic resides in the `partitionNetwork()` and `equilibrateNetwork()` methods.
* @par How
* The core logic resides in the `partitionNetwork()` and `equilibrateNetwork()` methods.
* The partitioning process involves:
* 1. **Timescale Analysis:** Using `getSpeciesDestructionTimescales` from the base engine,
* all species are sorted by their characteristic timescales.
@@ -207,10 +215,12 @@ namespace gridfire {
* `StaleEngineError` if the engine's QSE cache does not contain a solution
* for the given state.
*
* @purpose To compute the time derivatives for the ODE solver. This implementation
* @par Purpose
* To compute the time derivatives for the ODE solver. This implementation
* modifies the derivatives from the base engine to enforce the QSE condition.
*
* @how It first performs a lookup in the QSE abundance cache (`m_qse_abundance_cache`).
* @par How
* It first performs a lookup in the QSE abundance cache (`m_qse_abundance_cache`).
* If a cache hit occurs, it calls the base engine's `calculateRHSAndEnergy`. It then
* manually sets the time derivatives (`dydt`) of all identified algebraic species to zero,
* effectively removing their differential equations from the system being solved.
@@ -235,9 +245,11 @@ namespace gridfire {
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
*
* @purpose To compute the Jacobian matrix required by implicit ODE solvers.
* @par Purpose
* To compute the Jacobian matrix required by implicit ODE solvers.
*
* @how It first performs a QSE cache lookup. On a hit, it delegates the full Jacobian
* @par How
* It first performs a QSE cache lookup. On a hit, it delegates the full Jacobian
* calculation to the base engine. While this view could theoretically return a
* modified, sparser Jacobian reflecting the QSE constraints, the current implementation
* returns the full Jacobian from the base engine. The solver is expected to handle the
@@ -262,9 +274,11 @@ namespace gridfire {
* @param j_full Column index (species index) in the full network.
* @return Value of the Jacobian matrix at (i_full, j_full).
*
* @purpose To provide Jacobian entries to an implicit solver.
* @par Purpose
* To provide Jacobian entries to an implicit solver.
*
* @how This method directly delegates to the base engine's `getJacobianMatrixEntry`.
* @par How
* This method directly delegates to the base engine's `getJacobianMatrixEntry`.
* It does not currently modify the Jacobian to reflect the QSE algebraic constraints,
* as these are handled by setting `dY/dt = 0` in `calculateRHSAndEnergy`.
*
@@ -278,9 +292,11 @@ namespace gridfire {
/**
* @brief Generates the stoichiometry matrix for the network.
*
* @purpose To prepare the stoichiometry matrix for later queries.
* @par Purpose
* To prepare the stoichiometry matrix for later queries.
*
* @how This method delegates directly to the base engine's `generateStoichiometryMatrix()`.
* @par How
* This method delegates directly to the base engine's `generateStoichiometryMatrix()`.
* The stoichiometry is based on the full, unpartitioned network.
*/
void generateStoichiometryMatrix() override;
@@ -292,9 +308,11 @@ namespace gridfire {
* @param reactionIndex Index of the reaction in the full network.
* @return Stoichiometric coefficient for the species in the reaction.
*
* @purpose To query the stoichiometric relationship between a species and a reaction.
* @par Purpose
* To query the stoichiometric relationship between a species and a reaction.
*
* @how This method delegates directly to the base engine's `getStoichiometryMatrixEntry()`.
* @par How
* This method delegates directly to the base engine's `getStoichiometryMatrixEntry()`.
*
* @pre `generateStoichiometryMatrix()` must have been called.
*/
@@ -312,9 +330,11 @@ namespace gridfire {
* @param rho Density in g/cm^3.
* @return Molar flow rate for the reaction (e.g., mol/g/s).
*
* @purpose To compute the net rate of a single reaction.
* @par Purpose
* To compute the net rate of a single reaction.
*
* @how It first checks the QSE cache. On a hit, it retrieves the cached equilibrium
* @par How
* It first checks the QSE cache. On a hit, it retrieves the cached equilibrium
* abundances for the algebraic species. It creates a mutable copy of `Y_full`,
* overwrites the algebraic species abundances with the cached equilibrium values,
* and then calls the base engine's `calculateMolarReactionFlow` with this modified
@@ -343,9 +363,11 @@ namespace gridfire {
*
* @param reactions The set of logical reactions to use.
*
* @purpose To modify the reaction network.
* @par Purpose
* To modify the reaction network.
*
* @how This operation is not supported by the `MultiscalePartitioningEngineView` as it
* @par How
* This operation is not supported by the `MultiscalePartitioningEngineView` as it
* would invalidate the partitioning logic. It logs a critical error and throws an
* exception. Network modifications should be done on the base engine before it is
* wrapped by this view.
@@ -365,9 +387,11 @@ namespace gridfire {
* @return A `std::expected` containing a map from `Species` to their characteristic
* timescales (s) on success, or a `StaleEngineError` on failure.
*
* @purpose To get the characteristic timescale `Y / (dY/dt)` for each species.
* @par Purpose
* To get the characteristic timescale `Y / (dY/dt)` for each species.
*
* @how It delegates the calculation to the base engine. For any species identified
* @par How
* It delegates the calculation to the base engine. For any species identified
* as algebraic (in QSE), it manually sets their timescale to 0.0 to signify
* that they equilibrate instantaneously on the timescale of the solver.
*
@@ -389,10 +413,12 @@ namespace gridfire {
* @return A `std::expected` containing a map from `Species` to their characteristic
* destruction timescales (s) on success, or a `StaleEngineError` on failure.
*
* @purpose To get the timescale for species destruction, which is used as the primary
* @par Purpose
* To get the timescale for species destruction, which is used as the primary
* metric for network partitioning.
*
* @how It delegates the calculation to the base engine. For any species identified
* @par How
* It delegates the calculation to the base engine. For any species identified
* as algebraic (in QSE), it manually sets their timescale to 0.0.
*
* @pre The engine must have a valid QSE cache entry for the given state.
@@ -410,7 +436,8 @@ namespace gridfire {
* @param netIn A struct containing the current network input: temperature, density, and composition.
* @return The new composition after QSE species have been brought to equilibrium.
*
* @purpose This is the main entry point for preparing the multiscale engine for use. It
* @par Purpose
* This is the main entry point for preparing the multiscale engine for use. It
* triggers the network partitioning and solves for the initial QSE abundances, caching the result.
*
* @how
@@ -440,9 +467,11 @@ namespace gridfire {
* @param netIn A struct containing the current network input.
* @return `true` if the engine is stale, `false` otherwise.
*
* @purpose To determine if `update()` needs to be called.
* @par Purpose
* To determine if `update()` needs to be called.
*
* @how It creates a `QSECacheKey` from the `netIn` data and checks for its
* @par How
* It creates a `QSECacheKey` from the `netIn` data and checks for its
* existence in the `m_qse_abundance_cache`. A cache miss indicates the engine is
* stale because it does not have a valid QSE partition for the current conditions.
* It also queries the base engine's `isStale()` method.
@@ -454,7 +483,8 @@ namespace gridfire {
*
* @param model The type of screening model to use for reaction rate calculations.
*
* @how This method delegates directly to the base engine's `setScreeningModel()`.
* @par How
* This method delegates directly to the base engine's `setScreeningModel()`.
*/
void setScreeningModel(
screening::ScreeningType model
@@ -465,7 +495,8 @@ namespace gridfire {
*
* @return The currently active screening model type.
*
* @how This method delegates directly to the base engine's `getScreeningModel()`.
* @par How
* This method delegates directly to the base engine's `getScreeningModel()`.
*/
[[nodiscard]] screening::ScreeningType getScreeningModel() const override;
@@ -487,10 +518,12 @@ namespace gridfire {
* @return A vector of vectors of species indices, where each inner vector represents a
* single connected component.
*
* @purpose To merge timescale pools that are strongly connected by reactions, forming
* @par Purpose
* To merge timescale pools that are strongly connected by reactions, forming
* cohesive groups for QSE analysis.
*
* @how For each pool, it builds a reaction connectivity graph using `buildConnectivityGraph`.
* @par How
* For each pool, it builds a reaction connectivity graph using `buildConnectivityGraph`.
* It then finds the connected components within that graph using a Breadth-First Search (BFS).
* The resulting components from all pools are collected and returned.
*/
@@ -508,7 +541,8 @@ namespace gridfire {
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
*
* @purpose To perform the core partitioning logic that identifies which species are "fast"
* @par Purpose
* To perform the core partitioning logic that identifies which species are "fast"
* (and can be treated algebraically) and which are "slow" (and must be integrated dynamically).
*
* @how
@@ -539,9 +573,11 @@ namespace gridfire {
*
* @param netIn A struct containing the current network input.
*
* @purpose A convenience overload for `partitionNetwork`.
* @par Purpose
* A convenience overload for `partitionNetwork`.
*
* @how It unpacks the `netIn` struct into `Y`, `T9`, and `rho` and then calls the
* @par How
* It unpacks the `netIn` struct into `Y`, `T9`, and `rho` and then calls the
* primary `partitionNetwork` method.
*/
void partitionNetwork(
@@ -556,9 +592,11 @@ namespace gridfire {
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
*
* @purpose To visualize the partitioned network graph.
* @par Purpose
* To visualize the partitioned network graph.
*
* @how This method delegates the DOT file export to the base engine. It does not
* @par How
* This method delegates the DOT file export to the base engine. It does not
* currently add any partitioning information to the output graph.
*/
void exportToDot(
@@ -574,7 +612,8 @@ namespace gridfire {
* @param species The species to get the index of.
* @return The index of the species in the base engine's network.
*
* @how This method delegates directly to the base engine's `getSpeciesIndex()`.
* @par How
* This method delegates directly to the base engine's `getSpeciesIndex()`.
*/
[[nodiscard]] int getSpeciesIndex(const fourdst::atomic::Species &species) const override;
@@ -584,7 +623,8 @@ namespace gridfire {
* @param netIn A struct containing the current network input.
* @return A vector of molar abundances corresponding to the species order in the base engine.
*
* @how This method delegates directly to the base engine's `mapNetInToMolarAbundanceVector()`.
* @par How
* This method delegates directly to the base engine's `mapNetInToMolarAbundanceVector()`.
*/
[[nodiscard]] std::vector<double> mapNetInToMolarAbundanceVector(const NetIn &netIn) const override;
@@ -594,9 +634,11 @@ namespace gridfire {
* @param netIn A struct containing the current network input.
* @return A `PrimingReport` struct containing information about the priming process.
*
* @purpose To prepare the network for ignition or specific pathway studies.
* @par Purpose
* To prepare the network for ignition or specific pathway studies.
*
* @how This method delegates directly to the base engine's `primeEngine()`. The
* @par How
* This method delegates directly to the base engine's `primeEngine()`. The
* multiscale view does not currently interact with the priming process.
*/
[[nodiscard]] PrimingReport primeEngine(const NetIn &netIn) override;
@@ -606,9 +648,11 @@ namespace gridfire {
*
* @return A vector of species identified as "fast" or "algebraic" by the partitioning.
*
* @purpose To allow external queries of the partitioning results.
* @par Purpose
* To allow external queries of the partitioning results.
*
* @how It returns a copy of the `m_algebraic_species` member vector.
* @par How
* It returns a copy of the `m_algebraic_species` member vector.
*
* @pre `partitionNetwork()` must have been called.
*/
@@ -618,9 +662,11 @@ namespace gridfire {
*
* @return A const reference to the vector of species identified as "dynamic" or "slow".
*
* @purpose To allow external queries of the partitioning results.
* @par Purpose
* To allow external queries of the partitioning results.
*
* @how It returns a const reference to the `m_dynamic_species` member vector.
* @par How
* It returns a const reference to the `m_dynamic_species` member vector.
*
* @pre `partitionNetwork()` must have been called.
*/
@@ -634,10 +680,12 @@ namespace gridfire {
* @param rho Density in g/cm^3.
* @return A new composition object with the equilibrated abundances.
*
* @purpose A convenience method to run the full QSE analysis and get an equilibrated
* @par Purpose
* A convenience method to run the full QSE analysis and get an equilibrated
* composition object as a result.
*
* @how It first calls `partitionNetwork()` with the given state to define the QSE groups.
* @par How
* It first calls `partitionNetwork()` with the given state to define the QSE groups.
* Then, it calls `solveQSEAbundances()` to compute the new equilibrium abundances for the
* algebraic species. Finally, it packs the resulting full abundance vector into a new
* `fourdst::composition::Composition` object and returns it.
@@ -657,9 +705,11 @@ namespace gridfire {
* @param netIn A struct containing the current network input.
* @return The equilibrated composition.
*
* @purpose A convenience overload for `equilibrateNetwork`.
* @par Purpose
* A convenience overload for `equilibrateNetwork`.
*
* @how It unpacks the `netIn` struct into `Y`, `T9`, and `rho` and then calls the
* @par How
* It unpacks the `netIn` struct into `Y`, `T9`, and `rho` and then calls the
* primary `equilibrateNetwork` method.
*/
fourdst::composition::Composition equilibrateNetwork(
@@ -671,7 +721,8 @@ namespace gridfire {
/**
* @brief Struct representing a QSE group.
*
* @purpose A container to hold all information about a set of species that are potentially
* @par Purpose
* A container to hold all information about a set of species that are potentially
* in quasi-steady-state equilibrium with each other.
*/
struct QSEGroup {
@@ -710,7 +761,8 @@ namespace gridfire {
/**
* @brief Functor for solving QSE abundances using Eigen's nonlinear optimization.
*
* @purpose This struct provides the objective function (`operator()`) and its Jacobian
* @par Purpose
* This struct provides the objective function (`operator()`) and its Jacobian
* (`df`) to Eigen's Levenberg-Marquardt solver. The goal is to find the abundances
* of algebraic species that make their time derivatives (`dY/dt`) equal to zero.
*
@@ -816,7 +868,8 @@ namespace gridfire {
/**
* @brief Struct for tracking cache statistics.
*
* @purpose A simple utility to monitor the performance of the QSE cache by counting
* @par Purpose
* A simple utility to monitor the performance of the QSE cache by counting
* hits and misses for various engine operations.
*/
struct CacheStats {
@@ -946,7 +999,8 @@ namespace gridfire {
/**
* @brief Cache for QSE abundances based on T9, rho, and Y.
*
* @purpose This is the core of the caching mechanism. It stores the results of QSE solves
* @par Purpose
* This is the core of the caching mechanism. It stores the results of QSE solves
* to avoid re-computation. The key is a `QSECacheKey` which hashes the thermodynamic
* state, and the value is the vector of solved molar abundances for the algebraic species.
*/
@@ -969,9 +1023,11 @@ namespace gridfire {
* @return A vector of vectors of species indices, where each inner vector represents a
* timescale pool.
*
* @purpose To group species into "pools" based on their destruction timescales.
* @par Purpose
* To group species into "pools" based on their destruction timescales.
*
* @how It retrieves all species destruction timescales from the base engine, sorts them,
* @par How
* It retrieves all species destruction timescales from the base engine, sorts them,
* and then iterates through the sorted list, creating a new pool whenever it detects
* a gap between consecutive timescales that is larger than a predefined threshold
* (e.g., a factor of 100).
@@ -989,9 +1045,11 @@ namespace gridfire {
* @return An unordered map representing the adjacency list of the connectivity graph,
* where keys are species indices and values are vectors of connected species indices.
*
* @purpose To represent the reaction pathways among a subset of reactions.
* @par Purpose
* To represent the reaction pathways among a subset of reactions.
*
* @how It iterates through the specified fast reactions. For each reaction, it creates
* @par How
* It iterates through the specified fast reactions. For each reaction, it creates
* a two-way edge in the graph between every reactant and every product, signifying
* that mass can flow between them.
*/
@@ -1008,11 +1066,13 @@ namespace gridfire {
* @param rho Density in g/cm^3.
* @return A vector of validated QSE groups that meet the flux criteria.
*
* @purpose To ensure that a candidate QSE group is truly in equilibrium by checking that
* @par Purpose
* To ensure that a candidate QSE group is truly in equilibrium by checking that
* the reaction fluxes *within* the group are much larger than the fluxes
* *leaving* the group.
*
* @how For each candidate group, it calculates the sum of all internal reaction fluxes and
* @par How
* For each candidate group, it calculates the sum of all internal reaction fluxes and
* the sum of all external (bridge) reaction fluxes. If the ratio of internal to external
* flux exceeds a configurable threshold, the group is considered valid and is added
* to the returned vector.
@@ -1032,10 +1092,12 @@ namespace gridfire {
* @param rho Density in g/cm^3.
* @return A vector of molar abundances for the algebraic species.
*
* @purpose To find the equilibrium abundances of the algebraic species that satisfy
* @par Purpose
* To find the equilibrium abundances of the algebraic species that satisfy
* the QSE conditions.
*
* @how It uses the Levenberg-Marquardt algorithm via Eigen's `LevenbergMarquardt` class.
* @par How
* It uses the Levenberg-Marquardt algorithm via Eigen's `LevenbergMarquardt` class.
* The problem is defined by the `EigenFunctor` which computes the residuals and
* Jacobian for the QSE equations.
*
@@ -1058,9 +1120,11 @@ namespace gridfire {
* @param rho Density in g/cm^3.
* @return The index of the pool with the largest (slowest) mean destruction timescale.
*
* @purpose To identify the core set of dynamic species that will not be part of any QSE group.
* @par Purpose
* To identify the core set of dynamic species that will not be part of any QSE group.
*
* @how It calculates the geometric mean of the destruction timescales for all species in each
* @par How
* It calculates the geometric mean of the destruction timescales for all species in each
* pool and returns the index of the pool with the maximum mean timescale.
*/
size_t identifyMeanSlowestPool(
@@ -1076,9 +1140,11 @@ namespace gridfire {
* @param species_pool A vector of species indices representing a species pool.
* @return An unordered map representing the adjacency list of the connectivity graph.
*
* @purpose To find reaction connections within a specific group of species.
* @par Purpose
* To find reaction connections within a specific group of species.
*
* @how It iterates through all reactions in the base engine. If a reaction involves
* @par How
* It iterates through all reactions in the base engine. If a reaction involves
* at least two distinct species from the input `species_pool` (one as a reactant
* and one as a product), it adds edges between all reactants and products from
* that reaction that are also in the pool.
@@ -1097,7 +1163,8 @@ namespace gridfire {
* @param rho Density in g/cm^3.
* @return A vector of `QSEGroup` structs, ready for flux validation.
*
* @how For each input pool, it identifies "bridge" reactions that connect the pool to
* @par How
* For each input pool, it identifies "bridge" reactions that connect the pool to
* species outside the pool. The reactants of these bridge reactions that are *not* in the
* pool are identified as "seed" species. The original pool members are the "algebraic"
* species. It then bundles the seed and algebraic species into a `QSEGroup` struct.

View File

@@ -2,7 +2,6 @@
#include "gridfire/engine/engine_graph.h"
#include "gridfire/engine/engine_abstract.h"
#include "../engine/views/engine_adaptive.h"
#include "gridfire/network.h"
#include "fourdst/logging/logging.h"
@@ -10,9 +9,35 @@
#include "quill/Logger.h"
#include <functional>
#include <any>
#include <vector>
#include <tuple>
#include <string>
namespace gridfire::solver {
/**
* @struct SolverContextBase
* @brief Base class for solver callback contexts.
*
* This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces
* that derived classes implement a `describe` method that returns a vector of tuples describing
* the context that a callback will receive when called.
*/
struct SolverContextBase {
virtual ~SolverContextBase() = default;
/**
* @brief Describe the context for callback functions.
* @return A vector of tuples, each containing a string for the parameters name and a string for its type.
*
* This method should be overridden by derived classes to provide a description of the context
* that will be passed to the callback function. The intent of this method is that an end user can investigate
* the context that will be passed to the callback function, and use this information to craft their own
* callback function.
*/
virtual std::vector<std::tuple<std::string, std::string>> describe() const = 0;
};
/**
* @class NetworkSolverStrategy
* @brief Abstract base class for network solver strategies.
@@ -43,6 +68,31 @@ namespace gridfire::solver {
* @return The output conditions after the timestep.
*/
virtual NetOut evaluate(const NetIn& netIn) = 0;
/**
* @brief set the callback function to be called at the end of each timestep.
*
* This function allows the user to set a callback function that will be called at the end of each timestep.
* The callback function will receive a gridfire::solver::<SOMESOLVER>::TimestepContext object. Note that
* depending on the solver, this context may contain different information. Further, the exact
* signature of the callback function is left up to each solver. Every solver should provide a type or type alias
* TimestepCallback that defines the signature of the callback function so that the user can easily
* get that type information.
*
* @param callback The callback function to be called at the end of each timestep.
*/
virtual void set_callback(const std::any& callback) = 0;
/**
* @brief Describe the context that will be passed to the callback function.
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
*
* This method should be overridden by derived classes to provide a description of the context
* that will be passed to the callback function. The intent of this method is that an end user can investigate
* the context that will be passed to the callback function, and use this information to craft their own
* callback function.
*/
virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected:
EngineT& m_engine; ///< The engine used by this solver strategy.
};
@@ -70,15 +120,120 @@ namespace gridfire::solver {
*/
using DynamicNetworkSolverStrategy::DynamicNetworkSolverStrategy;
/**
* @struct TimestepContext
* @brief Context for the timestep callback function for the DirectNetworkSolver.
*
* This struct contains the context that will be passed to the callback function at the end of each timestep.
* It includes the current time, state, timestep size, cached results, and other relevant information.
*
* This type should be used when defining a callback function
*
* **Example:**
* @code
* #include "gridfire/solver/solver.h"
*
* #include <ofstream>
* #include <ranges>
*
* static std::ofstream consumptionFile("consumption.txt");
* void callback(const gridfire::solver::DirectNetworkSolver::TimestepContext& context) {
* int H1Index = context.engine.getSpeciesIndex(fourdst::atomic::H_1);
* int He4Index = context.engine.getSpeciesIndex(fourdst::atomic::He_4);
*
* consumptionFile << context.t << "," << context.state(H1Index) << "," << context.state(He4Index) << "\n";
* }
*
* int main() {
* ... // Code to set up engine and solvers...
* solver.set_callback(callback);
* solver.evaluate(netIn);
* consumptionFile.close();
* }
* @endcode
*/
struct TimestepContext final : public SolverContextBase {
const double t; ///< Current time.
const boost::numeric::ublas::vector<double>& state; ///< Current state of the system.
const double dt; ///< Time step size.
const double cached_time; ///< Cached time for the last observed state.
const double last_observed_time; ///< Last time the state was observed.
const double last_step_time; ///< Last step time.
const double T9; ///< Temperature in units of 10^9 K.
const double rho; ///< Density in g/cm^3.
const std::optional<StepDerivatives<double>>& cached_result; ///< Cached result of the step derivatives.
const int num_steps; ///< Total number of steps taken.
const DynamicEngine& engine; ///< Reference to the dynamic engine.
const std::vector<fourdst::atomic::Species>& networkSpecies;
TimestepContext(
const double t,
const boost::numeric::ublas::vector<double> &state,
const double dt,
const double cached_time,
const double last_observed_time,
const double last_step_time,
const double t9,
const double rho,
const std::optional<StepDerivatives<double>> &cached_result,
const int num_steps,
const DynamicEngine &engine,
const std::vector<fourdst::atomic::Species>& networkSpecies
);
/**
* @brief Describe the context for callback functions.
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
*
* This method provides a description of the context that will be passed to the callback function.
* The intent is that an end user can investigate the context and use this information to craft their own
* callback function.
*
* @implements SolverContextBase::describe
*/
std::vector<std::tuple<std::string, std::string>> describe() const override;
};
/**
* @brief Type alias for a timestep callback function.
*
* @brief The type alias for the callback function that will be called at the end of each timestep.
*
*/
using TimestepCallback = std::function<void(const TimestepContext& context)>; ///< Type alias for a timestep callback function.
/**
* @brief Evaluates the network for a given timestep using direct integration.
* @param netIn The input conditions for the network.
* @return The output conditions after the timestep.
*/
NetOut evaluate(const NetIn& netIn) override;
/**
* @brief Sets the callback function to be called at the end of each timestep.
* @param callback The callback function to be called at the end of each timestep.
*
* This function allows the user to set a callback function that will be called at the end of each timestep.
* The callback function will receive a gridfire::solver::DirectNetworkSolver::TimestepContext object.
*/
void set_callback(const std::any &callback) override;
/**
* @brief Describe the context that will be passed to the callback function.
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
*
* This method provides a description of the context that will be passed to the callback function.
* The intent is that an end user can investigate the context and use this information to craft their own
* callback function.
*
* @implements SolverContextBase::describe
*/
std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
private:
/**
* @struct RHSFunctor
* @struct RHSManager
* @brief Functor for calculating the right-hand side of the ODEs.
*
* This functor is used by the ODE solver to calculate the time derivatives of the
@@ -100,21 +255,30 @@ namespace gridfire::solver {
mutable int m_num_steps = 0;
mutable double m_last_step_time = 1e-20;
TimestepCallback& m_callback;
const std::vector<fourdst::atomic::Species>& m_networkSpecies;
/**
* @brief Constructor for the RHSFunctor.
* @param engine The engine used to evaluate the network.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
* @param callback callback function to be called at the end of each timestep.
* @param networkSpecies vector of species in the network in the correct order.
*/
RHSManager(
DynamicEngine& engine,
const double T9,
const double rho
const double rho,
TimestepCallback& callback,
const std::vector<fourdst::atomic::Species>& networkSpecies
) :
m_engine(engine),
m_T9(T9),
m_rho(rho),
m_cached_time(0) {}
m_cached_time(0),
m_callback(callback),
m_networkSpecies(networkSpecies){}
/**
* @brief Calculates the time derivatives of the species abundances.
@@ -179,5 +343,7 @@ namespace gridfire::solver {
private:
quill::Logger* m_logger = LogManager::getInstance().getLogger("log"); ///< Logger instance.
Config& m_config = Config::getInstance(); ///< Configuration instance.
TimestepCallback m_callback;
};
}

View File

@@ -34,7 +34,7 @@ namespace gridfire::solver {
size_t numSpecies = m_engine.getNetworkSpecies().size();
ublas::vector<double> Y(numSpecies + 1);
RHSManager manager(m_engine, T9, netIn.density);
RHSManager manager(m_engine, T9, netIn.density, m_callback, m_engine.getNetworkSpecies());
JacobianFunctor jacobianFunctor(m_engine, T9, netIn.density);
auto populateY = [&](const Composition& comp) {
@@ -149,6 +149,44 @@ namespace gridfire::solver {
return netOut;
}
void DirectNetworkSolver::set_callback(const std::any& callback) {
if (!callback.has_value()) {
m_callback = {};
return;
}
using FunctionPtrType = void (*)(const TimestepContext&);
if (callback.type() == typeid(TimestepCallback)) {
m_callback = std::any_cast<TimestepCallback>(callback);
}
else if (callback.type() == typeid(FunctionPtrType)) {
auto func_ptr = std::any_cast<FunctionPtrType>(callback);
m_callback = func_ptr;
}
else {
throw std::invalid_argument("Unsupported type passed to set_callback. "
"Provide a std::function or a matching function pointer.");
}
}
std::vector<std::tuple<std::string, std::string>> DirectNetworkSolver::describe_callback_context() const {
const TimestepContext context(
0.0, // time
boost::numeric::ublas::vector<double>(), // state
0.0, // dt
0.0, // cached_time
0.0, // last_observed_time
0.0, // last_step_time
0.0, // T9
0.0, // rho
std::nullopt, // cached_result
0, // num_steps
m_engine, // engine,
{}
);
return context.describe();
}
void DirectNetworkSolver::RHSManager::operator()(
const boost::numeric::ublas::vector<double> &Y,
boost::numeric::ublas::vector<double> &dYdt,
@@ -181,6 +219,29 @@ namespace gridfire::solver {
oss << std::scientific << std::setprecision(3);
oss << "(Step: " << std::setw(10) << m_num_steps << ") t = " << t << " (dt = " << dt << ", eps_nuc: " << state(state.size() - 1) << " [erg])\n";
std::cout << oss.str();
// Callback logic
if (m_callback) {
LOG_TRACE_L1(m_logger, "Calling user callback function at t = {:0.3E} with dt = {:0.3E}", t, dt);
const TimestepContext context(
t,
state,
dt,
m_cached_time,
m_last_observed_time,
m_last_step_time,
m_T9,
m_rho,
m_cached_result,
m_num_steps,
m_engine,
m_networkSpecies
);
m_callback(context);
LOG_TRACE_L1(m_logger, "User callback function completed at t = {:0.3E} with dt = {:0.3E}", t, dt);
}
m_last_observed_time = t;
m_last_step_time = dt;
@@ -228,4 +289,49 @@ namespace gridfire::solver {
}
}
DirectNetworkSolver::TimestepContext::TimestepContext(
const double t,
const boost::numeric::ublas::vector<double> &state,
const double dt,
const double cached_time,
const double last_observed_time,
const double last_step_time,
const double t9,
const double rho,
const std::optional<StepDerivatives<double>> &cached_result,
const int num_steps,
const DynamicEngine &engine,
const std::vector<fourdst::atomic::Species> &networkSpecies
)
: t(t),
state(state),
dt(dt),
cached_time(cached_time),
last_observed_time(last_observed_time),
last_step_time(last_step_time),
T9(t9),
rho(rho),
cached_result(cached_result),
num_steps(num_steps),
engine(engine),
networkSpecies(networkSpecies) {}
std::vector<std::tuple<std::string, std::string>> DirectNetworkSolver::TimestepContext::describe() const {
return {
{"time", "double"},
{"state", "boost::numeric::ublas::vector<double>&"},
{"dt", "double"},
{"cached_time", "double"},
{"last_observed_time", "double"},
{"last_step_time", "double"},
{"T9", "double"},
{"rho", "double"},
{"cached_result", "std::optional<StepDerivatives<double>>&"},
{"num_steps", "int"},
{"engine", "DynamicEngine&"},
{"networkSpecies", "std::vector<fourdst::atomic::Species>&"}
};
}
}

View File

@@ -1,6 +1,10 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <pybind11/numpy.h>
#include <pybind11/functional.h> // Needed for std::function
#include <boost/numeric/ublas/vector.hpp>
#include "bindings.h"
@@ -10,17 +14,60 @@
namespace py = pybind11;
void register_solver_bindings(py::module &m) {
void register_solver_bindings(const py::module &m) {
auto py_dynamic_network_solving_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
auto py_direct_network_solver = py::class_<gridfire::solver::DirectNetworkSolver, gridfire::solver::DynamicNetworkSolverStrategy>(m, "DirectNetworkSolver");
py_direct_network_solver.def(py::init<gridfire::DynamicEngine&>(),
py::arg("engine"),
"Constructor for the DirectNetworkSolver. Takes a DynamicEngine instance to use for evaluating the network.");
"Constructor for the DirectNetworkSolver. Takes a DynamicEngine instance to use for evaluating the network."
);
py_direct_network_solver.def("evaluate",
&gridfire::solver::DirectNetworkSolver::evaluate,
py::arg("netIn"),
"Evaluate the network for a given timestep. Returns the output conditions after the timestep.");
"Evaluate the network for a given timestep. Returns the output conditions after the timestep."
);
py_direct_network_solver.def("set_callback",
[](gridfire::solver::DirectNetworkSolver &self, gridfire::solver::DirectNetworkSolver::TimestepCallback cb) {
self.set_callback(cb);
},
py::arg("callback"),
"Sets a callback function to be called at each timestep."
);
py::class_<gridfire::solver::DirectNetworkSolver::TimestepContext>(py_direct_network_solver, "TimestepContext")
.def_readonly("t", &gridfire::solver::DirectNetworkSolver::TimestepContext::t, "Current time in the simulation.")
.def_property_readonly(
"state", [](const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) {
std::vector<double> state(ctx.state.size());
std::ranges::copy(ctx.state, state.begin());
return py::array_t<double>(static_cast<ssize_t>(state.size()), state.data());
})
.def_readonly("dt", &gridfire::solver::DirectNetworkSolver::TimestepContext::dt, "Current timestep size.")
.def_readonly("cached_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::cached_time, "Cached time for the last computed result.")
.def_readonly("last_observed_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::last_observed_time, "Last time the state was observed.")
.def_readonly("last_step_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::last_step_time, "Last step time taken for the integration.")
.def_readonly("T9", &gridfire::solver::DirectNetworkSolver::TimestepContext::T9, "Temperature in units of 10^9 K.")
.def_readonly("rho", &gridfire::solver::DirectNetworkSolver::TimestepContext::rho, "Temperature in units of 10^9 K.")
.def_property_readonly("cached_result", [](const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) -> py::object {
if (ctx.cached_result.has_value()) {
const auto&[dydt, nuclearEnergyGenerationRate] = ctx.cached_result.value();
return py::make_tuple(
py::array_t<double>(static_cast<ssize_t>(dydt.size()), dydt.data()),
nuclearEnergyGenerationRate
);
}
return py::none();
}, "Cached result of the step derivatives.")
.def_readonly("num_steps", &gridfire::solver::DirectNetworkSolver::TimestepContext::num_steps, "Total number of steps taken in the simulation.")
.def_property_readonly("engine", [](const gridfire::solver::DirectNetworkSolver::TimestepContext &ctx) -> const gridfire::DynamicEngine & {
return ctx.engine;
}, py::return_value_policy::reference)
.def_property_readonly("network_species", [](const gridfire::solver::DirectNetworkSolver::TimestepContext &ctx) -> const std::vector<fourdst::atomic::Species> & {
return ctx.networkSpecies;
}, py::return_value_policy::reference);
}

View File

@@ -2,4 +2,4 @@
#include <pybind11/pybind11.h>
void register_solver_bindings(pybind11::module &m);
void register_solver_bindings(const pybind11::module &m);

View File

@@ -5,6 +5,9 @@
#include <pybind11/functional.h> // Needed for std::function
#include <vector>
#include <tuple>
#include <string>
#include <any>
#include "py_solver.h"
@@ -19,3 +22,21 @@ gridfire::NetOut PyDynamicNetworkSolverStrategy::evaluate(const gridfire::NetIn
netIn // Arguments
);
}
void PyDynamicNetworkSolverStrategy::set_callback(const std::any &callback) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::solver::DynamicNetworkSolverStrategy, // Base class
set_callback, // Method name
callback // Arguments
);
}
std::vector<std::tuple<std::string, std::string>> PyDynamicNetworkSolverStrategy::describe_callback_context() const {
using DescriptionVector = std::vector<std::tuple<std::string, std::string>>;
PYBIND11_OVERRIDE_PURE(
DescriptionVector, // Return type
gridfire::solver::DynamicNetworkSolverStrategy, // Base class
describe_callback_context // Method name
);
}

View File

@@ -3,8 +3,13 @@
#include "gridfire/solver/solver.h"
#include <vector>
#include <tuple>
#include <string>
#include <any>
class PyDynamicNetworkSolverStrategy final : public gridfire::solver::DynamicNetworkSolverStrategy {
explicit PyDynamicNetworkSolverStrategy(gridfire::DynamicEngine &engine) : gridfire::solver::DynamicNetworkSolverStrategy(engine) {}
gridfire::NetOut evaluate(const gridfire::NetIn &netIn) override;
void set_callback(const std::any &callback) override;
std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
};