refactor(reaction): refactored to an abstract reaction class in prep for weak reactions

This commit is contained in:
2025-08-14 13:33:46 -04:00
parent d920a55ba6
commit 0b77f2e269
81 changed files with 1050041 additions and 913 deletions

View File

@@ -48,7 +48,7 @@ namespace gridfire {
}
GraphEngine::GraphEngine(
const reaction::LogicalReactionSet &reactions
const reaction::ReactionSet &reactions
) :
m_reactions(reactions) {
syncInternalMaps();
@@ -67,8 +67,8 @@ namespace gridfire {
// TODO: Add cache to this
for (const auto& reaction: m_reactions) {
bare_rates.push_back(reaction.calculate_rate(T9));
bare_reverse_rates.push_back(calculateReverseRate(reaction, T9));
bare_rates.push_back(reaction->calculate_rate(T9, rho, Y));
bare_reverse_rates.push_back(calculateReverseRate(*reaction, T9, rho, Y));
}
// --- The public facing interface can always use the precomputed version since taping is done internally ---
@@ -110,10 +110,10 @@ namespace gridfire {
std::set<std::string_view> uniqueSpeciesNames;
for (const auto& reaction: m_reactions) {
for (const auto& reactant: reaction.reactants()) {
for (const auto& reactant: reaction->reactants()) {
uniqueSpeciesNames.insert(reactant.name());
}
for (const auto& product: reaction.products()) {
for (const auto& product: reaction->products()) {
uniqueSpeciesNames.insert(product.name());
}
}
@@ -136,7 +136,7 @@ namespace gridfire {
LOG_TRACE_L1(m_logger, "Populating reaction ID map for REACLIB graph network (serif::network::GraphNetwork)...");
m_reactionIDMap.clear();
for (auto& reaction: m_reactions) {
m_reactionIDMap.emplace(reaction.id(), &reaction);
m_reactionIDMap.emplace(reaction->id(), reaction.get());
}
LOG_TRACE_L1(m_logger, "Populated {} reactions in the reaction ID map.", m_reactionIDMap.size());
}
@@ -165,13 +165,13 @@ namespace gridfire {
return m_networkSpecies;
}
const reaction::LogicalReactionSet& GraphEngine::getNetworkReactions() const {
const reaction::ReactionSet& GraphEngine::getNetworkReactions() const {
// Returns a constant reference to the set of reactions in the network.
LOG_TRACE_L3(m_logger, "Providing access to network reactions set. Size: {}.", m_reactions.size());
return m_reactions;
}
void GraphEngine::setNetworkReactions(const reaction::LogicalReactionSet &reactions) {
void GraphEngine::setNetworkReactions(const reaction::ReactionSet &reactions) {
m_reactions = reactions;
syncInternalMaps();
}
@@ -194,7 +194,7 @@ namespace gridfire {
uint64_t totalProductZ = 0;
// Calculate total A and Z for reactants
for (const auto& reactant : reaction.reactants()) {
for (const auto& reactant : reaction->reactants()) {
auto it = m_networkSpeciesMap.find(reactant.name());
if (it != m_networkSpeciesMap.end()) {
totalReactantA += it->second.a();
@@ -203,13 +203,13 @@ namespace gridfire {
// This scenario indicates a severe data integrity issue:
// a reactant is part of a reaction but not in the network's species map.
LOG_ERROR(m_logger, "CRITICAL ERROR: Reactant species '{}' in reaction '{}' not found in network species map during conservation validation.",
reactant.name(), reaction.id());
reactant.name(), reaction->id());
return false;
}
}
// Calculate total A and Z for products
for (const auto& product : reaction.products()) {
for (const auto& product : reaction->products()) {
auto it = m_networkSpeciesMap.find(product.name());
if (it != m_networkSpeciesMap.end()) {
totalProductA += it->second.a();
@@ -217,7 +217,7 @@ namespace gridfire {
} else {
// Similar critical error for product species
LOG_ERROR(m_logger, "CRITICAL ERROR: Product species '{}' in reaction '{}' not found in network species map during conservation validation.",
product.name(), reaction.id());
product.name(), reaction->id());
return false;
}
}
@@ -225,12 +225,12 @@ namespace gridfire {
// Compare totals for conservation
if (totalReactantA != totalProductA) {
LOG_ERROR(m_logger, "Mass number (A) not conserved for reaction '{}': Reactants A={} vs Products A={}.",
reaction.id(), totalReactantA, totalProductA);
reaction->id(), totalReactantA, totalProductA);
return false;
}
if (totalReactantZ != totalProductZ) {
LOG_ERROR(m_logger, "Atomic number (Z) not conserved for reaction '{}': Reactants Z={} vs Products Z={}.",
reaction.id(), totalReactantZ, totalProductZ);
reaction->id(), totalReactantZ, totalProductZ);
return false;
}
}
@@ -241,7 +241,9 @@ namespace gridfire {
double GraphEngine::calculateReverseRate(
const reaction::Reaction &reaction,
const double T9
const double T9,
const double rho,
const std::vector<double> &Y
) const {
if (!m_useReverseReactions) {
LOG_TRACE_L3_LIMIT_EVERY_N(std::numeric_limits<int>::max(), m_logger, "Reverse reactions are disabled. Returning 0.0 for reverse rate of reaction '{}'.", reaction.id());
@@ -255,12 +257,12 @@ namespace gridfire {
const double kBMeV = m_constants.kB * 624151; // Convert kB to MeV/K NOTE: This relies on the fact that m_constants.kB is in erg/K!
const double expFactor = std::exp(-reaction.qValue() / (kBMeV * temp));
double reverseRate = 0.0;
const double forwardRate = reaction.calculate_rate(T9);
const double forwardRate = reaction.calculate_rate(T9, rho, Y);
if (reaction.reactants().size() == 2 && reaction.products().size() == 2) {
reverseRate = calculateReverseRateTwoBody(reaction, T9, forwardRate, expFactor);
} else {
LOG_WARNING_LIMIT_EVERY_N(1000000, m_logger, "Reverse rate calculation for reactions with more than two reactants or products is not implemented (reaction id {}).", reaction.peName());
LOG_WARNING_LIMIT_EVERY_N(1000000, m_logger, "Reverse rate calculation for reactions with more than two reactants or products is not implemented (reaction id {}).", reaction.id());
}
LOG_TRACE_L2_LIMIT_EVERY_N(1000, m_logger, "Calculated reverse rate for reaction '{}': {:.3E} at T9={:.3E}.", reaction.id(), reverseRate, T9);
return reverseRate;
@@ -346,13 +348,15 @@ namespace gridfire {
double GraphEngine::calculateReverseRateTwoBodyDerivative(
const reaction::Reaction &reaction,
const double T9,
const double rho,
const std::vector<double> &Y,
const double reverseRate
) const {
if (!m_useReverseReactions) {
LOG_TRACE_L3_LIMIT_EVERY_N(std::numeric_limits<int>::max(), m_logger, "Reverse reactions are disabled. Returning 0.0 for reverse rate of reaction '{}'.", reaction.id());
return 0.0; // If reverse reactions are not used, return 0.0
}
const double d_log_kFwd = reaction.calculate_forward_rate_log_derivative(T9);
const double d_log_kFwd = reaction.calculate_forward_rate_log_derivative(T9, rho, Y);
auto log_deriv_pf_op = [&](double acc, const auto& species) {
const double g = m_partitionFunction->evaluate(species.z(), species.a(), T9);
@@ -392,7 +396,7 @@ namespace gridfire {
m_useReverseReactions = useReverse;
}
int GraphEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
size_t GraphEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
return m_speciesToIndexMap.at(species); // Returns the index of the species in the stoichiometry matrix
}
@@ -494,7 +498,7 @@ namespace gridfire {
bare_rate *
precomp.symmetry_factor *
forwardAbundanceProduct *
std::pow(rho, numReactants > 1 ? numReactants - 1 : 0.0);
std::pow(rho, numReactants > 1 ? static_cast<double>(numReactants) - 1 : 0.0);
double reverseMolarReactionFlow = 0.0;
if (precomp.reverse_symmetry_factor != 0.0 and m_useReverseReactions) {
@@ -507,7 +511,7 @@ namespace gridfire {
bare_reverse_rate *
precomp.reverse_symmetry_factor *
reverseAbundanceProduct *
std::pow(rho, numProducts > 1 ? numProducts - 1 : 0.0);
std::pow(rho, numProducts > 1 ? static_cast<double>(numProducts) - 1 : 0.0);
}
molarReactionFlows.push_back(forwardMolarReactionFlow - reverseMolarReactionFlow);
@@ -558,7 +562,7 @@ namespace gridfire {
size_t reactionColumnIndex = 0;
for (const auto& reaction : m_reactions) {
// Get the net stoichiometry for the current reaction
std::unordered_map<fourdst::atomic::Species, int> netStoichiometry = reaction.stoichiometry();
std::unordered_map<fourdst::atomic::Species, int> netStoichiometry = reaction->stoichiometry();
// Iterate through the species and their coefficients in the stoichiometry map
for (const auto& [species, coefficient] : netStoichiometry) {
@@ -571,7 +575,7 @@ namespace gridfire {
} else {
// This scenario should ideally not happen if m_networkSpeciesMap and m_speciesToIndexMap are correctly synced
LOG_ERROR(m_logger, "CRITICAL ERROR: Species '{}' from reaction '{}' stoichiometry not found in species to index map.",
species.name(), reaction.id());
species.name(), reaction->id());
m_logger -> flush_log();
throw std::runtime_error("Species not found in species to index map: " + std::string(species.name()));
}
@@ -763,19 +767,19 @@ namespace gridfire {
dotFile << " // --- Reaction Edges ---\n";
for (const auto& reaction : m_reactions) {
// Create a unique ID for the reaction node
std::string reactionNodeId = "reaction_" + std::string(reaction.id());
std::string reactionNodeId = "reaction_" + std::string(reaction->id());
// Define the reaction node (small, black dot)
dotFile << " \"" << reactionNodeId << "\" [shape=point, fillcolor=black, width=0.1, height=0.1, label=\"\"];\n";
// Draw edges from reactants to the reaction node
for (const auto& reactant : reaction.reactants()) {
for (const auto& reactant : reaction->reactants()) {
dotFile << " \"" << reactant.name() << "\" -> \"" << reactionNodeId << "\";\n";
}
// Draw edges from the reaction node to products
for (const auto& product : reaction.products()) {
dotFile << " \"" << reactionNodeId << "\" -> \"" << product.name() << "\" [label=\"" << reaction.qValue() << " MeV\"];\n";
for (const auto& product : reaction->products()) {
dotFile << " \"" << reactionNodeId << "\" -> \"" << product.name() << "\" [label=\"" << reaction->qValue() << " MeV\"];\n";
}
dotFile << "\n";
}
@@ -797,42 +801,25 @@ namespace gridfire {
csvFile << "Reaction;Reactants;Products;Q-value;sources;rates\n";
for (const auto& reaction : m_reactions) {
// Dynamic cast to REACLIBReaction to access specific properties
csvFile << reaction.id() << ";";
csvFile << reaction->id() << ";";
// Reactants
size_t count = 0;
for (const auto& reactant : reaction.reactants()) {
for (const auto& reactant : reaction->reactants()) {
csvFile << reactant.name();
if (++count < reaction.reactants().size()) {
if (++count < reaction->reactants().size()) {
csvFile << ",";
}
}
csvFile << ";";
count = 0;
for (const auto& product : reaction.products()) {
for (const auto& product : reaction->products()) {
csvFile << product.name();
if (++count < reaction.products().size()) {
if (++count < reaction->products().size()) {
csvFile << ",";
}
}
csvFile << ";" << reaction.qValue() << ";";
csvFile << ";" << reaction->qValue() << ";";
// Reaction coefficients
auto sources = reaction.sources();
count = 0;
for (const auto& source : sources) {
csvFile << source;
if (++count < sources.size()) {
csvFile << ",";
}
}
csvFile << ";";
// Reaction coefficients
count = 0;
for (const auto& rates : reaction) {
csvFile << rates;
if (++count < reaction.size()) {
csvFile << ",";
}
}
csvFile << "\n";
}
csvFile.close();
@@ -869,8 +856,8 @@ namespace gridfire {
for (const auto& species : m_networkSpecies) {
double netDestructionFlow = 0.0;
for (const auto& reaction : m_reactions) {
if (reaction.stoichiometry(species) < 0) {
const double flow = calculateMolarReactionFlow<double>(reaction, Y, T9, rho);
if (reaction->stoichiometry(species) < 0) {
const auto flow = calculateMolarReactionFlow<double>(*reaction, Y, T9, rho);
netDestructionFlow += flow;
}
}
@@ -899,7 +886,7 @@ namespace gridfire {
return false;
}
void GraphEngine::recordADTape() {
void GraphEngine::recordADTape() const {
LOG_TRACE_L1(m_logger, "Recording AD tape for the RHS calculation...");
// Task 1: Set dimensions and initialize the matrix
@@ -949,8 +936,8 @@ namespace gridfire {
m_atomicReverseRates.reserve(m_reactions.size());
for (const auto& reaction: m_reactions) {
if (reaction.qValue() != 0.0) {
m_atomicReverseRates.push_back(std::make_unique<AtomicReverseRate>(reaction, *this));
if (reaction->qValue() != 0.0) {
m_atomicReverseRates.push_back(std::make_unique<AtomicReverseRate>(*reaction, *this));
} else {
m_atomicReverseRates.push_back(nullptr);
}
@@ -1038,7 +1025,8 @@ namespace gridfire {
if ( p != 0) { return false; }
const double T9 = tx[0];
const double reverseRate = m_engine.calculateReverseRate(m_reaction, T9);
// TODO: Handle rho and Y
const double reverseRate = m_engine.calculateReverseRate(m_reaction, T9, 0, {});
// std::cout << m_reaction.peName() << " reverseRate: " << reverseRate << " at T9: " << T9 << "\n";
ty[0] = reverseRate; // Store the reverse rate in the output vector
@@ -1058,8 +1046,8 @@ namespace gridfire {
const double T9 = tx[0];
const double reverseRate = ty[0];
const double derivative = m_engine.calculateReverseRateTwoBodyDerivative(m_reaction, T9, reverseRate);
// std::cout << m_reaction.peName() << " reverseRate Derivative: " << derivative << "\n";
// TODO: Handle rho and Y
const double derivative = m_engine.calculateReverseRateTwoBodyDerivative(m_reaction, T9, 0, {}, reverseRate);
px[0] = py[0] * derivative; // Return the derivative of the reverse rate with respect to T9