feat(python): added robust python bindings covering the entire codebase

This commit is contained in:
2025-07-23 16:26:30 -04:00
parent 6a22cb65b8
commit f20bffc411
134 changed files with 2202 additions and 170 deletions

View File

@@ -0,0 +1,8 @@
#pragma once
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/engine/engine_graph.h"
#include "gridfire/engine/views/engine_views.h"
#include "gridfire/engine/procedures/engine_procedures.h"
#include "gridfire/engine/types/engine_types.h"

View File

@@ -0,0 +1,4 @@
#pragma once
#include "gridfire/engine/procedures/construction.h"
#include "gridfire/engine/procedures/priming.h"

View File

@@ -0,0 +1,4 @@
#pragma once
#include "gridfire/engine/types/building.h"
#include "gridfire/engine/types/reporting.h"

View File

@@ -0,0 +1,7 @@
#pragma once
#include "gridfire/engine/views/engine_adaptive.h"
#include "gridfire/engine/views/engine_defined.h"
#include "gridfire/engine/views/engine_multiscale.h"
#include "gridfire/engine/views/engine_priming.h"
#include "gridfire/engine/views/engine_view_abstract.h"

View File

@@ -0,0 +1,3 @@
#pragma once
#include "gridfire/exceptions/error_engine.h"

View File

@@ -0,0 +1,3 @@
#pragma once
#include "gridfire/expectations/expected_engine.h"

View File

