refactor(reaction): refactored to an abstract reaction class in prep for weak reactions
This commit is contained in:
@@ -48,7 +48,7 @@ namespace gridfire {
|
||||
}
|
||||
|
||||
GraphEngine::GraphEngine(
|
||||
const reaction::LogicalReactionSet &reactions
|
||||
const reaction::ReactionSet &reactions
|
||||
) :
|
||||
m_reactions(reactions) {
|
||||
syncInternalMaps();
|
||||
@@ -67,8 +67,8 @@ namespace gridfire {
|
||||
|
||||
// TODO: Add cache to this
|
||||
for (const auto& reaction: m_reactions) {
|
||||
bare_rates.push_back(reaction.calculate_rate(T9));
|
||||
bare_reverse_rates.push_back(calculateReverseRate(reaction, T9));
|
||||
bare_rates.push_back(reaction->calculate_rate(T9, rho, Y));
|
||||
bare_reverse_rates.push_back(calculateReverseRate(*reaction, T9, rho, Y));
|
||||
}
|
||||
|
||||
// --- The public facing interface can always use the precomputed version since taping is done internally ---
|
||||
@@ -110,10 +110,10 @@ namespace gridfire {
|
||||
std::set<std::string_view> uniqueSpeciesNames;
|
||||
|
||||
for (const auto& reaction: m_reactions) {
|
||||
for (const auto& reactant: reaction.reactants()) {
|
||||
for (const auto& reactant: reaction->reactants()) {
|
||||
uniqueSpeciesNames.insert(reactant.name());
|
||||
}
|
||||
for (const auto& product: reaction.products()) {
|
||||
for (const auto& product: reaction->products()) {
|
||||
uniqueSpeciesNames.insert(product.name());
|
||||
}
|
||||
}
|
||||
@@ -136,7 +136,7 @@ namespace gridfire {
|
||||
LOG_TRACE_L1(m_logger, "Populating reaction ID map for REACLIB graph network (serif::network::GraphNetwork)...");
|
||||
m_reactionIDMap.clear();
|
||||
for (auto& reaction: m_reactions) {
|
||||
m_reactionIDMap.emplace(reaction.id(), &reaction);
|
||||
m_reactionIDMap.emplace(reaction->id(), reaction.get());
|
||||
}
|
||||
LOG_TRACE_L1(m_logger, "Populated {} reactions in the reaction ID map.", m_reactionIDMap.size());
|
||||
}
|
||||
@@ -165,13 +165,13 @@ namespace gridfire {
|
||||
return m_networkSpecies;
|
||||
}
|
||||
|
||||
const reaction::LogicalReactionSet& GraphEngine::getNetworkReactions() const {
|
||||
const reaction::ReactionSet& GraphEngine::getNetworkReactions() const {
|
||||
// Returns a constant reference to the set of reactions in the network.
|
||||
LOG_TRACE_L3(m_logger, "Providing access to network reactions set. Size: {}.", m_reactions.size());
|
||||
return m_reactions;
|
||||
}
|
||||
|
||||
void GraphEngine::setNetworkReactions(const reaction::LogicalReactionSet &reactions) {
|
||||
void GraphEngine::setNetworkReactions(const reaction::ReactionSet &reactions) {
|
||||
m_reactions = reactions;
|
||||
syncInternalMaps();
|
||||
}
|
||||
@@ -194,7 +194,7 @@ namespace gridfire {
|
||||
uint64_t totalProductZ = 0;
|
||||
|
||||
// Calculate total A and Z for reactants
|
||||
for (const auto& reactant : reaction.reactants()) {
|
||||
for (const auto& reactant : reaction->reactants()) {
|
||||
auto it = m_networkSpeciesMap.find(reactant.name());
|
||||
if (it != m_networkSpeciesMap.end()) {
|
||||
totalReactantA += it->second.a();
|
||||
@@ -203,13 +203,13 @@ namespace gridfire {
|
||||
// This scenario indicates a severe data integrity issue:
|
||||
// a reactant is part of a reaction but not in the network's species map.
|
||||
LOG_ERROR(m_logger, "CRITICAL ERROR: Reactant species '{}' in reaction '{}' not found in network species map during conservation validation.",
|
||||
reactant.name(), reaction.id());
|
||||
reactant.name(), reaction->id());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate total A and Z for products
|
||||
for (const auto& product : reaction.products()) {
|
||||
for (const auto& product : reaction->products()) {
|
||||
auto it = m_networkSpeciesMap.find(product.name());
|
||||
if (it != m_networkSpeciesMap.end()) {
|
||||
totalProductA += it->second.a();
|
||||
@@ -217,7 +217,7 @@ namespace gridfire {
|
||||
} else {
|
||||
// Similar critical error for product species
|
||||
LOG_ERROR(m_logger, "CRITICAL ERROR: Product species '{}' in reaction '{}' not found in network species map during conservation validation.",
|
||||
product.name(), reaction.id());
|
||||
product.name(), reaction->id());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -225,12 +225,12 @@ namespace gridfire {
|
||||
// Compare totals for conservation
|
||||
if (totalReactantA != totalProductA) {
|
||||
LOG_ERROR(m_logger, "Mass number (A) not conserved for reaction '{}': Reactants A={} vs Products A={}.",
|
||||
reaction.id(), totalReactantA, totalProductA);
|
||||
reaction->id(), totalReactantA, totalProductA);
|
||||
return false;
|
||||
}
|
||||
if (totalReactantZ != totalProductZ) {
|
||||
LOG_ERROR(m_logger, "Atomic number (Z) not conserved for reaction '{}': Reactants Z={} vs Products Z={}.",
|
||||
reaction.id(), totalReactantZ, totalProductZ);
|
||||
reaction->id(), totalReactantZ, totalProductZ);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -241,7 +241,9 @@ namespace gridfire {
|
||||
|
||||
double GraphEngine::calculateReverseRate(
|
||||
const reaction::Reaction &reaction,
|
||||
const double T9
|
||||
const double T9,
|
||||
const double rho,
|
||||
const std::vector<double> &Y
|
||||
) const {
|
||||
if (!m_useReverseReactions) {
|
||||
LOG_TRACE_L3_LIMIT_EVERY_N(std::numeric_limits<int>::max(), m_logger, "Reverse reactions are disabled. Returning 0.0 for reverse rate of reaction '{}'.", reaction.id());
|
||||
@@ -255,12 +257,12 @@ namespace gridfire {
|
||||
const double kBMeV = m_constants.kB * 624151; // Convert kB to MeV/K NOTE: This relies on the fact that m_constants.kB is in erg/K!
|
||||
const double expFactor = std::exp(-reaction.qValue() / (kBMeV * temp));
|
||||
double reverseRate = 0.0;
|
||||
const double forwardRate = reaction.calculate_rate(T9);
|
||||
const double forwardRate = reaction.calculate_rate(T9, rho, Y);
|
||||
|
||||
if (reaction.reactants().size() == 2 && reaction.products().size() == 2) {
|
||||
reverseRate = calculateReverseRateTwoBody(reaction, T9, forwardRate, expFactor);
|
||||
} else {
|
||||
LOG_WARNING_LIMIT_EVERY_N(1000000, m_logger, "Reverse rate calculation for reactions with more than two reactants or products is not implemented (reaction id {}).", reaction.peName());
|
||||
LOG_WARNING_LIMIT_EVERY_N(1000000, m_logger, "Reverse rate calculation for reactions with more than two reactants or products is not implemented (reaction id {}).", reaction.id());
|
||||
}
|
||||
LOG_TRACE_L2_LIMIT_EVERY_N(1000, m_logger, "Calculated reverse rate for reaction '{}': {:.3E} at T9={:.3E}.", reaction.id(), reverseRate, T9);
|
||||
return reverseRate;
|
||||
@@ -346,13 +348,15 @@ namespace gridfire {
|
||||
double GraphEngine::calculateReverseRateTwoBodyDerivative(
|
||||
const reaction::Reaction &reaction,
|
||||
const double T9,
|
||||
const double rho,
|
||||
const std::vector<double> &Y,
|
||||
const double reverseRate
|
||||
) const {
|
||||
if (!m_useReverseReactions) {
|
||||
LOG_TRACE_L3_LIMIT_EVERY_N(std::numeric_limits<int>::max(), m_logger, "Reverse reactions are disabled. Returning 0.0 for reverse rate of reaction '{}'.", reaction.id());
|
||||
return 0.0; // If reverse reactions are not used, return 0.0
|
||||
}
|
||||
const double d_log_kFwd = reaction.calculate_forward_rate_log_derivative(T9);
|
||||
const double d_log_kFwd = reaction.calculate_forward_rate_log_derivative(T9, rho, Y);
|
||||
|
||||
auto log_deriv_pf_op = [&](double acc, const auto& species) {
|
||||
const double g = m_partitionFunction->evaluate(species.z(), species.a(), T9);
|
||||
@@ -392,7 +396,7 @@ namespace gridfire {
|
||||
m_useReverseReactions = useReverse;
|
||||
}
|
||||
|
||||
int GraphEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
|
||||
size_t GraphEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
|
||||
return m_speciesToIndexMap.at(species); // Returns the index of the species in the stoichiometry matrix
|
||||
}
|
||||
|
||||
@@ -494,7 +498,7 @@ namespace gridfire {
|
||||
bare_rate *
|
||||
precomp.symmetry_factor *
|
||||
forwardAbundanceProduct *
|
||||
std::pow(rho, numReactants > 1 ? numReactants - 1 : 0.0);
|
||||
std::pow(rho, numReactants > 1 ? static_cast<double>(numReactants) - 1 : 0.0);
|
||||
|
||||
double reverseMolarReactionFlow = 0.0;
|
||||
if (precomp.reverse_symmetry_factor != 0.0 and m_useReverseReactions) {
|
||||
@@ -507,7 +511,7 @@ namespace gridfire {
|
||||
bare_reverse_rate *
|
||||
precomp.reverse_symmetry_factor *
|
||||
reverseAbundanceProduct *
|
||||
std::pow(rho, numProducts > 1 ? numProducts - 1 : 0.0);
|
||||
std::pow(rho, numProducts > 1 ? static_cast<double>(numProducts) - 1 : 0.0);
|
||||
}
|
||||
|
||||
molarReactionFlows.push_back(forwardMolarReactionFlow - reverseMolarReactionFlow);
|
||||
@@ -558,7 +562,7 @@ namespace gridfire {
|
||||
size_t reactionColumnIndex = 0;
|
||||
for (const auto& reaction : m_reactions) {
|
||||
// Get the net stoichiometry for the current reaction
|
||||
std::unordered_map<fourdst::atomic::Species, int> netStoichiometry = reaction.stoichiometry();
|
||||
std::unordered_map<fourdst::atomic::Species, int> netStoichiometry = reaction->stoichiometry();
|
||||
|
||||
// Iterate through the species and their coefficients in the stoichiometry map
|
||||
for (const auto& [species, coefficient] : netStoichiometry) {
|
||||
@@ -571,7 +575,7 @@ namespace gridfire {
|
||||
} else {
|
||||
// This scenario should ideally not happen if m_networkSpeciesMap and m_speciesToIndexMap are correctly synced
|
||||
LOG_ERROR(m_logger, "CRITICAL ERROR: Species '{}' from reaction '{}' stoichiometry not found in species to index map.",
|
||||
species.name(), reaction.id());
|
||||
species.name(), reaction->id());
|
||||
m_logger -> flush_log();
|
||||
throw std::runtime_error("Species not found in species to index map: " + std::string(species.name()));
|
||||
}
|
||||
@@ -763,19 +767,19 @@ namespace gridfire {
|
||||
dotFile << " // --- Reaction Edges ---\n";
|
||||
for (const auto& reaction : m_reactions) {
|
||||
// Create a unique ID for the reaction node
|
||||
std::string reactionNodeId = "reaction_" + std::string(reaction.id());
|
||||
std::string reactionNodeId = "reaction_" + std::string(reaction->id());
|
||||
|
||||
// Define the reaction node (small, black dot)
|
||||
dotFile << " \"" << reactionNodeId << "\" [shape=point, fillcolor=black, width=0.1, height=0.1, label=\"\"];\n";
|
||||
|
||||
// Draw edges from reactants to the reaction node
|
||||
for (const auto& reactant : reaction.reactants()) {
|
||||
for (const auto& reactant : reaction->reactants()) {
|
||||
dotFile << " \"" << reactant.name() << "\" -> \"" << reactionNodeId << "\";\n";
|
||||
}
|
||||
|
||||
// Draw edges from the reaction node to products
|
||||
for (const auto& product : reaction.products()) {
|
||||
dotFile << " \"" << reactionNodeId << "\" -> \"" << product.name() << "\" [label=\"" << reaction.qValue() << " MeV\"];\n";
|
||||
for (const auto& product : reaction->products()) {
|
||||
dotFile << " \"" << reactionNodeId << "\" -> \"" << product.name() << "\" [label=\"" << reaction->qValue() << " MeV\"];\n";
|
||||
}
|
||||
dotFile << "\n";
|
||||
}
|
||||
@@ -797,42 +801,25 @@ namespace gridfire {
|
||||
csvFile << "Reaction;Reactants;Products;Q-value;sources;rates\n";
|
||||
for (const auto& reaction : m_reactions) {
|
||||
// Dynamic cast to REACLIBReaction to access specific properties
|
||||
csvFile << reaction.id() << ";";
|
||||
csvFile << reaction->id() << ";";
|
||||
// Reactants
|
||||
size_t count = 0;
|
||||
for (const auto& reactant : reaction.reactants()) {
|
||||
for (const auto& reactant : reaction->reactants()) {
|
||||
csvFile << reactant.name();
|
||||
if (++count < reaction.reactants().size()) {
|
||||
if (++count < reaction->reactants().size()) {
|
||||
csvFile << ",";
|
||||
}
|
||||
}
|
||||
csvFile << ";";
|
||||
count = 0;
|
||||
for (const auto& product : reaction.products()) {
|
||||
for (const auto& product : reaction->products()) {
|
||||
csvFile << product.name();
|
||||
if (++count < reaction.products().size()) {
|
||||
if (++count < reaction->products().size()) {
|
||||
csvFile << ",";
|
||||
}
|
||||
}
|
||||
csvFile << ";" << reaction.qValue() << ";";
|
||||
csvFile << ";" << reaction->qValue() << ";";
|
||||
// Reaction coefficients
|
||||
auto sources = reaction.sources();
|
||||
count = 0;
|
||||
for (const auto& source : sources) {
|
||||
csvFile << source;
|
||||
if (++count < sources.size()) {
|
||||
csvFile << ",";
|
||||
}
|
||||
}
|
||||
csvFile << ";";
|
||||
// Reaction coefficients
|
||||
count = 0;
|
||||
for (const auto& rates : reaction) {
|
||||
csvFile << rates;
|
||||
if (++count < reaction.size()) {
|
||||
csvFile << ",";
|
||||
}
|
||||
}
|
||||
csvFile << "\n";
|
||||
}
|
||||
csvFile.close();
|
||||
@@ -869,8 +856,8 @@ namespace gridfire {
|
||||
for (const auto& species : m_networkSpecies) {
|
||||
double netDestructionFlow = 0.0;
|
||||
for (const auto& reaction : m_reactions) {
|
||||
if (reaction.stoichiometry(species) < 0) {
|
||||
const double flow = calculateMolarReactionFlow<double>(reaction, Y, T9, rho);
|
||||
if (reaction->stoichiometry(species) < 0) {
|
||||
const auto flow = calculateMolarReactionFlow<double>(*reaction, Y, T9, rho);
|
||||
netDestructionFlow += flow;
|
||||
}
|
||||
}
|
||||
@@ -899,7 +886,7 @@ namespace gridfire {
|
||||
return false;
|
||||
}
|
||||
|
||||
void GraphEngine::recordADTape() {
|
||||
void GraphEngine::recordADTape() const {
|
||||
LOG_TRACE_L1(m_logger, "Recording AD tape for the RHS calculation...");
|
||||
|
||||
// Task 1: Set dimensions and initialize the matrix
|
||||
@@ -949,8 +936,8 @@ namespace gridfire {
|
||||
m_atomicReverseRates.reserve(m_reactions.size());
|
||||
|
||||
for (const auto& reaction: m_reactions) {
|
||||
if (reaction.qValue() != 0.0) {
|
||||
m_atomicReverseRates.push_back(std::make_unique<AtomicReverseRate>(reaction, *this));
|
||||
if (reaction->qValue() != 0.0) {
|
||||
m_atomicReverseRates.push_back(std::make_unique<AtomicReverseRate>(*reaction, *this));
|
||||
} else {
|
||||
m_atomicReverseRates.push_back(nullptr);
|
||||
}
|
||||
@@ -1038,7 +1025,8 @@ namespace gridfire {
|
||||
if ( p != 0) { return false; }
|
||||
const double T9 = tx[0];
|
||||
|
||||
const double reverseRate = m_engine.calculateReverseRate(m_reaction, T9);
|
||||
// TODO: Handle rho and Y
|
||||
const double reverseRate = m_engine.calculateReverseRate(m_reaction, T9, 0, {});
|
||||
// std::cout << m_reaction.peName() << " reverseRate: " << reverseRate << " at T9: " << T9 << "\n";
|
||||
ty[0] = reverseRate; // Store the reverse rate in the output vector
|
||||
|
||||
@@ -1058,8 +1046,8 @@ namespace gridfire {
|
||||
const double T9 = tx[0];
|
||||
const double reverseRate = ty[0];
|
||||
|
||||
const double derivative = m_engine.calculateReverseRateTwoBodyDerivative(m_reaction, T9, reverseRate);
|
||||
// std::cout << m_reaction.peName() << " reverseRate Derivative: " << derivative << "\n";
|
||||
// TODO: Handle rho and Y
|
||||
const double derivative = m_engine.calculateReverseRateTwoBodyDerivative(m_reaction, T9, 0, {}, reverseRate);
|
||||
|
||||
px[0] = py[0] * derivative; // Return the derivative of the reverse rate with respect to T9
|
||||
|
||||
|
||||
Reference in New Issue
Block a user