From 13e2ea9ffad822d50f59b9c33b769e41554f33b2 Mon Sep 17 00:00:00 2001 From: Emily Boudreaux Date: Wed, 8 Oct 2025 11:17:35 -0400 Subject: [PATCH] feat(weak-reactions): brought weak reaction code up to a point where it will compile (NOT YET TESTED) --- .../gridfire/engine/procedures/construction.h | 2 +- src/include/gridfire/reaction/reaction.h | 231 +++++++++++- src/include/gridfire/reaction/weak/weak.h | 343 +++++++++++++++++- .../reaction/weak/weak_interpolator.h | 91 +++++ .../gridfire/reaction/weak/weak_types.h | 172 ++++++--- .../solver/strategies/CVODE_solver_strategy.h | 176 +++++++-- .../triggers/engine_partitioning_trigger.h | 208 ++++++++++- .../trigger/procedures/trigger_pprint.h | 18 + .../gridfire/trigger/trigger_abstract.h | 49 +++ .../gridfire/trigger/trigger_logical.h | 74 ++++ src/include/gridfire/trigger/trigger_result.h | 24 +- src/lib/engine/engine_graph.cpp | 6 +- src/lib/engine/views/engine_multiscale.cpp | 3 +- src/lib/reaction/reaction.cpp | 4 +- src/lib/reaction/weak/weak.cpp | 204 ++++++++--- 15 files changed, 1452 insertions(+), 153 deletions(-) diff --git a/src/include/gridfire/engine/procedures/construction.h b/src/include/gridfire/engine/procedures/construction.h index 2a942912..654f1197 100644 --- a/src/include/gridfire/engine/procedures/construction.h +++ b/src/include/gridfire/engine/procedures/construction.h @@ -33,7 +33,7 @@ namespace gridfire { * @return A LogicalReactionSet encapsulating the collected reactions for graph-based engines. * @throws std::logic_error If the resolved network depth is zero (no reactions can be collected). */ - reaction::ReactionSet build_reaclib_nuclear_network( + reaction::ReactionSet build_nuclear_network( const fourdst::composition::Composition &composition, const rates::weak::WeakRateInterpolator &weakInterpolator, BuildDepthType maxLayers = NetworkBuildDepth::Full, bool reverse = false diff --git a/src/include/gridfire/reaction/reaction.h b/src/include/gridfire/reaction/reaction.h index ab9a71a8..18305ea1 100644 --- a/src/include/gridfire/reaction/reaction.h +++ b/src/include/gridfire/reaction/reaction.h @@ -77,66 +77,253 @@ namespace gridfire::reaction { class Reaction { public: + /** + * @brief Virtual destructor for correct polymorphic cleanup. + */ virtual ~Reaction() = default; + /** + * @brief Compute the temperature- and composition-dependent reaction rate. + * + * This is the primary interface used by the network to obtain the rate of a single + * reaction at the given thermodynamic state and composition. The exact units and + * normalization are defined by the concrete implementation (e.g., REACLIB typically + * provides NA with units depending on the reaction order). Implementations may + * use density/electron properties for weak processes or screening, and the composition + * vector for multi-body reactions. + * + * @param T9 Temperature in GK (10^9 K). + * @param rho Mass density (g cm^-3). May be unused for some reaction types. + * @param Ye Electron fraction. May be unused depending on the reaction type. + * @param mue Electron chemical potential. May be unused depending on the reaction type. + * @param Y Composition vector (molar abundances or number fractions) indexed consistently + * with index_to_species_map. + * @param index_to_species_map Mapping from state-vector index to Species, used to interpret Y. + * @return The reaction rate for the forward direction, with units/normalization defined by the + * specific model (implementation must document its convention). + */ [[nodiscard]] virtual double calculate_rate( double T9, double rho, double Ye, - double mue, const std::vector &Y, const std::unordered_map& index_to_species_map + double mue, + const std::vector &Y, + const std::unordered_map& index_to_species_map ) const = 0; + + /** + * @brief AD-enabled reaction rate for algorithmic differentiation. + * + * This overload mirrors calculate_rate(double, ...) but operates on CppAD types to + * enable derivative calculations w.r.t. its inputs. + * + * @param T9 Temperature in GK as CppAD::AD. + * @param rho Mass density as CppAD::AD. + * @param Ye Electron fraction as CppAD::AD. + * @param mue Electron chemical potential as CppAD::AD. + * @param Y Composition vector as CppAD::AD values. + * @param index_to_species_map Mapping from state-vector index to Species, used to interpret Y. + * @return The reaction rate as a CppAD::AD value. + */ [[nodiscard]] virtual CppAD::AD calculate_rate( CppAD::AD T9, CppAD::AD rho, CppAD::AD Ye, - CppAD::AD mue, const std::vector>& Y, const std::unordered_map& index_to_species_map + CppAD::AD mue, + const std::vector>& Y, + const std::unordered_map& index_to_species_map ) const = 0; + /** + * @brief A stable, unique identifier for this reaction instance. + * @return String view of the reaction ID (stable across runs and suitable for lookups). + */ [[nodiscard]] virtual std::string_view id() const = 0; + + /** + * @brief Ordered list of reactant species. + * + * Multiplicity is represented by duplicates, e.g., (p, p) would list H1 twice. + * @return Const reference to the vector of reactants. + */ [[nodiscard]] virtual const std::vector& reactants() const = 0; + + /** + * @brief Ordered list of product species. + * + * Multiplicity is represented by duplicates if applicable. + * @return Const reference to the vector of products. + */ [[nodiscard]] virtual const std::vector& products() const = 0; + /** + * @brief True if the species appears as a reactant or a product. + * @param species Species to test. + * @return Whether the species participates in the reaction (either side). + */ [[nodiscard]] virtual bool contains(const fourdst::atomic::Species& species) const = 0; + + /** + * @brief True if the species appears among the reactants. + * @param species Species to test. + * @return Whether the species is a reactant. + */ [[nodiscard]] virtual bool contains_reactant(const fourdst::atomic::Species& species) const = 0; + + /** + * @brief True if the species appears among the products. + * @param species Species to test. + * @return Whether the species is a product. + */ [[nodiscard]] virtual bool contains_product(const fourdst::atomic::Species& species) const = 0; + /** + * @brief Whether this object represents a reverse (backward) rate. + * + * Implementations may pair forward/reverse rates for detailed balance. This flag indicates + * that the parameterization corresponds to the reverse direction. + * @return True for a reverse rate, false for a forward rate. + */ [[nodiscard]] virtual bool is_reverse() const = 0; + /** + * @brief Set of all unique species appearing in the reaction. + * @return Unordered set of all reactants and products (no duplicates). + */ [[nodiscard]] virtual std::unordered_set all_species() const = 0; + + /** + * @brief Set of unique reactant species. + * @return Unordered set of reactant species (no duplicates). + */ [[nodiscard]] virtual std::unordered_set reactant_species() const = 0; + + /** + * @brief Set of unique product species. + * @return Unordered set of product species (no duplicates). + */ [[nodiscard]] virtual std::unordered_set product_species() const = 0; + /** + * @brief Number of unique species involved in the reaction. + * @return Count of distinct species across reactants and products. + */ [[nodiscard]] virtual size_t num_species() const = 0; + /** + * @brief Full stoichiometry map for this reaction. + * + * Coefficients are negative for reactants and positive for products; multiplicity is reflected + * in the magnitude (e.g., 2H -> He gives H: -2, He: +1). + * @return Map from Species to integer stoichiometric coefficient. + */ [[nodiscard]] virtual std::unordered_map stoichiometry() const = 0; + + /** + * @brief Stoichiometric coefficient for a particular species. + * @param species Species for which to query the coefficient. + * @return Negative for reactants, positive for products, zero if absent. + */ [[nodiscard]] virtual int stoichiometry(const fourdst::atomic::Species& species) const = 0; + /** + * @brief Stable content-based hash for this reaction. + * + * Intended for use in caches, sets, and order-independent hashing of Reaction collections. + * Implementations should produce the same value across processes for the same content and seed. + * @param seed Seed value to initialize/mix into the hash. + * @return 64-bit hash value. + */ [[nodiscard]] virtual uint64_t hash(uint64_t seed) const = 0; + /** + * @brief Q-value of the reaction (typically MeV), positive if exothermic. + * @return Reaction Q-value used for energy accounting. + */ [[nodiscard]] virtual double qValue() const = 0; + /** + * @brief Convenience: energy generation rate from this reaction (double version). + * + * Default implementation multiplies the scalar rate by the reaction Q-value. Electron + * quantities (Ye, mue) are ignored in this default, so override in derived classes if + * needed. Sign convention follows qValue(). + * + * @param T9 Temperature in GK (10^9 K). + * @param rho Mass density (g cm^-3). + * @param Ye Electron fraction (ignored by default implementation). + * @param mue Electron chemical potential (ignored by default implementation). + * @param Y Composition vector. + * @param index_to_species_map Mapping from state-vector index to Species. + * @return Energy generation rate, typically rate * qValue(). + */ [[nodiscard]] virtual double calculate_energy_generation_rate( - double T9, - double rho, - double Ye, - double mue, const std::vector& Y, const std::unordered_map& index_to_species_map + const double T9, + const double rho, + const double Ye, + double mue, + const std::vector& Y, + const std::unordered_map& index_to_species_map ) const { return calculate_rate(T9, rho, 0, 0, Y, index_to_species_map) * qValue(); } + /** + * @brief Convenience: AD-enabled energy generation rate (AD version). + * + * Default implementation multiplies the AD rate by the reaction Q-value. Electron + * quantities (Ye, mue) are ignored in this default, so override if they contribute. + * + * @param T9 Temperature in GK as CppAD::AD. + * @param rho Mass density as CppAD::AD. + * @param Ye Electron fraction as CppAD::AD (ignored by default). + * @param mue Electron chemical potential as CppAD::AD (ignored by default). + * @param Y Composition vector as CppAD::AD values. + * @param index_to_species_map Mapping from state-vector index to Species. + * @return Energy generation rate as CppAD::AD. + */ [[nodiscard]] virtual CppAD::AD calculate_energy_generation_rate( const CppAD::AD& T9, const CppAD::AD& rho, const CppAD::AD &Ye, - const CppAD::AD &mue, const std::vector>& Y, const std::unordered_map& index_to_species_map + const CppAD::AD &mue, + const std::vector>& Y, + const std::unordered_map& index_to_species_map ) const { return calculate_rate(T9, rho, {}, {}, Y, index_to_species_map) * qValue(); } - [[nodiscard]] virtual double calculate_forward_rate_log_derivative(double T9, double rho, double Ye, double mue, const fourdst::composition::Composition& comp) const = 0; + /** + * @brief Logarithmic partial derivative of the rate with respect to temperature. + * + * Implementations return d(ln rate)/d(ln T9) or an equivalent measure (as documented by the + * concrete class), evaluated at the provided state. + * + * @param T9 Temperature in GK (10^9 K). + * @param rho Mass density (g cm^-3). + * @param Ye Electron fraction. + * @param mue Electron chemical potential. + * @param comp Composition object providing composition in a convenient form. + * @return The logarithmic temperature derivative of the rate. + */ + [[nodiscard]] virtual double calculate_log_rate_partial_deriv_wrt_T9( + double T9, + double rho, + double Ye, + double mue, + const fourdst::composition::Composition& comp + ) const = 0; + /** + * @brief Category of this reaction (e.g., REACLIB, WEAK, LOGICAL_REACLIB). + * @return Enumerated reaction type for runtime dispatch and filtering. + */ [[nodiscard]] virtual ReactionType type() const = 0; + /** + * @brief Polymorphic deep copy. + * @return A std::unique_ptr owning a new Reaction equal to this one. + */ [[nodiscard]] virtual std::unique_ptr clone() const = 0; friend std::ostream& operator<<(std::ostream& os, const Reaction& r) { @@ -161,15 +348,15 @@ namespace gridfire::reaction { * @param reverse True if this is a reverse reaction rate. */ ReaclibReaction( - const std::string_view id, - const std::string_view peName, - const int chapter, + std::string_view id, + std::string_view peName, + int chapter, const std::vector &reactants, const std::vector &products, - const double qValue, - const std::string_view label, + double qValue, + std::string_view label, const RateCoefficientSet &sets, - const bool reverse = false); + bool reverse = false); /** * @brief Calculates the reaction rate for a given temperature. @@ -185,7 +372,10 @@ namespace gridfire::reaction { double T9, double rho, double Ye, - double mue, const std::vector &Y, const std::unordered_map& index_to_species_map + double mue, + const std::vector &Y, + const std::unordered_map& index_to_species_map ) const override; /** @@ -205,7 +395,13 @@ namespace gridfire::reaction { CppAD::AD mue, const std::vector>& Y, const std::unordered_map& index_to_species_map ) const override; - [[nodiscard]] double calculate_forward_rate_log_derivative(double T9, double rho, double Ye, double mue, const fourdst::composition::Composition& comp) const override; + [[nodiscard]] double calculate_log_rate_partial_deriv_wrt_T9( + double T9, + double rho, + double Ye, + double mue, + const fourdst::composition::Composition& comp + ) const override; /** * @brief Gets the reaction name in (projectile, ejectile) notation. @@ -448,7 +644,7 @@ namespace gridfire::reaction { double mue, const std::vector &Y, const std::unordered_map& index_to_species_map ) const override; - [[nodiscard]] double calculate_forward_rate_log_derivative( + [[nodiscard]] double calculate_log_rate_partial_deriv_wrt_T9( double T9, double rho, double Ye, double mue, const fourdst::composition::Composition& comp @@ -686,4 +882,3 @@ namespace gridfire::reaction { ReactionSet packReactionSet(const ReactionSet& reactionSet); } - diff --git a/src/include/gridfire/reaction/weak/weak.h b/src/include/gridfire/reaction/weak/weak.h index fbffffd4..582bd196 100644 --- a/src/include/gridfire/reaction/weak/weak.h +++ b/src/include/gridfire/reaction/weak/weak.h @@ -24,29 +24,135 @@ namespace gridfire::rates::weak { + /** + * @class WeakReactionMap + * @brief Index of available weak reactions keyed by species. + * + * Builds an in-memory map from the compiled weak-rate tables and provides + * simple query helpers to retrieve all weak reactions or those that involve + * a particular nuclide. + * + * Implementation summary: the constructor iterates over UNIFIED_WEAK_DATA and + * inserts entries keyed by the parent Species. For each channel (β−, β+, e−-capture, + * e+-capture), if the tabulated log10(rate) is above the sentinel (-60), a + * WeakReactionEntry is pushed containing the grids t9, log10(rho*Ye), mu_e, the log10(rate), + * and the corresponding log10(neutrino loss) column. + */ class WeakReactionMap { public: + /** + * @brief Construct the map by loading all weak reaction entries. + * @post All valid reactions from the compiled data are available via + * get_all_reactions() and get_species_reactions(). + * Implementation: iterates UNIFIED_WEAK_DATA, filters any log(rate) <= -60, + * and groups entries by parent Species. + */ WeakReactionMap(); ~WeakReactionMap() = default; + /** + * @brief Return a flat list of all weak reaction entries. + * @return Vector of WeakReactionEntry records. + * @par Example + * @code + * WeakReactionMap map; + * auto all = map.get_all_reactions(); + * // iterate or group as needed + * @endcode + */ [[nodiscard]] std::vector get_all_reactions() const; + /** + * @brief Get all weak reaction entries for a given species. + * @param species Nuclide to query (A,Z). + * @return expected, WeakMapError> + * containing reactions on success or SPECIES_NOT_FOUND on failure. + * @par Example + * @code + * using fourdst::atomic::Species; + * WeakReactionMap map; + * Species fe52 = fourdst::atomic::az_to_species(52, 26); + * if (auto res = map.get_species_reactions(fe52); res) { + * for (const auto& e : *res) { } // use e + * } else { + * // handle WeakMapError::SPECIES_NOT_FOUND + * } + * @endcode + */ [[nodiscard]] std::expected, WeakMapError> get_species_reactions( - const fourdst::atomic::Species &species) const; + const fourdst::atomic::Species &species + ) const; + /** + * @brief Get all weak reaction entries for a given species by name. + * @param species_name Symbolic name (e.g., "Fe52"). + * @return expected, WeakMapError> + * containing reactions on success or SPECIES_NOT_FOUND on failure. + * @par Example + * @code + * WeakReactionMap map; + * if (auto res = map.get_species_reactions("Fe52"); res) { + * // use *res + * } + * @endcode + */ [[nodiscard]] std::expected, WeakMapError> get_species_reactions( - const std::string &species_name) const; + const std::string &species_name + ) const; private: std::unordered_map> m_weak_network; }; + /** + * @class WeakReaction + * @brief Concrete Reaction representing a single weak process (beta±, e−/e+ capture). + * + * Wraps interpolation logic for tabulated weak rates and provides both scalar and AD + * interfaces for rate and energy generation. The reactants/products are the parent/daughter + * nuclei of the weak process. + * + * @details the product nucleus is resolved from (A,Z) and channel via + * simple charge-changing rules (β−: Z+1; β+: Z−1; e− capture: Z−1; e+ capture: Z+1). + * The reaction ID is formatted like "Parent(channel)Product" with ν/ν̄ decorations, and + * an internal CppAD atomic (AtomicWeakRate) is prepared for AD energy calculations. + */ class WeakReaction final : public reaction::Reaction { public: + /** + * @brief Construct a WeakReaction for a specific weak channel and parent species. + * @param species Parent nuclide undergoing the weak process. + * @param type The weak reaction channel (beta−, beta+, e− capture, e+ capture). + * @param interpolator Reference to a WeakRateInterpolator providing tabulated data. + * @pre The product nuclide must be resolvable for the given (species, type). + * @post Object is ready to compute rates using the provided interpolator. + * @throws std::runtime_error If the product species cannot be resolved for the channel + * (product resolution uses the charge-changing rules described above). + */ explicit WeakReaction( const fourdst::atomic::Species &species, WeakReactionType type, const WeakRateInterpolator& interpolator ); + /** + * @brief Scalar weak reaction rate λ(T9, rho, Ye, μe) in 1/s. + * + * @details Performs a single interpolation of the weak-rate tables at (T9, log10(rho*Ye), μe). + * If the selected log10(rate) is ≤ sentinel (-60), returns 0; otherwise returns 10^{log10(rate)}. + * On interpolation failure, throws with a message including (A,Z) and the state point. + * + * @param T9 Temperature in GK (1e9 K). + * @param rho Mass density (g cm^-3). + * @param Ye Electron fraction. + * @param mue Electron chemical potential (MeV). + * @param Y Composition vector (unused for weak channels). + * @param index_to_species_map Index-to-species map (unused for weak channels). + * @return Reaction rate (1/s). + * @throws std::runtime_error On interpolation failure. + * @par Example + * @code + * double lambda = rxn.calculate_rate(2.0, 1e8, 0.4, 1.5, {}, {}); + * @endcode + */ [[nodiscard]] double calculate_rate( double T9, double rho, @@ -55,6 +161,26 @@ namespace gridfire::rates::weak { const std::vector &Y, const std::unordered_map& index_to_species_map ) const override; + /** + * @brief AD-enabled weak reaction rate λ(T9, rho, Ye, μe) in 1/s. + * + * @details Current implementation returns 0.0. AD support is provided for the energy-generation + * overload below using an internal CppAD atomic that evaluates both the rate and neutrino + * loss consistently. A future implementation may mirror that atomic here and return the AD rate. + * + * @param T9 Temperature in GK (AD type). + * @param rho Mass density (g cm^-3, AD type). + * @param Ye Electron fraction (AD type). + * @param mue Electron chemical potential (MeV, AD type). + * @param Y Composition vector (unused for weak channels). + * @param index_to_species_map Index-to-species map (unused for weak channels). + * @return Reaction rate (1/s) as CppAD::AD (currently 0.0). + * @par Example + * @code + * using AD = CppAD::AD; + * AD lambda_ad = rxn.calculate_rate(AD(3.0), AD(1e7), AD(0.5), AD(2.0), {}, {}); + * @endcode + */ [[nodiscard]] CppAD::AD calculate_rate( CppAD::AD T9, CppAD::AD rho, @@ -63,20 +189,108 @@ namespace gridfire::rates::weak { const std::vector> &Y, const std::unordered_map& index_to_species_map ) const override; - [[nodiscard]] std::string_view id() const override { return m_id; } - [[nodiscard]] const std::vector &reactants() const override { return {m_reactant}; } - [[nodiscard]] const std::vector &products() const override { return {m_product}; } + + /** + * @brief Unique identifier string for the weak channel. + * @return A stable string view (e.g., "Fe52(e-,ν)Mn52"). + */ + [[nodiscard]] std::string_view id() const override; + + /** + * @brief Reactants list (parent nuclide only). + * @return Vector with the parent species. + */ + [[nodiscard]] const std::vector &reactants() const override; + + /** + * @brief Products list (daughter nuclide only). + * @return Vector with the daughter species. + */ + [[nodiscard]] const std::vector &products() const override; + + /** + * @brief Check if a species participates in this weak reaction. + */ [[nodiscard]] bool contains(const fourdst::atomic::Species &species) const override; + + /** + * @brief Check if a species is the reactant (parent). + */ [[nodiscard]] bool contains_reactant(const fourdst::atomic::Species &species) const override; + + /** + * @brief Check if a species is the product (daughter). + */ [[nodiscard]] bool contains_product(const fourdst::atomic::Species &species) const override; + + /** + * @brief Set of both parent and daughter species. + */ [[nodiscard]] std::unordered_set all_species() const override; + + /** + * @brief Singleton set containing only the parent species. + */ [[nodiscard]] std::unordered_set reactant_species() const override; + + /** + * @brief Singleton set containing only the daughter species. + */ [[nodiscard]] std::unordered_set product_species() const override; - [[nodiscard]] size_t num_species() const override { return 2; } // Always 2 for weak reactions + + /** + * @brief Number of unique species involved (always 2 for weak reactions). + */ + [[nodiscard]] size_t num_species() const override; + + /** + * @brief Full stoichiometry map: parent -1, daughter +1. + */ [[nodiscard]] std::unordered_map stoichiometry() const override; + + /** + * @brief Stoichiometric coefficient for a species: -1 (parent), +1 (daughter), 0 otherwise. + */ [[nodiscard]] int stoichiometry(const fourdst::atomic::Species &species) const override; + + /** + * @brief Content-based 64-bit hash for this reaction. + */ [[nodiscard]] uint64_t hash(uint64_t seed) const override; + + /** + * @brief Q-value (MeV) based on nuclear mass differences and channel. + * + * Computes Q = (M_parent − M_daughter) c^2 converted to MeV. For β+ decay subtract 2 m_e c^2, + * for e+ capture add 2 m_e c^2; for β− and e− capture it is just the nuclear mass difference. + * Neutrino rest mass is ignored. + */ [[nodiscard]] double qValue() const override; + + /** + * @brief Net energy generation rate (MeV/s) for this weak reaction. + * + * Interpolates once to obtain both the log10(rate) and the appropriate log10(neutrino-loss) + * for the channel, converts to linear values, computes E_deposited = Q − ν_loss, and returns + * λ · E_deposited. Throws on interpolation failure. + * + * Channel mapping for neutrino-loss column: + * - β− decay and e+ capture: use log_antineutrino_loss_bd + * - β+ decay and e− capture: use log_neutrino_loss_ec + * + * @param T9 Temperature in GK. + * @param rho Density in g cm^-3. + * @param Ye Electron fraction. + * @param mue Electron chemical potential (MeV). + * @param Y Composition vector (unused for weak channels). + * @param index_to_species_map Index-to-species map (unused for weak channels). + * @return Energy generation rate in MeV/s. + * @throws std::runtime_error On interpolation failure. + * @par Example + * @code + * double eps = rxn.calculate_energy_generation_rate(3.0, 1e7, 0.5, 2.0, {}, {}); + * @endcode + */ [[nodiscard]] double calculate_energy_generation_rate( double T9, double rho, @@ -85,6 +299,16 @@ namespace gridfire::rates::weak { const std::vector& Y, const std::unordered_map& index_to_species_map ) const override; + + /** + * @brief AD-enabled net energy generation rate (MeV/s). + * + * Uses an internal CppAD atomic to compute two outputs at once: the rate λ and the neutrino + * loss ν_loss at (T9, log10(rho*Ye), μe). Returns λ · (Q − ν_loss). The atomic throws on + * interpolation failure. + * + * @throws std::runtime_error If the atomic rate evaluation fails to interpolate. + */ [[nodiscard]] CppAD::AD calculate_energy_generation_rate( const CppAD::AD &T9, const CppAD::AD &rho, @@ -93,18 +317,64 @@ namespace gridfire::rates::weak { const std::vector> &Y, const std::unordered_map &index_to_species_map ) const override; - [[nodiscard]] double calculate_forward_rate_log_derivative( + + /** + * @brief Logarithmic temperature sensitivity of the rate at the given state. + * + * Implementation status: requests derivative tables from the interpolator and throws on + * failure; otherwise the function is not yet implemented and does not return a value. + * Avoid calling until implemented. + * + * @param T9 Temperature in GK. + * @param rho Density in g cm^-3. + * @param Ye Electron fraction. + * @param mue Electron chemical potential (MeV). + * @param composition Composition context (not used by weak channels presently). + * @return d ln λ / d ln T9. + * @throws std::runtime_error On interpolation failure. + */ + [[nodiscard]] double calculate_log_rate_partial_deriv_wrt_T9( double T9, double rho, double Ye, double mue, const fourdst::composition::Composition& composition ) const override; - [[nodiscard]] reaction::ReactionType type() const override { return reaction::ReactionType::WEAK; } + + /** + * @brief Reaction type tag for runtime dispatch. + */ + [[nodiscard]] reaction::ReactionType type() const override; + + /** + * @brief Polymorphic deep copy. + */ [[nodiscard]] std::unique_ptr clone() const override; - [[nodiscard]] bool is_reverse() const override { return false; }; + + /** + * @brief Weak reactions are parameterized in the forward sense (never reverse). + */ + [[nodiscard]] bool is_reverse() const override; + + /** + * @brief Access the underlying rate interpolator used by this reaction. + */ + [[nodiscard]] const WeakRateInterpolator& getWeakRateInterpolator() const; private: + /** + * @brief Internal unified implementation for scalar/AD rate evaluation. + * @tparam T double or CppAD::AD. + * @param T9, rho, Ye, mue Thermodynamic state. + * @param Y Composition vector (unused for weak channels). + * @param index_to_species_map Index-to-species map (unused for weak channels). + * @return Reaction rate (1/s) as T. For double, performs table interpolation and returns + * 0 when the tabulated log10(rate) ≤ sentinel; for AD, calls the atomic and returns + * the first output. + * @pre T9 > 0, rho > 0, 0 < Ye <= 1. + * @post No persistent state is modified. + * @throws std::runtime_error If interpolation fails (double path) or the atomic fails (AD path). + */ template T calculate_rate( T T9, @@ -115,11 +385,37 @@ namespace gridfire::rates::weak { const std::unordered_map& index_to_species_map ) const; + /** + * @brief Extract the channel-specific log10(rate) from an interpolated payload. + * Mapping: β−→log_beta_minus, β+→log_beta_plus, e− capture→log_electron_capture, + * e+ capture→log_positron_capture. + */ double get_log_rate_from_payload(const WeakRatePayload& payload) const; + /** + * @brief Extract the channel-specific log10(neutrino loss) from a payload. + * Mapping: β−/e+ capture→log_antineutrino_loss_bd; β+/e− capture→log_neutrino_loss_ec. + */ + double get_log_neutrino_loss_from_payload(const WeakRatePayload& payload) const; + private: + /** + * @brief CppAD atomic that wraps weak-rate interpolation for AD evaluation. + * + * Forward pass computes two outputs (λ, ν_loss) by interpolating the tables at the + * provided state; reverse pass uses derivative tables to backpropagate adjoints for + * all three inputs (T9, log10(rho*Ye), μe). Sparsity routines declare full dependence + * of both outputs on all inputs. + */ class AtomicWeakRate final : public CppAD::atomic_base { public: + /** + * @brief Construct the atomic operation for a specific (A,Z) and channel. + * @param interpolator Rate source. + * @param a Mass number A of the parent. + * @param z Proton number Z of the parent. + * @param type Weak channel. + */ AtomicWeakRate( const WeakRateInterpolator& interpolator, const size_t a, @@ -132,6 +428,11 @@ namespace gridfire::rates::weak { m_z(z) , m_type(type) {} + /** + * @brief Forward pass: compute rate and neutrino-loss values for AD. + * On failure to interpolate, throws a std::runtime_error with details; sets output + * sparsity such that both outputs depend on all inputs when any input is variable. + */ bool forward( size_t p, size_t q, @@ -140,6 +441,12 @@ namespace gridfire::rates::weak { const CppAD::vector& tx, CppAD::vector& ty ) override; + + /** + * @brief Reverse pass: propagate adjoints using tabulated derivatives. + * Uses d log10 columns, converting to linear-scale derivatives via ln(10) scaling and + * chain rule with the forward-pass outputs. + */ bool reverse( size_t q, const CppAD::vector& tx, @@ -147,11 +454,19 @@ namespace gridfire::rates::weak { CppAD::vector& px, const CppAD::vector& py ) override; + + /** + * @brief Forward-mode sparsity for Jacobian. + */ bool for_sparse_jac( size_t q, const CppAD::vector>&r, CppAD::vector>& s ) override; + + /** + * @brief Reverse-mode sparsity for Jacobian. + */ bool rev_sparse_jac( size_t q, const CppAD::vector>&rt, @@ -190,6 +505,9 @@ namespace gridfire::rates::weak { fourdst::atomic::Species m_reactant; fourdst::atomic::Species m_product; + std::vector m_reactants; + std::vector m_products; + size_t m_reactant_a; size_t m_reactant_z; size_t m_product_a; @@ -200,10 +518,11 @@ namespace gridfire::rates::weak { const WeakRateInterpolator& m_interpolator; - AtomicWeakRate m_atomic; + mutable AtomicWeakRate m_atomic; }; + // template implementation lives at the end of the header for AD instantiation template T WeakReaction::calculate_rate( T T9, @@ -218,7 +537,7 @@ namespace gridfire::rates::weak { T rateConstant = static_cast(0.0); if constexpr (std::is_same_v>) { // Case where T is an AD type std::vector ax = {T9, log_rhoYe, mue}; - std::vector ay(1); + std::vector ay(2); m_atomic(ax, ay); rateConstant = static_cast(ay[0]); } else { // The case where T is of type double @@ -251,5 +570,3 @@ namespace gridfire::rates::weak { } } - - diff --git a/src/include/gridfire/reaction/weak/weak_interpolator.h b/src/include/gridfire/reaction/weak/weak_interpolator.h index 92e9c797..c5fdc63a 100644 --- a/src/include/gridfire/reaction/weak/weak_interpolator.h +++ b/src/include/gridfire/reaction/weak/weak_interpolator.h @@ -7,18 +7,79 @@ #include #include #include +#include namespace gridfire::rates::weak { + /** + * @class WeakRateInterpolator + * @brief 3D table interpolator for tabulated weak reaction data by isotope. + * + * Builds per-isotope 3D grids over (T9, log10(rho*Ye), mu_e) and provides: + * - Trilinear interpolation of the tabulated log10(rate) and neutrino-loss fields + * into a WeakRatePayload via get_rates(). + * - Finite-difference estimates of partial derivatives via get_rate_derivatives(). + * + * Implementation summary (constructor): rows are grouped by (A,Z), then each group's unique + * axis values are collected and sorted to form the three axes; the 3D payload array is + * populated at each lattice point with the 6 log10() fields from the raw table. + */ class WeakRateInterpolator { public: + /** + * @brief Raw weak-rate table type expected by the constructor. + * + * The size must match the number of rows compiled into the weak-rate library. + */ using RowDataTable = std::array; // Total number of entries in the weak rate table NOTE: THIS MUST EQUAL THE VALUE IN weak_rate_library.h + /** + * @brief Construct the interpolator from raw weak-rate rows. + * + * Groups rows by isotope (A,Z), extracts unique sorted axes for T9, log10(rho*Ye), and mu_e, + * and fills an internal regular grid with the log10(rate) and neutrino-loss payloads at each node. + * No interpolation occurs at construction time. + */ explicit WeakRateInterpolator(const RowDataTable& raw_data); + /** + * @brief List isotopes for which tables are available. + * @return Vector of available Species (A,Z) derived from internal tables. + * @throws std::runtime_error If any packed (A,Z) cannot be converted to Species. + * @par Example + * @code + * WeakRateInterpolator interp(rows); + * auto isotopes = interp.available_isotopes(); + * @endcode + */ [[nodiscard]] std::vector available_isotopes() const; + /** + * @brief Trilinear interpolation of weak-rate payload at a state. + * + * Interpolates the 6 log10() fields (rates and neutrino losses) at the given state + * for the requested isotope. If the isotope is unknown or the state lies outside + * the tabulated ranges, returns an error via std::expected with detailed bounds info. + * + * @param A Mass number of the isotope. + * @param Z Proton number of the isotope. + * @param t9 Temperature in GK (10^9 K). + * @param log_rhoYe Log10 of rho*Ye (cgs density times electron fraction). + * @param mu_e Electron chemical potential (MeV). + * @return expected: payload on success; + * InterpolationError::UNKNOWN_SPECIES_ERROR if (A,Z) not present; or + * InterpolationError::BOUNDS_ERROR if any coordinate is outside the table + * (with per-axis bounds included). + * @par Example + * @code + * if (auto res = interp.get_rates(52, 26, 3.0, 6.0, 2.0); res) { + * const WeakRatePayload& p = *res; + * } else { + * // inspect res.error().type and optional bounds info + * } + * @endcode + */ [[nodiscard]] std::expected get_rates( uint16_t A, uint8_t Z, @@ -27,6 +88,28 @@ namespace gridfire::rates::weak { double mu_e ) const; + /** + * @brief Finite-difference partial derivatives of the log10() fields. + * + * Uses central differences with small fixed (1e-6) perturbations in each variable + * (T9, log10(rho*Ye), mu_e) and returns arrays of d(log10(field))/d(var) for all fields. + * If any perturbed state falls outside the table, returns a BOUNDS_ERROR with per-axis + * bounds; if the isotope is unknown, returns UNKNOWN_SPECIES_ERROR. + * + * @param A Mass number of the isotope. + * @param Z Proton number of the isotope. + * @param t9 Temperature in GK (10^9 K). + * @param log_rhoYe Log10 of rho*Ye (cgs density times electron fraction). + * @param mu_e Electron chemical potential (MeV). + * @return expected: derivative payload on success; + * otherwise an InterpolationError as described above. + * @par Example + * @code + * if (auto d = interp.get_rate_derivatives(52, 26, 3.0, 6.0, 2.0); d) { + * // use d->d_log_beta_minus[0..2], etc. + * } + * @endcode + */ [[nodiscard]] std::expected get_rate_derivatives( uint16_t A, uint8_t Z, @@ -35,8 +118,16 @@ namespace gridfire::rates::weak { double mu_e ) const; private: + /** + * @brief Pack (A,Z) into a 32-bit key used for the internal map. + * + * Layout: (A << 8) | Z. To unpack, use (key >> 8) for A and (key & 0xFF) for Z. + */ static uint32_t pack_isotope_id(uint16_t A, uint8_t Z); + /** + * @brief Per-isotope grids over (T9, log10(rho*Ye), mu_e) with payloads at lattice nodes. + */ std::unordered_map m_rate_table; }; diff --git a/src/include/gridfire/reaction/weak/weak_types.h b/src/include/gridfire/reaction/weak/weak_types.h index 7f0ccd9e..793d4585 100644 --- a/src/include/gridfire/reaction/weak/weak_types.h +++ b/src/include/gridfire/reaction/weak/weak_types.h @@ -1,5 +1,14 @@ #pragma once +/** + * @file weak_types.h + * @brief Plain data structures and enums for weak reaction tables, interpolation payloads, and errors. + * + * This header defines the raw row format loaded from the unified weak-rate library, simple + * enumerations for channels and axes, compact payloads for interpolated values and derivatives, + * and error-reporting structures used by the interpolator. + */ + #include #include #include @@ -8,27 +17,48 @@ #include namespace gridfire::rates::weak { + /** + * @brief One row of the unified weak-rate data table for a specific isotope and state. + * + * Units and meanings: + * - t9: temperature in GK (10^9 K). + * - log_rhoye: base-10 logarithm of rho*Ye where rho is g cm^-3 and Ye is electron fraction. + * - mu_e: electron chemical potential in MeV. + * - log_*: base-10 logarithm of the tabulated rate or neutrino-energy loss term. + * + * Channel mappings: + * - beta-plus (β+): log_beta_plus, neutrino-loss column log_neutrino_loss_ec. + * - electron capture (e− cap): log_electron_capture, neutrino-loss column log_neutrino_loss_ec. + * - beta-minus (β−): log_beta_minus, neutrino-loss column log_antineutrino_loss_bd. + * - positron capture (e+ cap): log_positron_capture, neutrino-loss column log_antineutrino_loss_bd. + */ struct RateDataRow { - uint16_t A; - uint8_t Z; - float t9; - float log_rhoye; - float mu_e; - float log_beta_plus; - float log_electron_capture; - float log_neutrino_loss_ec; - float log_beta_minus; - float log_positron_capture; - float log_antineutrino_loss_bd; + uint16_t A; ///< Mass number. + uint8_t Z; ///< Proton number. + float t9; ///< Temperature in GK. + float log_rhoye; ///< log10(rho*Ye) (cgs density times electron fraction). + float mu_e; ///< Electron chemical potential (MeV). + float log_beta_plus; ///< log10(β+ decay rate). + float log_electron_capture; ///< log10(e− capture rate). + float log_neutrino_loss_ec; ///< log10(neutrino loss for β+ and e− capture). + float log_beta_minus; ///< log10(β− decay rate). + float log_positron_capture; ///< log10(e+ capture rate). + float log_antineutrino_loss_bd; ///< log10(antineutrino loss for β− and e+ capture). }; + /** + * @brief Weak reaction channel identifiers. + */ enum class WeakReactionType { - BETA_PLUS_DECAY, - BETA_MINUS_DECAY, - ELECTRON_CAPTURE, - POSITRON_CAPTURE, + BETA_PLUS_DECAY, ///< β+ decay: Z -> Z-1 + e+ + ν_e + BETA_MINUS_DECAY, ///< β− decay: Z -> Z+1 + e− + ν̄_e + ELECTRON_CAPTURE, ///< e− capture: (Z, e−) -> Z-1 + ν_e + POSITRON_CAPTURE, ///< e+ capture: (Z, e+) -> Z+1 + ν̄_e }; + /** + * @brief Enumeration of neutrino flavors (for potential extensions and tagging). + */ enum class NeutrinoTypes { ELECTRON_NEUTRINO, ELECTRON_ANTINEUTRINO, @@ -38,20 +68,35 @@ namespace gridfire::rates::weak { TAU_ANTINEUTRINO }; + /** + * @brief Lookup errors for WeakReactionMap queries. + */ enum class WeakMapError { - SPECIES_NOT_FOUND, + SPECIES_NOT_FOUND, ///< No entries for the requested Species. UNKNOWN_ERROR }; + /** + * @brief Interpolated weak-rate payload at a single state. + * + * All values are base-10 logarithms of the corresponding rates or neutrino-loss terms. + * Consumers typically convert with pow(10, log_value) and may apply sentinel thresholds + * at the usage site. + */ struct WeakRatePayload { - double log_beta_plus; - double log_electron_capture; - double log_neutrino_loss_ec; - double log_beta_minus; - double log_positron_capture; - double log_antineutrino_loss_bd; + double log_beta_plus; ///< log10(β+ decay rate). + double log_electron_capture; ///< log10(e− capture rate). + double log_neutrino_loss_ec; ///< log10(neutrino loss for β+ and e− capture). + double log_beta_minus; ///< log10(β− decay rate). + double log_positron_capture; ///< log10(e+ capture rate). + double log_antineutrino_loss_bd; ///< log10(antineutrino loss for β− and e+ capture). }; + /** + * @brief Partial derivatives of the log10() fields w.r.t. (T9, log10(rho*Ye), mu_e). + * + * Array ordering is [d/dT9, d/dlogRhoYe, d/dMuE] for each corresponding field. + */ struct WeakRateDerivatives { // Each array holds [d/dT9, d/dlogRhoYe, d/dMuE] std::array d_log_beta_plus; @@ -62,45 +107,74 @@ namespace gridfire::rates::weak { std::array d_log_antineutrino_loss_bd; }; + /** + * @brief Error categories for interpolation attempts. + */ enum class InterpolationErrorType { - BOUNDS_ERROR, - UNKNOWN_SPECIES_ERROR, + BOUNDS_ERROR, ///< Query outside the per-axis min/max of the table. + UNKNOWN_SPECIES_ERROR, ///< Requested (A,Z) not present in the tables. UNKNOWN_ERROR }; + /** + * @brief Human-readable names for InterpolationErrorType. + */ inline std::unordered_map InterpolationErrorTypeMap = { {InterpolationErrorType::BOUNDS_ERROR, "Bounds Error"}, {InterpolationErrorType::UNKNOWN_SPECIES_ERROR, "Unknown Species Error"}, {InterpolationErrorType::UNKNOWN_ERROR, "Unknown Error"} }; + /** + * @brief Axes of the interpolation table. + */ enum class TableAxes { - T9, - LOG_RHOYE, - MUE + T9, ///< Temperature in GK. + LOG_RHOYE, ///< log10(rho*Ye). + MUE ///< Electron chemical potential (MeV). }; + /** + * @brief Detailed bounds information for a BOUNDS_ERROR. + */ struct BoundsErrorInfo { - TableAxes axis; - double axisMinValue; - double axisMaxValue; - double queryValue; + TableAxes axis; ///< Axis on which the error occurred. + double axisMinValue; ///< Minimum tabulated value on the axis. + double axisMaxValue; ///< Maximum tabulated value on the axis. + double queryValue; ///< Requested value. }; + /** + * @brief Interpolation error with optional per-axis bounds details. + * + * For BOUNDS_ERROR, boundsErrorInfo may contain an entry per offending axis. + */ struct InterpolationError { - InterpolationErrorType type; + InterpolationErrorType type; ///< Error category. std::optional> boundsErrorInfo = std::nullopt; }; + /** + * @brief Regular 3D grid and payloads for a single isotope (A,Z). + * + * Axes are monotonically increasing per dimension. Data vector is laid out in + * row-major order with index computed as: + * index = ((i_t9 * rhoYe_axis.size() + j_rhoYe) * mue_axis.size()) + k_mue + */ struct IsotopeGrid { - std::vector t9_axis; - std::vector rhoYe_axis; - std::vector mue_axis; + std::vector t9_axis; ///< Unique sorted T9 grid. + std::vector rhoYe_axis;///< Unique sorted log10(rho*Ye) grid. + std::vector mue_axis; ///< Unique sorted mu_e grid. - // index = (i_t9 * logRhoYe_axis.size() + j_rhoYe) + mue_axis.size() + k_mue - std::vector data; + // index = ((i_t9 * rhoYe_axis.size() + j_rhoYe) * mue_axis.size()) + k_mue + std::vector data; ///< Payloads at each grid node. }; + /** + * @brief Abbreviated channel name used in printing and IDs. + * @param t Channel enum. + * @return Short name: bp, bm, ec, or pc. + */ constexpr std::string_view weak_reaction_type_name(const WeakReactionType t) noexcept { switch (t) { case WeakReactionType::BETA_PLUS_DECAY: return "bp"; @@ -111,13 +185,25 @@ namespace gridfire::rates::weak { return "Unknown"; } + /** + * @brief A single weak-reaction data point (type, state, and log values). + * + * All rates and losses are base-10 logarithms. Useful for listing and filtering + * weak entries for a Species. + * + * @par Example + * @code + * WeakReactionEntry e{WeakReactionType::ELECTRON_CAPTURE, 3.0f, 6.0f, 2.0f, -2.3f, -1.7f}; + * std::cout << e << "\n"; // prints a compact summary + * @endcode + */ struct WeakReactionEntry { - WeakReactionType type; - float T9; - float log_rhoYe; - float mu_e; - float log_rate; - float log_neutrino_loss; + WeakReactionType type; ///< Channel. + float T9; ///< Temperature in GK. + float log_rhoYe; ///< log10(rho*Ye). + float mu_e; ///< Electron chemical potential (MeV). + float log_rate; ///< Channel-specific log10(rate). + float log_neutrino_loss; ///< Corresponding log10(neutrino or antineutrino energy loss). friend std::ostream& operator<<(std::ostream& os, const WeakReactionEntry& reaction) { os << "WeakReactionEntry(type=" << weak_reaction_type_name(reaction.type) diff --git a/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h b/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h index 7df1b7ca..be986cda 100644 --- a/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h +++ b/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h @@ -40,9 +40,51 @@ #endif namespace gridfire::solver { + /** + * @class CVODESolverStrategy + * @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy. + * + * Integrates the nuclear network abundances along with a final accumulator entry for specific + * energy using CVODE's BDF method and a dense linear solver. The state vector layout is: + * [y_0, y_1, ..., y_{N-1}, eps], where eps is the accumulated specific energy (erg/g). + * + * Implementation summary: + * - Creates a SUNContext and CVODE memory; initializes the state from a Composition. + * - Enforces non-negativity on species via CVodeSetConstraints (>= 0 for all species slots). + * - Uses a user-provided DynamicEngine to compute RHS and to fill the dense Jacobian. + * - The Jacobian is assembled column-major into a SUNDenseMatrix; the energy row/column is + * currently set to zero (decoupled from abundances in the linearization). + * - An internal trigger can rebuild the engine/network; when triggered, CVODE resources are + * torn down and recreated with the new network size, preserving the energy accumulator. + * - The CVODE RHS wrapper captures exceptions::StaleEngineTrigger from the engine evaluation + * path as recoverable (return code 1) and stores a copy in user-data for the driver loop. + * + * @par Example + * @code + * using gridfire::solver::CVODESolverStrategy; + * using gridfire::solver::NetIn; + * + * CVODESolverStrategy solver(engine); + * NetIn in; + * in.temperature = 1.0e9; // K + * in.density = 1.0e6; // g/cm^3 + * in.tMax = 1.0; // s + * in.composition = initialComposition; + * auto out = solver.evaluate(in); + * std::cout << "Final energy: " << out.energy << " erg/g\n"; + * @endcode + */ class CVODESolverStrategy final : public DynamicNetworkSolverStrategy { public: + /** + * @brief Construct the CVODE strategy and create a SUNDIALS context. + * @param engine DynamicEngine used for RHS/Jacobian evaluation and network access. + * @throws std::runtime_error If SUNContext_Create fails. + */ explicit CVODESolverStrategy(DynamicEngine& engine); + /** + * @brief Destructor: cleans CVODE/SUNDIALS resources and frees SUNContext. + */ ~CVODESolverStrategy() override; // Make the class non-copyable and non-movable to prevent shallow copies of CVODE pointers @@ -51,47 +93,99 @@ namespace gridfire::solver { CVODESolverStrategy(CVODESolverStrategy&&) = delete; CVODESolverStrategy& operator=(CVODESolverStrategy&&) = delete; + /** + * @brief Integrate from t=0 to netIn.tMax and return final composition and energy. + * + * Implementation summary: + * - Converts temperature to T9, initializes CVODE memory and state (size = numSpecies + 1). + * - Repeatedly calls CVode in single-step or normal mode depending on stdout logging. + * - Wraps RHS to capture exceptions::StaleEngineTrigger as a recoverable step failure; + * if present after a step, it is rethrown for upstream handling. + * - Prints/collects diagnostics per step (step size, energy, solver iterations). + * - On trigger activation, rebuilds CVODE resources to reflect a changed network and + * reinitializes the state using the latest engine composition, preserving energy. + * - At the end, converts molar abundances to mass fractions and assembles NetOut, + * including derivatives of energy w.r.t. T and rho from the engine. + * + * @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition. + * @return NetOut containing final Composition, accumulated energy [erg/g], step count, + * and dEps/dT, dEps/dRho. + * @throws std::runtime_error If any CVODE or SUNDIALS call fails (negative return codes), + * or if internal consistency checks fail during engine updates. + * @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state + * during RHS evaluation (captured in the wrapper then rethrown here). + */ NetOut evaluate(const NetIn& netIn) override; + /** + * @brief Install a timestep callback. + * @param callback std::any containing TimestepCallback (std::function). + * @throws std::bad_any_cast If callback is not of the expected type. + */ void set_callback(const std::any &callback) override; + /** + * @brief Whether per-step logs are printed to stdout and CVode is stepped with CV_ONE_STEP. + */ [[nodiscard]] bool get_stdout_logging_enabled() const; + /** + * @brief Enable/disable per-step stdout logging. + */ void set_stdout_logging_enabled(const bool value); + /** + * @brief Schema of fields exposed to the timestep callback context. + */ [[nodiscard]] std::vector> describe_callback_context() const override; + /** + * @struct TimestepContext + * @brief Immutable view of the current integration state passed to callbacks. + * + * Fields capture CVODE time/state, step size, thermodynamic state, the engine reference, + * and the list of network species used to interpret the state vector layout. + */ struct TimestepContext final : public SolverContextBase { // This struct can be identical to the one in DirectNetworkSolver - const double t; - const N_Vector& state; // Note: state is now an N_Vector - const double dt; - const double last_step_time; - const double T9; - const double rho; - const size_t num_steps; - const DynamicEngine& engine; - const std::vector& networkSpecies; + const double t; ///< Current integration time [s]. + const N_Vector& state; ///< Current CVODE state vector (N_Vector). + const double dt; ///< Last step size [s]. + const double last_step_time; ///< Time at last callback [s]. + const double T9; ///< Temperature in GK. + const double rho; ///< Density [g cm^-3]. + const size_t num_steps; ///< Number of CVODE steps taken so far. + const DynamicEngine& engine; ///< Reference to the engine. + const std::vector& networkSpecies; ///< Species layout. - // Constructor + /** + * @brief Construct a context snapshot. + */ TimestepContext( double t, const N_Vector& state, double dt, double last_step_time, double t9, double rho, size_t num_steps, const DynamicEngine& engine, const std::vector& networkSpecies ); + /** + * @brief Human-readable description of the context fields. + */ [[nodiscard]] std::vector> describe() const override; }; + /** + * @brief Type alias for a timestep callback. + */ using TimestepCallback = std::function; ///< Type alias for a timestep callback function. private: /** * @struct CVODEUserData * @brief A helper struct to pass C++ context to C-style CVODE callbacks. * - * CVODE callbacks are C functions and use a void* pointer to pass user data. - * This struct bundles all the necessary C++ objects (like 'this', engine references, etc.) - * to be accessed safely within those static C wrappers. + * Carries pointers back to the solver instance and engine, the current thermodynamic + * state, energy accumulator, and a slot to capture a copy of exceptions::StaleEngineTrigger + * from RHS evaluation. The RHS wrapper treats this as a recoverable failure and returns 1 + * to CVODE, then the driver loop inspects and rethrows. */ struct CVODEUserData { CVODESolverStrategy* solver_instance; // Pointer back to the class instance @@ -106,11 +200,36 @@ namespace gridfire::solver { private: fourdst::config::Config& m_config = fourdst::config::Config::getInstance(); quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log"); + /** + * @brief CVODE RHS C-wrapper that delegates to calculate_rhs and captures exceptions. + * @return 0 on success; 1 on recoverable StaleEngineTrigger; -1 on other failures. + */ static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, void *user_data); + /** + * @brief CVODE dense Jacobian C-wrapper that fills SUNDenseMatrix using the engine. + * + * Assembles J(i,j) = d(f_i)/d(y_j) for all species using engine->getJacobianMatrixEntry, + * then zeros the last row and column corresponding to the energy variable. + */ static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3); + /** + * @brief Compute RHS into ydot at time t from the engine and current state y. + * + * Converts the CVODE state to a Composition (mass fractions) and calls + * engine.calculateRHSAndEnergy(T9, rho). Negative small abundances are clamped to zero + * before constructing Composition. On stale engine, throws exceptions::StaleEngineTrigger. + */ void calculate_rhs(sunrealtype t, N_Vector y, N_Vector ydot, const CVODEUserData* data) const; + /** + * @brief Allocate and initialize CVODE vectors, linear algebra, tolerances, and constraints. + * + * State vector m_Y is sized to N (numSpecies + 1). Species slots are initialized from Composition + * molar abundances when present, otherwise a tiny positive value; the last slot is set to + * accumulatedEnergy. Sets scalar tolerances, non-negativity constraints for species, maximum + * step size, creates a dense matrix and dense linear solver, and registers the Jacobian. + */ void initialize_cvode_integration_resources( uint64_t N, size_t numSpecies, @@ -121,23 +240,34 @@ namespace gridfire::solver { double accumulatedEnergy ); + /** + * @brief Destroy CVODE vectors/linear algebra and optionally the CVODE memory block. + * @param memFree If true, also calls CVodeFree on m_cvode_mem. + */ void cleanup_cvode_resources(bool memFree); + /** + * @brief Compute and print per-component error ratios; run diagnostic helpers. + * + * Gathers CVODE's estimated local errors, converts the state to a Composition, and prints a + * sorted table of species with highest error ratios; then invokes diagnostic routines to + * inspect Jacobian stiffness and species balance. + */ void log_step_diagnostics(const CVODEUserData& user_data) const; private: - SUNContext m_sun_ctx = nullptr; - void* m_cvode_mem = nullptr; - N_Vector m_Y = nullptr; - N_Vector m_YErr = nullptr; - SUNMatrix m_J = nullptr; - SUNLinearSolver m_LS = nullptr; + SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver). + void* m_cvode_mem = nullptr; ///< CVODE memory block. + N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator). + N_Vector m_YErr = nullptr; ///< Estimated local errors. + SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix. + SUNLinearSolver m_LS = nullptr; ///< Dense linear solver. - TimestepCallback m_callback; - int m_num_steps = 0; + TimestepCallback m_callback; ///< Optional per-step callback. + int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers). - bool m_stdout_logging_enabled = true; + bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP. - N_Vector m_constraints = nullptr; + N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries). }; } \ No newline at end of file diff --git a/src/include/gridfire/solver/strategies/triggers/engine_partitioning_trigger.h b/src/include/gridfire/solver/strategies/triggers/engine_partitioning_trigger.h index c900605a..64507542 100644 --- a/src/include/gridfire/solver/strategies/triggers/engine_partitioning_trigger.h +++ b/src/include/gridfire/solver/strategies/triggers/engine_partitioning_trigger.h @@ -9,81 +9,287 @@ #include #include +/** + * @file engine_partitioning_trigger.h + * @brief CVODE-specific triggers that decide when to (re)partition the reaction network engine. + * + * @details + * This header provides three concrete Trigger implementations: + * - SimulationTimeTrigger: fires when the simulated time advances by a fixed interval. + * - OffDiagonalTrigger: fires when any off-diagonal Jacobian entry exceeds a threshold. + * - TimestepCollapseTrigger: fires when the timestep changes sharply relative to a moving average. + * + * It also provides a convenience factory (makeEnginePartitioningTrigger) composing these triggers + * with logical combinators defined in trigger_logical.h. + * + * See the implementation for details: src/lib/solver/strategies/triggers/engine_partitioning_trigger.cpp + * + * @par Related headers: + * - trigger_abstract.h: base Trigger interface and lifecycle semantics + * - trigger_logical.h: AND/OR/NOT/EveryNth composition utilities + */ namespace gridfire::trigger::solver::CVODE { + /** + * @class SimulationTimeTrigger + * @brief Triggers when the current simulation time advances by at least a fixed interval. + * + * @details + * - check(ctx) returns true when (ctx.t - last_trigger_time) >= interval. + * - update(ctx) will, if check(ctx) is true, record ctx.t as the new last_trigger_time and + * store a small delta relative to the configured interval (for diagnostics/logging). + * - Counters (hits/misses/updates/resets) are maintained for diagnostics; they are + * mutable to allow updates from const check(). + * + * @par Constraints/Errors: + * - Constructing with a non-positive interval throws std::invalid_argument. + * + * @note Thread-safety: not thread-safe; intended for single-threaded trigger evaluation. + * + * See also: engine_partitioning_trigger.cpp for the concrete logic and logging. + */ class SimulationTimeTrigger final : public Trigger { public: + /** + * @brief Construct with a positive time interval between firings. + * @param interval Strictly positive time interval (simulation units) between triggers. + * @throws std::invalid_argument if interval <= 0. + */ explicit SimulationTimeTrigger(double interval); + /** + * @brief Evaluate whether enough simulated time has elapsed since the last trigger. + * @param ctx CVODE timestep context providing the current simulation time (ctx.t). + * @return true if (ctx.t - last_trigger_time) >= interval; false otherwise. + * + * @post increments hit/miss counters and may emit trace logs. + */ bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + /** + * @brief Update internal state; if check(ctx) is true, advance last_trigger_time. + * @param ctx CVODE timestep context. + * + * @note update() calls check(ctx) and, on success, records the overshoot delta + * (ctx.t - last_trigger_time) - interval for diagnostics. + */ void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + /** + * @brief Reset counters and last trigger bookkeeping (time and delta) to zero. + */ void reset() override; + /** @brief Stable human-readable name. */ std::string name() const override; + /** + * @brief Structured explanation of the evaluation outcome. + * @param ctx CVODE timestep context. + * @return TriggerResult including name, value, and description. + */ TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + /** @brief Textual description including configured interval. */ std::string describe() const override; + /** @brief Number of true evaluations since last reset. */ size_t numTriggers() const override; + /** @brief Number of false evaluations since last reset. */ size_t numMisses() const override; private: + /** @brief Logger used for trace/error diagnostics. */ quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log"); + /** @name Diagnostics counters */ + ///@{ mutable size_t m_hits = 0; mutable size_t m_misses = 0; mutable size_t m_updates = 0; mutable size_t m_resets = 0; + ///@} + /** @brief Required time interval between successive triggers. */ double m_interval; + /** @brief Time at which the trigger last fired; initialized to 0. */ mutable double m_last_trigger_time = 0.0; + /** @brief Overshoot relative to interval at the last firing; for diagnostics. */ mutable double m_last_trigger_time_delta = 0.0; }; + /** + * @class OffDiagonalTrigger + * @brief Triggers when any off-diagonal Jacobian entry magnitude exceeds a threshold. + * + * Semantics: + * - Iterates over all species pairs (row != col) and queries the engine's Jacobian + * via ctx.engine.getJacobianMatrixEntry(row, col). If any |entry| > threshold, + * check(ctx) returns true (short-circuits on first exceedance). + * - update(ctx) only records an update counter; it does not cache Jacobian values. + * + * @note Complexity: O(S^2) per check for S species (due to dense scan). + * + * @par Constraints/Errors: + * - Constructing with threshold < 0 throws std::invalid_argument. + * + * @par See also + * - engine_partitioning_trigger.cpp for concrete logic and trace logging. + */ class OffDiagonalTrigger final : public Trigger { public: + /** + * @brief Construct with a non-negative magnitude threshold. + * @param threshold Off-diagonal Jacobian magnitude threshold (>= 0). + * @throws std::invalid_argument if threshold < 0. + */ explicit OffDiagonalTrigger(double threshold); + /** + * @brief Check if any off-diagonal Jacobian entry exceeds the threshold. + * @param ctx CVODE timestep context providing access to engine species and Jacobian. + * @return true if max_{i!=j} |J(i,j)| > threshold; false otherwise. + * + * @post increments hit/miss counters and may emit trace logs. + */ bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + /** + * @brief Record an update; does not mutate any Jacobian-related state. + * @param ctx CVODE timestep context (unused except for symmetry with interface). + */ void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + /** @brief Reset counters to zero. */ void reset() override; + /** @brief Stable human-readable name. */ std::string name() const override; + /** + * @brief Structured explanation of the evaluation outcome. + * @param ctx CVODE timestep context. + */ TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + /** @brief Textual description including configured threshold. */ std::string describe() const override; + /** @brief Number of true evaluations since last reset. */ size_t numTriggers() const override; + /** @brief Number of false evaluations since last reset. */ size_t numMisses() const override; private: + /** @brief Logger used for trace/error diagnostics. */ quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log"); + /** @name Diagnostics counters */ + ///@{ mutable size_t m_hits = 0; mutable size_t m_misses = 0; mutable size_t m_updates = 0; mutable size_t m_resets = 0; + ///@} + /** @brief Off-diagonal magnitude threshold (>= 0). */ double m_threshold; }; + /** + * @class TimestepCollapseTrigger + * @brief Triggers when the timestep deviates from its recent average beyond a threshold. + * + * @details + * - Maintains a sliding window of recent dt values (size = windowSize). + * - check(ctx): + * - If the window is empty, returns false. + * - Computes the arithmetic mean of values in the window and compares either: + * - relative: |dt - mean| / mean >= threshold + * - absolute: |dt - mean| >= threshold + * - update(ctx): pushes ctx.dt into the fixed-size window (dropping oldest when full). + * + * @par Constraints/Errors: + * - threshold must be >= 0. + * - If relative==true, threshold must be in [0, 1]. Violations throw std::invalid_argument. + * + * @note + * - With windowSize==1, the mean is the most recent prior dt. + * - Counter fields are mutable to allow updates during const check(). + * + * See also: engine_partitioning_trigger.cpp for exact logic and logging. + */ class TimestepCollapseTrigger final : public Trigger { public: + /** + * @brief Construct with threshold and relative/absolute mode; window size defaults to 1. + * @param threshold Non-negative threshold; if relative, must be in [0, 1]. + * @param relative If true, use relative deviation; otherwise use absolute deviation. + * @throws std::invalid_argument on invalid threshold constraints. + */ explicit TimestepCollapseTrigger(double threshold, bool relative); + /** + * @brief Construct with threshold, mode, and window size. + * @param threshold Non-negative threshold; if relative, must be in [0, 1]. + * @param relative If true, use relative deviation; otherwise use absolute deviation. + * @param windowSize Number of dt samples to average over (>= 1 recommended). + * @throws std::invalid_argument on invalid threshold constraints. + */ explicit TimestepCollapseTrigger(double threshold, bool relative, size_t windowSize); + /** + * @brief Evaluate whether the current dt deviates sufficiently from recent average. + * @param ctx CVODE timestep context providing current dt. + * @return true if deviation exceeds the configured threshold; false otherwise. + * + * @post increments hit/miss counters and may emit trace logs. + */ bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + /** + * @brief Update sliding window with the most recent dt and increment update counter. + * @param ctx CVODE timestep context. + */ void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + /** @brief Reset counters and clear the dt window. */ void reset() override; + /** @brief Stable human-readable name. */ std::string name() const override; + /** @brief Structured explanation of the evaluation outcome. */ TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + /** @brief Textual description including threshold, mode, and window size. */ std::string describe() const override; + /** @brief Number of true evaluations since last reset. */ size_t numTriggers() const override; + /** @brief Number of false evaluations since last reset. */ size_t numMisses() const override; private: + /** @brief Logger used for trace/error diagnostics. */ quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log"); + /** @name Diagnostics counters */ + ///@{ mutable size_t m_hits = 0; mutable size_t m_misses = 0; mutable size_t m_updates = 0; mutable size_t m_resets = 0; + ///@} + /** @brief Threshold for absolute or relative deviation. */ double m_threshold; + /** @brief When true, use relative deviation; otherwise absolute deviation. */ bool m_relative; + /** @brief Number of dt samples to maintain in the moving window. */ size_t m_windowSize; + /** @brief Sliding window of recent timesteps (most recent at back). */ std::deque m_timestep_window; }; + /** + * @brief Compose a trigger suitable for deciding engine re-partitioning during CVODE solves. + * + * Policy (as of implementation): + * - OR of the following three conditions: + * 1) Every 1000th firing of SimulationTimeTrigger(simulationTimeInterval) + * 2) OffDiagonalTrigger(offDiagonalThreshold) + * 3) Every 10th firing of TimestepCollapseTrigger(timestepGrowthThreshold, + * timestepGrowthRelative, timestepGrowthWindowSize) + * + * See engine_partitioning_trigger.cpp for construction details using OrTrigger and + * EveryNthTrigger from trigger_logical.h. + * + * @param simulationTimeInterval Interval used by SimulationTimeTrigger (> 0). + * @param offDiagonalThreshold Off-diagonal Jacobian magnitude threshold (>= 0). + * @param timestepGrowthThreshold Threshold for timestep deviation (>= 0, and <= 1 when relative). + * @param timestepGrowthRelative Whether deviation is measured relatively. + * @param timestepGrowthWindowSize Window size for timestep averaging (>= 1 recommended). + * @return A unique_ptr to a composed Trigger implementing the policy above. + * + * @note The exact policy is subject to change; this function centralizes that decision. + */ std::unique_ptr> makeEnginePartitioningTrigger( const double simulationTimeInterval, const double offDiagonalThreshold, @@ -91,4 +297,4 @@ namespace gridfire::trigger::solver::CVODE { const bool timestepGrowthRelative, const size_t timestepGrowthWindowSize ); -} \ No newline at end of file +} diff --git a/src/include/gridfire/trigger/procedures/trigger_pprint.h b/src/include/gridfire/trigger/procedures/trigger_pprint.h index f6721f89..031b8d9d 100644 --- a/src/include/gridfire/trigger/procedures/trigger_pprint.h +++ b/src/include/gridfire/trigger/procedures/trigger_pprint.h @@ -5,6 +5,24 @@ #include namespace gridfire::trigger { + /** + * @brief Pretty-print a TriggerResult explanation tree to std::cout. + * + * Prints one line per node prefixed with Unicode bullets and indentation to visualize + * the explanation hierarchy. Each line shows [TRUE|FALSE], the node name, and description. + * + * @param result Root TriggerResult to print. + * @param indent Current indentation level (number of two-space indents); callers typically + * omit this parameter and let recursion handle it. + * + * @par Example + * @code + * using gridfire::trigger::TriggerResult; + * TriggerResult leaf{"A>5", "Threshold passed", true, {}}; + * TriggerResult root{"AND", "Both conditions true", true, {leaf}}; + * gridfire::trigger::printWhy(root); + * @endcode + */ inline void printWhy(const TriggerResult& result, const int indent = 0) { // NOLINT(*-no-recursion) const std::string prefix(indent * 2, ' '); std::cout << prefix << "• [" << (result.value ? "TRUE" : "FALSE") diff --git a/src/include/gridfire/trigger/trigger_abstract.h b/src/include/gridfire/trigger/trigger_abstract.h index 20465de8..6d366777 100644 --- a/src/include/gridfire/trigger/trigger_abstract.h +++ b/src/include/gridfire/trigger/trigger_abstract.h @@ -5,21 +5,70 @@ #include namespace gridfire::trigger { + /** + * @brief Generic trigger interface for signaling events/conditions during integration. + * + * A Trigger encapsulates a condition evaluated against a user-defined context struct + * (TriggerContextStruct). The interface supports: + * - check(): evaluate the current condition without mutating external state + * - update(): feed the context to update internal state (counters, sliding windows, etc.) + * - reset(): clear internal state and counters + * - name()/describe(): human-readable identification and textual description + * - why(): structured explanation (tree of TriggerResult) of the last evaluation + * - numTriggers()/numMisses(): simple counters for diagnostics + * + * Logical compositions (AND/OR/NOT/EveryNth) are provided in trigger_logical.h and implement + * this interface for any TriggerContextStruct. + * + * @tparam TriggerContextStruct User-defined context passed to triggers (e.g., timestep info). + */ template class Trigger { public: + /** + * @brief Virtual destructor for polymorphic use. + */ virtual ~Trigger() = default; + /** + * @brief Evaluate the trigger condition against the provided context. + * @param ctx Context snapshot (immutable view) used to evaluate the condition. + * @return true if the condition is satisfied; false otherwise. + */ virtual bool check(const TriggerContextStruct& ctx) const = 0; + /** + * @brief Update any internal state with the given context (e.g., counters, windows). + * @param ctx Context snapshot used to update state. + */ virtual void update(const TriggerContextStruct& ctx) = 0; + /** + * @brief Reset internal state and diagnostics counters. + */ virtual void reset() = 0; + /** + * @brief Short, stable name for this trigger (suitable for logs/UI). + */ [[nodiscard]] virtual std::string name() const = 0; + /** + * @brief Human-readable description of this trigger's logic. + */ [[nodiscard]] virtual std::string describe() const = 0; + /** + * @brief Explain why the last evaluation would be true/false in a structured way. + * @param ctx Context snapshot for the explanation. + * @return A TriggerResult tree with a boolean value and per-cause details. + */ [[nodiscard]] virtual TriggerResult why(const TriggerContextStruct& ctx) const = 0; + /** + * @brief Total number of times this trigger evaluated to true since last reset. + */ [[nodiscard]] virtual size_t numTriggers() const = 0; + /** + * @brief Total number of times this trigger evaluated to false since last reset. + */ [[nodiscard]] virtual size_t numMisses() const = 0; }; } \ No newline at end of file diff --git a/src/include/gridfire/trigger/trigger_logical.h b/src/include/gridfire/trigger/trigger_logical.h index b86c90ab..78fffbee 100644 --- a/src/include/gridfire/trigger/trigger_logical.h +++ b/src/include/gridfire/trigger/trigger_logical.h @@ -6,24 +6,75 @@ #include #include #include +#include namespace gridfire::trigger { + /** + * @file trigger_logical.h + * @brief Combinators for composing triggers with boolean logic (AND/OR/NOT/EveryNth). + * + * These templates wrap any Trigger and provide convenient composition. They also + * maintain simple hit/miss counters and implement short-circuit logic in check() and why(). + */ template class LogicalTrigger : public Trigger {}; + /** + * @class AndTrigger + * @brief Logical conjunction of two triggers with short-circuit evaluation. + * + * check(ctx) returns A.check(ctx) && B.check(ctx). The why(ctx) explanation short-circuits + * if A is false, avoiding evaluation of B. update(ctx) calls update() on both A and B. + * + * Counters (mutable) are incremented inside const check(): m_hits on true; m_misses on false; + * m_updates on each update(); m_resets on reset(). + * + * @par Example + * @code + * auto t = AndTrigger(ctxA, ctxB); + * if (t.check(ctx)) { (void)ctx; } + * @endcode + */ template class AndTrigger final : public LogicalTrigger { public: + /** + * @brief Construct AND from two triggers (ownership transferred). + */ AndTrigger(std::unique_ptr> A, std::unique_ptr> B); ~AndTrigger() override = default; + /** + * @brief Evaluate A && B; increments hit/miss counters. + */ bool check(const TriggerContextStruct& ctx) const override; + /** + * @brief Update both sub-triggers and increment update counter. + */ void update(const TriggerContextStruct& ctx) override; + /** + * @brief Reset both sub-triggers and local counters. + */ void reset() override; + /** + * @brief Human-readable name. + */ std::string name() const override; + /** + * @brief Structured explanation; short-circuits on A=false. + */ TriggerResult why(const TriggerContextStruct& ctx) const override; + /** + * @brief Description expression e.g. "(A) AND (B)". + */ std::string describe() const override; + /** + * @brief Number of true evaluations since last reset. + */ size_t numTriggers() const override; + /** + * @brief Number of false evaluations since last reset. + */ size_t numMisses() const override; private: std::unique_ptr> m_A; @@ -35,6 +86,13 @@ namespace gridfire::trigger { mutable size_t m_resets = 0; }; + /** + * @class OrTrigger + * @brief Logical disjunction of two triggers with short-circuit evaluation. + * + * check(ctx) returns A.check(ctx) || B.check(ctx). why(ctx) returns early when A is true. + * update(ctx) calls update() on both A and B. Counters behave as in AndTrigger. + */ template class OrTrigger final : public LogicalTrigger { public: @@ -60,6 +118,13 @@ namespace gridfire::trigger { mutable size_t m_resets = 0; }; + /** + * @class NotTrigger + * @brief Logical negation of a trigger. + * + * check(ctx) returns !A.check(ctx). why(ctx) explains the inverted condition. Counter + * semantics match the other logical triggers. + */ template class NotTrigger final : public LogicalTrigger { public: @@ -86,6 +151,15 @@ namespace gridfire::trigger { }; + /** + * @class EveryNthTrigger + * @brief Pass-through trigger that fires every Nth time its child trigger is true. + * + * On update(ctx), increments an internal counter when A.check(ctx) is true. check(ctx) + * returns true only when A.check(ctx) is true and the internal counter is a multiple of N. + * + * @throws std::invalid_argument When constructed with N==0. + */ template class EveryNthTrigger final : public LogicalTrigger { public: diff --git a/src/include/gridfire/trigger/trigger_result.h b/src/include/gridfire/trigger/trigger_result.h index 9b865a91..409b809e 100644 --- a/src/include/gridfire/trigger/trigger_result.h +++ b/src/include/gridfire/trigger/trigger_result.h @@ -4,10 +4,26 @@ #include namespace gridfire::trigger { + /** + * @file trigger_result.h + * @brief Structured explanation node for trigger evaluations. + * + * TriggerResult represents a tree describing why a trigger evaluated to true/false. + * Each node contains a boolean value, a short name, a human-readable description, + * and optional nested causes for composite triggers (e.g., AND/OR/NOT). + * + * @par Example + * @code + * // Produce a result and pretty-print it + * gridfire::trigger::TriggerResult r{"A>5", "Threshold passed", true, {}}; + * // See procedures/trigger_pprint.h for printWhy() + * // gridfire::trigger::printWhy(r); + * @endcode + */ struct TriggerResult { - std::string name; - std::string description; - bool value; - std::vector causes; + std::string name; ///< Short identifier for the condition (e.g., "Temperature Rise"). + std::string description; ///< Human-readable reason summarizing the outcome at this node. + bool value; ///< Evaluation result for this node (true/false). + std::vector causes; ///< Sub-reasons for composite triggers. }; } \ No newline at end of file diff --git a/src/lib/engine/engine_graph.cpp b/src/lib/engine/engine_graph.cpp index c84f0c1a..39027f66 100644 --- a/src/lib/engine/engine_graph.cpp +++ b/src/lib/engine/engine_graph.cpp @@ -40,7 +40,7 @@ namespace gridfire { const partition::PartitionFunction& partitionFunction, const BuildDepthType buildDepth) : m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA), - m_reactions(build_reaclib_nuclear_network(composition, m_weakRateInterpolator, buildDepth, false)), + m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, false)), m_depth(buildDepth), m_partitionFunction(partitionFunction.clone()) { @@ -419,7 +419,7 @@ namespace gridfire { double Ye = comp.getElectronAbundance(); // TODO: This is a dummy value for the electron chemical potential. We eventually need to replace this with an EOS call. double mue = 5.0e-3 * std::pow(rho * Ye, 1.0 / 3.0) + 0.5 * T9; - const double d_log_kFwd = reaction.calculate_forward_rate_log_derivative(T9, rho, Ye, mue, comp); + const double d_log_kFwd = reaction.calculate_log_rate_partial_deriv_wrt_T9(T9, rho, Ye, mue, comp); auto log_deriv_pf_op = [&](double acc, const auto& species) { const double g = m_partitionFunction->evaluate(species.z(), species.a(), T9); @@ -505,7 +505,7 @@ namespace gridfire { void GraphEngine::rebuild(const fourdst::composition::Composition& comp, const BuildDepthType depth) { if (depth != m_depth) { m_depth = depth; - m_reactions = build_reaclib_nuclear_network(comp, m_weakRateInterpolator, m_depth, false); + m_reactions = build_nuclear_network(comp, m_weakRateInterpolator, m_depth, false); syncInternalMaps(); // Resync internal maps after changing the depth } else { LOG_DEBUG(m_logger, "Rebuild requested with the same depth. No changes made to the network."); diff --git a/src/lib/engine/views/engine_multiscale.cpp b/src/lib/engine/views/engine_multiscale.cpp index 5dac52f2..81fc596a 100644 --- a/src/lib/engine/views/engine_multiscale.cpp +++ b/src/lib/engine/views/engine_multiscale.cpp @@ -170,7 +170,6 @@ namespace gridfire { } auto deriv = result.value(); - //TODO: Sort out how to deal with this. Need to return something like a step derivative but with the index consistent... for (const auto& species : m_algebraic_species) { deriv.dydt[species] = 0.0; // Fix the algebraic species to the equilibrium abundances we calculate. } @@ -1115,7 +1114,7 @@ namespace gridfire { normalized_composition.getMolarAbundance(species), Y_final_qse(i) ); - //TODO: CHeck this conversion + //TODO: Check this conversion double Xi = Y_final_qse(i) * species.mass(); // Convert from molar abundance to mass fraction if (!outputComposition.contains(species)) { outputComposition.registerSpecies(species); diff --git a/src/lib/reaction/reaction.cpp b/src/lib/reaction/reaction.cpp index d03a29db..33dde775 100644 --- a/src/lib/reaction/reaction.cpp +++ b/src/lib/reaction/reaction.cpp @@ -67,7 +67,7 @@ namespace gridfire::reaction { return calculate_rate>(T9); } - double ReaclibReaction::calculate_forward_rate_log_derivative( + double ReaclibReaction::calculate_log_rate_partial_deriv_wrt_T9( const double T9, const double rho, double Ye, @@ -243,7 +243,7 @@ namespace gridfire::reaction { return calculate_rate(T9); } - double LogicalReaclibReaction::calculate_forward_rate_log_derivative( + double LogicalReaclibReaction::calculate_log_rate_partial_deriv_wrt_T9( const double T9, const double rho, double Ye, double mue, const fourdst::composition::Composition& comp diff --git a/src/lib/reaction/weak/weak.cpp b/src/lib/reaction/weak/weak.cpp index 7f5c1085..428f9008 100644 --- a/src/lib/reaction/weak/weak.cpp +++ b/src/lib/reaction/weak/weak.cpp @@ -162,6 +162,8 @@ namespace gridfire::rates::weak { ) : m_reactant(species), m_product(resolve_weak_product(type, species)), + m_reactants({m_reactant}), + m_products({m_product}), m_reactant_a(species.a()), m_reactant_z(species.z()), m_product_a(m_product.a()), @@ -183,12 +185,26 @@ namespace gridfire::rates::weak { } CppAD::AD WeakReaction::calculate_rate( - CppAD::AD T9, - CppAD::AD rho, - CppAD::AD Ye, - CppAD::AD mue, const std::vector> &Y, const std::unordered_map& index_to_species_map + const CppAD::AD T9, + const CppAD::AD rho, + const CppAD::AD Ye, + const CppAD::AD mue, + const std::vector> &Y, + const std::unordered_map& index_to_species_map ) const { - return static_cast>(0.0); + return calculate_rate>(T9, rho, Ye, mue, Y, index_to_species_map); + } + + std::string_view WeakReaction::id() const { + return m_id; + } + + const std::vector & WeakReaction::reactants() const { + return m_reactants; + } + + const std::vector & WeakReaction::products() const { + return m_products; } bool WeakReaction::contains(const fourdst::atomic::Species &species) const { @@ -221,6 +237,10 @@ namespace gridfire::rates::weak { return {m_product}; } + size_t WeakReaction::num_species() const { + return 2; + } + int WeakReaction::stoichiometry(const fourdst::atomic::Species &species) const { if (species == m_reactant) { return -1; @@ -330,14 +350,49 @@ namespace gridfire::rates::weak { const std::unordered_map &index_to_species_map ) const { const CppAD::AD log_rhoYe = CppAD::log10(rho * Ye); - std::vector> ax = {T9, log_rhoYe, mue}; - std::vector> ay(1); - m_atomic(ax, ay); // TODO: Sort out why this isn't working and checkline 222 in weak.h where a similar line is - //TODO: think about how to get out neutrino loss in a autodiff safe way. This may mean I need to add an extra output to the atomic base - // so that I can get out both the rate and the neutrino loss rate. This will also mean that the sparsity pattern will need to - // be updated to account for the extra output. - CppAD::AD rateConstant = ay[0]; + const std::vector> ax = {T9, log_rhoYe, mue}; + std::vector> ay(2); // 2 outputs are the reaction rate (1/s) and the neutrino loss (MeV) + + m_atomic(ax, ay); // Note: We needed to make m_atomic mutable to allow this call in a const method. + + const CppAD::AD rateConstant = ay[0]; + const CppAD::AD NuLoss = ay[1]; + + return rateConstant * (qValue() - NuLoss); // returns in MeV / s + } + + double WeakReaction::calculate_log_rate_partial_deriv_wrt_T9( + const double T9, + const double rho, + const double Ye, + const double mue, + const fourdst::composition::Composition &composition + ) const { + const double log_rhoYe = std::log10(rho * Ye); + std::expected rates = m_interpolator.get_rate_derivatives( + static_cast(m_reactant_a), + static_cast(m_reactant_z), + T9, + log_rhoYe, + mue + ); + if (!rates.has_value()) { + const InterpolationErrorType type = rates.error().type; + const std::string msg = std::format( + "Failed to interpolate weak rate for (A={}, Z={}) at T9={}, log10(rho*Ye)={}, mu_e={} with error: {}", + m_reactant.name(), m_reactant_a, m_reactant_z, T9, log_rhoYe, mue, InterpolationErrorTypeMap.at(type) + ); + throw std::runtime_error(msg); + } + + // TODO: Finish implementing this (just need a switch statement) + return 0.0; + + } + + reaction::ReactionType WeakReaction::type() const { + return reaction::ReactionType::WEAK; } std::unique_ptr WeakReaction::clone() const { @@ -349,6 +404,14 @@ namespace gridfire::rates::weak { return reaction_ptr; } + bool WeakReaction::is_reverse() const { + return false; + } + + const WeakRateInterpolator & WeakReaction::getWeakRateInterpolator() const { + return m_interpolator; + } + double WeakReaction::get_log_rate_from_payload(const WeakRatePayload &payload) const { double logRate = 0.0; switch (m_type) { @@ -368,6 +431,25 @@ namespace gridfire::rates::weak { return logRate; } + double WeakReaction::get_log_neutrino_loss_from_payload(const WeakRatePayload &payload) const { + double logNeutrinoLoss = 0.0; + switch (m_type) { + case WeakReactionType::BETA_MINUS_DECAY: + logNeutrinoLoss = payload.log_antineutrino_loss_bd; + break; + case WeakReactionType::BETA_PLUS_DECAY: + logNeutrinoLoss = payload.log_neutrino_loss_ec; + break; + case WeakReactionType::ELECTRON_CAPTURE: + logNeutrinoLoss = payload.log_neutrino_loss_ec; + break; + case WeakReactionType::POSITRON_CAPTURE: + logNeutrinoLoss = payload.log_antineutrino_loss_bd; + break; + } + return logNeutrinoLoss; + } + bool WeakReaction::AtomicWeakRate::forward ( const size_t p, const size_t q, @@ -393,28 +475,35 @@ namespace gridfire::rates::weak { ); if (!result.has_value()) { const InterpolationErrorType type = result.error().type; - std::string msg = std::format( + const std::string msg = std::format( "Failed to interpolate weak rate for (A={}, Z={}) at T9={}, log10(rho*Ye)={}, mu_e={} with error: {}", m_a, m_z, T9, log10_rhoye, mu_e, InterpolationErrorTypeMap.at(type) ); + throw std::runtime_error(msg); } switch (m_type) { case WeakReactionType::BETA_MINUS_DECAY: ty[0] = std::pow(10, result.value().log_beta_minus); + ty[1] = std::pow(10, result.value().log_antineutrino_loss_bd); break; case WeakReactionType::BETA_PLUS_DECAY: ty[0] = std::pow(10, result.value().log_beta_plus); + ty[1] = std::pow(10, result.value().log_neutrino_loss_ec); break; case WeakReactionType::ELECTRON_CAPTURE: ty[0] = std::pow(10, result.value().log_electron_capture); + ty[1] = std::pow(10, result.value().log_neutrino_loss_ec); break; case WeakReactionType::POSITRON_CAPTURE: ty[0] = std::pow(10, result.value().log_positron_capture); + ty[1] = std::pow(10, result.value().log_antineutrino_loss_bd); break; } - if (vx.size() > 0) { - vy[0] = vx[0] || vx[1] || vx[2]; // Sets the output sparsity pattern + if (vx.size() > 0) { // Set up the sparsity pattern. This is saying that all input variables affect the output variable. + const bool any_input_varies = vx[0] || vx[1] || vx[2]; + vy[0] = any_input_varies; + vy[1] = any_input_varies; } return true; } @@ -430,6 +519,9 @@ namespace gridfire::rates::weak { const double log10_rhoye = tx[1]; const double mu_e = tx[2]; + const double forwardPassRate = ty[0]; // This is the rate from the forward pass. + const double forwardPassNeutrinoLossRate = ty[1]; // This is the neutrino loss rate from the forward pass. + const std::expected result = m_interpolator.get_rate_derivatives( static_cast(m_a), static_cast(m_z), @@ -447,37 +539,56 @@ namespace gridfire::rates::weak { throw std::runtime_error(msg); } - WeakRateDerivatives derivatives = result.value(); + const WeakRateDerivatives derivatives = result.value(); - double dT9 = 0.0; - double dRho = 0.0; - double dMuE = 0.0; + std::array dLogRate; // d(rate)/dT9, d(rate)/dlogRhoYe, d(rate)/dMuE + std::array dLogNuLoss; // d(nu loss)/dT9, d(nu loss)/dlogRhoYe, d(nu loss)/dMuE switch (m_type) { case WeakReactionType::BETA_MINUS_DECAY: - dT9 = py[0] * derivatives.d_log_beta_minus[0]; - dRho = py[0] * derivatives.d_log_beta_minus[1]; - dMuE = py[0] * derivatives.d_log_beta_minus[2]; + dLogRate[0] = derivatives.d_log_beta_minus[0]; + dLogRate[1] = derivatives.d_log_beta_minus[1]; + dLogRate[2] = derivatives.d_log_beta_minus[2]; + dLogNuLoss[0] = derivatives.d_log_antineutrino_loss_bd[0]; + dLogNuLoss[1] = derivatives.d_log_antineutrino_loss_bd[1]; + dLogNuLoss[2] = derivatives.d_log_antineutrino_loss_bd[2]; break; case WeakReactionType::BETA_PLUS_DECAY: - dT9 = py[0] * derivatives.d_log_beta_plus[0]; - dRho = py[0] * derivatives.d_log_beta_plus[1]; - dMuE = py[0] * derivatives.d_log_beta_plus[2]; + dLogRate[0] = derivatives.d_log_beta_plus[0]; + dLogRate[1] = derivatives.d_log_beta_plus[1]; + dLogRate[2] = derivatives.d_log_beta_plus[2]; + dLogNuLoss[0] = derivatives.d_log_neutrino_loss_ec[0]; + dLogNuLoss[1] = derivatives.d_log_neutrino_loss_ec[1]; + dLogNuLoss[2] = derivatives.d_log_neutrino_loss_ec[2]; break; case WeakReactionType::ELECTRON_CAPTURE: - dT9 = py[0] * derivatives.d_log_electron_capture[0]; - dRho = py[0] * derivatives.d_log_electron_capture[1]; - dMuE = py[0] * derivatives.d_log_electron_capture[2]; + dLogRate[0] = derivatives.d_log_electron_capture[0]; + dLogRate[1] = derivatives.d_log_electron_capture[1]; + dLogRate[2] = derivatives.d_log_electron_capture[2]; + dLogNuLoss[0] = derivatives.d_log_neutrino_loss_ec[0]; + dLogNuLoss[1] = derivatives.d_log_neutrino_loss_ec[1]; + dLogNuLoss[2] = derivatives.d_log_neutrino_loss_ec[2]; break; case WeakReactionType::POSITRON_CAPTURE: - dT9 = py[0] * derivatives.d_log_positron_capture[0]; - dRho = py[0] * derivatives.d_log_positron_capture[1]; - dMuE = py[0] * derivatives.d_log_positron_capture[2]; + dLogRate[0] = derivatives.d_log_positron_capture[0]; + dLogRate[1] = derivatives.d_log_positron_capture[1]; + dLogRate[2] = derivatives.d_log_positron_capture[2]; + dLogNuLoss[0] = derivatives.d_log_antineutrino_loss_bd[0]; + dLogNuLoss[1] = derivatives.d_log_antineutrino_loss_bd[1]; + dLogNuLoss[2] = derivatives.d_log_antineutrino_loss_bd[2]; break; } - px[0] = py[0] * dT9; // d(rate)/dT9 - px[1] = py[0] * dRho; // d(rate)/dlogRhoYe - px[2] = py[0] * dMuE; // d(rate)/dMuE + const double ln10 = std::log(10.0); + + // Contributions from the reaction rate (output 0) + px[0] = py[0] * forwardPassRate * ln10 * dLogRate[0]; + px[1] = py[0] * forwardPassRate * ln10 * dLogRate[1]; + px[2] = py[0] * forwardPassRate * ln10 * dLogRate[2]; + + // Contributions from the neutrino loss rate (output 1) + px[0] += py[1] * forwardPassNeutrinoLossRate * ln10 * dLogNuLoss[0]; + px[1] += py[1] * forwardPassNeutrinoLossRate * ln10 * dLogNuLoss[1]; + px[2] += py[1] * forwardPassNeutrinoLossRate * ln10 * dLogNuLoss[2]; return true; @@ -488,9 +599,14 @@ namespace gridfire::rates::weak { const CppAD::vector > &r, CppAD::vector > &s ) { - s[0] = r[0]; - s[0].insert(r[1].begin(), r[1].end()); - s[0].insert(r[2].begin(), r[2].end()); + std::set all_input_deps; + all_input_deps.insert(r[0].begin(), r[0].end()); + all_input_deps.insert(r[1].begin(), r[1].end()); + all_input_deps.insert(r[2].begin(), r[2].end()); + + // What this is saying is that both output variables depend on all input variables. + s[0] = all_input_deps; + s[1] = all_input_deps; return true; } @@ -500,12 +616,14 @@ namespace gridfire::rates::weak { const CppAD::vector > &rt, CppAD::vector > &st ) { - // What this is saying is that each of the three input variables (T9, rho, Ye) - // all only affect the output variable (the rate) since there is only - // one output variable. - st[0] = rt[0]; - st[1] = rt[0]; - st[2] = rt[0]; + // What this is saying is that all input variables may affect both output variables. + std::set all_output_deps; + all_output_deps.insert(rt[0].begin(), rt[0].end()); + all_output_deps.insert(rt[1].begin(), rt[1].end()); + + st[0] = all_output_deps; + st[1] = all_output_deps; + st[2] = all_output_deps; return true; }