@@ -14,9 +14,16 @@ namespace gridfire::expectations {
SYSTEM_RESIZED
};
// TODO: rename this to EngineExpectation or something similar
struct EngineError {
std::string m_message;
EngineErrorTypes type = EngineErrorTypes::FAILURE;
const EngineErrorTypes type = EngineErrorTypes::FAILURE;
explicit EngineError(const std::string &message, const EngineErrorTypes type)
: m_message(message), type(type) {}
virtual ~EngineError() = default;
friend std::ostream& operator<<(std::ostream& os, const EngineError& e) {
os << "EngineError: " << e.m_message;
return os;
@@ -25,7 +32,9 @@ namespace gridfire::expectations {
struct EngineIndexError : EngineError {
int m_index;
EngineErrorTypes type = EngineErrorTypes::INDEX;
explicit EngineIndexError(const int index)
: EngineError("Index error occurred", EngineErrorTypes::INDEX), m_index(index) {}
friend std::ostream& operator<<(std::ostream& os, const EngineIndexError& e) {
os << "EngineIndexError: " << e.m_message << " at index " << e.m_index;
return os;
@@ -33,10 +42,10 @@ namespace gridfire::expectations {
};
struct StaleEngineError : EngineError {
EngineErrorTypes type = EngineErrorTypes::STALE;
StaleEngineErrorTypes staleType;
explicit StaleEngineError(StaleEngineErrorTypes staleType) : staleType(staleType) {}
explicit StaleEngineError(const StaleEngineErrorTypes sType)
: EngineError("Stale engine error occurred", EngineErrorTypes::STALE), staleType(sType) {}
explicit operator std::string() const {
switch (staleType) {

View File

@@ -0,0 +1,3 @@
#pragma once
#include "gridfire/io/network_file.h"

View File

@@ -0,0 +1,8 @@
#pragma once
#include "gridfire/partition/partition_types.h"
#include "gridfire/partition/partition_abstract.h"
#include "gridfire/partition/partition_ground.h"
#include "gridfire/partition/partition_rauscher_thielemann.h"
#include "gridfire/partition/rauscher_thielemann_partition_data_record.h"
#include "gridfire/partition/composite/partition_composite.h"

View File

@@ -0,0 +1,6 @@
#pragma once
#include "gridfire/screening/screening_types.h"
#include "gridfire/screening/screening_abstract.h"
#include "gridfire/screening/screening_bare.h"
#include "gridfire/screening/screening_weak.h"

View File

@@ -103,6 +103,6 @@ namespace gridfire::screening {
const std::vector<ADDouble>& Y,
const ADDouble T9,
const ADDouble rho
) const = 0;
) const = 0;
};
}

View File

@@ -1 +1,80 @@
subdir('network')
# Define the library
gridfire_sources = files(
'lib/network.cpp',
'lib/engine/engine_approx8.cpp',
'lib/engine/engine_graph.cpp',
'lib/engine/views/engine_adaptive.cpp',
'lib/engine/views/engine_defined.cpp',
'lib/engine/views/engine_multiscale.cpp',
'lib/engine/views/engine_priming.cpp',
'lib/engine/procedures/priming.cpp',
'lib/engine/procedures/construction.cpp',
'lib/reaction/reaction.cpp',
'lib/reaction/reaclib.cpp',
'lib/io/network_file.cpp',
'lib/solver/solver.cpp',
'lib/screening/screening_types.cpp',
'lib/screening/screening_weak.cpp',
'lib/screening/screening_bare.cpp',
'lib/partition/partition_rauscher_thielemann.cpp',
'lib/partition/partition_ground.cpp',
'lib/partition/composite/partition_composite.cpp',
'lib/utils/logging.cpp',
)
gridfire_build_dependencies = [
boost_dep,
const_dep,
config_dep,
composition_dep,
cppad_dep,
log_dep,
xxhash_dep,
eigen_dep,
]
# Define the libnetwork library so it can be linked against by other parts of the build system
libgridfire = library('gridfire',
gridfire_sources,
include_directories: include_directories('include'),
dependencies: gridfire_build_dependencies,
install : true)
gridfire_dep = declare_dependency(
include_directories: include_directories('include'),
link_with: libgridfire,
sources: gridfire_sources,
dependencies: gridfire_build_dependencies,
)
# Make headers accessible
gridfire_headers = files(
'include/gridfire/network.h',
'include/gridfire/engine/engine_abstract.h',
'include/gridfire/engine/views/engine_view_abstract.h',
'include/gridfire/engine/engine_approx8.h',
'include/gridfire/engine/engine_graph.h',
'include/gridfire/engine/views/engine_adaptive.h',
'include/gridfire/engine/views/engine_defined.h',
'include/gridfire/engine/views/engine_multiscale.h',
'include/gridfire/engine/views/engine_priming.h',
'include/gridfire/engine/procedures/priming.h',
'include/gridfire/engine/procedures/construction.h',
'include/gridfire/reaction/reaction.h',
'include/gridfire/reaction/reaclib.h',
'include/gridfire/io/network_file.h',
'include/gridfire/solver/solver.h',
'include/gridfire/screening/screening_abstract.h',
'include/gridfire/screening/screening_bare.h',
'include/gridfire/screening/screening_weak.h',
'include/gridfire/screening/screening_types.h',
'include/gridfire/partition/partition_abstract.h',
'include/gridfire/partition/partition_rauscher_thielemann.h',
'include/gridfire/partition/partition_ground.h',
'include/gridfire/partition/composite/partition_composite.h',
'include/gridfire/utils/logging.h',
)
install_headers(gridfire_headers, subdir : 'gridfire')
subdir('python')

View File

@@ -1,78 +0,0 @@
# Define the library
network_sources = files(
'lib/network.cpp',
'lib/engine/engine_approx8.cpp',
'lib/engine/engine_graph.cpp',
'lib/engine/views/engine_adaptive.cpp',
'lib/engine/views/engine_defined.cpp',
'lib/engine/views/engine_multiscale.cpp',
'lib/engine/views/engine_priming.cpp',
'lib/engine/procedures/priming.cpp',
'lib/engine/procedures/construction.cpp',
'lib/reaction/reaction.cpp',
'lib/reaction/reaclib.cpp',
'lib/io/network_file.cpp',
'lib/solver/solver.cpp',
'lib/screening/screening_types.cpp',
'lib/screening/screening_weak.cpp',
'lib/screening/screening_bare.cpp',
'lib/partition/partition_rauscher_thielemann.cpp',
'lib/partition/partition_ground.cpp',
'lib/partition/composite/partition_composite.cpp',
'lib/utils/logging.cpp',
)
dependencies = [
boost_dep,
const_dep,
config_dep,
composition_dep,
cppad_dep,
log_dep,
xxhash_dep,
eigen_dep,
]
# Define the libnetwork library so it can be linked against by other parts of the build system
libnetwork = library('network',
network_sources,
include_directories: include_directories('include'),
dependencies: dependencies,
install : true)
network_dep = declare_dependency(
include_directories: include_directories('include'),
link_with: libnetwork,
sources: network_sources,
dependencies: dependencies,
)
# Make headers accessible
network_headers = files(
'include/gridfire/network.h',
'include/gridfire/engine/engine_abstract.h',
'include/gridfire/engine/views/engine_view_abstract.h',
'include/gridfire/engine/engine_approx8.h',
'include/gridfire/engine/engine_graph.h',
'include/gridfire/engine/views/engine_adaptive.h',
'include/gridfire/engine/views/engine_defined.h',
'include/gridfire/engine/views/engine_multiscale.h',
'include/gridfire/engine/views/engine_priming.h',
'include/gridfire/engine/procedures/priming.h',
'include/gridfire/engine/procedures/construction.h',
'include/gridfire/reaction/reaction.h',
'include/gridfire/reaction/reaclib.h',
'include/gridfire/io/network_file.h',
'include/gridfire/solver/solver.h',
'include/gridfire/screening/screening_abstract.h',
'include/gridfire/screening/screening_bare.h',
'include/gridfire/screening/screening_weak.h',
'include/gridfire/screening/screening_types.h',
'include/gridfire/partition/partition_abstract.h',
'include/gridfire/partition/partition_rauscher_thielemann.h',
'include/gridfire/partition/partition_ground.h',
'include/gridfire/partition/composite/partition_composite.h',
'include/gridfire/utils/logging.h',
)
install_headers(network_headers, subdir : 'gridfire')

54
src/python/bindings.cpp Normal file
View File

@@ -0,0 +1,54 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "types/bindings.h"
#include "partition/bindings.h"
#include "expectations/bindings.h"
#include "engine/bindings.h"
#include "exceptions/bindings.h"
#include "io/bindings.h"
#include "reaction/bindings.h"
#include "screening/bindings.h"
#include "solver/bindings.h"
#include "utils/bindings.h"
PYBIND11_MODULE(gridfire, m) {
m.doc() = "Python bindings for the fourdst utility modules which are a part of the 4D-STAR project.";
pybind11::module::import("fourdst.constants");
pybind11::module::import("fourdst.composition");
pybind11::module::import("fourdst.config");
pybind11::module::import("fourdst.atomic");
auto typeMod = m.def_submodule("type", "GridFire type bindings");
register_type_bindings(typeMod);
auto partitionMod = m.def_submodule("partition", "GridFire partition function bindings");
register_partition_bindings(partitionMod);
auto expectationMod = m.def_submodule("expectations", "GridFire expectations bindings");
register_expectation_bindings(expectationMod);
auto reactionMod = m.def_submodule("reaction", "GridFire reaction bindings");
register_reaction_bindings(reactionMod);
auto screeningMod = m.def_submodule("screening", "GridFire plasma screening bindings");
register_screening_bindings(screeningMod);
auto ioMod = m.def_submodule("io", "GridFire io bindings");
register_io_bindings(ioMod);
auto exceptionMod = m.def_submodule("exceptions", "GridFire exceptions bindings");
register_exception_bindings(exceptionMod);
auto engineMod = m.def_submodule("engine", "Engine and Engine View bindings");
register_engine_bindings(engineMod);
auto solverMod = m.def_submodule("solver", "GridFire numerical solver bindings");
register_solver_bindings(solverMod);
auto utilsMod = m.def_submodule("utils", "GridFire utility method bindings");
register_utils_bindings(utilsMod);
}

View File

@@ -0,0 +1,391 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <iostream>
#include "bindings.h"
#include "gridfire/engine/engine.h"
#include "trampoline/py_engine.h"
namespace py = pybind11;
namespace {
template <typename T>
concept IsDynamicEngine = std::is_base_of_v<gridfire::DynamicEngine, T>;
template <IsDynamicEngine T, IsDynamicEngine BaseT>
void registerDynamicEngineDefs(py::class_<T, BaseT> pyClass) {
pyClass.def("calculateRHSAndEnergy", &T::calculateRHSAndEnergy,
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Calculate the right-hand side (dY/dt) and energy generation rate."
)
.def("generateJacobianMatrix", py::overload_cast<const std::vector<double>&, double, double>(&T::generateJacobianMatrix, py::const_),
py::arg("Y_dynamic"),
py::arg("T9"),
py::arg("rho"),
"Generate the Jacobian matrix for the current state."
)
.def("generateStoichiometryMatrix", &T::generateStoichiometryMatrix)
.def("calculateMolarReactionFlow",
static_cast<double (T::*)(const gridfire::reaction::Reaction&, const std::vector<double>&, const double, const double) const>(&T::calculateMolarReactionFlow),
py::arg("reaction"),
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Calculate the molar reaction flow for a given reaction."
)
.def("getNetworkSpecies", &T::getNetworkSpecies,
"Get the list of species in the network."
)
.def("getNetworkReactions", &T::getNetworkReactions,
"Get the set of logical reactions in the network."
)
.def ("setNetworkReactions", &T::setNetworkReactions,
py::arg("reactions"),
"Set the network reactions to a new set of reactions."
)
.def("getJacobianMatrixEntry", &T::getJacobianMatrixEntry,
py::arg("i"),
py::arg("j"),
"Get an entry from the previously generated Jacobian matrix."
)
.def("getStoichiometryMatrixEntry", &T::getStoichiometryMatrixEntry,
py::arg("speciesIndex"),
py::arg("reactionIndex"),
"Get an entry from the stoichiometry matrix."
)
.def("getSpeciesTimescales", &T::getSpeciesTimescales,
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Get the timescales for each species in the network."
)
.def("getSpeciesDestructionTimescales", &T::getSpeciesDestructionTimescales,
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Get the destruction timescales for each species in the network."
)
.def("update", &T::update,
py::arg("netIn"),
"Update the engine state based on the provided NetIn object."
)
.def("setScreeningModel", &T::setScreeningModel,
py::arg("screeningModel"),
"Set the screening model for the engine."
)
.def("getScreeningModel", &T::getScreeningModel,
"Get the current screening model of the engine."
)
.def("getSpeciesIndex", &T::getSpeciesIndex,
py::arg("species"),
"Get the index of a species in the network."
)
.def("mapNetInToMolarAbundanceVector", &T::mapNetInToMolarAbundanceVector,
py::arg("netIn"),
"Map a NetIn object to a vector of molar abundances."
)
.def("primeEngine", &T::primeEngine,
py::arg("netIn"),
"Prime the engine with a NetIn object to prepare for calculations."
)
.def("getDepth", &T::getDepth,
"Get the current build depth of the engine."
)
.def("rebuild", &T::rebuild,
py::arg("composition"),
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
"Rebuild the engine with a new composition and build depth."
)
.def("isStale", &T::isStale,
py::arg("netIn"),
"Check if the engine is stale based on the provided NetIn object."
);
}
}
void register_engine_bindings(py::module &m) {
register_base_engine_bindings(m);
register_engine_view_bindings(m);
m.def("build_reaclib_nuclear_network", &gridfire::build_reaclib_nuclear_network,
py::arg("composition"),
py::arg("maxLayers") = gridfire::NetworkBuildDepth::Full,
py::arg("reverse") = false,
"Build a nuclear network from a composition using ReacLib data."
);
py::enum_<gridfire::PrimingReportStatus>(m, "PrimingReportStatus")
.value("FULL_SUCCESS", gridfire::PrimingReportStatus::FULL_SUCCESS, "Priming was full successful.")
.value("NO_SPECIES_TO_PRIME", gridfire::PrimingReportStatus::NO_SPECIES_TO_PRIME, "No species to prime.")
.value("MAX_ITERATIONS_REACHED", gridfire::PrimingReportStatus::MAX_ITERATIONS_REACHED, "Maximum iterations reached during priming.")
.value("FAILED_TO_FINALIZE_COMPOSITION", gridfire::PrimingReportStatus::FAILED_TO_FINALIZE_COMPOSITION, "Failed to finalize the composition after priming.")
.value("FAILED_TO_FIND_CREATION_CHANNEL", gridfire::PrimingReportStatus::FAILED_TO_FIND_CREATION_CHANNEL, "Failed to find a creation channel for the priming species.")
.value("FAILED_TO_FIND_PRIMING_REACTIONS", gridfire::PrimingReportStatus::FAILED_TO_FIND_PRIMING_REACTIONS, "Failed to find priming reactions for the species.")
.value("BASE_NETWORK_TOO_SHALLOW", gridfire::PrimingReportStatus::BASE_NETWORK_TOO_SHALLOW, "The base network is too shallow for priming.")
.export_values()
.def("__repr__", [](const gridfire::PrimingReportStatus& status) {
std::stringstream ss;
ss << gridfire::PrimingReportStatusStrings.at(status) << "\n";
return ss.str();
},
"String representation of the PrimingReport."
);
py::class_<gridfire::PrimingReport>(m, "PrimingReport")
.def_readonly("success", &gridfire::PrimingReport::success, "Indicates if the priming was successful.")
.def_readonly("massFractionChanges", &gridfire::PrimingReport::massFractionChanges, "Map of species to their mass fraction changes after priming.")
.def_readonly("primedComposition", &gridfire::PrimingReport::primedComposition, "The composition after priming.")
.def_readonly("status", &gridfire::PrimingReport::status, "Status message from the priming process.")
.def("__repr__", [](const gridfire::PrimingReport& report) {
std::stringstream ss;
ss << report;
return ss.str();
}
);
}
void register_base_engine_bindings(pybind11::module &m) {
py::class_<gridfire::StepDerivatives<double>>(m, "StepDerivatives")
.def_readonly("dYdt", &gridfire::StepDerivatives<double>::dydt, "The right-hand side (dY/dt) of the ODE system.")
.def_readonly("energy", &gridfire::StepDerivatives<double>::nuclearEnergyGenerationRate, "The energy generation rate.");
py::class_<gridfire::SparsityPattern>(m, "SparsityPattern");
abs_stype_register_engine_bindings(m);
abs_stype_register_dynamic_engine_bindings(m);
con_stype_register_graph_engine_bindings(m);
}
void abs_stype_register_engine_bindings(pybind11::module &m) {
py::class_<gridfire::Engine, PyEngine>(m, "Engine");
}
void abs_stype_register_dynamic_engine_bindings(pybind11::module &m) {
const auto a = py::class_<gridfire::DynamicEngine, PyDynamicEngine>(m, "DynamicEngine");
}
void con_stype_register_graph_engine_bindings(pybind11::module &m) {
py::enum_<gridfire::NetworkBuildDepth>(m, "NetworkBuildDepth")
.value("Full", gridfire::NetworkBuildDepth::Full, "Full network build depth")
.value("Shallow", gridfire::NetworkBuildDepth::Shallow, "Shallow network build depth")
.value("SecondOrder", gridfire::NetworkBuildDepth::SecondOrder, "Second order network build depth")
.value("ThirdOrder", gridfire::NetworkBuildDepth::ThirdOrder, "Third order network build depth")
.value("FourthOrder", gridfire::NetworkBuildDepth::FourthOrder, "Fourth order network build depth")
.value("FifthOrder", gridfire::NetworkBuildDepth::FifthOrder, "Fifth order network build depth")
.export_values();
py::class_<gridfire::BuildDepthType>(m, "BuildDepthType");
auto py_dynamic_engine_bindings = py::class_<gridfire::GraphEngine, gridfire::DynamicEngine>(m, "GraphEngine");
// Register the Graph Engine Specific Bindings
py_dynamic_engine_bindings.def(py::init<const fourdst::composition::Composition &, const gridfire::BuildDepthType>(),
py::arg("composition"),
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
"Initialize GraphEngine with a composition and build depth."
);
py_dynamic_engine_bindings.def(py::init<const fourdst::composition::Composition &, const gridfire::partition::PartitionFunction &, const gridfire::BuildDepthType>(),
py::arg("composition"),
py::arg("partitionFunction"),
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
"Initialize GraphEngine with a composition, partition function and build depth."
);
py_dynamic_engine_bindings.def(py::init<const gridfire::reaction::LogicalReactionSet &>(),
py::arg("reactions"),
"Initialize GraphEngine with a set of reactions."
);
py_dynamic_engine_bindings.def("generateJacobianMatrix", py::overload_cast<const std::vector<double>&, double, double, const gridfire::SparsityPattern&>(&gridfire::GraphEngine::generateJacobianMatrix, py::const_),
py::arg("Y_dynamic"),
py::arg("T9"),
py::arg("rho"),
py::arg("sparsityPattern"),
"Generate the Jacobian matrix for the current state with a specified sparsity pattern."
);
py_dynamic_engine_bindings.def_static("getNetReactionStoichiometry", &gridfire::GraphEngine::getNetReactionStoichiometry,
py::arg("reaction"),
"Get the net stoichiometry for a given reaction."
);
py_dynamic_engine_bindings.def("involvesSpecies", &gridfire::GraphEngine::involvesSpecies,
py::arg("species"),
"Check if a given species is involved in the network."
);
py_dynamic_engine_bindings.def("exportToDot", &gridfire::GraphEngine::exportToDot,
py::arg("filename"),
"Export the network to a DOT file for visualization."
);
py_dynamic_engine_bindings.def("exportToCSV", &gridfire::GraphEngine::exportToCSV,
py::arg("filename"),
"Export the network to a CSV file for analysis."
);
py_dynamic_engine_bindings.def("setPrecomputation", &gridfire::GraphEngine::setPrecomputation,
py::arg("precompute"),
"Enable or disable precomputation for the engine."
);
py_dynamic_engine_bindings.def("isPrecomputationEnabled", &gridfire::GraphEngine::isPrecomputationEnabled,
"Check if precomputation is enabled for the engine."
);
py_dynamic_engine_bindings.def("getPartitionFunction", &gridfire::GraphEngine::getPartitionFunction,
"Get the partition function used by the engine."
);
py_dynamic_engine_bindings.def("calculateReverseRate", &gridfire::GraphEngine::calculateReverseRate,
py::arg("reaction"),
py::arg("T9"),
"Calculate the reverse rate for a given reaction at a specific temperature."
);
py_dynamic_engine_bindings.def("calculateReverseRateTwoBody", &gridfire::GraphEngine::calculateReverseRateTwoBody,
py::arg("reaction"),
py::arg("T9"),
py::arg("forwardRate"),
py::arg("expFactor"),
"Calculate the reverse rate for a two-body reaction at a specific temperature."
);
py_dynamic_engine_bindings.def("calculateReverseRateTwoBodyDerivative", &gridfire::GraphEngine::calculateReverseRateTwoBodyDerivative,
py::arg("reaction"),
py::arg("T9"),
py::arg("reverseRate"),
"Calculate the derivative of the reverse rate for a two-body reaction at a specific temperature."
);
py_dynamic_engine_bindings.def("isUsingReverseReactions", &gridfire::GraphEngine::isUsingReverseReactions,
"Check if the engine is using reverse reactions."
);
py_dynamic_engine_bindings.def("setUseReverseReactions", &gridfire::GraphEngine::setUseReverseReactions,
py::arg("useReverse"),
"Enable or disable the use of reverse reactions in the engine."
);
// Register the general dynamic engine bindings
registerDynamicEngineDefs<gridfire::GraphEngine, gridfire::DynamicEngine>(py_dynamic_engine_bindings);
}
void register_engine_view_bindings(pybind11::module &m) {
auto py_defined_engine_view_bindings = py::class_<gridfire::DefinedEngineView, gridfire::DynamicEngine>(m, "DefinedEngineView");
py_defined_engine_view_bindings.def(py::init<std::vector<std::string>, gridfire::DynamicEngine&>(),
py::arg("peNames"),
py::arg("baseEngine"),
"Construct a defined engine view with a list of tracked reactions and a base engine.");
py_defined_engine_view_bindings.def("getBaseEngine", &gridfire::DefinedEngineView::getBaseEngine,
"Get the base engine associated with this defined engine view.");
registerDynamicEngineDefs<gridfire::DefinedEngineView, gridfire::DynamicEngine>(py_defined_engine_view_bindings);
auto py_file_defined_engine_view_bindings = py::class_<gridfire::FileDefinedEngineView, gridfire::DefinedEngineView>(m, "FileDefinedEngineView");
py_file_defined_engine_view_bindings.def(py::init<gridfire::DynamicEngine&, const std::string&, const gridfire::io::NetworkFileParser&>(),
py::arg("baseEngine"),
py::arg("fileName"),
py::arg("parser"),
"Construct a defined engine view from a file and a base engine."
);
py_file_defined_engine_view_bindings.def("getNetworkFile", &gridfire::FileDefinedEngineView::getNetworkFile,
"Get the network file associated with this defined engine view."
);
py_file_defined_engine_view_bindings.def("getParser", &gridfire::FileDefinedEngineView::getParser,
"Get the parser used for this defined engine view."
);
py_file_defined_engine_view_bindings.def("getBaseEngine", &gridfire::FileDefinedEngineView::getBaseEngine,
"Get the base engine associated with this file defined engine view.");
registerDynamicEngineDefs<gridfire::FileDefinedEngineView, gridfire::DefinedEngineView>(py_file_defined_engine_view_bindings);
auto py_priming_engine_view_bindings = py::class_<gridfire::NetworkPrimingEngineView, gridfire::DefinedEngineView>(m, "NetworkPrimingEngineView");
py_priming_engine_view_bindings.def(py::init<const std::string&, gridfire::DynamicEngine&>(),
py::arg("primingSymbol"),
py::arg("baseEngine"),
"Construct a priming engine view with a priming symbol and a base engine.");
py_priming_engine_view_bindings.def(py::init<const fourdst::atomic::Species&, gridfire::DynamicEngine&>(),
py::arg("primingSpecies"),
py::arg("baseEngine"),
"Construct a priming engine view with a priming species and a base engine.");
py_priming_engine_view_bindings.def("getBaseEngine", &gridfire::NetworkPrimingEngineView::getBaseEngine,
"Get the base engine associated with this priming engine view.");
registerDynamicEngineDefs<gridfire::NetworkPrimingEngineView, gridfire::DefinedEngineView>(py_priming_engine_view_bindings);
auto py_adaptive_engine_view_bindings = py::class_<gridfire::AdaptiveEngineView, gridfire::DynamicEngine>(m, "AdaptiveEngineView");
py_adaptive_engine_view_bindings.def(py::init<gridfire::DynamicEngine&>(),
py::arg("baseEngine"),
"Construct an adaptive engine view with a base engine.");
py_adaptive_engine_view_bindings.def("getBaseEngine", &gridfire::AdaptiveEngineView::getBaseEngine,
"Get the base engine associated with this adaptive engine view.");
registerDynamicEngineDefs<gridfire::AdaptiveEngineView, gridfire::DynamicEngine>(py_adaptive_engine_view_bindings);
auto py_qse_cache_config = py::class_<gridfire::QSECacheConfig>(m, "QSECacheConfig");
auto py_qse_cache_key = py::class_<gridfire::QSECacheKey>(m, "QSECacheKey");
py_qse_cache_key.def(py::init<double, double, const std::vector<double>&>(),
py::arg("T9"),
py::arg("rho"),
py::arg("Y")
);
py_qse_cache_key.def("hash", &gridfire::QSECacheKey::hash,
"Get the pre-computed hash value of the key");
py_qse_cache_key.def_static("bin", &gridfire::QSECacheKey::bin,
py::arg("value"),
py::arg("tol"),
"bin a value based on a tolerance");
py_qse_cache_key.def("__eq__", &gridfire::QSECacheKey::operator==,
py::arg("other"),
"Check if two QSECacheKeys are equal");
auto py_multiscale_engine_view_bindings = py::class_<gridfire::MultiscalePartitioningEngineView, gridfire::DynamicEngine>(m, "MultiscalePartitioningEngineView");
py_multiscale_engine_view_bindings.def(py::init<gridfire::GraphEngine&>(),
py::arg("baseEngine"),
"Construct a multiscale partitioning engine view with a base engine.");
py_multiscale_engine_view_bindings.def("getBaseEngine", &gridfire::MultiscalePartitioningEngineView::getBaseEngine,
"Get the base engine associated with this multiscale partitioning engine view.");
py_multiscale_engine_view_bindings.def("analyzeTimescalePoolConnectivity", &gridfire::MultiscalePartitioningEngineView::analyzeTimescalePoolConnectivity,
py::arg("timescale_pools"),
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Analyze the connectivity of timescale pools in the network.");
py_multiscale_engine_view_bindings.def("partitionNetwork", py::overload_cast<const std::vector<double>&, double, double>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork),
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Partition the network based on species timescales and connectivity.");
py_multiscale_engine_view_bindings.def("partitionNetwork", py::overload_cast<const gridfire::NetIn&>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork),
py::arg("netIn"),
"Partition the network based on a NetIn object.");
py_multiscale_engine_view_bindings.def("exportToDot", &gridfire::MultiscalePartitioningEngineView::exportToDot,
py::arg("filename"),
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Export the network to a DOT file for visualization.");
py_multiscale_engine_view_bindings.def("getFastSpecies", &gridfire::MultiscalePartitioningEngineView::getFastSpecies,
"Get the list of fast species in the network.");
py_multiscale_engine_view_bindings.def("getDynamicSpecies", &gridfire::MultiscalePartitioningEngineView::getDynamicSpecies,
"Get the list of dynamic species in the network.");
py_multiscale_engine_view_bindings.def("equilibrateNetwork", py::overload_cast<const std::vector<double>&, double, double>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork),
py::arg("Y"),
py::arg("T9"),
py::arg("rho"),
"Equilibrate the network based on species abundances and conditions.");
py_multiscale_engine_view_bindings.def("equilibrateNetwork", py::overload_cast<const gridfire::NetIn&>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork),
py::arg("netIn"),
"Equilibrate the network based on a NetIn object.");
registerDynamicEngineDefs<gridfire::MultiscalePartitioningEngineView, gridfire::DynamicEngine>(py_multiscale_engine_view_bindings);
}

View File

@@ -0,0 +1,16 @@
#pragma once
#include <pybind11/pybind11.h>
void register_engine_bindings(pybind11::module &m);
void register_base_engine_bindings(pybind11::module &m);
void register_engine_view_bindings(pybind11::module &m);
void abs_stype_register_engine_bindings(pybind11::module &m);
void abs_stype_register_dynamic_engine_bindings(pybind11::module &m);
void con_stype_register_graph_engine_bindings(pybind11::module &m);

View File

@@ -0,0 +1,21 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
message('⏳ Python bindings for GridFire Engine are being registered...')
shared_module('py_gf_engine',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)
message('✅ Python bindings for GridFire Engine registered successfully!')

View File

@@ -0,0 +1,21 @@
gf_engine_trampoline_sources = files('py_engine.cpp')
gf_engine_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_engine_trampoline_lib = static_library(
'engine_trampolines',
gf_engine_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_engine_trapoline_dependencies,
install: false,
)
gr_engine_trampoline_dep = declare_dependency(
link_with: gf_engine_trampoline_lib,
include_directories: ('.'),
dependencies: gf_engine_trapoline_dependencies,
)

View File

@@ -0,0 +1,219 @@
#include "py_engine.h"
#include "gridfire/engine/engine.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h> // Needed for std::function
#include <expected>
#include <unordered_map>
#include <vector>
namespace py = pybind11;
const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies() const {
PYBIND11_OVERRIDE_PURE(
std::vector<fourdst::atomic::Species>,
gridfire::Engine, /* Base class */
getNetworkSpecies
);
}
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> PyEngine::calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError>),
gridfire::Engine,
calculateRHSAndEnergy,
Y, T9, rho
);
}
///////////////////////////////////////
/// PyDynamicEngine Implementation ///
/////////////////////////////////////
const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies() const {
PYBIND11_OVERRIDE_PURE(
std::vector<fourdst::atomic::Species>,
gridfire::DynamicEngine, /* Base class */
getNetworkSpecies
);
}
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError>),
gridfire::Engine,
calculateRHSAndEnergy,
Y, T9, rho
);
}
void PyDynamicEngine::generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
generateJacobianMatrix,
Y_dynamic, T9, rho
);
}
void PyDynamicEngine::generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
generateJacobianMatrix,
Y_dynamic, T9, rho, sparsityPattern
);
}
double PyDynamicEngine::getJacobianMatrixEntry(int i, int j) const {
PYBIND11_OVERRIDE_PURE(
double,
gridfire::DynamicEngine,
getJacobianMatrixEntry,
i, j
);
}
void PyDynamicEngine::generateStoichiometryMatrix() {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
generateStoichiometryMatrix
);
}
int PyDynamicEngine::getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const {
PYBIND11_OVERRIDE_PURE(
int,
gridfire::DynamicEngine,
getStoichiometryMatrixEntry,
speciesIndex, reactionIndex
);
}
double PyDynamicEngine::calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector<double> &Y, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
double,
gridfire::DynamicEngine,
calculateMolarReactionFlow,
reaction, Y, T9, rho
);
}
const gridfire::reaction::LogicalReactionSet& PyDynamicEngine::getNetworkReactions() const {
PYBIND11_OVERRIDE_PURE(
const gridfire::reaction::LogicalReactionSet&,
gridfire::DynamicEngine,
getNetworkReactions
);
}
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::LogicalReactionSet& reactions) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
setNetworkReactions,
reactions
);
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesTimescales(const std::vector<double> &Y, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError>),
gridfire::DynamicEngine,
getSpeciesTimescales,
Y, T9, rho
);
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesDestructionTimescales(const std::vector<double> &Y, double T9, double rho) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError>),
gridfire::DynamicEngine,
getSpeciesDestructionTimescales,
Y, T9, rho
);
}
fourdst::composition::Composition PyDynamicEngine::update(const gridfire::NetIn &netIn) {
PYBIND11_OVERRIDE_PURE(
fourdst::composition::Composition,
gridfire::DynamicEngine,
update,
netIn
);
}
bool PyDynamicEngine::isStale(const gridfire::NetIn &netIn) {
PYBIND11_OVERRIDE_PURE(
bool,
gridfire::DynamicEngine,
isStale,
netIn
);
}
void PyDynamicEngine::setScreeningModel(gridfire::screening::ScreeningType model) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::DynamicEngine,
setScreeningModel,
model
);
}
gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel() const {
PYBIND11_OVERRIDE_PURE(
gridfire::screening::ScreeningType,
gridfire::DynamicEngine,
getScreeningModel
);
}
int PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
PYBIND11_OVERRIDE_PURE(
int,
gridfire::DynamicEngine,
getSpeciesIndex,
species
);
}
std::vector<double> PyDynamicEngine::mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const {
PYBIND11_OVERRIDE_PURE(
std::vector<double>,
gridfire::DynamicEngine,
mapNetInToMolarAbundanceVector,
netIn
);
}
gridfire::PrimingReport PyDynamicEngine::primeEngine(const gridfire::NetIn &netIn) {
PYBIND11_OVERRIDE_PURE(
gridfire::PrimingReport,
gridfire::DynamicEngine,
primeEngine,
netIn
);
}
const gridfire::Engine& PyEngineView::getBaseEngine() const {
PYBIND11_OVERRIDE_PURE(
const gridfire::Engine&,
gridfire::EngineView<gridfire::Engine>,
getBaseEngine
);
}
const gridfire::DynamicEngine& PyDynamicEngineView::getBaseEngine() const {
PYBIND11_OVERRIDE_PURE(
const gridfire::DynamicEngine&,
gridfire::EngineView<gridfire::DynamicEngine>,
getBaseEngine
);
}

