refactor(reaction): refactored to an abstract reaction class in prep for weak reactions

This commit is contained in:
2025-08-14 13:33:46 -04:00
parent d920a55ba6
commit 0b77f2e269
81 changed files with 1050041 additions and 913 deletions

View File

@@ -1,8 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "types/bindings.h"
#include "partition/bindings.h"
#include "expectations/bindings.h"

View File

@@ -1,8 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <iostream>
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include "bindings.h"
@@ -152,7 +150,7 @@ void register_engine_bindings(py::module &m) {
);
}
void register_base_engine_bindings(pybind11::module &m) {
void register_base_engine_bindings(const pybind11::module &m) {
py::class_<gridfire::StepDerivatives<double>>(m, "StepDerivatives")
.def_readonly("dYdt", &gridfire::StepDerivatives<double>::dydt, "The right-hand side (dY/dt) of the ODE system.")
@@ -165,15 +163,15 @@ void register_base_engine_bindings(pybind11::module &m) {
con_stype_register_graph_engine_bindings(m);
}
void abs_stype_register_engine_bindings(pybind11::module &m) {
void abs_stype_register_engine_bindings(const pybind11::module &m) {
py::class_<gridfire::Engine, PyEngine>(m, "Engine");
}
void abs_stype_register_dynamic_engine_bindings(pybind11::module &m) {
void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m) {
const auto a = py::class_<gridfire::DynamicEngine, PyDynamicEngine>(m, "DynamicEngine");
}
void con_stype_register_graph_engine_bindings(pybind11::module &m) {
void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
py::enum_<gridfire::NetworkBuildDepth>(m, "NetworkBuildDepth")
.value("Full", gridfire::NetworkBuildDepth::Full, "Full network build depth")
.value("Shallow", gridfire::NetworkBuildDepth::Shallow, "Shallow network build depth")
@@ -199,7 +197,7 @@ void con_stype_register_graph_engine_bindings(pybind11::module &m) {
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
"Initialize GraphEngine with a composition, partition function and build depth."
);
py_dynamic_engine_bindings.def(py::init<const gridfire::reaction::LogicalReactionSet &>(),
py_dynamic_engine_bindings.def(py::init<const gridfire::reaction::ReactionSet &>(),
py::arg("reactions"),
"Initialize GraphEngine with a set of reactions."
);
@@ -267,7 +265,7 @@ void con_stype_register_graph_engine_bindings(pybind11::module &m) {
registerDynamicEngineDefs<gridfire::GraphEngine, gridfire::DynamicEngine>(py_dynamic_engine_bindings);
}
void register_engine_view_bindings(pybind11::module &m) {
void register_engine_view_bindings(const pybind11::module &m) {
auto py_defined_engine_view_bindings = py::class_<gridfire::DefinedEngineView, gridfire::DynamicEngine>(m, "DefinedEngineView");
py_defined_engine_view_bindings.def(py::init<std::vector<std::string>, gridfire::DynamicEngine&>(),

View File

@@ -4,13 +4,13 @@
void register_engine_bindings(pybind11::module &m);
void register_base_engine_bindings(pybind11::module &m);
void register_base_engine_bindings(const pybind11::module &m);
void register_engine_view_bindings(pybind11::module &m);
void register_engine_view_bindings(const pybind11::module &m);
void abs_stype_register_engine_bindings(pybind11::module &m);
void abs_stype_register_dynamic_engine_bindings(pybind11::module &m);
void abs_stype_register_engine_bindings(const pybind11::module &m);
void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m);
void con_stype_register_graph_engine_bindings(pybind11::module &m);
void con_stype_register_graph_engine_bindings(const pybind11::module &m);

View File

@@ -23,10 +23,9 @@ const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies() const
/*
* get_override() looks for a Python method that overrides this C++ one.
*/
py::function override = py::get_override(this, "getNetworkSpecies");
if (override) {
py::object result = override();
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
const py::object result = override();
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
return m_species_cache;
}
@@ -57,10 +56,9 @@ const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies(
/*
* get_override() looks for a Python method that overrides this C++ one.
*/
py::function override = py::get_override(this, "getNetworkSpecies");
if (override) {
py::object result = override();
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
const py::object result = override();
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
return m_species_cache;
}
@@ -129,15 +127,15 @@ double PyDynamicEngine::calculateMolarReactionFlow(const gridfire::reaction::Rea
);
}
const gridfire::reaction::LogicalReactionSet& PyDynamicEngine::getNetworkReactions() const {
const gridfire::reaction::ReactionSet& PyDynamicEngine::getNetworkReactions() const {
PYBIND11_OVERRIDE_PURE(
const gridfire::reaction::LogicalReactionSet&,
const gridfire::reaction::ReactionSet&,
gridfire::DynamicEngine,
getNetworkReactions
);
}
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::LogicalReactionSet& reactions) {
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
@@ -199,7 +197,7 @@ gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel() const {
);
}
int PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
size_t PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
PYBIND11_OVERRIDE_PURE(
int,
gridfire::DynamicEngine,

View File

@@ -27,15 +27,16 @@ public:
void generateStoichiometryMatrix() override;
int getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const override;
double calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector<double> &Y, double T9, double rho) const override;
const gridfire::reaction::LogicalReactionSet& getNetworkReactions() const override;
void setNetworkReactions(const gridfire::reaction::LogicalReactionSet& reactions) override;
const gridfire::reaction::ReactionSet& getNetworkReactions() const override;
void setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesTimescales(const std::vector<double> &Y, double T9, double rho) const override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesDestructionTimescales(const std::vector<double> &Y, double T9, double rho) const override;
fourdst::composition::Composition update(const gridfire::NetIn &netIn) override;
bool isStale(const gridfire::NetIn &netIn) override;
void setScreeningModel(gridfire::screening::ScreeningType model) override;
gridfire::screening::ScreeningType getScreeningModel() const override;
int getSpeciesIndex(const fourdst::atomic::Species &species) const override;
size_t getSpeciesIndex(const fourdst::atomic::Species &species) const override;
std::vector<double> mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const override;
gridfire::PrimingReport primeEngine(const gridfire::NetIn &netIn) override;
gridfire::BuildDepthType getDepth() const override {

View File

@@ -1,8 +1,4 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <iostream>
#include "bindings.h"
@@ -10,7 +6,7 @@ namespace py = pybind11;
#include "gridfire/exceptions/exceptions.h"
void register_exception_bindings(py::module &m) {
void register_exception_bindings(const py::module &m) {
py::register_exception<gridfire::exceptions::EngineError>(m, "GridFireEngineError");
// TODO: Make it so that we can grab the stale state in python

View File

@@ -2,4 +2,4 @@
#include <pybind11/pybind11.h>
void register_exception_bindings(pybind11::module &m);
void register_exception_bindings(const pybind11::module &m);

View File

@@ -8,7 +8,7 @@ namespace py = pybind11;
#include "gridfire/expectations/expectations.h"
void register_expectation_bindings(py::module &m) {
void register_expectation_bindings(const py::module &m) {
py::enum_<gridfire::expectations::EngineErrorTypes>(m, "EngineErrorTypes")
.value("FAILURE", gridfire::expectations::EngineErrorTypes::FAILURE)
.value("INDEX", gridfire::expectations::EngineErrorTypes::INDEX)

View File

@@ -2,4 +2,4 @@
#include <pybind11/pybind11.h>
void register_expectation_bindings(pybind11::module &m);
void register_expectation_bindings(const pybind11::module &m);

View File

@@ -1,8 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <string_view>
#include <vector>
#include "bindings.h"
@@ -12,12 +10,12 @@
namespace py = pybind11;
void register_io_bindings(py::module &m) {
py::class_<gridfire::io::ParsedNetworkData>(m, "ParsedNetworkData");
auto register_io_bindings(const py::module &m) -> void {
auto ParsedNetworkData = py::class_<gridfire::io::ParsedNetworkData>(m, "ParsedNetworkData");
py::class_<gridfire::io::NetworkFileParser, PyNetworkFileParser>(m, "NetworkFileParser");
auto NetworkFileParser = py::class_<gridfire::io::NetworkFileParser, PyNetworkFileParser>(m, "NetworkFileParser");
py::class_<gridfire::io::SimpleReactionListFileParser, gridfire::io::NetworkFileParser>(m, "SimpleReactionListFileParser")
auto SimpleReactionListFileParser = py::class_<gridfire::io::SimpleReactionListFileParser, gridfire::io::NetworkFileParser>(m, "SimpleReactionListFileParser")
.def("parse", &gridfire::io::SimpleReactionListFileParser::parse,
py::arg("filename"),
"Parse a simple reaction list file and return a ParsedNetworkData object.");

View File

@@ -2,4 +2,4 @@
#include <pybind11/pybind11.h>
void register_io_bindings(pybind11::module &m);
void register_io_bindings(const pybind11::module &m);

View File

@@ -1,6 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <iostream>
#include <memory>
@@ -19,7 +19,7 @@ namespace py = pybind11;
void register_partition_bindings(pybind11::module &m) {
using PF = gridfire::partition::PartitionFunction;
py::class_<PF, PyPartitionFunction>(m, "PartitionFunction");
auto TrampPartitionFunction = py::class_<PF, PyPartitionFunction>(m, "PartitionFunction");
register_partition_types_bindings(m);
register_ground_state_partition_bindings(m);
@@ -44,7 +44,7 @@ void register_partition_types_bindings(pybind11::module &m) {
}, py::arg("typeStr"), "Convert string to BasePartitionType.");
}
void register_ground_state_partition_bindings(pybind11::module &m) {
void register_ground_state_partition_bindings(const pybind11::module &m) {
using GSPF = gridfire::partition::GroundStatePartitionFunction;
using PF = gridfire::partition::PartitionFunction;
py::class_<GSPF, PF>(m, "GroundStatePartitionFunction")
@@ -62,7 +62,7 @@ void register_ground_state_partition_bindings(pybind11::module &m) {
"Get the type of the partition function (should return 'GroundState').");
}
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m) {
void register_rauscher_thielemann_partition_data_record_bindings(const pybind11::module &m) {
py::class_<gridfire::partition::record::RauscherThielemannPartitionDataRecord>(m, "RauscherThielemannPartitionDataRecord")
.def_readonly("z", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::z, "Atomic number")
.def_readonly("a", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::a, "Mass number")
@@ -71,7 +71,7 @@ void register_rauscher_thielemann_partition_data_record_bindings(pybind11::modul
}
void register_rauscher_thielemann_partition_bindings(pybind11::module &m) {
void register_rauscher_thielemann_partition_bindings(const pybind11::module &m) {
using RTPF = gridfire::partition::RauscherThielemannPartitionFunction;
using PF = gridfire::partition::PartitionFunction;
py::class_<RTPF, PF>(m, "RauscherThielemannPartitionFunction")
@@ -89,7 +89,7 @@ void register_rauscher_thielemann_partition_bindings(pybind11::module &m) {
"Get the type of the partition function (should return 'RauscherThielemann').");
}
void register_composite_partition_bindings(pybind11::module &m) {
void register_composite_partition_bindings(const pybind11::module &m) {
py::class_<gridfire::partition::CompositePartitionFunction>(m, "CompositePartitionFunction")
.def(py::init<const std::vector<gridfire::partition::BasePartitionType>&>(),
py::arg("partitionFunctions"),

View File

@@ -6,11 +6,11 @@ void register_partition_bindings(pybind11::module &m);
void register_partition_types_bindings(pybind11::module &m);
void register_ground_state_partition_bindings(pybind11::module &m);
void register_ground_state_partition_bindings(const pybind11::module &m);
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m);
void register_rauscher_thielemann_partition_data_record_bindings(const pybind11::module &m);
void register_rauscher_thielemann_partition_bindings(pybind11::module &m);
void register_rauscher_thielemann_partition_bindings(const pybind11::module &m);
void register_composite_partition_bindings(pybind11::module &m);
void register_composite_partition_bindings(const pybind11::module &m);

View File

@@ -7,9 +7,9 @@
class PyPartitionFunction final : public gridfire::partition::PartitionFunction {
double evaluate(int z, int a, double T9) const override;
double evaluateDerivative(int z, int a, double T9) const override;
bool supports(int z, int a) const override;
std::string type() const override;
std::unique_ptr<gridfire::partition::PartitionFunction> clone() const override;
[[nodiscard]] double evaluate(int z, int a, double T9) const override;
[[nodiscard]] double evaluateDerivative(int z, int a, double T9) const override;
[[nodiscard]] bool supports(int z, int a) const override;
[[nodiscard]] std::string type() const override;
[[nodiscard]] std::unique_ptr<PartitionFunction> clone() const override;
};

View File

@@ -1,6 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <string_view>
#include <vector>
@@ -22,150 +22,332 @@ void register_reaction_bindings(py::module &m) {
);
using fourdst::atomic::Species;
py::class_<gridfire::reaction::Reaction>(m, "Reaction")
.def(py::init<const std::string_view, const std::string_view, int, const std::vector<Species>&, const std::vector<Species>&, double, std::string_view, gridfire::reaction::RateCoefficientSet, bool>(),
py::arg("id"), py::arg("peName"), py::arg("chapter"),
py::arg("reactants"), py::arg("products"), py::arg("qValue"),
py::arg("label"), py::arg("sets"), py::arg("reverse") = false,
"Construct a Reaction with the given parameters.")
.def("calculate_rate", static_cast<double (gridfire::reaction::Reaction::*)(double) const>(&gridfire::reaction::Reaction::calculate_rate),
py::arg("T9"), "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).")
.def("peName", &gridfire::reaction::Reaction::peName,
"Get the reaction name in (projectile, ejectile) notation (e.g., 'p(p,g)d').")
.def("chapter", &gridfire::reaction::Reaction::chapter,
"Get the REACLIB chapter number defining the reaction structure.")
.def("sourceLabel", &gridfire::reaction::Reaction::sourceLabel,
"Get the source label for the rate data (e.g., 'wc12w', 'st08').")
.def("rateCoefficients", &gridfire::reaction::Reaction::rateCoefficients,
"get the set of rate coefficients.")
.def("contains", &gridfire::reaction::Reaction::contains,
py::arg("species"), "Check if the reaction contains a specific species.")
.def("contains_reactant", &gridfire::reaction::Reaction::contains_reactant,
"Check if the reaction contains a specific reactant species.")
.def("contains_product", &gridfire::reaction::Reaction::contains_product,
"Check if the reaction contains a specific product species.")
.def("all_species", &gridfire::reaction::Reaction::all_species,
"Get all species involved in the reaction (both reactants and products) as a set.")
.def("reactant_species", &gridfire::reaction::Reaction::reactant_species,
"Get the reactant species of the reaction as a set.")
.def("product_species", &gridfire::reaction::Reaction::product_species,
"Get the product species of the reaction as a set.")
.def("num_species", &gridfire::reaction::Reaction::num_species,
"Count the number of species in the reaction.")
.def("stoichiometry", static_cast<int (gridfire::reaction::Reaction::*)(const Species&) const>(&gridfire::reaction::Reaction::stoichiometry),
py::arg("species"),
"Get the stoichiometry of the reaction as a map from species to their coefficients.")
.def("stoichiometry", static_cast<std::unordered_map<Species, int> (gridfire::reaction::Reaction::*)() const>(&gridfire::reaction::Reaction::stoichiometry),
"Get the stoichiometry of the reaction as a map from species to their coefficients.")
.def("id", &gridfire::reaction::Reaction::id,
"Get the unique identifier of the reaction.")
.def("qValue", &gridfire::reaction::Reaction::qValue,
"Get the Q-value of the reaction in MeV.")
.def("reactants", &gridfire::reaction::Reaction::reactants,
"Get a list of reactant species in the reaction.")
.def("products", &gridfire::reaction::Reaction::products,
"Get a list of product species in the reaction.")
.def("is_reverse", &gridfire::reaction::Reaction::is_reverse,
"Check if this is a reverse reaction rate.")
.def("excess_energy", &gridfire::reaction::Reaction::excess_energy,
"Calculate the excess energy from the mass difference of reactants and products.")
.def("__eq__", &gridfire::reaction::Reaction::operator==,
"Equality operator for reactions based on their IDs.")
.def("__neq__", &gridfire::reaction::Reaction::operator!=,
"Inequality operator for reactions based on their IDs.")
.def("hash", &gridfire::reaction::Reaction::hash,
py::arg("seed") = 0,
"Compute a hash for the reaction based on its ID.")
.def("__repr__", [](const gridfire::reaction::Reaction& self) {
std::stringstream ss;
ss << self; // Use the existing operator<< for Reaction
return ss.str();
});
py::class_<gridfire::reaction::LogicalReaction, gridfire::reaction::Reaction>(m, "LogicalReaction")
.def(py::init<const std::vector<gridfire::reaction::Reaction>>(),
py::arg("reactions"),
"Construct a LogicalReaction from a vector of Reaction objects.")
.def("add_reaction", &gridfire::reaction::LogicalReaction::add_reaction,
py::arg("reaction"),
"Add another Reaction source to this logical reaction.")
.def("size", &gridfire::reaction::LogicalReaction::size,
"Get the number of source rates contributing to this logical reaction.")
.def("__len__", &gridfire::reaction::LogicalReaction::size,
"Overload len() to return the number of source rates.")
.def("sources", &gridfire::reaction::LogicalReaction::sources,
"Get the list of source labels for the aggregated rates.")
.def("calculate_rate", static_cast<double (gridfire::reaction::LogicalReaction::*)(double) const>(&gridfire::reaction::LogicalReaction::calculate_rate),
py::arg("T9"), "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).")
.def("calculate_forward_rate_log_derivative", &gridfire::reaction::LogicalReaction::calculate_forward_rate_log_derivative,
py::arg("T9"), "Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K).");
py::class_<gridfire::reaction::LogicalReactionSet>(m, "LogicalReactionSet")
.def(py::init<const std::vector<gridfire::reaction::LogicalReaction>>(),
py::arg("reactions"),
"Construct a LogicalReactionSet from a vector of LogicalReaction objects.")
.def(py::init<>(),
"Default constructor for an empty LogicalReactionSet.")
.def(py::init<const gridfire::reaction::LogicalReactionSet&>(),
py::arg("other"),
"Copy constructor for LogicalReactionSet.")
.def("add_reaction", &gridfire::reaction::LogicalReactionSet::add_reaction,
py::arg("reaction"),
"Add a LogicalReaction to the set.")
.def("remove_reaction", &gridfire::reaction::LogicalReactionSet::remove_reaction,
py::arg("reaction"),
"Remove a LogicalReaction from the set.")
.def("contains", py::overload_cast<const std::string_view&>(&gridfire::reaction::LogicalReactionSet::contains, py::const_),
py::class_<gridfire::reaction::ReaclibReaction>(m, "ReaclibReaction")
.def(
py::init<
const std::string_view,
const std::string_view,
int,
const std::vector<Species>&,
const std::vector<Species>&,
double, std::string_view,
gridfire::reaction::RateCoefficientSet,
bool
>(),
py::arg("id"),
"Check if the set contains a specific LogicalReaction.")
.def("contains", py::overload_cast<const gridfire::reaction::Reaction&>(&gridfire::reaction::LogicalReactionSet::contains, py::const_),
py::arg("reaction"),
"Check if the set contains a specific Reaction.")
.def("size", &gridfire::reaction::LogicalReactionSet::size,
"Get the number of LogicalReactions in the set.")
.def("__len__", &gridfire::reaction::LogicalReactionSet::size,
"Overload len() to return the number of LogicalReactions.")
.def("clear", &gridfire::reaction::LogicalReactionSet::clear,
"Remove all LogicalReactions from the set.")
.def("containes_species", &gridfire::reaction::LogicalReactionSet::contains_species,
py::arg("species"),
"Check if any reaction in the set involves the given species.")
.def("contains_reactant", &gridfire::reaction::LogicalReactionSet::contains_reactant,
py::arg("species"),
"Check if any reaction in the set has the species as a reactant.")
.def("contains_product", &gridfire::reaction::LogicalReactionSet::contains_product,
py::arg("species"),
"Check if any reaction in the set has the species as a product.")
.def("__getitem__", py::overload_cast<size_t>(&gridfire::reaction::LogicalReactionSet::operator[], py::const_),
py::arg("index"),
"Get a LogicalReaction by index.")
.def("__getitem___", py::overload_cast<const std::string_view&>(&gridfire::reaction::LogicalReactionSet::operator[], py::const_),
py::arg("id"),
"Get a LogicalReaction by its ID.")
.def("__eq__", &gridfire::reaction::LogicalReactionSet::operator==,
py::arg("LogicalReactionSet"),
"Equality operator for LogicalReactionSets based on their contents.")
.def("__ne__", &gridfire::reaction::LogicalReactionSet::operator!=,
py::arg("LogicalReactionSet"),
"Inequality operator for LogicalReactionSets based on their contents.")
.def("hash", &gridfire::reaction::LogicalReactionSet::hash,
py::arg("seed") = 0,
"Compute a hash for the LogicalReactionSet based on its contents."
py::arg("peName"),
py::arg("chapter"),
py::arg("reactants"),
py::arg("products"),
py::arg("qValue"),
py::arg("label"),
py::arg("sets"),
py::arg("reverse") = false,
"Construct a Reaction with the given parameters."
)
.def("__repr__", [](const gridfire::reaction::LogicalReactionSet& self) {
std::stringstream ss;
ss << self;
return ss.str();
})
.def("getReactionSetSpecies", &gridfire::reaction::LogicalReactionSet::getReactionSetSpecies,
"Get all species involved in the reactions of the set as a set of Species objects.");
.def(
"calculate_rate",
[](const gridfire::reaction::ReaclibReaction& self, const double T9, const double rho, const std::vector<double>& Y) -> double {
return self.calculate_rate(T9, rho, Y);
},
py::arg("T9"),
py::arg("rho"),
py::arg("Y"),
"Calculate the reaction rate at a given temperature T9 (in units of 10^9 K)."
)
.def(
"peName",
&gridfire::reaction::ReaclibReaction::peName,
"Get the reaction name in (projectile, ejectile) notation (e.g., 'p(p,g)d')."
)
.def(
"chapter",
&gridfire::reaction::ReaclibReaction::chapter,
"Get the REACLIB chapter number defining the reaction structure."
)
.def(
"sourceLabel",
&gridfire::reaction::ReaclibReaction::sourceLabel,
"Get the source label for the rate data (e.g., 'wc12w', 'st08')."
)
.def(
"rateCoefficients",
&gridfire::reaction::ReaclibReaction::rateCoefficients,
"get the set of rate coefficients."
)
.def(
"contains",
&gridfire::reaction::ReaclibReaction::contains,
py::arg("species"),
"Check if the reaction contains a specific species."
)
.def(
"contains_reactant",
&gridfire::reaction::ReaclibReaction::contains_reactant,
"Check if the reaction contains a specific reactant species."
)
.def(
"contains_product",
&gridfire::reaction::ReaclibReaction::contains_product,
"Check if the reaction contains a specific product species."
)
.def(
"all_species",
&gridfire::reaction::ReaclibReaction::all_species,
"Get all species involved in the reaction (both reactants and products) as a set."
)
.def(
"reactant_species",
&gridfire::reaction::ReaclibReaction::reactant_species,
"Get the reactant species of the reaction as a set."
)
.def(
"product_species",
&gridfire::reaction::ReaclibReaction::product_species,
"Get the product species of the reaction as a set."
)
.def(
"num_species",
&gridfire::reaction::ReaclibReaction::num_species,
"Count the number of species in the reaction."
)
.def(
"stoichiometry",
[](const gridfire::reaction::ReaclibReaction& self, const Species& species) -> int {
return self.stoichiometry(species);
},
py::arg("species"),
"Get the stoichiometry of the reaction as a map from species to their coefficients."
)
.def(
"stoichiometry",
[](const gridfire::reaction::ReaclibReaction& self) -> std::unordered_map<Species, int> {
return self.stoichiometry();
},
"Get the stoichiometry of the reaction as a map from species to their coefficients."
)
.def(
"id",
&gridfire::reaction::ReaclibReaction::id,
"Get the unique identifier of the reaction."
)
.def(
"qValue",
&gridfire::reaction::ReaclibReaction::qValue,
"Get the Q-value of the reaction in MeV."
)
.def(
"reactants",
&gridfire::reaction::ReaclibReaction::reactants,
"Get a list of reactant species in the reaction."
)
.def(
"products",
&gridfire::reaction::ReaclibReaction::products,
"Get a list of product species in the reaction."
)
.def(
"is_reverse",
&gridfire::reaction::ReaclibReaction::is_reverse,
"Check if this is a reverse reaction rate."
)
.def(
"excess_energy",
&gridfire::reaction::ReaclibReaction::excess_energy,
"Calculate the excess energy from the mass difference of reactants and products."
)
.def(
"__eq__",
&gridfire::reaction::ReaclibReaction::operator==,
"Equality operator for reactions based on their IDs."
)
.def(
"__neq__",
&gridfire::reaction::ReaclibReaction::operator!=,
"Inequality operator for reactions based on their IDs."
)
.def(
"hash",
&gridfire::reaction::ReaclibReaction::hash,
py::arg("seed") = 0,
"Compute a hash for the reaction based on its ID."
)
.def(
"__repr__",
[](const gridfire::reaction::ReaclibReaction& self) {
std::stringstream ss;
ss << self; // Use the existing operator<< for Reaction
return ss.str();
}
);
m.def("packReactionSetToLogicalReactionSet",
&gridfire::reaction::packReactionSetToLogicalReactionSet,
py::class_<gridfire::reaction::LogicalReaclibReaction, gridfire::reaction::ReaclibReaction>(m, "LogicalReaclibReaction")
.def(
py::init<const std::vector<gridfire::reaction::Reaction>>(),
py::arg("reactions"),
"Construct a LogicalReaclibReaction from a vector of Reaction objects."
)
.def(
"add_reaction",
&gridfire::reaction::LogicalReaclibReaction::add_reaction,
py::arg("reaction"),
"Add another Reaction source to this logical reaction."
)
.def(
"size",
&gridfire::reaction::LogicalReaclibReaction::size,
"Get the number of source rates contributing to this logical reaction."
)
.def(
"__len__",
&gridfire::reaction::LogicalReaclibReaction::size,
"Overload len() to return the number of source rates."
)
.def(
"sources",
&gridfire::reaction::LogicalReaclibReaction::sources,
"Get the list of source labels for the aggregated rates."
)
.def(
"calculate_rate",
[](const gridfire::reaction::LogicalReaclibReaction& self, const double T9, const double rho, const std::vector<double>& Y) -> double {
return self.calculate_rate(T9, rho, Y);
},
py::arg("T9"),
"Calculate the reaction rate at a given temperature T9 (in units of 10^9 K)."
)
.def(
"calculate_forward_rate_log_derivative",
&gridfire::reaction::LogicalReaclibReaction::calculate_forward_rate_log_derivative,
py::arg("T9"),
"Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K)."
);
py::class_<gridfire::reaction::ReactionSet>(m, "ReactionSet")
// TODO: Fix the constructor to accept a vector of unique ptrs to Reaclib Reactions
.def(
py::init<const std::vector<gridfire::reaction::Reaction>>(),
py::arg("reactions"),
"Construct a LogicalReactionSet from a vector of LogicalReaclibReaction objects."
)
.def(
py::init<>(),
"Default constructor for an empty LogicalReactionSet."
)
.def(
py::init<const gridfire::reaction::ReactionSet&>(),
py::arg("other"),
"Copy constructor for LogicalReactionSet."
)
.def(
"add_reaction",
py::overload_cast<const gridfire::reaction::Reaction&>(&gridfire::reaction::ReactionSet::add_reaction),
py::arg("reaction"),
"Add a LogicalReaclibReaction to the set."
)
.def(
"remove_reaction",
&gridfire::reaction::ReactionSet::remove_reaction,
py::arg("reaction"),
"Remove a LogicalReaclibReaction from the set."
)
.def(
"contains",
py::overload_cast<const std::string_view&>(&gridfire::reaction::ReactionSet::contains, py::const_),
py::arg("id"),
"Check if the set contains a specific LogicalReaclibReaction."
)
.def(
"contains",
py::overload_cast<const gridfire::reaction::Reaction&>(&gridfire::reaction::ReactionSet::contains, py::const_),
py::arg("reaction"),
"Check if the set contains a specific Reaction."
)
.def(
"size",
&gridfire::reaction::ReactionSet::size,
"Get the number of LogicalReactions in the set."
)
.def(
"__len__", &gridfire::reaction::ReactionSet::size,
"Overload len() to return the number of LogicalReactions."
)
.def(
"clear",
&gridfire::reaction::ReactionSet::clear,
"Remove all LogicalReactions from the set."
)
.def("contains_species",
&gridfire::reaction::ReactionSet::contains_species,
py::arg("species"),
"Check if any reaction in the set involves the given species."
)
.def(
"contains_reactant",
&gridfire::reaction::ReactionSet::contains_reactant,
py::arg("species"),
"Check if any reaction in the set has the species as a reactant."
)
.def(
"contains_product",
&gridfire::reaction::ReactionSet::contains_product,
py::arg("species"),
"Check if any reaction in the set has the species as a product."
)
.def(
"__getitem__",
py::overload_cast<size_t>(&gridfire::reaction::ReactionSet::operator[], py::const_),
py::arg("index"),
"Get a LogicalReaclibReaction by index."
)
.def(
"__getitem___",
py::overload_cast<const std::string_view&>(&gridfire::reaction::ReactionSet::operator[], py::const_),
py::arg("id"),
"Get a LogicalReaclibReaction by its ID."
)
.def(
"__eq__",
&gridfire::reaction::ReactionSet::operator==,
py::arg("LogicalReactionSet"),
"Equality operator for LogicalReactionSets based on their contents."
)
.def(
"__ne__",
&gridfire::reaction::ReactionSet::operator!=,
py::arg("LogicalReactionSet"),
"Inequality operator for LogicalReactionSets based on their contents."
)
.def(
"hash",
&gridfire::reaction::ReactionSet::hash,
py::arg("seed") = 0,
"Compute a hash for the LogicalReactionSet based on its contents."
)
.def(
"__repr__",
[](const gridfire::reaction::ReactionSet& self) {
std::stringstream ss;
ss << self;
return ss.str();
}
)
.def(
"getReactionSetSpecies",
&gridfire::reaction::ReactionSet::getReactionSetSpecies,
"Get all species involved in the reactions of the set as a set of Species objects."
);
m.def(
"packReactionSet",
&gridfire::reaction::packReactionSet,
py::arg("reactionSet"),
"Convert a ReactionSet to a LogicalReactionSet by aggregating reactions with the same peName."
);
m.def("get_all_reactions", &gridfire::reaclib::get_all_reactions,
"Get all reactions from the REACLIB database.");
m.def(
"get_all_reactions",
&gridfire::reaclib::get_all_reaclib_reactions,
"Get all reactions from the REACLIB database."
);
}

View File

@@ -1,8 +1,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <string_view>
#include <vector>
#include "bindings.h"
@@ -13,30 +12,53 @@
namespace py = pybind11;
void register_screening_bindings(py::module &m) {
py::class_<gridfire::screening::ScreeningModel, PyScreening>(m, "ScreeningModel");
auto screening_model = py::class_<gridfire::screening::ScreeningModel, PyScreening>(m, "ScreeningModel");
py::enum_<gridfire::screening::ScreeningType>(m, "ScreeningType")
.value("BARE", gridfire::screening::ScreeningType::BARE)
.value("WEAK", gridfire::screening::ScreeningType::WEAK)
.export_values();
m.def("selectScreeningModel", &gridfire::screening::selectScreeningModel,
py::arg("type"),
"Select a screening model based on the specified type. Returns a pointer to the selected model.");
m.def(
"selectScreeningModel",
&gridfire::screening::selectScreeningModel,
py::arg("type"),
"Select a screening model based on the specified type. Returns a pointer to the selected model."
);
py::class_<gridfire::screening::BareScreeningModel>(m, "BareScreeningModel")
.def(py::init<>())
.def("calculateScreeningFactors",
py::overload_cast<const gridfire::reaction::LogicalReactionSet&, const std::vector<fourdst::atomic::Species>&, const std::vector<double>&, double, double>(&gridfire::screening::BareScreeningModel::calculateScreeningFactors, py::const_),
py::arg("reactions"), py::arg("species"), py::arg("Y"), py::arg("T9"), py::arg("rho"),
"Calculate the bare plasma screening factors. This always returns 1.0 (bare)"
);
py::overload_cast<
const gridfire::reaction::ReactionSet&,
const std::vector<fourdst::atomic::Species>&,
const std::vector<double>&,
double,
double
>(&gridfire::screening::BareScreeningModel::calculateScreeningFactors, py::const_),
py::arg("reactions"),
py::arg("species"),
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Calculate the bare plasma screening factors. This always returns 1.0 (bare)"
);
py::class_<gridfire::screening::WeakScreeningModel>(m, "WeakScreeningModel")
.def(py::init<>())
.def("calculateScreeningFactors",
py::overload_cast<const gridfire::reaction::LogicalReactionSet&, const std::vector<fourdst::atomic::Species>&, const std::vector<double>&, double, double>(&gridfire::screening::WeakScreeningModel::calculateScreeningFactors, py::const_),
py::arg("reactions"), py::arg("species"), py::arg("Y"), py::arg("T9"), py::arg("rho"),
"Calculate the weak plasma screening factors using the Salpeter (1954) model."
);
py::overload_cast<
const gridfire::reaction::ReactionSet&,
const std::vector<fourdst::atomic::Species>&,
const std::vector<double>&,
double,
double
>(&gridfire::screening::WeakScreeningModel::calculateScreeningFactors, py::const_),
py::arg("reactions"),
py::arg("species"),
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Calculate the weak plasma screening factors using the Salpeter (1954) model."
);
}

View File

@@ -13,7 +13,13 @@
namespace py = pybind11;
std::vector<double> PyScreening::calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<double> &Y, const double T9, const double rho) const {
std::vector<double> PyScreening::calculateScreeningFactors(
const gridfire::reaction::ReactionSet &reactions,
const std::vector<fourdst::atomic::Species> &species,
const std::vector<double> &Y,
const double T9,
const double rho
) const {
PYBIND11_OVERLOAD_PURE(
std::vector<double>, // Return type
gridfire::screening::ScreeningModel,
@@ -22,7 +28,13 @@ std::vector<double> PyScreening::calculateScreeningFactors(const gridfire::react
}
using ADDouble = gridfire::screening::ScreeningModel::ADDouble;
std::vector<ADDouble> PyScreening::calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<ADDouble> &Y, const ADDouble T9, const ADDouble rho) const {
std::vector<ADDouble> PyScreening::calculateScreeningFactors(
const gridfire::reaction::ReactionSet &reactions,
const std::vector<fourdst::atomic::Species> &species,
const std::vector<ADDouble> &Y,
const ADDouble T9,
const ADDouble rho
) const {
PYBIND11_OVERLOAD_PURE(
std::vector<ADDouble>, // Return type
gridfire::screening::ScreeningModel,

View File

@@ -2,15 +2,24 @@
#include "gridfire/screening/screening.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h> // Needed for std::function
#include <vector>
#include "cppad/cppad.hpp"
class PyScreening final : public gridfire::screening::ScreeningModel {
std::vector<double> calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<double> &Y, const double T9, const double rho) const override;
std::vector<ADDouble> calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<ADDouble> &Y, const ADDouble T9, const ADDouble rho) const override;
[[nodiscard]] std::vector<double> calculateScreeningFactors(
const gridfire::reaction::ReactionSet &reactions,
const std::vector<fourdst::atomic::Species> &species,
const std::vector<double> &Y,
double T9,
double rho
) const override;
[[nodiscard]] std::vector<ADDouble> calculateScreeningFactors(
const gridfire::reaction::ReactionSet &reactions,
const std::vector<fourdst::atomic::Species> &species,
const std::vector<ADDouble> &Y,
ADDouble T9,
ADDouble rho
) const override;
};

View File

@@ -2,7 +2,6 @@
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <pybind11/numpy.h>
#include <pybind11/functional.h> // Needed for std::function
#include <boost/numeric/ublas/vector.hpp>
@@ -30,7 +29,7 @@ void register_solver_bindings(const py::module &m) {
);
py_direct_network_solver.def("set_callback",
[](gridfire::solver::DirectNetworkSolver &self, gridfire::solver::DirectNetworkSolver::TimestepCallback cb) {
[](gridfire::solver::DirectNetworkSolver &self, const gridfire::solver::DirectNetworkSolver::TimestepCallback& cb) {
self.set_callback(cb);
},
py::arg("callback"),

View File

@@ -2,7 +2,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h> // Needed for std::function
#include <vector>
#include <tuple>

View File

@@ -11,5 +11,5 @@ class PyDynamicNetworkSolverStrategy final : public gridfire::solver::DynamicNet
explicit PyDynamicNetworkSolverStrategy(gridfire::DynamicEngine &engine) : gridfire::solver::DynamicNetworkSolverStrategy(engine) {}
gridfire::NetOut evaluate(const gridfire::NetIn &netIn) override;
void set_callback(const std::any &callback) override;
std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
};

View File

@@ -8,7 +8,7 @@ namespace py = pybind11;
#include "gridfire/network.h"
void register_type_bindings(pybind11::module &m) {
void register_type_bindings(const pybind11::module &m) {
py::class_<gridfire::NetIn>(m, "NetIn")
.def(py::init<>())
.def_readwrite("composition", &gridfire::NetIn::composition)

View File

@@ -2,4 +2,4 @@
#include <pybind11/pybind11.h>
void register_type_bindings(pybind11::module &m);
void register_type_bindings(const pybind11::module &m);

View File

@@ -1,5 +1,4 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include "bindings.h"