From 9fdbb579964e8b43a4092412ffb20169a5f8a1dd Mon Sep 17 00:00:00 2001 From: Emily Boudreaux Date: Wed, 5 Nov 2025 12:48:08 -0500 Subject: [PATCH] feat(policy): began addition of robust policy system The policy system provides a way for users to ensure that they get a network with certain traits. For example being sure that the network they get has all of the proton proton reactions in its base reaction set. This is an extensible system which is intended to be used by researchers to build various determanistic network policies to make network results more reproducable --- .../gridfire/exceptions/error_policy.h | 32 ++++ src/include/gridfire/policy/chains.h | 169 ++++++++++++++++++ src/include/gridfire/policy/policy_abstract.h | 45 +++++ src/include/gridfire/policy/stellar_policy.h | 151 ++++++++++++++++ src/include/gridfire/reaction/reaction.h | 2 + src/lib/engine/engine_graph.cpp | 1 - src/lib/reaction/reaction.cpp | 7 + 7 files changed, 406 insertions(+), 1 deletion(-) create mode 100644 src/include/gridfire/exceptions/error_policy.h create mode 100644 src/include/gridfire/policy/chains.h create mode 100644 src/include/gridfire/policy/policy_abstract.h create mode 100644 src/include/gridfire/policy/stellar_policy.h diff --git a/src/include/gridfire/exceptions/error_policy.h b/src/include/gridfire/exceptions/error_policy.h new file mode 100644 index 00000000..0f124017 --- /dev/null +++ b/src/include/gridfire/exceptions/error_policy.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace gridfire::exceptions { + class PolicyError : std::exception { + public: + explicit PolicyError(const std::string& msg) : m_message(msg) {}; + + [[nodiscard]] const char* what() const noexcept override { + return m_message.c_str(); + } + private: + std::string m_message; + }; + + class MissingBaseReactionError final : public PolicyError { + public: + explicit MissingBaseReactionError(const std::string& msg) : PolicyError(msg) {}; + }; + + class MissingSeedSpeciesError final : public PolicyError { + public: + explicit MissingSeedSpeciesError(const std::string& msg) : PolicyError(msg) {}; + }; + + class MissingKeyReactionError final : public PolicyError { + public: + explicit MissingKeyReactionError(const std::string& msg) : PolicyError(msg) {} + }; +} \ No newline at end of file diff --git a/src/include/gridfire/policy/chains.h b/src/include/gridfire/policy/chains.h new file mode 100644 index 00000000..288f9858 --- /dev/null +++ b/src/include/gridfire/policy/chains.h @@ -0,0 +1,169 @@ +#pragma once + +#include "gridfire/policy/policy_abstract.h" +#include "gridfire/reaction/reaction.h" + +#include "gridfire/reaction/reaclib.h" + +#include "gridfire/exceptions/error_policy.h" + +#include + +namespace gridfire::policy { + + class ProtonProtonChainPolicy final: public ReactionChainPolicy { + public: + ProtonProtonChainPolicy(); + + const reaction::ReactionSet& get_reactions() const override { return m_reactions; } + private: + std::vector m_reactionIDs = { + "p(p,e+)d", + "d(p,g)he3", + "he3(he3,2p)he4", + "he4(he3,g)be7", + "be7(e-,)li7", + "li7(p,a)he4", + "be7(p,g)b8", + "b8(,e+)be8", + "be8(,a)he4" + }; + reaction::ReactionSet m_reactions; + }; + + class CNOChainPolicy final: public ReactionChainPolicy { + public: + CNOChainPolicy(); + const reaction::ReactionSet& get_reactions() const override { return m_reactions; } + private: + std::set m_reactionIDs = { + "c12(p,g)n13", + "n13(,e+)c13", + "c13(p,g)n14", + "n14(p,g)o15", + "o15(,e+)n15", + "n15(p,a)c12", + + "n15(p,g)o16", + "o16(p,g)f17", + "f17(,e+)o17", + "o17(p,a)n14", + "n14(p,g)o15", + "o15(,e+)n15", + + "o17(p,g)f18", + "f18(,e+)o18", + "o18(p,a)n15", + "n15(p,g)o16", + "o16(p,g)f17", + "f17(,e+)o17", + + "o18(p,g)f19", + "f19(p,a)o16", + "o16(p,g)f17", + "f17(,e+)o17", + "o17(p,g)f18", + "f18(,e+)o18" + }; + reaction::ReactionSet m_reactions; + }; + + class HotCNOChainPolicy final : public ReactionChainPolicy { + public: + HotCNOChainPolicy(); + const reaction::ReactionSet& get_reactions() const override { return m_reactions; } + private: + std::set m_reactionIDs = { + "c12(p,g)n13", + "n13(p,g)o14", + "o14(,e+)n14", + "n14(p,g)o15", + "o15(,e+)n15", + "n15(p,a)c12", + + "n15(p,g)o16", + "o16(p,g)f17", + "f17(p,g)ne18", + "ne18(,e+)f18", + "f18(p,a)o15", + "o15(,e+)n15", + + "f18(p,g)ne19", + "ne19(,e+)f19", + "f19(p,a)o16", + "o16(p,g)f17", + "f17(p,g)ne18", + "ne18(,e+)f18" + }; + + reaction::ReactionSet m_reactions; + }; + + class LowMassMainSequenceReactionChainPolicy final : public MultiReactionChainPolicy { + public: + LowMassMainSequenceReactionChainPolicy(); + + const reaction::ReactionSet & get_reactions() const override; + + const std::vector>& get_chain_policies() const override; + + private: + std::vector> m_chain_policies; + reaction::ReactionSet m_reactions; + }; + + inline ProtonProtonChainPolicy::ProtonProtonChainPolicy() { + const auto& all_reaclib_reactions = reaclib::get_all_reaclib_reactions(); + + for (const auto& reactionID : m_reactionIDs) { + auto reaction = all_reaclib_reactions.get(reactionID); + if (!reaction) { + throw exceptions::MissingBaseReactionError("The Underlying REACLIB reaction set is missing the reaction " + std::string(reactionID) + " needed for the proton-proton chain. This indicates that there is an issue with the GridFire binary you are using. Please try to recompile and if that fails please report this issue to the developers."); + } + m_reactions.add_reaction(reaction.value()->clone()); + } + } + + inline CNOChainPolicy::CNOChainPolicy() { + const auto& all_reaclib_reactions = reaclib::get_all_reaclib_reactions(); + for (const auto& reactionID : m_reactionIDs) { + auto reaction = all_reaclib_reactions.get(reactionID); + if (!reaction) { + throw exceptions::MissingBaseReactionError("The Underlying REACLIB reaction set is missing the reaction " + std::string(reactionID) + " needed for the CNO cycle. This indicates that there is an issue with the GridFire binary you are using. Please try to recompile and if that fails please report this issue to the developers."); + } + m_reactions.add_reaction(reaction.value()->clone()); + } + } + + inline HotCNOChainPolicy::HotCNOChainPolicy() { + const auto& all_reaclib_reactions = reaclib::get_all_reaclib_reactions(); + for (const auto& reactionID : m_reactionIDs) { + auto reaction = all_reaclib_reactions.get(reactionID); + if (!reaction) { + throw exceptions::MissingBaseReactionError("The Underlying REACLIB reaction set is missing the reaction " + std::string(reactionID) + " needed for the Hot CNO cycle. This indicates that there is an issue with the GridFire binary you are using. Please try to recompile and if that fails please report this issue to the developers."); + } + m_reactions.add_reaction(reaction.value()->clone()); + } + } + + + inline LowMassMainSequenceReactionChainPolicy::LowMassMainSequenceReactionChainPolicy() { + m_chain_policies.emplace_back(std::make_unique()); + m_chain_policies.emplace_back(std::make_unique()); + for (const auto& policy_ptr : m_chain_policies) { + m_reactions.extend(policy_ptr->get_reactions()); + } + } + + inline const reaction::ReactionSet & LowMassMainSequenceReactionChainPolicy::get_reactions() const { + return m_reactions; + } + + inline const std::vector>& LowMassMainSequenceReactionChainPolicy::get_chain_policies() const { + return m_chain_policies; + } + + + +} + diff --git a/src/include/gridfire/policy/policy_abstract.h b/src/include/gridfire/policy/policy_abstract.h new file mode 100644 index 00000000..bf325f52 --- /dev/null +++ b/src/include/gridfire/policy/policy_abstract.h @@ -0,0 +1,45 @@ +#pragma once + +#include "fourdst/composition/atomicSpecies.h" +#include "gridfire/engine/types/building.h" +#include "gridfire/reaction/reaction.h" +#include "gridfire/engine/engine_abstract.h" + +#include +#include + + +namespace gridfire::policy { + enum class NetworkPolicyStatus { + UNINITIALIZED, + INITIALIZED_UNVERIFIED, + MISSING_KEY_REACTION, + MISSING_KEY_SPECIES, + INITIALIZED_VERIFIED + }; + + class NetworkPolicy { + public: + virtual ~NetworkPolicy() = default; + virtual std::string name() const = 0; + + virtual const std::set get_seed_species() const = 0; + virtual const reaction::ReactionSet& get_seed_reactions() const = 0; + + virtual DynamicEngine& construct() = 0; + + virtual NetworkPolicyStatus getStatus() const = 0; + }; + + class ReactionChainPolicy { + public: + virtual ~ReactionChainPolicy() = default; + virtual const reaction::ReactionSet& get_reactions() const = 0; + }; + + class MultiReactionChainPolicy : public ReactionChainPolicy { + public: + virtual const std::vector>& get_chain_policies() const = 0; + }; + +} diff --git a/src/include/gridfire/policy/stellar_policy.h b/src/include/gridfire/policy/stellar_policy.h new file mode 100644 index 00000000..93cef85d --- /dev/null +++ b/src/include/gridfire/policy/stellar_policy.h @@ -0,0 +1,151 @@ +#pragma once + +#include +#include + +#include "gridfire/policy/policy_abstract.h" +#include "gridfire/engine/engine_abstract.h" +#include "gridfire/reaction/reaction.h" + +#include "gridfire/exceptions/error_policy.h" + +#include "fourdst/composition/composition.h" +#include "fourdst/composition/atomicSpecies.h" +#include "fourdst/composition/exceptions/exceptions_composition.h" +#include "gridfire/engine/engine_graph.h" +#include "gridfire/engine/views/engine_adaptive.h" +#include "gridfire/engine/views/engine_multiscale.h" +#include "gridfire/partition/composite/partition_composite.h" + +#include "gridfire/policy/chains.h" + +namespace gridfire::policy { + class LowMassMainSequencePolicy final: public NetworkPolicy { + public: + explicit LowMassMainSequencePolicy(const fourdst::composition::Composition& composition); + explicit LowMassMainSequencePolicy(std::vector seed_species, std::vector mass_fractions); + + std::string name() const override { return "LowMassMainSequencePolicy"; } + + const std::set get_seed_species() const override { return m_seed_species; } + const reaction::ReactionSet& get_seed_reactions() const override { return m_reaction_policy->get_reactions(); } + + DynamicEngine& construct() override; + + NetworkPolicyStatus getStatus() const override; + private: + std::set m_seed_species = { + fourdst::atomic::H_1, + fourdst::atomic::He_3, + fourdst::atomic::He_4, + fourdst::atomic::C_12, + fourdst::atomic::N_14, + fourdst::atomic::O_16, + fourdst::atomic::Ne_20, + fourdst::atomic::Mg_24 + }; + + std::unique_ptr m_reaction_policy = std::make_unique(); + fourdst::composition::Composition m_initializing_composition; + std::unique_ptr m_partition_function; + std::vector> m_network_stack; + + NetworkPolicyStatus m_status = NetworkPolicyStatus::UNINITIALIZED; + private: + static std::unique_ptr build_partition_function(); + NetworkPolicyStatus check_status() const; + + }; + + inline LowMassMainSequencePolicy::LowMassMainSequencePolicy(const fourdst::composition::Composition& composition) { + for (const auto& species : m_seed_species) { + if (!composition.hasSpecies(species)) { + throw exceptions::MissingSeedSpeciesError("Cannot initialize LowMassMainSequencePolicy: Required Seed species " + std::string(species.name()) + " is missing from the provided composition."); + } + } + m_initializing_composition = composition; + m_partition_function = build_partition_function(); + } + + inline LowMassMainSequencePolicy::LowMassMainSequencePolicy(std::vector seed_species, std::vector mass_fractions) { + for (const auto& species : m_seed_species) { + if (std::ranges::find(seed_species, species) == seed_species.end()) { + throw exceptions::MissingSeedSpeciesError("Cannot initialize LowMassMainSequencePolicy: Required Seed species " + std::string(species.name()) + " is missing from the provided composition."); + } + } + + for (const auto& [species, x] : std::views::zip(seed_species, mass_fractions)) { + m_initializing_composition.registerSpecies(species); + m_initializing_composition.setMassFraction(species, x); + } + + const bool didFinalize = m_initializing_composition.finalize(true); + if (!didFinalize) { + throw fourdst::composition::exceptions::CompositionNotFinalizedError("Failed to finalize initial composition for LowMassMainSequencePolicy."); + } + + m_partition_function = build_partition_function(); + } + + inline DynamicEngine& LowMassMainSequencePolicy::construct() { + m_network_stack.clear(); + + m_network_stack.emplace_back( + std::make_unique(m_initializing_composition, *m_partition_function, NetworkBuildDepth::ThirdOrder, NetworkConstructionFlags::DEFAULT) + ); + m_network_stack.emplace_back( + std::make_unique(*m_network_stack.back().get()) + ); + m_network_stack.emplace_back( + std::make_unique(*m_network_stack.back().get()) + ); + + m_status = NetworkPolicyStatus::INITIALIZED_UNVERIFIED; + m_status = check_status(); + + switch (m_status) { + case NetworkPolicyStatus::MISSING_KEY_REACTION: + throw exceptions::MissingKeyReactionError("LowMassMainSequencePolicy construction failed: The constructed network is missing key reactions required by the policy."); + case NetworkPolicyStatus::MISSING_KEY_SPECIES: + throw exceptions::MissingSeedSpeciesError("LowMassMainSequencePolicy construction failed: The constructed network is missing key seed species required by the policy."); + case NetworkPolicyStatus::UNINITIALIZED: + throw exceptions::PolicyError("LowMassMainSequencePolicy construction failed: The network policy is uninitialized."); + case NetworkPolicyStatus::INITIALIZED_UNVERIFIED: + throw exceptions::PolicyError("LowMassMainSequencePolicy construction failed: The network policy status could not be verified."); + case NetworkPolicyStatus::INITIALIZED_VERIFIED: + break; + } + return *m_network_stack.back(); + } + + inline std::unique_ptr LowMassMainSequencePolicy::build_partition_function() { + using partition::BasePartitionType; + const auto partitionFunction = partition::CompositePartitionFunction({ + BasePartitionType::RauscherThielemann, + BasePartitionType::GroundState + }); + return std::make_unique(partitionFunction); + } + + inline NetworkPolicyStatus LowMassMainSequencePolicy::getStatus() const { + return m_status; + } + + inline NetworkPolicyStatus LowMassMainSequencePolicy::check_status() const { + for (const auto& species : m_seed_species) { + if (!m_initializing_composition.hasSpecies(species)) { + return NetworkPolicyStatus::MISSING_KEY_SPECIES; + } + } + const reaction::ReactionSet& baseReactions = m_network_stack.front()->getNetworkReactions(); + for (const auto& reaction : m_reaction_policy->get_reactions()) { + const bool result = baseReactions.contains(*reaction); + if (!result) { + return NetworkPolicyStatus::MISSING_KEY_REACTION; + } + } + return NetworkPolicyStatus::INITIALIZED_VERIFIED; + } + + +} \ No newline at end of file diff --git a/src/include/gridfire/reaction/reaction.h b/src/include/gridfire/reaction/reaction.h index 1f28e45e..a6e46e94 100644 --- a/src/include/gridfire/reaction/reaction.h +++ b/src/include/gridfire/reaction/reaction.h @@ -788,6 +788,8 @@ namespace gridfire::reaction { void extend(const ReactionSet& other); + [[nodiscard]] std::optional> get(const std::string_view& id) const; + /** * @brief Removes a reaction from the set. * @param reaction The Reaction to remove. diff --git a/src/lib/engine/engine_graph.cpp b/src/lib/engine/engine_graph.cpp index 3f0d894a..42e54793 100644 --- a/src/lib/engine/engine_graph.cpp +++ b/src/lib/engine/engine_graph.cpp @@ -900,7 +900,6 @@ namespace gridfire { const double rho, const SparsityPattern &sparsityPattern ) const { - //TODO: The issue now seems to be that the jacobian is returning all zeros. I need to sort out why this is SparsityPattern intersectionSparsityPattern; for (const auto& entry : sparsityPattern) { if (m_full_sparsity_set.contains(entry)) { diff --git a/src/lib/reaction/reaction.cpp b/src/lib/reaction/reaction.cpp index 67e172ef..c040b5a0 100644 --- a/src/lib/reaction/reaction.cpp +++ b/src/lib/reaction/reaction.cpp @@ -417,6 +417,13 @@ namespace gridfire::reaction { } } + std::optional> ReactionSet::get(const std::string_view &id) const { + if (!contains(id)) { + return std::nullopt; + } + return std::make_optional(m_reactions[m_reactionNameMap.at(std::string(id))]->clone()); + } + void ReactionSet::remove_reaction(const Reaction& reaction) { const auto reaction_id = std::string(reaction.id()); if (!m_reactionNameMap.contains(reaction_id)) {