View File

@@ -0,0 +1,52 @@
#pragma once
#include "gridfire/engine/engine.h"
#include "gridfire/expectations/expected_engine.h"
#include "fourdst/composition/atomicSpecies.h"
#include <vector>
#include <expected>
class PyEngine final : public gridfire::Engine {
public:
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
std::expected<gridfire::StepDerivatives<double>,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const override;
};
class PyDynamicEngine final : public gridfire::DynamicEngine {
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
std::expected<gridfire::StepDerivatives<double>,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const override;
void generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho) const override;
void generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const override;
double getJacobianMatrixEntry(int i, int j) const override;
void generateStoichiometryMatrix() override;
int getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const override;
double calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector<double> &Y, double T9, double rho) const override;
const gridfire::reaction::LogicalReactionSet& getNetworkReactions() const override;
void setNetworkReactions(const gridfire::reaction::LogicalReactionSet& reactions) override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesTimescales(const std::vector<double> &Y, double T9, double rho) const override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesDestructionTimescales(const std::vector<double> &Y, double T9, double rho) const override;
fourdst::composition::Composition update(const gridfire::NetIn &netIn) override;
bool isStale(const gridfire::NetIn &netIn) override;
void setScreeningModel(gridfire::screening::ScreeningType model) override;
gridfire::screening::ScreeningType getScreeningModel() const override;
int getSpeciesIndex(const fourdst::atomic::Species &species) const override;
std::vector<double> mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const override;
gridfire::PrimingReport primeEngine(const gridfire::NetIn &netIn) override;
gridfire::BuildDepthType getDepth() const override {
throw std::logic_error("Network depth not supported by this engine.");
}
void rebuild(const fourdst::composition::Composition& comp, gridfire::BuildDepthType depth) override {
throw std::logic_error("Setting network depth not supported by this engine.");
}
};
class PyEngineView final : public gridfire::EngineView<gridfire::Engine> {
const gridfire::Engine& getBaseEngine() const override;
};
class PyDynamicEngineView final : public gridfire::EngineView<gridfire::DynamicEngine> {
const gridfire::DynamicEngine& getBaseEngine() const override;
};

