fix(python-bindings): Updated python bindings to new interface

The python bindings now work with the polymorphic reaction class and the CVODE solver
This commit is contained in:
2025-10-30 15:05:08 -04:00
parent 23df87f915
commit 7fded59814
27 changed files with 962 additions and 255 deletions

View File

@@ -5,6 +5,8 @@
#include "bindings.h"
#include "gridfire/engine/engine.h"
#include "gridfire/engine/diagnostics/dynamic_engine_diagnostics.h"
#include "gridfire/exceptions/exceptions.h"
#include "trampoline/py_engine.h"
@@ -17,23 +19,70 @@ namespace {
template <IsDynamicEngine T, IsDynamicEngine BaseT>
void registerDynamicEngineDefs(py::class_<T, BaseT> pyClass) {
pyClass.def("calculateRHSAndEnergy", &T::calculateRHSAndEnergy,
py::arg("Y"),
pyClass.def(
"calculateRHSAndEnergy",
[](
const gridfire::DynamicEngine& self,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) {
auto result = self.calculateRHSAndEnergy(comp, T9, rho);
if (!result.has_value()) {
throw gridfire::exceptions::StaleEngineError("Engine reports stale state, call update().");
}
return result.value();
},
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Calculate the right-hand side (dY/dt) and energy generation rate."
)
.def("generateJacobianMatrix", py::overload_cast<const std::vector<double>&, double, double>(&T::generateJacobianMatrix, py::const_),
py::arg("Y_dynamic"),
.def("calculateEpsDerivatives",
&gridfire::DynamicEngine::calculateEpsDerivatives,
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Calculate deps/dT and deps/drho"
)
.def("generateJacobianMatrix",
py::overload_cast<const fourdst::composition::Composition&, double, double>(&T::generateJacobianMatrix, py::const_),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Generate the Jacobian matrix for the current state."
)
.def("generateStoichiometryMatrix", &T::generateStoichiometryMatrix)
.def("generateJacobianMatrix",
py::overload_cast<const fourdst::composition::Composition&, double, double, const std::vector<fourdst::atomic::Species>&>(&T::generateJacobianMatrix, py::const_),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
py::arg("activeSpecies"),
"Generate the jacobian matrix only for the subset of the matrix representing the active species."
)
.def("generateJacobianMatrix",
py::overload_cast<const fourdst::composition::Composition&, double, double, const gridfire::SparsityPattern&>(&T::generateJacobianMatrix, py::const_),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
py::arg("sparsityPattern"),
"Generate the jacobian matrix for the given sparsity pattern"
)
.def("generateStoichiometryMatrix",
&T::generateStoichiometryMatrix
)
.def("calculateMolarReactionFlow",
static_cast<double (T::*)(const gridfire::reaction::Reaction&, const std::vector<double>&, const double, const double) const>(&T::calculateMolarReactionFlow),
[](
const gridfire::DynamicEngine& self,
const gridfire::reaction::Reaction& reaction,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) -> double {
return self.calculateMolarReactionFlow(reaction, comp, T9, rho);
},
py::arg("reaction"),
py::arg("Y"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Calculate the molar reaction flow for a given reaction."
@@ -49,61 +98,99 @@ namespace {
"Set the network reactions to a new set of reactions."
)
.def("getJacobianMatrixEntry", &T::getJacobianMatrixEntry,
py::arg("i"),
py::arg("j"),
py::arg("rowSpecies"),
py::arg("colSpecies"),
"Get an entry from the previously generated Jacobian matrix."
)
.def("getStoichiometryMatrixEntry", &T::getStoichiometryMatrixEntry,
py::arg("speciesIndex"),
py::arg("reactionIndex"),
py::arg("species"),
py::arg("reaction"),
"Get an entry from the stoichiometry matrix."
)
.def("getSpeciesTimescales", &T::getSpeciesTimescales,
py::arg("Y"),
.def("getSpeciesTimescales",
[](
const gridfire::DynamicEngine& self,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) -> std::unordered_map<fourdst::atomic::Species, double> {
const auto result = self.getSpeciesTimescales(comp, T9, rho);
if (!result.has_value()) {
throw gridfire::exceptions::StaleEngineError("Engine reports stale state, call update().");
}
return result.value();
},
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Get the timescales for each species in the network."
)
.def("getSpeciesDestructionTimescales", &T::getSpeciesDestructionTimescales,
py::arg("Y"),
.def("getSpeciesDestructionTimescales",
[](
const gridfire::DynamicEngine& self,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) -> std::unordered_map<fourdst::atomic::Species, double> {
const auto result = self.getSpeciesDestructionTimescales(comp, T9, rho);
if (!result.has_value()) {
throw gridfire::exceptions::StaleEngineError("Engine reports stale state, call update().");
}
return result.value();
},
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Get the destruction timescales for each species in the network."
)
.def("update", &T::update,
.def("update",
&T::update,
py::arg("netIn"),
"Update the engine state based on the provided NetIn object."
)
.def("setScreeningModel", &T::setScreeningModel,
.def("setScreeningModel",
&T::setScreeningModel,
py::arg("screeningModel"),
"Set the screening model for the engine."
)
.def("getScreeningModel", &T::getScreeningModel,
.def("getScreeningModel",
&T::getScreeningModel,
"Get the current screening model of the engine."
)
.def("getSpeciesIndex", &T::getSpeciesIndex,
.def("getSpeciesIndex",
&T::getSpeciesIndex,
py::arg("species"),
"Get the index of a species in the network."
)
.def("mapNetInToMolarAbundanceVector", &T::mapNetInToMolarAbundanceVector,
.def("mapNetInToMolarAbundanceVector",
&T::mapNetInToMolarAbundanceVector,
py::arg("netIn"),
"Map a NetIn object to a vector of molar abundances."
)
.def("primeEngine", &T::primeEngine,
.def("primeEngine",
&T::primeEngine,
py::arg("netIn"),
"Prime the engine with a NetIn object to prepare for calculations."
)
.def("getDepth", &T::getDepth,
.def("getDepth",
&T::getDepth,
"Get the current build depth of the engine."
)
.def("rebuild", &T::rebuild,
.def("rebuild",
&T::rebuild,
py::arg("composition"),
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
"Rebuild the engine with a new composition and build depth."
)
.def("isStale", &T::isStale,
.def("isStale",
&T::isStale,
py::arg("netIn"),
"Check if the engine is stale based on the provided NetIn object."
)
.def("collectComposition",
&T::collectComposition,
py::arg("composition"),
"Recursively collect composition from current engine and any sub engines if they exist."
);
}
@@ -112,14 +199,123 @@ namespace {
void register_engine_bindings(py::module &m) {
register_base_engine_bindings(m);
register_engine_view_bindings(m);
register_engine_diagnostic_bindings(m);
register_engine_procedural_bindings(m);
register_engine_type_bindings(m);
}
m.def("build_reaclib_nuclear_network", &gridfire::build_reaclib_nuclear_network,
py::arg("composition"),
py::arg("maxLayers") = gridfire::NetworkBuildDepth::Full,
py::arg("reverse") = false,
"Build a nuclear network from a composition using ReacLib data."
void register_base_engine_bindings(const pybind11::module &m) {
py::class_<gridfire::StepDerivatives<double>>(m, "StepDerivatives")
.def_readonly("dYdt", &gridfire::StepDerivatives<double>::dydt, "The right-hand side (dY/dt) of the ODE system.")
.def_readonly("energy", &gridfire::StepDerivatives<double>::nuclearEnergyGenerationRate, "The energy generation rate.");
py::class_<gridfire::SparsityPattern> py_sparsity_pattern(m, "SparsityPattern");
abs_stype_register_engine_bindings(m);
abs_stype_register_dynamic_engine_bindings(m);
con_stype_register_graph_engine_bindings(m);
}
void abs_stype_register_engine_bindings(const pybind11::module &m) {
py::class_<gridfire::Engine, PyEngine>(m, "Engine");
}
void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m) {
const auto a = py::class_<gridfire::DynamicEngine, PyDynamicEngine>(m, "DynamicEngine");
}
void register_engine_procedural_bindings(pybind11::module &m) {
auto procedures = m.def_submodule("procedures", "Procedural functions associated with engine module");
register_engine_construction_bindings(procedures);
register_engine_construction_bindings(procedures);
}
void register_engine_diagnostic_bindings(pybind11::module &m) {
auto diagnostics = m.def_submodule("diagnostics", "A submodule for engine diagnostics");
diagnostics.def("report_limiting_species",
&gridfire::diagnostics::report_limiting_species,
py::arg("engine"),
py::arg("Y_full"),
py::arg("E_full"),
py::arg("dydt_full"),
py::arg("relTol"),
py::arg("absTol"),
py::arg("top_n") = 10
);
diagnostics.def("inspect_species_balance",
&gridfire::diagnostics::inspect_species_balance,
py::arg("engine"),
py::arg("species_name"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho")
);
diagnostics.def("inspect_jacobian_stiffness",
&gridfire::diagnostics::inspect_jacobian_stiffness,
py::arg("engine"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho")
);
}
void register_engine_construction_bindings(pybind11::module &m) {
m.def("build_nuclear_network", &gridfire::build_nuclear_network,
py::arg("composition"),
py::arg("weakInterpolator"),
py::arg("maxLayers") = gridfire::NetworkBuildDepth::Full,
py::arg("reverse") = false,
"Build a nuclear network from a composition using all archived reaction data."
);
}
void register_engine_priming_bindings(pybind11::module &m) {
m.def("calculateDestructionRateConstant",
&gridfire::calculateDestructionRateConstant,
py::arg("engine"),
py::arg("species"),
py::arg("composition"),
py::arg("T9"),
py::arg("rho"),
py::arg("reactionTypesToIgnore")
);
m.def("calculateCreationRate",
&gridfire::calculateCreationRate,
py::arg("engine"),
py::arg("species"),
py::arg("composition"),
py::arg("T9"),
py::arg("rho"),
py::arg("reactionTypesToIgnore")
);
}
void register_engine_type_bindings(pybind11::module &m) {
auto types = m.def_submodule("types", "Types associated with engine module");
register_engine_building_type_bindings(types);
register_engine_reporting_type_bindings(types);
}
void register_engine_building_type_bindings(pybind11::module &m) {
py::enum_<gridfire::NetworkBuildDepth>(m, "NetworkBuildDepth")
.value("Full", gridfire::NetworkBuildDepth::Full, "Full network build depth")
.value("Shallow", gridfire::NetworkBuildDepth::Shallow, "Shallow network build depth")
.value("SecondOrder", gridfire::NetworkBuildDepth::SecondOrder, "Second order network build depth")
.value("ThirdOrder", gridfire::NetworkBuildDepth::ThirdOrder, "Third order network build depth")
.value("FourthOrder", gridfire::NetworkBuildDepth::FourthOrder, "Fourth order network build depth")
.value("FifthOrder", gridfire::NetworkBuildDepth::FifthOrder, "Fifth order network build depth")
.export_values();
py::class_<gridfire::BuildDepthType> py_build_depth_type(m, "BuildDepthType");
}
void register_engine_reporting_type_bindings(pybind11::module &m) {
py::enum_<gridfire::PrimingReportStatus>(m, "PrimingReportStatus")
.value("FULL_SUCCESS", gridfire::PrimingReportStatus::FULL_SUCCESS, "Priming was full successful.")
.value("NO_SPECIES_TO_PRIME", gridfire::PrimingReportStatus::NO_SPECIES_TO_PRIME, "No species to prime.")
@@ -150,135 +346,128 @@ void register_engine_bindings(py::module &m) {
);
}
void register_base_engine_bindings(const pybind11::module &m) {
py::class_<gridfire::StepDerivatives<double>>(m, "StepDerivatives")
.def_readonly("dYdt", &gridfire::StepDerivatives<double>::dydt, "The right-hand side (dY/dt) of the ODE system.")
.def_readonly("energy", &gridfire::StepDerivatives<double>::nuclearEnergyGenerationRate, "The energy generation rate.");
py::class_<gridfire::SparsityPattern>(m, "SparsityPattern");
abs_stype_register_engine_bindings(m);
abs_stype_register_dynamic_engine_bindings(m);
con_stype_register_graph_engine_bindings(m);
}
void abs_stype_register_engine_bindings(const pybind11::module &m) {
py::class_<gridfire::Engine, PyEngine>(m, "Engine");
}
void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m) {
const auto a = py::class_<gridfire::DynamicEngine, PyDynamicEngine>(m, "DynamicEngine");
}
void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
py::enum_<gridfire::NetworkBuildDepth>(m, "NetworkBuildDepth")
.value("Full", gridfire::NetworkBuildDepth::Full, "Full network build depth")
.value("Shallow", gridfire::NetworkBuildDepth::Shallow, "Shallow network build depth")
.value("SecondOrder", gridfire::NetworkBuildDepth::SecondOrder, "Second order network build depth")
.value("ThirdOrder", gridfire::NetworkBuildDepth::ThirdOrder, "Third order network build depth")
.value("FourthOrder", gridfire::NetworkBuildDepth::FourthOrder, "Fourth order network build depth")
.value("FifthOrder", gridfire::NetworkBuildDepth::FifthOrder, "Fifth order network build depth")
.export_values();
py::class_<gridfire::BuildDepthType>(m, "BuildDepthType");
auto py_dynamic_engine_bindings = py::class_<gridfire::GraphEngine, gridfire::DynamicEngine>(m, "GraphEngine");
auto py_graph_engine_bindings = py::class_<gridfire::GraphEngine, gridfire::DynamicEngine>(m, "GraphEngine");
// Register the Graph Engine Specific Bindings
py_dynamic_engine_bindings.def(py::init<const fourdst::composition::Composition &, const gridfire::BuildDepthType>(),
py_graph_engine_bindings.def(py::init<const fourdst::composition::Composition &, const gridfire::BuildDepthType>(),
py::arg("composition"),
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
"Initialize GraphEngine with a composition and build depth."
);
py_dynamic_engine_bindings.def(py::init<const fourdst::composition::Composition &, const gridfire::partition::PartitionFunction &, const gridfire::BuildDepthType>(),
py_graph_engine_bindings.def(py::init<const fourdst::composition::Composition &,const gridfire::partition::PartitionFunction &, const gridfire::BuildDepthType>(),
py::arg("composition"),
py::arg("partitionFunction"),
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
"Initialize GraphEngine with a composition, partition function and build depth."
);
py_dynamic_engine_bindings.def(py::init<const gridfire::reaction::ReactionSet &>(),
py_graph_engine_bindings.def(py::init<const gridfire::reaction::ReactionSet &>(),
py::arg("reactions"),
"Initialize GraphEngine with a set of reactions."
);
py_dynamic_engine_bindings.def("generateJacobianMatrix", py::overload_cast<const std::vector<double>&, double, double, const gridfire::SparsityPattern&>(&gridfire::GraphEngine::generateJacobianMatrix, py::const_),
py::arg("Y_dynamic"),
py::arg("T9"),
py::arg("rho"),
py::arg("sparsityPattern"),
"Generate the Jacobian matrix for the current state with a specified sparsity pattern."
);
py_dynamic_engine_bindings.def_static("getNetReactionStoichiometry", &gridfire::GraphEngine::getNetReactionStoichiometry,
py_graph_engine_bindings.def_static("getNetReactionStoichiometry",
&gridfire::GraphEngine::getNetReactionStoichiometry,
py::arg("reaction"),
"Get the net stoichiometry for a given reaction."
);
py_dynamic_engine_bindings.def("involvesSpecies", &gridfire::GraphEngine::involvesSpecies,
py_graph_engine_bindings.def("getSpeciesTimescales",
py::overload_cast<const fourdst::composition::Composition&, double, double, const gridfire::reaction::ReactionSet&>(&gridfire::GraphEngine::getSpeciesTimescales, py::const_),
py::arg("composition"),
py::arg("T9"),
py::arg("rho"),
py::arg("activeReactions")
);
py_graph_engine_bindings.def("getSpeciesDestructionTimescales",
py::overload_cast<const fourdst::composition::Composition&, double, double, const gridfire::reaction::ReactionSet&>(&gridfire::GraphEngine::getSpeciesDestructionTimescales, py::const_),
py::arg("composition"),
py::arg("T9"),
py::arg("rho"),
py::arg("activeReactions")
);
py_graph_engine_bindings.def("involvesSpecies",
&gridfire::GraphEngine::involvesSpecies,
py::arg("species"),
"Check if a given species is involved in the network."
);
py_dynamic_engine_bindings.def("exportToDot", &gridfire::GraphEngine::exportToDot,
py_graph_engine_bindings.def("exportToDot",
&gridfire::GraphEngine::exportToDot,
py::arg("filename"),
"Export the network to a DOT file for visualization."
);
py_dynamic_engine_bindings.def("exportToCSV", &gridfire::GraphEngine::exportToCSV,
py_graph_engine_bindings.def("exportToCSV",
&gridfire::GraphEngine::exportToCSV,
py::arg("filename"),
"Export the network to a CSV file for analysis."
);
py_dynamic_engine_bindings.def("setPrecomputation", &gridfire::GraphEngine::setPrecomputation,
py_graph_engine_bindings.def("setPrecomputation",
&gridfire::GraphEngine::setPrecomputation,
py::arg("precompute"),
"Enable or disable precomputation for the engine."
);
py_dynamic_engine_bindings.def("isPrecomputationEnabled", &gridfire::GraphEngine::isPrecomputationEnabled,
py_graph_engine_bindings.def("isPrecomputationEnabled",
&gridfire::GraphEngine::isPrecomputationEnabled,
"Check if precomputation is enabled for the engine."
);
py_dynamic_engine_bindings.def("getPartitionFunction", &gridfire::GraphEngine::getPartitionFunction,
py_graph_engine_bindings.def("getPartitionFunction",
&gridfire::GraphEngine::getPartitionFunction,
"Get the partition function used by the engine."
);
py_dynamic_engine_bindings.def("calculateReverseRate", &gridfire::GraphEngine::calculateReverseRate,
py_graph_engine_bindings.def("calculateReverseRate",
&gridfire::GraphEngine::calculateReverseRate,
py::arg("reaction"),
py::arg("T9"),
"Calculate the reverse rate for a given reaction at a specific temperature."
py::arg("rho"),
py::arg("composition"),
"Calculate the reverse rate for a given reaction at a specific temperature, density, and composition."
);
py_dynamic_engine_bindings.def("calculateReverseRateTwoBody", &gridfire::GraphEngine::calculateReverseRateTwoBody,
py_graph_engine_bindings.def("calculateReverseRateTwoBody",
&gridfire::GraphEngine::calculateReverseRateTwoBody,
py::arg("reaction"),
py::arg("T9"),
py::arg("forwardRate"),
py::arg("expFactor"),
"Calculate the reverse rate for a two-body reaction at a specific temperature."
);
py_dynamic_engine_bindings.def("calculateReverseRateTwoBodyDerivative", &gridfire::GraphEngine::calculateReverseRateTwoBodyDerivative,
py_graph_engine_bindings.def("calculateReverseRateTwoBodyDerivative",
&gridfire::GraphEngine::calculateReverseRateTwoBodyDerivative,
py::arg("reaction"),
py::arg("T9"),
py::arg("rho"),
py::arg("composition"),
py::arg("reverseRate"),
"Calculate the derivative of the reverse rate for a two-body reaction at a specific temperature."
);
py_dynamic_engine_bindings.def("isUsingReverseReactions", &gridfire::GraphEngine::isUsingReverseReactions,
py_graph_engine_bindings.def("isUsingReverseReactions",
&gridfire::GraphEngine::isUsingReverseReactions,
"Check if the engine is using reverse reactions."
);
py_dynamic_engine_bindings.def("setUseReverseReactions", &gridfire::GraphEngine::setUseReverseReactions,
py_graph_engine_bindings.def("setUseReverseReactions",
&gridfire::GraphEngine::setUseReverseReactions,
py::arg("useReverse"),
"Enable or disable the use of reverse reactions in the engine."
);
// Register the general dynamic engine bindings
registerDynamicEngineDefs<gridfire::GraphEngine, gridfire::DynamicEngine>(py_dynamic_engine_bindings);
registerDynamicEngineDefs<gridfire::GraphEngine, gridfire::DynamicEngine>(py_graph_engine_bindings);
}
void register_engine_view_bindings(const pybind11::module &m) {
auto py_defined_engine_view_bindings = py::class_<gridfire::DefinedEngineView, gridfire::DynamicEngine>(m, "DefinedEngineView");
py_defined_engine_view_bindings.def(py::init<std::vector<std::string>, gridfire::DynamicEngine&>(),
py_defined_engine_view_bindings.def(py::init<std::vector<std::string>, gridfire::GraphEngine&>(),
py::arg("peNames"),
py::arg("baseEngine"),
"Construct a defined engine view with a list of tracked reactions and a base engine.");
"Construct a defined engine view with a list of tracked reactions and a base engine."
);
py_defined_engine_view_bindings.def("getBaseEngine", &gridfire::DefinedEngineView::getBaseEngine,
"Get the base engine associated with this defined engine view.");
registerDynamicEngineDefs<gridfire::DefinedEngineView, gridfire::DynamicEngine>(py_defined_engine_view_bindings);
auto py_file_defined_engine_view_bindings = py::class_<gridfire::FileDefinedEngineView, gridfire::DefinedEngineView>(m, "FileDefinedEngineView");
py_file_defined_engine_view_bindings.def(py::init<gridfire::DynamicEngine&, const std::string&, const gridfire::io::NetworkFileParser&>(),
py_file_defined_engine_view_bindings.def(
py::init<gridfire::GraphEngine&, const std::string&, const gridfire::io::NetworkFileParser&>(),
py::arg("baseEngine"),
py::arg("fileName"),
py::arg("parser"),
@@ -296,11 +485,11 @@ void register_engine_view_bindings(const pybind11::module &m) {
registerDynamicEngineDefs<gridfire::FileDefinedEngineView, gridfire::DefinedEngineView>(py_file_defined_engine_view_bindings);
auto py_priming_engine_view_bindings = py::class_<gridfire::NetworkPrimingEngineView, gridfire::DefinedEngineView>(m, "NetworkPrimingEngineView");
py_priming_engine_view_bindings.def(py::init<const std::string&, gridfire::DynamicEngine&>(),
py_priming_engine_view_bindings.def(py::init<const std::string&, gridfire::GraphEngine&>(),
py::arg("primingSymbol"),
py::arg("baseEngine"),
"Construct a priming engine view with a priming symbol and a base engine.");
py_priming_engine_view_bindings.def(py::init<const fourdst::atomic::Species&, gridfire::DynamicEngine&>(),
py_priming_engine_view_bindings.def(py::init<const fourdst::atomic::Species&, gridfire::GraphEngine&>(),
py::arg("primingSpecies"),
py::arg("baseEngine"),
"Construct a priming engine view with a priming species and a base engine.");
@@ -313,8 +502,10 @@ void register_engine_view_bindings(const pybind11::module &m) {
py_adaptive_engine_view_bindings.def(py::init<gridfire::DynamicEngine&>(),
py::arg("baseEngine"),
"Construct an adaptive engine view with a base engine.");
py_adaptive_engine_view_bindings.def("getBaseEngine", &gridfire::AdaptiveEngineView::getBaseEngine,
"Get the base engine associated with this adaptive engine view.");
py_adaptive_engine_view_bindings.def("getBaseEngine",
&gridfire::AdaptiveEngineView::getBaseEngine,
"Get the base engine associated with this adaptive engine view."
);
registerDynamicEngineDefs<gridfire::AdaptiveEngineView, gridfire::DynamicEngine>(py_adaptive_engine_view_bindings);
@@ -341,43 +532,63 @@ void register_engine_view_bindings(const pybind11::module &m) {
auto py_multiscale_engine_view_bindings = py::class_<gridfire::MultiscalePartitioningEngineView, gridfire::DynamicEngine>(m, "MultiscalePartitioningEngineView");
py_multiscale_engine_view_bindings.def(py::init<gridfire::GraphEngine&>(),
py::arg("baseEngine"),
"Construct a multiscale partitioning engine view with a base engine.");
py_multiscale_engine_view_bindings.def("getBaseEngine", &gridfire::MultiscalePartitioningEngineView::getBaseEngine,
"Get the base engine associated with this multiscale partitioning engine view.");
py_multiscale_engine_view_bindings.def("analyzeTimescalePoolConnectivity", &gridfire::MultiscalePartitioningEngineView::analyzeTimescalePoolConnectivity,
"Construct a multiscale partitioning engine view with a base engine."
);
py_multiscale_engine_view_bindings.def("getBaseEngine",
&gridfire::MultiscalePartitioningEngineView::getBaseEngine,
"Get the base engine associated with this multiscale partitioning engine view."
);
py_multiscale_engine_view_bindings.def("analyzeTimescalePoolConnectivity",
&gridfire::MultiscalePartitioningEngineView::analyzeTimescalePoolConnectivity,
py::arg("timescale_pools"),
py::arg("Y"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Analyze the connectivity of timescale pools in the network.");
py_multiscale_engine_view_bindings.def("partitionNetwork", py::overload_cast<const std::vector<double>&, double, double>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork),
py::arg("Y"),
"Analyze the connectivity of timescale pools in the network."
);
py_multiscale_engine_view_bindings.def("partitionNetwork",
py::overload_cast<const fourdst::composition::Composition&, double, double>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Partition the network based on species timescales and connectivity.");
py_multiscale_engine_view_bindings.def("partitionNetwork", py::overload_cast<const gridfire::NetIn&>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork),
py_multiscale_engine_view_bindings.def("partitionNetwork",
py::overload_cast<const gridfire::NetIn&>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork),
py::arg("netIn"),
"Partition the network based on a NetIn object.");
py_multiscale_engine_view_bindings.def("exportToDot", &gridfire::MultiscalePartitioningEngineView::exportToDot,
"Partition the network based on a NetIn object."
);
py_multiscale_engine_view_bindings.def("exportToDot",
&gridfire::MultiscalePartitioningEngineView::exportToDot,
py::arg("filename"),
py::arg("Y"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Export the network to a DOT file for visualization.");
py_multiscale_engine_view_bindings.def("getFastSpecies", &gridfire::MultiscalePartitioningEngineView::getFastSpecies,
"Get the list of fast species in the network.");
py_multiscale_engine_view_bindings.def("getDynamicSpecies", &gridfire::MultiscalePartitioningEngineView::getDynamicSpecies,
"Get the list of dynamic species in the network.");
py_multiscale_engine_view_bindings.def("equilibrateNetwork", py::overload_cast<const std::vector<double>&, double, double>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork),
py::arg("Y"),
"Export the network to a DOT file for visualization."
);
py_multiscale_engine_view_bindings.def("getFastSpecies",
&gridfire::MultiscalePartitioningEngineView::getFastSpecies,
"Get the list of fast species in the network."
);
py_multiscale_engine_view_bindings.def("getDynamicSpecies",
&gridfire::MultiscalePartitioningEngineView::getDynamicSpecies,
"Get the list of dynamic species in the network."
);
py_multiscale_engine_view_bindings.def("equilibrateNetwork",
py::overload_cast<const fourdst::composition::Composition&, double, double>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Equilibrate the network based on species abundances and conditions.");
py_multiscale_engine_view_bindings.def("equilibrateNetwork", py::overload_cast<const gridfire::NetIn&>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork),
py_multiscale_engine_view_bindings.def("equilibrateNetwork",
py::overload_cast<const gridfire::NetIn&>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork),
py::arg("netIn"),
"Equilibrate the network based on a NetIn object.");
"Equilibrate the network based on a NetIn object."
);
registerDynamicEngineDefs<gridfire::MultiscalePartitioningEngineView, gridfire::DynamicEngine>(
py_multiscale_engine_view_bindings
);
registerDynamicEngineDefs<gridfire::MultiscalePartitioningEngineView, gridfire::DynamicEngine>(py_multiscale_engine_view_bindings);
}
@@ -387,3 +598,4 @@ void register_engine_view_bindings(const pybind11::module &m) {

View File

@@ -13,4 +13,14 @@ void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m);
void con_stype_register_graph_engine_bindings(const pybind11::module &m);
void register_engine_diagnostic_bindings(pybind11::module &m);
void register_engine_procedural_bindings(pybind11::module &m);
void register_engine_construction_bindings(pybind11::module &m);
void register_engine_priming_bindings(pybind11::module &m);
void register_engine_type_bindings(pybind11::module &m);
void register_engine_building_type_bindings(pybind11::module &m);
void register_engine_reporting_type_bindings(pybind11::module &m);

View File

@@ -33,12 +33,12 @@ const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies() const
py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\"");
}
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> PyEngine::calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const {
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> PyEngine::calculateRHSAndEnergy(const fourdst::composition::Composition &comp, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError>),
gridfire::Engine,
calculateRHSAndEnergy,
Y, T9, rho
comp, T9, rho
);
}
@@ -65,39 +65,62 @@ const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies(
py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\"");
}
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const {
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::calculateRHSAndEnergy(const fourdst::composition::Composition &comp, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError>),
gridfire::Engine,
calculateRHSAndEnergy,
Y, T9, rho
comp, T9, rho
);
}
void PyDynamicEngine::generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho) const {
void PyDynamicEngine::generateJacobianMatrix(const fourdst::composition::Composition& comp, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
generateJacobianMatrix,
Y_dynamic, T9, rho
comp,
T9,
rho
);
}
void PyDynamicEngine::generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const {
void PyDynamicEngine::generateJacobianMatrix(
const fourdst::composition::Composition &comp,
const double T9,
const double rho,
const std::vector<fourdst::atomic::Species> &activeSpecies
) const {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
generateJacobianMatrix,
Y_dynamic, T9, rho, sparsityPattern
comp,
T9,
rho,
activeSpecies
);
}
double PyDynamicEngine::getJacobianMatrixEntry(int i, int j) const {
void PyDynamicEngine::generateJacobianMatrix(const fourdst::composition::Composition &comp, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
generateJacobianMatrix,
comp,
T9,
rho,
sparsityPattern
);
}
double PyDynamicEngine::getJacobianMatrixEntry(const fourdst::atomic::Species& rowSpecies, const fourdst::atomic::Species& colSpecies) const {
PYBIND11_OVERRIDE_PURE(
double,
gridfire::DynamicEngine,
getJacobianMatrixEntry,
i, j
rowSpecies,
colSpecies
);
}
@@ -109,21 +132,25 @@ void PyDynamicEngine::generateStoichiometryMatrix() {
);
}
int PyDynamicEngine::getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const {
int PyDynamicEngine::getStoichiometryMatrixEntry(const fourdst::atomic::Species& species, const gridfire::reaction::Reaction& reaction) const {
PYBIND11_OVERRIDE_PURE(
int,
gridfire::DynamicEngine,
getStoichiometryMatrixEntry,
speciesIndex, reactionIndex
species,
reaction
);
}
double PyDynamicEngine::calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector<double> &Y, double T9, double rho) const {
double PyDynamicEngine::calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const fourdst::composition::Composition &comp, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
double,
gridfire::DynamicEngine,
calculateMolarReactionFlow,
reaction, Y, T9, rho
reaction,
comp,
T9,
rho
);
}
@@ -144,21 +171,23 @@ void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::ReactionSet&
);
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesTimescales(const std::vector<double> &Y, double T9, double rho) const {
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesTimescales(const fourdst::composition::Composition &comp, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError>),
gridfire::DynamicEngine,
getSpeciesTimescales,
Y, T9, rho
comp,
T9,
rho
);
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesDestructionTimescales(const std::vector<double> &Y, double T9, double rho) const {
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesDestructionTimescales(const fourdst::composition::Composition &comp, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError>),
gridfire::DynamicEngine,
getSpeciesDestructionTimescales,
Y, T9, rho
comp, T9, rho
);
}
@@ -224,6 +253,31 @@ gridfire::PrimingReport PyDynamicEngine::primeEngine(const gridfire::NetIn &netI
);
}
gridfire::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives(
const fourdst::composition::Composition &comp,
const double T9,
const double rho) const {
PYBIND11_OVERRIDE_PURE(
gridfire::EnergyDerivatives,
gridfire::DynamicEngine,
calculateEpsDerivatives,
comp,
T9,
rho
);
}
fourdst::composition::Composition PyDynamicEngine::collectComposition(
fourdst::composition::Composition &comp
) const {
PYBIND11_OVERRIDE_PURE(
fourdst::composition::Composition,
gridfire::DynamicEngine,
collectComposition,
comp
);
}
const gridfire::Engine& PyEngineView::getBaseEngine() const {
PYBIND11_OVERRIDE_PURE(
const gridfire::Engine&,

View File

@@ -12,7 +12,12 @@
class PyEngine final : public gridfire::Engine {
public:
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
std::expected<gridfire::StepDerivatives<double>,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const override;
std::expected<gridfire::StepDerivatives<double>,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(
const fourdst::composition::Composition& comp,
double T9,
double rho
) const override;
private:
mutable std::vector<fourdst::atomic::Species> m_species_cache;
};
@@ -20,41 +25,124 @@ private:
class PyDynamicEngine final : public gridfire::DynamicEngine {
public:
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
std::expected<gridfire::StepDerivatives<double>,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const override;
void generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho) const override;
void generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const override;
double getJacobianMatrixEntry(int i, int j) const override;
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(
const fourdst::composition::Composition& comp,
double T9,
double rho
) const override;
void generateJacobianMatrix(
const fourdst::composition::Composition& comp,
double T9,
double rho
) const override;
void generateJacobianMatrix(
const fourdst::composition::Composition &comp,
double T9,
double rho,
const std::vector<fourdst::atomic::Species> &activeSpecies
) const override;
void generateJacobianMatrix(
const fourdst::composition::Composition& comp,
double T9,
double rho,
const gridfire::SparsityPattern &sparsityPattern
) const override;
double getJacobianMatrixEntry(
const fourdst::atomic::Species& rowSpecies,
const fourdst::atomic::Species& colSpecies
) const override;
void generateStoichiometryMatrix() override;
int getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const override;
double calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector<double> &Y, double T9, double rho) const override;
int getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const gridfire::reaction::Reaction& reaction
) const override;
double calculateMolarReactionFlow(
const gridfire::reaction::Reaction &reaction,
const fourdst::composition::Composition& comp,
double T9,
double rho
) const override;
const gridfire::reaction::ReactionSet& getNetworkReactions() const override;
void setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesTimescales(const std::vector<double> &Y, double T9, double rho) const override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesDestructionTimescales(const std::vector<double> &Y, double T9, double rho) const override;
fourdst::composition::Composition update(const gridfire::NetIn &netIn) override;
bool isStale(const gridfire::NetIn &netIn) override;
void setScreeningModel(gridfire::screening::ScreeningType model) override;
void setNetworkReactions(
const gridfire::reaction::ReactionSet& reactions
) override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesTimescales(
const fourdst::composition::Composition& comp,
double T9,
double rho
) const override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesDestructionTimescales(
const fourdst::composition::Composition &comp,
double T9,
double rho
) const override;
fourdst::composition::Composition update(
const gridfire::NetIn &netIn
) override;
bool isStale(
const gridfire::NetIn &netIn
) override;
void setScreeningModel(
gridfire::screening::ScreeningType model
) override;
gridfire::screening::ScreeningType getScreeningModel() const override;
size_t getSpeciesIndex(const fourdst::atomic::Species &species) const override;
std::vector<double> mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const override;
gridfire::PrimingReport primeEngine(const gridfire::NetIn &netIn) override;
size_t getSpeciesIndex(
const fourdst::atomic::Species &species
) const override;
std::vector<double> mapNetInToMolarAbundanceVector(
const gridfire::NetIn &netIn
) const override;
gridfire::PrimingReport primeEngine(
const gridfire::NetIn &netIn
) override;
gridfire::BuildDepthType getDepth() const override {
throw std::logic_error("Network depth not supported by this engine.");
}
void rebuild(const fourdst::composition::Composition& comp, gridfire::BuildDepthType depth) override {
void rebuild(
const fourdst::composition::Composition& comp,
gridfire::BuildDepthType depth
) override {
throw std::logic_error("Setting network depth not supported by this engine.");
}
[[nodiscard]] gridfire::EnergyDerivatives calculateEpsDerivatives(
const fourdst::composition::Composition &comp,
double T9,
double rho
) const override;
fourdst::composition::Composition collectComposition(
fourdst::composition::Composition &comp
) const override;
private:
mutable std::vector<fourdst::atomic::Species> m_species_cache;
};
class PyEngineView final : public gridfire::EngineView<gridfire::Engine> {
const gridfire::Engine& getBaseEngine() const override;
[[nodiscard]] const gridfire::Engine& getBaseEngine() const override;
};
class PyDynamicEngineView final : public gridfire::EngineView<gridfire::DynamicEngine> {
const gridfire::DynamicEngine& getBaseEngine() const override;
[[nodiscard]] const gridfire::DynamicEngine& getBaseEngine() const override;
};

View File

@@ -38,4 +38,21 @@ void register_exception_bindings(const py::module &m) {
return self.what();
});
py::register_exception<gridfire::exceptions::FailedToPartitionEngineError>(m, "FailedToPartitionEngineError", m.attr("GridFireEngineError"));
py::register_exception<gridfire::exceptions::NetworkResizedError>(m, "NetworkResizedError", m.attr("GridFireEngineError"));
py::register_exception<gridfire::exceptions::UnableToSetNetworkReactionsError>(m, "UnableToSetNetworkReactionsError", m.attr("GridFireEngineError"));
py::register_exception<gridfire::exceptions::BadCollectionError>(m, "BadCollectionError", m.attr("GridFireEngineError"));
py::register_exception<gridfire::exceptions::JacobianError>(m, "JacobianError", m.attr("GridFireEngineError"));
py::register_exception<gridfire::exceptions::StaleJacobianError>(m, "StaleJacobianError", m.attr("JacobianEngineError"));
py::register_exception<gridfire::exceptions::UninitializedJacobianError>(m, "UninitializedJacobianError", m.attr("JacobianEngineError"));
py::register_exception<gridfire::exceptions::UnknownJacobianError>(m, "UnknownJacobianError", m.attr("JacobianEngineError"));
py::register_exception<gridfire::exceptions::UtilityError>(m, "UtilityError");
py::register_exception<gridfire::exceptions::HashingError>(m, "HashingError", m.attr("UtilityError"));
}

View File

@@ -48,7 +48,7 @@ void register_reaction_bindings(py::module &m) {
.def(
"calculate_rate",
[](const gridfire::reaction::ReaclibReaction& self, const double T9, const double rho, const std::vector<double>& Y) -> double {
return self.calculate_rate(T9, rho, 0, TODO, Y, TODO);
return self.calculate_rate(T9, rho, 0, {}, Y, {});
},
py::arg("T9"),
py::arg("rho"),
@@ -183,9 +183,15 @@ void register_reaction_bindings(py::module &m) {
py::class_<gridfire::reaction::LogicalReaclibReaction, gridfire::reaction::ReaclibReaction>(m, "LogicalReaclibReaction")
.def(
py::init<const std::vector<gridfire::reaction::Reaction>>(),
py::init<const std::vector<gridfire::reaction::ReaclibReaction>>(),
py::arg("reactions"),
"Construct a LogicalReaclibReaction from a vector of Reaction objects."
"Construct a LogicalReaclibReaction from a vector of ReaclibReaction objects."
)
.def(
py::init<const std::vector<gridfire::reaction::ReaclibReaction>, bool>(),
py::arg("reactions"),
py::arg("is_reverse"),
"Construct a LogicalReaclibReaction from a vector of ReaclibReaction objects."
)
.def(
"add_reaction",
@@ -210,26 +216,56 @@ void register_reaction_bindings(py::module &m) {
)
.def(
"calculate_rate",
[](const gridfire::reaction::LogicalReaclibReaction& self, const double T9, const double rho, const std::vector<double>& Y) -> double {
return self.calculate_rate(T9, rho, 0, TODO, Y, TODO);
[](
const gridfire::reaction::LogicalReaclibReaction& self,
const double T9,
const double rho,
const double Ye,
const double mue,
const std::vector<double>& Y,
const std::unordered_map<size_t, Species>& index_to_species_map
) -> double {
return self.calculate_rate(T9, rho, Ye, mue, Y, index_to_species_map);
},
py::arg("T9"),
"Calculate the reaction rate at a given temperature T9 (in units of 10^9 K)."
py::arg("rho"),
py::arg("Ye"),
py::arg("mue"),
py::arg("Y"),
py::arg("index_to_species_map"),
"Calculate the reaction rate at a given temperature T9 (in units of 10^9 K). Note that for a reaclib reaction only T9 is actually used, all other parameters are there for interface compatibility."
)
.def(
"calculate_forward_rate_log_derivative",
&gridfire::reaction::LogicalReaclibReaction::calculate_forward_rate_log_derivative,
&gridfire::reaction::LogicalReaclibReaction::calculate_log_rate_partial_deriv_wrt_T9,
py::arg("T9"),
py::arg("rho"),
py::arg("Ye"),
py::arg("mue"),
py::arg("Composition"),
"Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K)."
);
py::class_<gridfire::reaction::ReactionSet>(m, "ReactionSet")
// TODO: Fix the constructor to accept a vector of unique ptrs to Reaclib Reactions
.def(
py::init<const std::vector<gridfire::reaction::Reaction>>(),
py::init<const std::vector<gridfire::reaction::Reaction*>>(),
py::arg("reactions"),
py::keep_alive<1, 2>(), // Keep arg 2 (reactions) alive as long as arg 1 (self) is alive. This helps mitigate use-after-free errors
"Construct a LogicalReactionSet from a vector of LogicalReaclibReaction objects."
)
.def_static(
"from_clones",
[](const std::vector<gridfire::reaction::Reaction*>& py_reactions) {
std::vector<std::unique_ptr<gridfire::reaction::Reaction>> cpp_reactions;
cpp_reactions.reserve(py_reactions.size());
for (const auto& reaction : py_reactions) {
cpp_reactions.emplace_back(reaction->clone());
}
return std::make_unique<gridfire::reaction::ReactionSet>(std::move(cpp_reactions));
},
py::arg("reactions"),
"Create a ReactionSet that takes ownership of the reactions by cloning the input reactions."
)
.def(
py::init<>(),
"Default constructor for an empty LogicalReactionSet."

View File

@@ -2,71 +2,97 @@
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <pybind11/numpy.h>
#include <functional>
#include <boost/numeric/ublas/vector.hpp>
#include "bindings.h"
#include "gridfire/solver/solver.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "trampoline/py_solver.h"
namespace py = pybind11;
void register_solver_bindings(const py::module &m) {
auto py_dynamic_network_solving_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
auto py_direct_network_solver = py::class_<gridfire::solver::DirectNetworkSolver, gridfire::solver::DynamicNetworkSolverStrategy>(m, "DirectNetworkSolver");
py_direct_network_solver.def(py::init<gridfire::DynamicEngine&>(),
py::arg("engine"),
"Constructor for the DirectNetworkSolver. Takes a DynamicEngine instance to use for evaluating the network."
);
py_direct_network_solver.def("evaluate",
&gridfire::solver::DirectNetworkSolver::evaluate,
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
py_dynamic_network_solver_strategy.def(
"evaluate",
&gridfire::solver::DynamicNetworkSolverStrategy::evaluate,
py::arg("netIn"),
"Evaluate the network for a given timestep. Returns the output conditions after the timestep."
"evaluate the dynamic engine using the dynamic engine class"
);
py_direct_network_solver.def("set_callback",
[](gridfire::solver::DirectNetworkSolver &self, const gridfire::solver::DirectNetworkSolver::TimestepCallback& cb) {
py_dynamic_network_solver_strategy.def(
"set_callback",
[](gridfire::solver::DynamicNetworkSolverStrategy& self, std::function<void(const gridfire::solver::SolverContextBase&)> cb) {
self.set_callback(cb);
},
py::arg("callback"),
"Sets a callback function to be called at each timestep."
"Set a callback function which will run at the end of every successful timestep"
);
py::class_<gridfire::solver::DirectNetworkSolver::TimestepContext>(py_direct_network_solver, "TimestepContext")
.def_readonly("t", &gridfire::solver::DirectNetworkSolver::TimestepContext::t, "Current time in the simulation.")
.def_property_readonly(
"state", [](const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) {
std::vector<double> state(ctx.state.size());
std::ranges::copy(ctx.state, state.begin());
return py::array_t<double>(static_cast<ssize_t>(state.size()), state.data());
})
.def_readonly("dt", &gridfire::solver::DirectNetworkSolver::TimestepContext::dt, "Current timestep size.")
.def_readonly("cached_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::cached_time, "Cached time for the last computed result.")
.def_readonly("last_observed_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::last_observed_time, "Last time the state was observed.")
.def_readonly("last_step_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::last_step_time, "Last step time taken for the integration.")
.def_readonly("T9", &gridfire::solver::DirectNetworkSolver::TimestepContext::T9, "Temperature in units of 10^9 K.")
.def_readonly("rho", &gridfire::solver::DirectNetworkSolver::TimestepContext::rho, "Temperature in units of 10^9 K.")
.def_property_readonly("cached_result", [](const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) -> py::object {
if (ctx.cached_result.has_value()) {
const auto&[dydt, nuclearEnergyGenerationRate] = ctx.cached_result.value();
return py::make_tuple(
py::array_t<double>(static_cast<ssize_t>(dydt.size()), dydt.data()),
nuclearEnergyGenerationRate
);
}
return py::none();
}, "Cached result of the step derivatives.")
.def_readonly("num_steps", &gridfire::solver::DirectNetworkSolver::TimestepContext::num_steps, "Total number of steps taken in the simulation.")
.def_property_readonly("engine", [](const gridfire::solver::DirectNetworkSolver::TimestepContext &ctx) -> const gridfire::DynamicEngine & {
return ctx.engine;
}, py::return_value_policy::reference)
py_dynamic_network_solver_strategy.def(
"describe_callback_context",
&gridfire::solver::DynamicNetworkSolverStrategy::describe_callback_context,
"Get a structure representing what data is in the callback context in a human readable format"
);
auto py_cvode_solver_strategy = py::class_<gridfire::solver::CVODESolverStrategy, gridfire::solver::DynamicNetworkSolverStrategy>(m, "CVODESolverStrategy");
py_cvode_solver_strategy.def(
py::init<gridfire::DynamicEngine&>(),
py::arg("engine"),
"Initialize the CVODESolverStrategy object."
);
py_cvode_solver_strategy.def(
"evaluate",
py::overload_cast<const gridfire::NetIn&, bool>(&gridfire::solver::CVODESolverStrategy::evaluate),
py::arg("netIn"),
py::arg("display_trigger"),
"evaluate the dynamic engine using the dynamic engine class"
);
py_cvode_solver_strategy.def(
"get_stdout_logging_enabled",
&gridfire::solver::CVODESolverStrategy::get_stdout_logging_enabled,
"Check if solver logging to standard output is enabled."
);
py_cvode_solver_strategy.def(
"set_stdout_logging_enabled",
&gridfire::solver::CVODESolverStrategy::set_stdout_logging_enabled,
py::arg("logging_enabled"),
"Enable logging to standard output."
);
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext>(m, "CVODETimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
py_cvode_timestep_context.def_property_readonly(
"state",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
const sunindextype length = N_VGetLength(self.state);
return std::vector<double>(nvec_data, nvec_data + length);
}
);
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
py_cvode_timestep_context.def_property_readonly(
"engine",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& {
return self.engine;
}
);
py_cvode_timestep_context.def_property_readonly(
"networkSpecies",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
return self.networkSpecies;
}
);
.def_property_readonly("network_species", [](const gridfire::solver::DirectNetworkSolver::TimestepContext &ctx) -> const std::vector<fourdst::atomic::Species> & {
return ctx.networkSpecies;
}, py::return_value_policy::reference);
}

View File

@@ -39,3 +39,12 @@ std::vector<std::tuple<std::string, std::string>> PyDynamicNetworkSolverStrategy
describe_callback_context // Method name
);
}
std::vector<std::tuple<std::string, std::string>> PySolverContextBase::describe() const {
using DescriptionVector = std::vector<std::tuple<std::string, std::string>>;
PYBIND11_OVERRIDE_PURE(
DescriptionVector,
gridfire::solver::SolverContextBase,
describe
);
}

View File

@@ -12,4 +12,9 @@ class PyDynamicNetworkSolverStrategy final : public gridfire::solver::DynamicNet
gridfire::NetOut evaluate(const gridfire::NetIn &netIn) override;
void set_callback(const std::any &callback) override;
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
};
class PySolverContextBase final : public gridfire::solver::SolverContextBase {
public:
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};

View File

@@ -1,6 +1,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include <format>
#include "bindings.h"
@@ -32,12 +33,18 @@ void register_type_bindings(const pybind11::module &m) {
.def_readonly("composition", &gridfire::NetOut::composition)
.def_readonly("num_steps", &gridfire::NetOut::num_steps)
.def_readonly("energy", &gridfire::NetOut::energy)
.def_readonly("dEps_dT", &gridfire::NetOut::dEps_dT)
.def_readonly("dEps_dRho", &gridfire::NetOut::dEps_dRho)
.def("__repr__", [](const gridfire::NetOut &netOut) {
std::stringstream ss;
ss << "NetOut(composition=" << netOut.composition
<< ", num_steps=" << netOut.num_steps
<< ", energy=" << netOut.energy << ")";
return ss.str();
std::string repr = std::format(
"NetOut(<μ> = {} steps = {}, ε = {}, dε/dT = {}, dε/dρ = {})",
netOut.composition.getMeanParticleMass(),
netOut.num_steps,
netOut.energy,
netOut.dEps_dT,
netOut.dEps_dRho
);
return repr;
});
}

View File

@@ -3,6 +3,9 @@
#include "bindings.h"
#include "gridfire/utils/general_composition.h"
#include "gridfire/utils/hashing.h"
namespace py = pybind11;
#include "gridfire/utils/logging.h"
@@ -16,4 +19,65 @@ void register_utils_bindings(py::module &m) {
py::arg("rho"),
"Format a string for logging nuclear timescales based on temperature, density, and energy generation rate."
);
m.def(
"massFractionFromMolarAbundanceAndComposition",
&gridfire::utils::massFractionFromMolarAbundanceAndComposition,
py::arg("composition"),
py::arg("species"),
py::arg("Yi"),
"Convert a specific species molar abundance into its mass fraction if it were present in a given composition."
);
m.def(
"massFractionFromMolarAbundanceAndMolarMass",
&gridfire::utils::massFractionFromMolarAbundanceAndMolarMass,
py::arg("molarAbundances"),
py::arg("molarMasses"),
"Convert a vector of molar abundances and a parallel vector of molar masses into a vector of mass fractions"
);
m.def(
"molarMassVectorFromComposition",
&gridfire::utils::molarMassVectorFromComposition,
py::arg("composition"),
"Extract vector of molar masses from a composition object, this will be sorted by species mass so that the lightest species are at the front of the list."
);
m.def(
"hash_atomic",
&gridfire::utils::hash_atomic,
py::arg("a"),
py::arg("z")
);
auto hashing_module = m.def_submodule("hashing", "module for gridfire hashing functions");
auto reaction_hashing_module = hashing_module.def_submodule("reaction", "utility module for hashing gridfire reaction functions");
reaction_hashing_module.def(
"splitmix64",
&gridfire::utils::hashing::reaction::splitmix64,
py::arg("x")
);
reaction_hashing_module.def(
"mix_species",
&gridfire::utils::hashing::reaction::mix_species,
py::arg("a"),
py::arg("z")
);
reaction_hashing_module.def(
"multiset_combine",
&gridfire::utils::hashing::reaction::multiset_combine,
py::arg("acc"),
py::arg("x")
);
m.def(
"hash_reaction",
&gridfire::utils::hash_reaction,
py::arg("reaction")
);
}