View File

@@ -0,0 +1,45 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <iostream>
#include "bindings.h"
namespace py = pybind11;
#include "gridfire/exceptions/exceptions.h"
void register_exception_bindings(py::module &m) {
py::register_exception<gridfire::exceptions::EngineError>(m, "GridFireEngineError");
// TODO: Make it so that we can grab the stale state in python
// m.attr("StaleEngineTrigger") = py::register_exception<gridfire::exceptions::StaleEngineTrigger>(m, "StaleEngineTrigger", m.attr("GridFireEngineError"));
m.attr("StaleEngineError") = py::register_exception<gridfire::exceptions::StaleEngineError>(m, "StaleEngineError", m.attr("GridFireEngineError"));
m.attr("FailedToPartitionEngineError") = py::register_exception<gridfire::exceptions::FailedToPartitionEngineError>(m, "FailedToPartitionEngineError", m.attr("GridFireEngineError"));
m.attr("NetworkResizedError") = py::register_exception<gridfire::exceptions::NetworkResizedError>(m, "NetworkResizedError", m.attr("GridFireEngineError"));
m.attr("UnableToSetNetworkReactionsError") = py::register_exception<gridfire::exceptions::UnableToSetNetworkReactionsError>(m, "UnableToSetNetworkReactionsError", m.attr("GridFireEngineError"));
py::class_<gridfire::exceptions::StaleEngineTrigger::state>(m, "StaleEngineState")
.def(py::init<>())
.def_readwrite("T9", &gridfire::exceptions::StaleEngineTrigger::state::m_T9)
.def_readwrite("rho", &gridfire::exceptions::StaleEngineTrigger::state::m_rho)
.def_readwrite("Y", &gridfire::exceptions::StaleEngineTrigger::state::m_Y)
.def_readwrite("t", &gridfire::exceptions::StaleEngineTrigger::state::m_t)
.def_readwrite("total_steps", &gridfire::exceptions::StaleEngineTrigger::state::m_total_steps)
.def_readwrite("eps_nuc", &gridfire::exceptions::StaleEngineTrigger::state::m_eps_nuc);
py::class_<gridfire::exceptions::StaleEngineTrigger>(m, "StaleEngineTrigger")
.def(py::init<const gridfire::exceptions::StaleEngineTrigger::state &>())
.def("getState", &gridfire::exceptions::StaleEngineTrigger::getState)
.def("numSpecies", &gridfire::exceptions::StaleEngineTrigger::numSpecies)
.def("totalSteps", &gridfire::exceptions::StaleEngineTrigger::totalSteps)
.def("energy", &gridfire::exceptions::StaleEngineTrigger::energy)
.def("getMolarAbundance", &gridfire::exceptions::StaleEngineTrigger::getMolarAbundance)
.def("temperature", &gridfire::exceptions::StaleEngineTrigger::temperature)
.def("density", &gridfire::exceptions::StaleEngineTrigger::density)
.def("__repr__", [&](const gridfire::exceptions::StaleEngineTrigger& self) {
return self.what();
});
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include <pybind11/pybind11.h>
void register_exception_bindings(pybind11::module &m);

View File

@@ -0,0 +1,17 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_exceptions',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -0,0 +1,43 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include "bindings.h"
namespace py = pybind11;
#include "gridfire/expectations/expectations.h"
void register_expectation_bindings(py::module &m) {
py::enum_<gridfire::expectations::EngineErrorTypes>(m, "EngineErrorTypes")
.value("FAILURE", gridfire::expectations::EngineErrorTypes::FAILURE)
.value("INDEX", gridfire::expectations::EngineErrorTypes::INDEX)
.value("STALE", gridfire::expectations::EngineErrorTypes::STALE)
.export_values();
py::enum_<gridfire::expectations::StaleEngineErrorTypes>(m, "StaleEngineErrorTypes")
.value("SYSTEM_RESIZED", gridfire::expectations::StaleEngineErrorTypes::SYSTEM_RESIZED)
.export_values();
// Bind the base class
py::class_<gridfire::expectations::EngineError>(m, "EngineError")
.def_readonly("message", &gridfire::expectations::EngineError::m_message)
.def_readonly("type", &gridfire::expectations::EngineError::type)
.def("__str__", [](const gridfire::expectations::EngineError &e) {return e.m_message;});
// Bind the EngineIndexError, specifying EngineError as the base
py::class_<gridfire::expectations::EngineIndexError, gridfire::expectations::EngineError>(m, "EngineIndexError")
.def(py::init<int>(), py::arg("index"))
.def_readonly("index", &gridfire::expectations::EngineIndexError::m_index)
.def("__str__", [](const gridfire::expectations::EngineIndexError &e) {
return e.m_message + " at index " + std::to_string(e.m_index);
});
// Bind the StaleEngineError, specifying EngineError as the base
py::class_<gridfire::expectations::StaleEngineError, gridfire::expectations::EngineError>(m, "StaleEngineError")
.def(py::init<gridfire::expectations::StaleEngineErrorTypes>(), py::arg("stale_type"))
.def_readonly("stale_type", &gridfire::expectations::StaleEngineError::staleType)
.def("__str__", [](const gridfire::expectations::StaleEngineError &e) {
return static_cast<std::string>(e);
});
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include <pybind11/pybind11.h>
void register_expectation_bindings(pybind11::module &m);

View File

@@ -0,0 +1,17 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_expectations',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -0,0 +1,29 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <string_view>
#include <vector>
#include "bindings.h"
#include "gridfire/io/io.h"
#include "trampoline/py_io.h"
namespace py = pybind11;
void register_io_bindings(py::module &m) {
py::class_<gridfire::io::ParsedNetworkData>(m, "ParsedNetworkData");
py::class_<gridfire::io::NetworkFileParser, PyNetworkFileParser>(m, "NetworkFileParser");
py::class_<gridfire::io::SimpleReactionListFileParser, gridfire::io::NetworkFileParser>(m, "SimpleReactionListFileParser")
.def("parse", &gridfire::io::SimpleReactionListFileParser::parse,
py::arg("filename"),
"Parse a simple reaction list file and return a ParsedNetworkData object.");
// py::class_<gridfire::io::MESANetworkFileParser, gridfire::io::NetworkFileParser>(m, "MESANetworkFileParser")
// .def("parse", &gridfire::io::MESANetworkFileParser::parse,
// py::arg("filename"),
// "Parse a MESA network file and return a ParsedNetworkData object.");
}

5
src/python/io/bindings.h Normal file
View File

@@ -0,0 +1,5 @@
#pragma once
#include <pybind11/pybind11.h>
void register_io_bindings(pybind11::module &m);

17
src/python/io/meson.build Normal file
View File

@@ -0,0 +1,17 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_io',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -0,0 +1,21 @@
gf_io_trampoline_sources = files('py_io.cpp')
gf_io_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_io_trampoline_lib = static_library(
'io_trampolines',
gf_io_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_io_trapoline_dependencies,
install: false,
)
gr_io_trampoline_dep = declare_dependency(
link_with: gf_io_trampoline_lib,
include_directories: ('.'),
dependencies: gf_io_trapoline_dependencies,
)

View File

@@ -0,0 +1,14 @@
#include "gridfire/io/io.h"
#include "py_io.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
gridfire::io::ParsedNetworkData PyNetworkFileParser::parse(const std::string &filename) const {
PYBIND11_OVERLOAD_PURE(
gridfire::io::ParsedNetworkData,
gridfire::io::NetworkFileParser,
parse // Method name
);
}

View File

@@ -0,0 +1,7 @@
#pragma once
#include "gridfire/io/io.h"
class PyNetworkFileParser final : public gridfire::io::NetworkFileParser {
gridfire::io::ParsedNetworkData parse(const std::string &filename) const override;
};

10
src/python/meson.build Normal file
View File

@@ -0,0 +1,10 @@
subdir('types')
subdir('utils')
subdir('expectations')
subdir('exceptions')
subdir('io')
subdir('partition')
subdir('reaction')
subdir('screening')
subdir('engine')
subdir('solver')

View File

@@ -0,0 +1,113 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <iostream>
#include <memory>
#include "bindings.h"
#include "gridfire/partition/partition.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
#include "trampoline/py_partition.h"
namespace py = pybind11;
void register_partition_bindings(pybind11::module &m) {
using PF = gridfire::partition::PartitionFunction;
py::class_<PF, PyPartitionFunction>(m, "PartitionFunction");
register_partition_types_bindings(m);
register_ground_state_partition_bindings(m);
register_rauscher_thielemann_partition_data_record_bindings(m);
register_rauscher_thielemann_partition_bindings(m);
register_composite_partition_bindings(m);
}
void register_partition_types_bindings(pybind11::module &m) {
py::enum_<gridfire::partition::BasePartitionType>(m, "BasePartitionType")
.value("RauscherThielemann", gridfire::partition::BasePartitionType::RauscherThielemann)
.value("GroundState", gridfire::partition::BasePartitionType::GroundState)
.export_values();
m.def("basePartitionTypeToString", [](gridfire::partition::BasePartitionType type) {
return gridfire::partition::basePartitionTypeToString[type];
}, py::arg("type"), "Convert BasePartitionType to string.");
m.def("stringToBasePartitionType", [](const std::string &typeStr) {
return gridfire::partition::stringToBasePartitionType[typeStr];
}, py::arg("typeStr"), "Convert string to BasePartitionType.");
}
void register_ground_state_partition_bindings(pybind11::module &m) {
using GSPF = gridfire::partition::GroundStatePartitionFunction;
using PF = gridfire::partition::PartitionFunction;
py::class_<GSPF, PF>(m, "GroundStatePartitionFunction")
.def(py::init<>())
.def("evaluate", &gridfire::partition::GroundStatePartitionFunction::evaluate,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the ground state partition function for given Z, A, and T9.")
.def("evaluateDerivative", &gridfire::partition::GroundStatePartitionFunction::evaluateDerivative,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the derivative of the ground state partition function for given Z, A, and T9.")
.def("supports", &gridfire::partition::GroundStatePartitionFunction::supports,
py::arg("z"), py::arg("a"),
"Check if the ground state partition function supports given Z and A.")
.def("get_type", &gridfire::partition::GroundStatePartitionFunction::type,
"Get the type of the partition function (should return 'GroundState').");
}
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m) {
py::class_<gridfire::partition::record::RauscherThielemannPartitionDataRecord>(m, "RauscherThielemannPartitionDataRecord")
.def_readonly("z", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::z, "Atomic number")
.def_readonly("a", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::a, "Mass number")
.def_readonly("ground_state_spin", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::ground_state_spin, "Ground state spin")
.def_readonly("normalized_g_values", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::normalized_g_values, "Normalized g-values for the first 24 energy levels");
}
void register_rauscher_thielemann_partition_bindings(pybind11::module &m) {
using RTPF = gridfire::partition::RauscherThielemannPartitionFunction;
using PF = gridfire::partition::PartitionFunction;
py::class_<RTPF, PF>(m, "RauscherThielemannPartitionFunction")
.def(py::init<>())
.def("evaluate", &gridfire::partition::RauscherThielemannPartitionFunction::evaluate,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the Rauscher-Thielemann partition function for given Z, A, and T9.")
.def("evaluateDerivative", &gridfire::partition::RauscherThielemannPartitionFunction::evaluateDerivative,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the derivative of the Rauscher-Thielemann partition function for given Z, A, and T9.")
.def("supports", &gridfire::partition::RauscherThielemannPartitionFunction::supports,
py::arg("z"), py::arg("a"),
"Check if the Rauscher-Thielemann partition function supports given Z and A.")
.def("get_type", &gridfire::partition::RauscherThielemannPartitionFunction::type,
"Get the type of the partition function (should return 'RauscherThielemann').");
}
void register_composite_partition_bindings(pybind11::module &m) {
py::class_<gridfire::partition::CompositePartitionFunction>(m, "CompositePartitionFunction")
.def(py::init<const std::vector<gridfire::partition::BasePartitionType>&>(),
py::arg("partitionFunctions"),
"Create a composite partition function from a list of base partition types.")
.def(py::init<const gridfire::partition::CompositePartitionFunction&>(),
"Copy constructor for CompositePartitionFunction.")
.def("evaluate", &gridfire::partition::CompositePartitionFunction::evaluate,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the composite partition function for given Z, A, and T9.")
.def("evaluateDerivative", &gridfire::partition::CompositePartitionFunction::evaluateDerivative,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the derivative of the composite partition function for given Z, A, and T9.")
.def("supports", &gridfire::partition::CompositePartitionFunction::supports,
py::arg("z"), py::arg("a"),
"Check if the composite partition function supports given Z and A.")
.def("get_type", &gridfire::partition::CompositePartitionFunction::type,
"Get the type of the partition function (should return 'Composite').");
}

View File

@@ -0,0 +1,16 @@
#pragma once
#include <pybind11/pybind11.h>
void register_partition_bindings(pybind11::module &m);
void register_partition_types_bindings(pybind11::module &m);
void register_ground_state_partition_bindings(pybind11::module &m);
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m);
void register_rauscher_thielemann_partition_bindings(pybind11::module &m);
void register_composite_partition_bindings(pybind11::module &m);

View File

@@ -0,0 +1,19 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_partition',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -0,0 +1,21 @@
gf_partition_trampoline_sources = files('py_partition.cpp')
gf_partition_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_partition_trampoline_lib = static_library(
'partition_trampolines',
gf_partition_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_partition_trapoline_dependencies,
install: false,
)
gr_partition_trampoline_dep = declare_dependency(
link_with: gf_partition_trampoline_lib,
include_directories: ('.'),
dependencies: gf_partition_trapoline_dependencies,
)

View File

@@ -0,0 +1,57 @@
#include "py_partition.h"
#include "gridfire/partition/partition.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include <string>
#include <memory>
namespace py = pybind11;
double PyPartitionFunction::evaluate(int z, int a, double T9) const {
PYBIND11_OVERRIDE_PURE(
double,
gridfire::partition::PartitionFunction,
evaluate,
z, a, T9
);
}
double PyPartitionFunction::evaluateDerivative(int z, int a, double T9) const {
PYBIND11_OVERRIDE_PURE(
double,
gridfire::partition::PartitionFunction,
evaluateDerivative,
z, a, T9
);
}
bool PyPartitionFunction::supports(int z, int a) const {
PYBIND11_OVERRIDE_PURE(
bool,
gridfire::partition::PartitionFunction,
supports,
z, a
);
}
std::string PyPartitionFunction::type() const {
PYBIND11_OVERRIDE_PURE(
std::string,
gridfire::partition::PartitionFunction,
type
);
}
std::unique_ptr<gridfire::partition::PartitionFunction> PyPartitionFunction::clone() const {
PYBIND11_OVERRIDE_PURE(
std::unique_ptr<gridfire::partition::PartitionFunction>,
gridfire::partition::PartitionFunction,
clone
);
}

View File

@@ -0,0 +1,15 @@
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "gridfire/partition/partition.h"
class PyPartitionFunction final : public gridfire::partition::PartitionFunction {
double evaluate(int z, int a, double T9) const override;
double evaluateDerivative(int z, int a, double T9) const override;
bool supports(int z, int a) const override;
std::string type() const override;
std::unique_ptr<gridfire::partition::PartitionFunction> clone() const override;
};

View File

@@ -0,0 +1,171 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <string_view>
#include <vector>
#include "bindings.h"
#include "gridfire/reaction/reaction.h"
#include "gridfire/reaction/reaclib.h"
namespace py = pybind11;
void register_reaction_bindings(py::module &m) {
py::class_<gridfire::reaction::RateCoefficientSet>(m, "RateCoefficientSet")
.def(py::init<double, double, double, double, double, double, double>(),
py::arg("a0"), py::arg("a1"), py::arg("a2"), py::arg("a3"),
py::arg("a4"), py::arg("a5"), py::arg("a6"),
"Construct a RateCoefficientSet with the given parameters."
);
using fourdst::atomic::Species;
py::class_<gridfire::reaction::Reaction>(m, "Reaction")
.def(py::init<const std::string_view, const std::string_view, int, const std::vector<Species>&, const std::vector<Species>&, double, std::string_view, gridfire::reaction::RateCoefficientSet, bool>(),
py::arg("id"), py::arg("peName"), py::arg("chapter"),
py::arg("reactants"), py::arg("products"), py::arg("qValue"),
py::arg("label"), py::arg("sets"), py::arg("reverse") = false,
"Construct a Reaction with the given parameters.")
.def("calculate_rate", static_cast<double (gridfire::reaction::Reaction::*)(double) const>(&gridfire::reaction::Reaction::calculate_rate),
py::arg("T9"), "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).")
.def("peName", &gridfire::reaction::Reaction::peName,
"Get the reaction name in (projectile, ejectile) notation (e.g., 'p(p,g)d').")
.def("chapter", &gridfire::reaction::Reaction::chapter,
"Get the REACLIB chapter number defining the reaction structure.")
.def("sourceLabel", &gridfire::reaction::Reaction::sourceLabel,
"Get the source label for the rate data (e.g., 'wc12w', 'st08').")
.def("rateCoefficients", &gridfire::reaction::Reaction::rateCoefficients,
"get the set of rate coefficients.")
.def("contains", &gridfire::reaction::Reaction::contains,
py::arg("species"), "Check if the reaction contains a specific species.")
.def("contains_reactant", &gridfire::reaction::Reaction::contains_reactant,
"Check if the reaction contains a specific reactant species.")
.def("contains_product", &gridfire::reaction::Reaction::contains_product,
"Check if the reaction contains a specific product species.")
.def("all_species", &gridfire::reaction::Reaction::all_species,
"Get all species involved in the reaction (both reactants and products) as a set.")
.def("reactant_species", &gridfire::reaction::Reaction::reactant_species,
"Get the reactant species of the reaction as a set.")
.def("product_species", &gridfire::reaction::Reaction::product_species,
"Get the product species of the reaction as a set.")
.def("num_species", &gridfire::reaction::Reaction::num_species,
"Count the number of species in the reaction.")
.def("stoichiometry", static_cast<int (gridfire::reaction::Reaction::*)(const Species&) const>(&gridfire::reaction::Reaction::stoichiometry),
py::arg("species"),
"Get the stoichiometry of the reaction as a map from species to their coefficients.")
.def("stoichiometry", static_cast<std::unordered_map<Species, int> (gridfire::reaction::Reaction::*)() const>(&gridfire::reaction::Reaction::stoichiometry),
"Get the stoichiometry of the reaction as a map from species to their coefficients.")
.def("id", &gridfire::reaction::Reaction::id,
"Get the unique identifier of the reaction.")
.def("qValue", &gridfire::reaction::Reaction::qValue,
"Get the Q-value of the reaction in MeV.")
.def("reactants", &gridfire::reaction::Reaction::reactants,
"Get a list of reactant species in the reaction.")
.def("products", &gridfire::reaction::Reaction::products,
"Get a list of product species in the reaction.")
.def("is_reverse", &gridfire::reaction::Reaction::is_reverse,
"Check if this is a reverse reaction rate.")
.def("excess_energy", &gridfire::reaction::Reaction::excess_energy,
"Calculate the excess energy from the mass difference of reactants and products.")
.def("__eq__", &gridfire::reaction::Reaction::operator==,
"Equality operator for reactions based on their IDs.")
.def("__neq__", &gridfire::reaction::Reaction::operator!=,
"Inequality operator for reactions based on their IDs.")
.def("hash", &gridfire::reaction::Reaction::hash,
py::arg("seed") = 0,
"Compute a hash for the reaction based on its ID.")
.def("__repr__", [](const gridfire::reaction::Reaction& self) {
std::stringstream ss;
ss << self; // Use the existing operator<< for Reaction
return ss.str();
});
py::class_<gridfire::reaction::LogicalReaction, gridfire::reaction::Reaction>(m, "LogicalReaction")
.def(py::init<const std::vector<gridfire::reaction::Reaction>>(),
py::arg("reactions"),
"Construct a LogicalReaction from a vector of Reaction objects.")
.def("add_reaction", &gridfire::reaction::LogicalReaction::add_reaction,
py::arg("reaction"),
"Add another Reaction source to this logical reaction.")
.def("size", &gridfire::reaction::LogicalReaction::size,
"Get the number of source rates contributing to this logical reaction.")
.def("__len__", &gridfire::reaction::LogicalReaction::size,
"Overload len() to return the number of source rates.")
.def("sources", &gridfire::reaction::LogicalReaction::sources,
"Get the list of source labels for the aggregated rates.")
.def("calculate_rate", static_cast<double (gridfire::reaction::LogicalReaction::*)(double) const>(&gridfire::reaction::LogicalReaction::calculate_rate),
py::arg("T9"), "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).")
.def("calculate_forward_rate_log_derivative", &gridfire::reaction::LogicalReaction::calculate_forward_rate_log_derivative,
py::arg("T9"), "Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K).");
py::class_<gridfire::reaction::LogicalReactionSet>(m, "LogicalReactionSet")
.def(py::init<const std::vector<gridfire::reaction::LogicalReaction>>(),
py::arg("reactions"),
"Construct a LogicalReactionSet from a vector of LogicalReaction objects.")
.def(py::init<>(),
"Default constructor for an empty LogicalReactionSet.")
.def(py::init<const gridfire::reaction::LogicalReactionSet&>(),
py::arg("other"),
"Copy constructor for LogicalReactionSet.")
.def("add_reaction", &gridfire::reaction::LogicalReactionSet::add_reaction,
py::arg("reaction"),
"Add a LogicalReaction to the set.")
.def("remove_reaction", &gridfire::reaction::LogicalReactionSet::remove_reaction,
py::arg("reaction"),
"Remove a LogicalReaction from the set.")
.def("contains", py::overload_cast<const std::string_view&>(&gridfire::reaction::LogicalReactionSet::contains, py::const_),
py::arg("id"),
"Check if the set contains a specific LogicalReaction.")
.def("contains", py::overload_cast<const gridfire::reaction::Reaction&>(&gridfire::reaction::LogicalReactionSet::contains, py::const_),
py::arg("reaction"),
"Check if the set contains a specific Reaction.")
.def("size", &gridfire::reaction::LogicalReactionSet::size,
"Get the number of LogicalReactions in the set.")
.def("__len__", &gridfire::reaction::LogicalReactionSet::size,
"Overload len() to return the number of LogicalReactions.")
.def("clear", &gridfire::reaction::LogicalReactionSet::clear,
"Remove all LogicalReactions from the set.")
.def("containes_species", &gridfire::reaction::LogicalReactionSet::contains_species,
py::arg("species"),
"Check if any reaction in the set involves the given species.")
.def("contains_reactant", &gridfire::reaction::LogicalReactionSet::contains_reactant,
py::arg("species"),
"Check if any reaction in the set has the species as a reactant.")
.def("contains_product", &gridfire::reaction::LogicalReactionSet::contains_product,
py::arg("species"),
"Check if any reaction in the set has the species as a product.")
.def("__getitem__", py::overload_cast<size_t>(&gridfire::reaction::LogicalReactionSet::operator[], py::const_),
py::arg("index"),
"Get a LogicalReaction by index.")
.def("__getitem___", py::overload_cast<const std::string_view&>(&gridfire::reaction::LogicalReactionSet::operator[], py::const_),
py::arg("id"),
"Get a LogicalReaction by its ID.")
.def("__eq__", &gridfire::reaction::LogicalReactionSet::operator==,
py::arg("LogicalReactionSet"),
"Equality operator for LogicalReactionSets based on their contents.")
.def("__ne__", &gridfire::reaction::LogicalReactionSet::operator!=,
py::arg("LogicalReactionSet"),
"Inequality operator for LogicalReactionSets based on their contents.")
.def("hash", &gridfire::reaction::LogicalReactionSet::hash,
py::arg("seed") = 0,
"Compute a hash for the LogicalReactionSet based on its contents."
)
.def("__repr__", [](const gridfire::reaction::LogicalReactionSet& self) {
std::stringstream ss;
ss << self;
return ss.str();
})
.def("getReactionSetSpecies", &gridfire::reaction::LogicalReactionSet::getReactionSetSpecies,
"Get all species involved in the reactions of the set as a set of Species objects.");
m.def("packReactionSetToLogicalReactionSet",
&gridfire::reaction::packReactionSetToLogicalReactionSet,
py::arg("reactionSet"),
"Convert a ReactionSet to a LogicalReactionSet by aggregating reactions with the same peName."
);
m.def("get_all_reactions", &gridfire::reaclib::get_all_reactions,
"Get all reactions from the REACLIB database.");
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include <pybind11/pybind11.h>
void register_reaction_bindings(pybind11::module &m);

View File

@@ -0,0 +1,17 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_reaction',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -0,0 +1,42 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <string_view>
#include <vector>
#include "bindings.h"
#include "gridfire/screening/screening.h"
#include "trampoline/py_screening.h"
namespace py = pybind11;
void register_screening_bindings(py::module &m) {
py::class_<gridfire::screening::ScreeningModel, PyScreening>(m, "ScreeningModel");
py::enum_<gridfire::screening::ScreeningType>(m, "ScreeningType")
.value("BARE", gridfire::screening::ScreeningType::BARE)
.value("WEAK", gridfire::screening::ScreeningType::WEAK)
.export_values();
m.def("selectScreeningModel", &gridfire::screening::selectScreeningModel,
py::arg("type"),
"Select a screening model based on the specified type. Returns a pointer to the selected model.");
py::class_<gridfire::screening::BareScreeningModel>(m, "BareScreeningModel")
.def(py::init<>())
.def("calculateScreeningFactors",
py::overload_cast<const gridfire::reaction::LogicalReactionSet&, const std::vector<fourdst::atomic::Species>&, const std::vector<double>&, double, double>(&gridfire::screening::BareScreeningModel::calculateScreeningFactors, py::const_),
py::arg("reactions"), py::arg("species"), py::arg("Y"), py::arg("T9"), py::arg("rho"),
"Calculate the bare plasma screening factors. This always returns 1.0 (bare)"
);
py::class_<gridfire::screening::WeakScreeningModel>(m, "WeakScreeningModel")
.def(py::init<>())
.def("calculateScreeningFactors",
py::overload_cast<const gridfire::reaction::LogicalReactionSet&, const std::vector<fourdst::atomic::Species>&, const std::vector<double>&, double, double>(&gridfire::screening::WeakScreeningModel::calculateScreeningFactors, py::const_),
py::arg("reactions"), py::arg("species"), py::arg("Y"), py::arg("T9"), py::arg("rho"),
"Calculate the weak plasma screening factors using the Salpeter (1954) model."
);
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include <pybind11/pybind11.h>
void register_screening_bindings(pybind11::module &m);

View File

@@ -0,0 +1,19 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_screening',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -0,0 +1,21 @@
gf_screening_trampoline_sources = files('py_screening.cpp')
gf_screening_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_screening_trampoline_lib = static_library(
'screening_trampolines',
gf_screening_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_screening_trapoline_dependencies,
install: false,
)
gr_screening_trampoline_dep = declare_dependency(
link_with: gf_screening_trampoline_lib,
include_directories: ('.'),
dependencies: gf_screening_trapoline_dependencies,
)

View File

@@ -0,0 +1,31 @@
#include "gridfire/screening/screening.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h> // Needed for std::function
#include <vector>
#include "py_screening.h"
#include "cppad/cppad.hpp"
namespace py = pybind11;
std::vector<double> PyScreening::calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<double> &Y, const double T9, const double rho) const {
PYBIND11_OVERLOAD_PURE(
std::vector<double>, // Return type
gridfire::screening::ScreeningModel,
calculateScreeningFactors // Method name
);
}
using ADDouble = gridfire::screening::ScreeningModel::ADDouble;
std::vector<ADDouble> PyScreening::calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<ADDouble> &Y, const ADDouble T9, const ADDouble rho) const {
PYBIND11_OVERLOAD_PURE(
std::vector<ADDouble>, // Return type
gridfire::screening::ScreeningModel,
calculateScreeningFactors // Method name
);
}

View File

@@ -0,0 +1,16 @@
#pragma once
#include "gridfire/screening/screening.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h> // Needed for std::function
#include <vector>
#include "cppad/cppad.hpp"
class PyScreening final : public gridfire::screening::ScreeningModel {
std::vector<double> calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<double> &Y, const double T9, const double rho) const override;
std::vector<ADDouble> calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<ADDouble> &Y, const ADDouble T9, const ADDouble rho) const override;
};

View File

@@ -0,0 +1,26 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include "bindings.h"
#include "gridfire/solver/solver.h"
#include "trampoline/py_solver.h"
namespace py = pybind11;
void register_solver_bindings(py::module &m) {
auto py_dynamic_network_solving_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
auto py_direct_network_solver = py::class_<gridfire::solver::DirectNetworkSolver, gridfire::solver::DynamicNetworkSolverStrategy>(m, "DirectNetworkSolver");
py_direct_network_solver.def(py::init<gridfire::DynamicEngine&>(),
py::arg("engine"),
"Constructor for the DirectNetworkSolver. Takes a DynamicEngine instance to use for evaluating the network.");
py_direct_network_solver.def("evaluate",
&gridfire::solver::DirectNetworkSolver::evaluate,
py::arg("netIn"),
"Evaluate the network for a given timestep. Returns the output conditions after the timestep.");
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include <pybind11/pybind11.h>
void register_solver_bindings(pybind11::module &m);

Some files were not shown because too many files have changed in this diff Show More