feat(python): added robust python bindings covering the entire codebase
This commit is contained in:
8
src/include/gridfire/engine/engine.h
Normal file
8
src/include/gridfire/engine/engine.h
Normal file
@@ -0,0 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/engine/engine_abstract.h"
|
||||
#include "gridfire/engine/engine_graph.h"
|
||||
|
||||
#include "gridfire/engine/views/engine_views.h"
|
||||
#include "gridfire/engine/procedures/engine_procedures.h"
|
||||
#include "gridfire/engine/types/engine_types.h"
|
||||
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/engine/procedures/construction.h"
|
||||
#include "gridfire/engine/procedures/priming.h"
|
||||
4
src/include/gridfire/engine/types/engine_types.h
Normal file
4
src/include/gridfire/engine/types/engine_types.h
Normal file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/engine/types/building.h"
|
||||
#include "gridfire/engine/types/reporting.h"
|
||||
7
src/include/gridfire/engine/views/engine_views.h
Normal file
7
src/include/gridfire/engine/views/engine_views.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/engine/views/engine_adaptive.h"
|
||||
#include "gridfire/engine/views/engine_defined.h"
|
||||
#include "gridfire/engine/views/engine_multiscale.h"
|
||||
#include "gridfire/engine/views/engine_priming.h"
|
||||
#include "gridfire/engine/views/engine_view_abstract.h"
|
||||
3
src/include/gridfire/exceptions/exceptions.h
Normal file
3
src/include/gridfire/exceptions/exceptions.h
Normal file
@@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/exceptions/error_engine.h"
|
||||
3
src/include/gridfire/expectations/expectations.h
Normal file
3
src/include/gridfire/expectations/expectations.h
Normal file
@@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/expectations/expected_engine.h"
|
||||
@@ -14,9 +14,16 @@ namespace gridfire::expectations {
|
||||
SYSTEM_RESIZED
|
||||
};
|
||||
|
||||
// TODO: rename this to EngineExpectation or something similar
|
||||
struct EngineError {
|
||||
std::string m_message;
|
||||
EngineErrorTypes type = EngineErrorTypes::FAILURE;
|
||||
const EngineErrorTypes type = EngineErrorTypes::FAILURE;
|
||||
|
||||
explicit EngineError(const std::string &message, const EngineErrorTypes type)
|
||||
: m_message(message), type(type) {}
|
||||
|
||||
virtual ~EngineError() = default;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const EngineError& e) {
|
||||
os << "EngineError: " << e.m_message;
|
||||
return os;
|
||||
@@ -25,7 +32,9 @@ namespace gridfire::expectations {
|
||||
|
||||
struct EngineIndexError : EngineError {
|
||||
int m_index;
|
||||
EngineErrorTypes type = EngineErrorTypes::INDEX;
|
||||
|
||||
explicit EngineIndexError(const int index)
|
||||
: EngineError("Index error occurred", EngineErrorTypes::INDEX), m_index(index) {}
|
||||
friend std::ostream& operator<<(std::ostream& os, const EngineIndexError& e) {
|
||||
os << "EngineIndexError: " << e.m_message << " at index " << e.m_index;
|
||||
return os;
|
||||
@@ -33,10 +42,10 @@ namespace gridfire::expectations {
|
||||
};
|
||||
|
||||
struct StaleEngineError : EngineError {
|
||||
EngineErrorTypes type = EngineErrorTypes::STALE;
|
||||
StaleEngineErrorTypes staleType;
|
||||
|
||||
explicit StaleEngineError(StaleEngineErrorTypes staleType) : staleType(staleType) {}
|
||||
explicit StaleEngineError(const StaleEngineErrorTypes sType)
|
||||
: EngineError("Stale engine error occurred", EngineErrorTypes::STALE), staleType(sType) {}
|
||||
|
||||
explicit operator std::string() const {
|
||||
switch (staleType) {
|
||||
3
src/include/gridfire/io/io.h
Normal file
3
src/include/gridfire/io/io.h
Normal file
@@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/io/network_file.h"
|
||||
8
src/include/gridfire/partition/partition.h
Normal file
8
src/include/gridfire/partition/partition.h
Normal file
@@ -0,0 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/partition/partition_types.h"
|
||||
#include "gridfire/partition/partition_abstract.h"
|
||||
#include "gridfire/partition/partition_ground.h"
|
||||
#include "gridfire/partition/partition_rauscher_thielemann.h"
|
||||
#include "gridfire/partition/rauscher_thielemann_partition_data_record.h"
|
||||
#include "gridfire/partition/composite/partition_composite.h"
|
||||
6
src/include/gridfire/screening/screening.h
Normal file
6
src/include/gridfire/screening/screening.h
Normal file
@@ -0,0 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/screening/screening_types.h"
|
||||
#include "gridfire/screening/screening_abstract.h"
|
||||
#include "gridfire/screening/screening_bare.h"
|
||||
#include "gridfire/screening/screening_weak.h"
|
||||
@@ -103,6 +103,6 @@ namespace gridfire::screening {
|
||||
const std::vector<ADDouble>& Y,
|
||||
const ADDouble T9,
|
||||
const ADDouble rho
|
||||
) const = 0;
|
||||
) const = 0;
|
||||
};
|
||||
}
|
||||
@@ -1 +1,80 @@
|
||||
subdir('network')
|
||||
# Define the library
|
||||
gridfire_sources = files(
|
||||
'lib/network.cpp',
|
||||
'lib/engine/engine_approx8.cpp',
|
||||
'lib/engine/engine_graph.cpp',
|
||||
'lib/engine/views/engine_adaptive.cpp',
|
||||
'lib/engine/views/engine_defined.cpp',
|
||||
'lib/engine/views/engine_multiscale.cpp',
|
||||
'lib/engine/views/engine_priming.cpp',
|
||||
'lib/engine/procedures/priming.cpp',
|
||||
'lib/engine/procedures/construction.cpp',
|
||||
'lib/reaction/reaction.cpp',
|
||||
'lib/reaction/reaclib.cpp',
|
||||
'lib/io/network_file.cpp',
|
||||
'lib/solver/solver.cpp',
|
||||
'lib/screening/screening_types.cpp',
|
||||
'lib/screening/screening_weak.cpp',
|
||||
'lib/screening/screening_bare.cpp',
|
||||
'lib/partition/partition_rauscher_thielemann.cpp',
|
||||
'lib/partition/partition_ground.cpp',
|
||||
'lib/partition/composite/partition_composite.cpp',
|
||||
'lib/utils/logging.cpp',
|
||||
)
|
||||
|
||||
|
||||
gridfire_build_dependencies = [
|
||||
boost_dep,
|
||||
const_dep,
|
||||
config_dep,
|
||||
composition_dep,
|
||||
cppad_dep,
|
||||
log_dep,
|
||||
xxhash_dep,
|
||||
eigen_dep,
|
||||
]
|
||||
|
||||
# Define the libnetwork library so it can be linked against by other parts of the build system
|
||||
libgridfire = library('gridfire',
|
||||
gridfire_sources,
|
||||
include_directories: include_directories('include'),
|
||||
dependencies: gridfire_build_dependencies,
|
||||
install : true)
|
||||
|
||||
gridfire_dep = declare_dependency(
|
||||
include_directories: include_directories('include'),
|
||||
link_with: libgridfire,
|
||||
sources: gridfire_sources,
|
||||
dependencies: gridfire_build_dependencies,
|
||||
)
|
||||
|
||||
# Make headers accessible
|
||||
gridfire_headers = files(
|
||||
'include/gridfire/network.h',
|
||||
'include/gridfire/engine/engine_abstract.h',
|
||||
'include/gridfire/engine/views/engine_view_abstract.h',
|
||||
'include/gridfire/engine/engine_approx8.h',
|
||||
'include/gridfire/engine/engine_graph.h',
|
||||
'include/gridfire/engine/views/engine_adaptive.h',
|
||||
'include/gridfire/engine/views/engine_defined.h',
|
||||
'include/gridfire/engine/views/engine_multiscale.h',
|
||||
'include/gridfire/engine/views/engine_priming.h',
|
||||
'include/gridfire/engine/procedures/priming.h',
|
||||
'include/gridfire/engine/procedures/construction.h',
|
||||
'include/gridfire/reaction/reaction.h',
|
||||
'include/gridfire/reaction/reaclib.h',
|
||||
'include/gridfire/io/network_file.h',
|
||||
'include/gridfire/solver/solver.h',
|
||||
'include/gridfire/screening/screening_abstract.h',
|
||||
'include/gridfire/screening/screening_bare.h',
|
||||
'include/gridfire/screening/screening_weak.h',
|
||||
'include/gridfire/screening/screening_types.h',
|
||||
'include/gridfire/partition/partition_abstract.h',
|
||||
'include/gridfire/partition/partition_rauscher_thielemann.h',
|
||||
'include/gridfire/partition/partition_ground.h',
|
||||
'include/gridfire/partition/composite/partition_composite.h',
|
||||
'include/gridfire/utils/logging.h',
|
||||
)
|
||||
install_headers(gridfire_headers, subdir : 'gridfire')
|
||||
|
||||
subdir('python')
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
# Define the library
|
||||
network_sources = files(
|
||||
'lib/network.cpp',
|
||||
'lib/engine/engine_approx8.cpp',
|
||||
'lib/engine/engine_graph.cpp',
|
||||
'lib/engine/views/engine_adaptive.cpp',
|
||||
'lib/engine/views/engine_defined.cpp',
|
||||
'lib/engine/views/engine_multiscale.cpp',
|
||||
'lib/engine/views/engine_priming.cpp',
|
||||
'lib/engine/procedures/priming.cpp',
|
||||
'lib/engine/procedures/construction.cpp',
|
||||
'lib/reaction/reaction.cpp',
|
||||
'lib/reaction/reaclib.cpp',
|
||||
'lib/io/network_file.cpp',
|
||||
'lib/solver/solver.cpp',
|
||||
'lib/screening/screening_types.cpp',
|
||||
'lib/screening/screening_weak.cpp',
|
||||
'lib/screening/screening_bare.cpp',
|
||||
'lib/partition/partition_rauscher_thielemann.cpp',
|
||||
'lib/partition/partition_ground.cpp',
|
||||
'lib/partition/composite/partition_composite.cpp',
|
||||
'lib/utils/logging.cpp',
|
||||
)
|
||||
|
||||
|
||||
dependencies = [
|
||||
boost_dep,
|
||||
const_dep,
|
||||
config_dep,
|
||||
composition_dep,
|
||||
cppad_dep,
|
||||
log_dep,
|
||||
xxhash_dep,
|
||||
eigen_dep,
|
||||
]
|
||||
|
||||
# Define the libnetwork library so it can be linked against by other parts of the build system
|
||||
libnetwork = library('network',
|
||||
network_sources,
|
||||
include_directories: include_directories('include'),
|
||||
dependencies: dependencies,
|
||||
install : true)
|
||||
|
||||
network_dep = declare_dependency(
|
||||
include_directories: include_directories('include'),
|
||||
link_with: libnetwork,
|
||||
sources: network_sources,
|
||||
dependencies: dependencies,
|
||||
)
|
||||
|
||||
# Make headers accessible
|
||||
network_headers = files(
|
||||
'include/gridfire/network.h',
|
||||
'include/gridfire/engine/engine_abstract.h',
|
||||
'include/gridfire/engine/views/engine_view_abstract.h',
|
||||
'include/gridfire/engine/engine_approx8.h',
|
||||
'include/gridfire/engine/engine_graph.h',
|
||||
'include/gridfire/engine/views/engine_adaptive.h',
|
||||
'include/gridfire/engine/views/engine_defined.h',
|
||||
'include/gridfire/engine/views/engine_multiscale.h',
|
||||
'include/gridfire/engine/views/engine_priming.h',
|
||||
'include/gridfire/engine/procedures/priming.h',
|
||||
'include/gridfire/engine/procedures/construction.h',
|
||||
'include/gridfire/reaction/reaction.h',
|
||||
'include/gridfire/reaction/reaclib.h',
|
||||
'include/gridfire/io/network_file.h',
|
||||
'include/gridfire/solver/solver.h',
|
||||
'include/gridfire/screening/screening_abstract.h',
|
||||
'include/gridfire/screening/screening_bare.h',
|
||||
'include/gridfire/screening/screening_weak.h',
|
||||
'include/gridfire/screening/screening_types.h',
|
||||
'include/gridfire/partition/partition_abstract.h',
|
||||
'include/gridfire/partition/partition_rauscher_thielemann.h',
|
||||
'include/gridfire/partition/partition_ground.h',
|
||||
'include/gridfire/partition/composite/partition_composite.h',
|
||||
'include/gridfire/utils/logging.h',
|
||||
)
|
||||
install_headers(network_headers, subdir : 'gridfire')
|
||||
54
src/python/bindings.cpp
Normal file
54
src/python/bindings.cpp
Normal file
@@ -0,0 +1,54 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "types/bindings.h"
|
||||
#include "partition/bindings.h"
|
||||
#include "expectations/bindings.h"
|
||||
#include "engine/bindings.h"
|
||||
#include "exceptions/bindings.h"
|
||||
#include "io/bindings.h"
|
||||
#include "reaction/bindings.h"
|
||||
#include "screening/bindings.h"
|
||||
#include "solver/bindings.h"
|
||||
#include "utils/bindings.h"
|
||||
|
||||
PYBIND11_MODULE(gridfire, m) {
|
||||
m.doc() = "Python bindings for the fourdst utility modules which are a part of the 4D-STAR project.";
|
||||
|
||||
pybind11::module::import("fourdst.constants");
|
||||
pybind11::module::import("fourdst.composition");
|
||||
pybind11::module::import("fourdst.config");
|
||||
pybind11::module::import("fourdst.atomic");
|
||||
|
||||
auto typeMod = m.def_submodule("type", "GridFire type bindings");
|
||||
register_type_bindings(typeMod);
|
||||
|
||||
auto partitionMod = m.def_submodule("partition", "GridFire partition function bindings");
|
||||
register_partition_bindings(partitionMod);
|
||||
|
||||
auto expectationMod = m.def_submodule("expectations", "GridFire expectations bindings");
|
||||
register_expectation_bindings(expectationMod);
|
||||
|
||||
auto reactionMod = m.def_submodule("reaction", "GridFire reaction bindings");
|
||||
register_reaction_bindings(reactionMod);
|
||||
|
||||
auto screeningMod = m.def_submodule("screening", "GridFire plasma screening bindings");
|
||||
register_screening_bindings(screeningMod);
|
||||
|
||||
auto ioMod = m.def_submodule("io", "GridFire io bindings");
|
||||
register_io_bindings(ioMod);
|
||||
|
||||
auto exceptionMod = m.def_submodule("exceptions", "GridFire exceptions bindings");
|
||||
register_exception_bindings(exceptionMod);
|
||||
|
||||
auto engineMod = m.def_submodule("engine", "Engine and Engine View bindings");
|
||||
register_engine_bindings(engineMod);
|
||||
|
||||
auto solverMod = m.def_submodule("solver", "GridFire numerical solver bindings");
|
||||
register_solver_bindings(solverMod);
|
||||
|
||||
auto utilsMod = m.def_submodule("utils", "GridFire utility method bindings");
|
||||
register_utils_bindings(utilsMod);
|
||||
}
|
||||
391
src/python/engine/bindings.cpp
Normal file
391
src/python/engine/bindings.cpp
Normal file
@@ -0,0 +1,391 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/engine/engine.h"
|
||||
#include "trampoline/py_engine.h"
|
||||
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
concept IsDynamicEngine = std::is_base_of_v<gridfire::DynamicEngine, T>;
|
||||
|
||||
template <IsDynamicEngine T, IsDynamicEngine BaseT>
|
||||
void registerDynamicEngineDefs(py::class_<T, BaseT> pyClass) {
|
||||
pyClass.def("calculateRHSAndEnergy", &T::calculateRHSAndEnergy,
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Calculate the right-hand side (dY/dt) and energy generation rate."
|
||||
)
|
||||
.def("generateJacobianMatrix", py::overload_cast<const std::vector<double>&, double, double>(&T::generateJacobianMatrix, py::const_),
|
||||
py::arg("Y_dynamic"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Generate the Jacobian matrix for the current state."
|
||||
)
|
||||
.def("generateStoichiometryMatrix", &T::generateStoichiometryMatrix)
|
||||
.def("calculateMolarReactionFlow",
|
||||
static_cast<double (T::*)(const gridfire::reaction::Reaction&, const std::vector<double>&, const double, const double) const>(&T::calculateMolarReactionFlow),
|
||||
py::arg("reaction"),
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Calculate the molar reaction flow for a given reaction."
|
||||
)
|
||||
.def("getNetworkSpecies", &T::getNetworkSpecies,
|
||||
"Get the list of species in the network."
|
||||
)
|
||||
.def("getNetworkReactions", &T::getNetworkReactions,
|
||||
"Get the set of logical reactions in the network."
|
||||
)
|
||||
.def ("setNetworkReactions", &T::setNetworkReactions,
|
||||
py::arg("reactions"),
|
||||
"Set the network reactions to a new set of reactions."
|
||||
)
|
||||
.def("getJacobianMatrixEntry", &T::getJacobianMatrixEntry,
|
||||
py::arg("i"),
|
||||
py::arg("j"),
|
||||
"Get an entry from the previously generated Jacobian matrix."
|
||||
)
|
||||
.def("getStoichiometryMatrixEntry", &T::getStoichiometryMatrixEntry,
|
||||
py::arg("speciesIndex"),
|
||||
py::arg("reactionIndex"),
|
||||
"Get an entry from the stoichiometry matrix."
|
||||
)
|
||||
.def("getSpeciesTimescales", &T::getSpeciesTimescales,
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Get the timescales for each species in the network."
|
||||
)
|
||||
.def("getSpeciesDestructionTimescales", &T::getSpeciesDestructionTimescales,
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Get the destruction timescales for each species in the network."
|
||||
)
|
||||
.def("update", &T::update,
|
||||
py::arg("netIn"),
|
||||
"Update the engine state based on the provided NetIn object."
|
||||
)
|
||||
.def("setScreeningModel", &T::setScreeningModel,
|
||||
py::arg("screeningModel"),
|
||||
"Set the screening model for the engine."
|
||||
)
|
||||
.def("getScreeningModel", &T::getScreeningModel,
|
||||
"Get the current screening model of the engine."
|
||||
)
|
||||
.def("getSpeciesIndex", &T::getSpeciesIndex,
|
||||
py::arg("species"),
|
||||
"Get the index of a species in the network."
|
||||
)
|
||||
.def("mapNetInToMolarAbundanceVector", &T::mapNetInToMolarAbundanceVector,
|
||||
py::arg("netIn"),
|
||||
"Map a NetIn object to a vector of molar abundances."
|
||||
)
|
||||
.def("primeEngine", &T::primeEngine,
|
||||
py::arg("netIn"),
|
||||
"Prime the engine with a NetIn object to prepare for calculations."
|
||||
)
|
||||
.def("getDepth", &T::getDepth,
|
||||
"Get the current build depth of the engine."
|
||||
)
|
||||
.def("rebuild", &T::rebuild,
|
||||
py::arg("composition"),
|
||||
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
|
||||
"Rebuild the engine with a new composition and build depth."
|
||||
)
|
||||
.def("isStale", &T::isStale,
|
||||
py::arg("netIn"),
|
||||
"Check if the engine is stale based on the provided NetIn object."
|
||||
);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
void register_engine_bindings(py::module &m) {
|
||||
register_base_engine_bindings(m);
|
||||
register_engine_view_bindings(m);
|
||||
|
||||
m.def("build_reaclib_nuclear_network", &gridfire::build_reaclib_nuclear_network,
|
||||
py::arg("composition"),
|
||||
py::arg("maxLayers") = gridfire::NetworkBuildDepth::Full,
|
||||
py::arg("reverse") = false,
|
||||
"Build a nuclear network from a composition using ReacLib data."
|
||||
);
|
||||
|
||||
py::enum_<gridfire::PrimingReportStatus>(m, "PrimingReportStatus")
|
||||
.value("FULL_SUCCESS", gridfire::PrimingReportStatus::FULL_SUCCESS, "Priming was full successful.")
|
||||
.value("NO_SPECIES_TO_PRIME", gridfire::PrimingReportStatus::NO_SPECIES_TO_PRIME, "No species to prime.")
|
||||
.value("MAX_ITERATIONS_REACHED", gridfire::PrimingReportStatus::MAX_ITERATIONS_REACHED, "Maximum iterations reached during priming.")
|
||||
.value("FAILED_TO_FINALIZE_COMPOSITION", gridfire::PrimingReportStatus::FAILED_TO_FINALIZE_COMPOSITION, "Failed to finalize the composition after priming.")
|
||||
.value("FAILED_TO_FIND_CREATION_CHANNEL", gridfire::PrimingReportStatus::FAILED_TO_FIND_CREATION_CHANNEL, "Failed to find a creation channel for the priming species.")
|
||||
.value("FAILED_TO_FIND_PRIMING_REACTIONS", gridfire::PrimingReportStatus::FAILED_TO_FIND_PRIMING_REACTIONS, "Failed to find priming reactions for the species.")
|
||||
.value("BASE_NETWORK_TOO_SHALLOW", gridfire::PrimingReportStatus::BASE_NETWORK_TOO_SHALLOW, "The base network is too shallow for priming.")
|
||||
.export_values()
|
||||
.def("__repr__", [](const gridfire::PrimingReportStatus& status) {
|
||||
std::stringstream ss;
|
||||
ss << gridfire::PrimingReportStatusStrings.at(status) << "\n";
|
||||
return ss.str();
|
||||
},
|
||||
"String representation of the PrimingReport."
|
||||
);
|
||||
|
||||
py::class_<gridfire::PrimingReport>(m, "PrimingReport")
|
||||
.def_readonly("success", &gridfire::PrimingReport::success, "Indicates if the priming was successful.")
|
||||
.def_readonly("massFractionChanges", &gridfire::PrimingReport::massFractionChanges, "Map of species to their mass fraction changes after priming.")
|
||||
.def_readonly("primedComposition", &gridfire::PrimingReport::primedComposition, "The composition after priming.")
|
||||
.def_readonly("status", &gridfire::PrimingReport::status, "Status message from the priming process.")
|
||||
.def("__repr__", [](const gridfire::PrimingReport& report) {
|
||||
std::stringstream ss;
|
||||
ss << report;
|
||||
return ss.str();
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
void register_base_engine_bindings(pybind11::module &m) {
|
||||
|
||||
py::class_<gridfire::StepDerivatives<double>>(m, "StepDerivatives")
|
||||
.def_readonly("dYdt", &gridfire::StepDerivatives<double>::dydt, "The right-hand side (dY/dt) of the ODE system.")
|
||||
.def_readonly("energy", &gridfire::StepDerivatives<double>::nuclearEnergyGenerationRate, "The energy generation rate.");
|
||||
|
||||
py::class_<gridfire::SparsityPattern>(m, "SparsityPattern");
|
||||
|
||||
abs_stype_register_engine_bindings(m);
|
||||
abs_stype_register_dynamic_engine_bindings(m);
|
||||
con_stype_register_graph_engine_bindings(m);
|
||||
}
|
||||
|
||||
void abs_stype_register_engine_bindings(pybind11::module &m) {
|
||||
py::class_<gridfire::Engine, PyEngine>(m, "Engine");
|
||||
}
|
||||
|
||||
void abs_stype_register_dynamic_engine_bindings(pybind11::module &m) {
|
||||
const auto a = py::class_<gridfire::DynamicEngine, PyDynamicEngine>(m, "DynamicEngine");
|
||||
}
|
||||
|
||||
void con_stype_register_graph_engine_bindings(pybind11::module &m) {
|
||||
py::enum_<gridfire::NetworkBuildDepth>(m, "NetworkBuildDepth")
|
||||
.value("Full", gridfire::NetworkBuildDepth::Full, "Full network build depth")
|
||||
.value("Shallow", gridfire::NetworkBuildDepth::Shallow, "Shallow network build depth")
|
||||
.value("SecondOrder", gridfire::NetworkBuildDepth::SecondOrder, "Second order network build depth")
|
||||
.value("ThirdOrder", gridfire::NetworkBuildDepth::ThirdOrder, "Third order network build depth")
|
||||
.value("FourthOrder", gridfire::NetworkBuildDepth::FourthOrder, "Fourth order network build depth")
|
||||
.value("FifthOrder", gridfire::NetworkBuildDepth::FifthOrder, "Fifth order network build depth")
|
||||
.export_values();
|
||||
|
||||
py::class_<gridfire::BuildDepthType>(m, "BuildDepthType");
|
||||
|
||||
auto py_dynamic_engine_bindings = py::class_<gridfire::GraphEngine, gridfire::DynamicEngine>(m, "GraphEngine");
|
||||
|
||||
// Register the Graph Engine Specific Bindings
|
||||
py_dynamic_engine_bindings.def(py::init<const fourdst::composition::Composition &, const gridfire::BuildDepthType>(),
|
||||
py::arg("composition"),
|
||||
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
|
||||
"Initialize GraphEngine with a composition and build depth."
|
||||
);
|
||||
py_dynamic_engine_bindings.def(py::init<const fourdst::composition::Composition &, const gridfire::partition::PartitionFunction &, const gridfire::BuildDepthType>(),
|
||||
py::arg("composition"),
|
||||
py::arg("partitionFunction"),
|
||||
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
|
||||
"Initialize GraphEngine with a composition, partition function and build depth."
|
||||
);
|
||||
py_dynamic_engine_bindings.def(py::init<const gridfire::reaction::LogicalReactionSet &>(),
|
||||
py::arg("reactions"),
|
||||
"Initialize GraphEngine with a set of reactions."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("generateJacobianMatrix", py::overload_cast<const std::vector<double>&, double, double, const gridfire::SparsityPattern&>(&gridfire::GraphEngine::generateJacobianMatrix, py::const_),
|
||||
py::arg("Y_dynamic"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
py::arg("sparsityPattern"),
|
||||
"Generate the Jacobian matrix for the current state with a specified sparsity pattern."
|
||||
);
|
||||
py_dynamic_engine_bindings.def_static("getNetReactionStoichiometry", &gridfire::GraphEngine::getNetReactionStoichiometry,
|
||||
py::arg("reaction"),
|
||||
"Get the net stoichiometry for a given reaction."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("involvesSpecies", &gridfire::GraphEngine::involvesSpecies,
|
||||
py::arg("species"),
|
||||
"Check if a given species is involved in the network."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("exportToDot", &gridfire::GraphEngine::exportToDot,
|
||||
py::arg("filename"),
|
||||
"Export the network to a DOT file for visualization."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("exportToCSV", &gridfire::GraphEngine::exportToCSV,
|
||||
py::arg("filename"),
|
||||
"Export the network to a CSV file for analysis."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("setPrecomputation", &gridfire::GraphEngine::setPrecomputation,
|
||||
py::arg("precompute"),
|
||||
"Enable or disable precomputation for the engine."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("isPrecomputationEnabled", &gridfire::GraphEngine::isPrecomputationEnabled,
|
||||
"Check if precomputation is enabled for the engine."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("getPartitionFunction", &gridfire::GraphEngine::getPartitionFunction,
|
||||
"Get the partition function used by the engine."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("calculateReverseRate", &gridfire::GraphEngine::calculateReverseRate,
|
||||
py::arg("reaction"),
|
||||
py::arg("T9"),
|
||||
"Calculate the reverse rate for a given reaction at a specific temperature."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("calculateReverseRateTwoBody", &gridfire::GraphEngine::calculateReverseRateTwoBody,
|
||||
py::arg("reaction"),
|
||||
py::arg("T9"),
|
||||
py::arg("forwardRate"),
|
||||
py::arg("expFactor"),
|
||||
"Calculate the reverse rate for a two-body reaction at a specific temperature."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("calculateReverseRateTwoBodyDerivative", &gridfire::GraphEngine::calculateReverseRateTwoBodyDerivative,
|
||||
py::arg("reaction"),
|
||||
py::arg("T9"),
|
||||
py::arg("reverseRate"),
|
||||
"Calculate the derivative of the reverse rate for a two-body reaction at a specific temperature."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("isUsingReverseReactions", &gridfire::GraphEngine::isUsingReverseReactions,
|
||||
"Check if the engine is using reverse reactions."
|
||||
);
|
||||
py_dynamic_engine_bindings.def("setUseReverseReactions", &gridfire::GraphEngine::setUseReverseReactions,
|
||||
py::arg("useReverse"),
|
||||
"Enable or disable the use of reverse reactions in the engine."
|
||||
);
|
||||
|
||||
|
||||
// Register the general dynamic engine bindings
|
||||
registerDynamicEngineDefs<gridfire::GraphEngine, gridfire::DynamicEngine>(py_dynamic_engine_bindings);
|
||||
}
|
||||
|
||||
void register_engine_view_bindings(pybind11::module &m) {
|
||||
auto py_defined_engine_view_bindings = py::class_<gridfire::DefinedEngineView, gridfire::DynamicEngine>(m, "DefinedEngineView");
|
||||
|
||||
py_defined_engine_view_bindings.def(py::init<std::vector<std::string>, gridfire::DynamicEngine&>(),
|
||||
py::arg("peNames"),
|
||||
py::arg("baseEngine"),
|
||||
"Construct a defined engine view with a list of tracked reactions and a base engine.");
|
||||
py_defined_engine_view_bindings.def("getBaseEngine", &gridfire::DefinedEngineView::getBaseEngine,
|
||||
"Get the base engine associated with this defined engine view.");
|
||||
|
||||
registerDynamicEngineDefs<gridfire::DefinedEngineView, gridfire::DynamicEngine>(py_defined_engine_view_bindings);
|
||||
|
||||
auto py_file_defined_engine_view_bindings = py::class_<gridfire::FileDefinedEngineView, gridfire::DefinedEngineView>(m, "FileDefinedEngineView");
|
||||
py_file_defined_engine_view_bindings.def(py::init<gridfire::DynamicEngine&, const std::string&, const gridfire::io::NetworkFileParser&>(),
|
||||
py::arg("baseEngine"),
|
||||
py::arg("fileName"),
|
||||
py::arg("parser"),
|
||||
"Construct a defined engine view from a file and a base engine."
|
||||
);
|
||||
py_file_defined_engine_view_bindings.def("getNetworkFile", &gridfire::FileDefinedEngineView::getNetworkFile,
|
||||
"Get the network file associated with this defined engine view."
|
||||
);
|
||||
py_file_defined_engine_view_bindings.def("getParser", &gridfire::FileDefinedEngineView::getParser,
|
||||
"Get the parser used for this defined engine view."
|
||||
);
|
||||
py_file_defined_engine_view_bindings.def("getBaseEngine", &gridfire::FileDefinedEngineView::getBaseEngine,
|
||||
"Get the base engine associated with this file defined engine view.");
|
||||
|
||||
registerDynamicEngineDefs<gridfire::FileDefinedEngineView, gridfire::DefinedEngineView>(py_file_defined_engine_view_bindings);
|
||||
|
||||
auto py_priming_engine_view_bindings = py::class_<gridfire::NetworkPrimingEngineView, gridfire::DefinedEngineView>(m, "NetworkPrimingEngineView");
|
||||
py_priming_engine_view_bindings.def(py::init<const std::string&, gridfire::DynamicEngine&>(),
|
||||
py::arg("primingSymbol"),
|
||||
py::arg("baseEngine"),
|
||||
"Construct a priming engine view with a priming symbol and a base engine.");
|
||||
py_priming_engine_view_bindings.def(py::init<const fourdst::atomic::Species&, gridfire::DynamicEngine&>(),
|
||||
py::arg("primingSpecies"),
|
||||
py::arg("baseEngine"),
|
||||
"Construct a priming engine view with a priming species and a base engine.");
|
||||
py_priming_engine_view_bindings.def("getBaseEngine", &gridfire::NetworkPrimingEngineView::getBaseEngine,
|
||||
"Get the base engine associated with this priming engine view.");
|
||||
|
||||
registerDynamicEngineDefs<gridfire::NetworkPrimingEngineView, gridfire::DefinedEngineView>(py_priming_engine_view_bindings);
|
||||
|
||||
auto py_adaptive_engine_view_bindings = py::class_<gridfire::AdaptiveEngineView, gridfire::DynamicEngine>(m, "AdaptiveEngineView");
|
||||
py_adaptive_engine_view_bindings.def(py::init<gridfire::DynamicEngine&>(),
|
||||
py::arg("baseEngine"),
|
||||
"Construct an adaptive engine view with a base engine.");
|
||||
py_adaptive_engine_view_bindings.def("getBaseEngine", &gridfire::AdaptiveEngineView::getBaseEngine,
|
||||
"Get the base engine associated with this adaptive engine view.");
|
||||
|
||||
registerDynamicEngineDefs<gridfire::AdaptiveEngineView, gridfire::DynamicEngine>(py_adaptive_engine_view_bindings);
|
||||
|
||||
auto py_qse_cache_config = py::class_<gridfire::QSECacheConfig>(m, "QSECacheConfig");
|
||||
auto py_qse_cache_key = py::class_<gridfire::QSECacheKey>(m, "QSECacheKey");
|
||||
|
||||
py_qse_cache_key.def(py::init<double, double, const std::vector<double>&>(),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
py::arg("Y")
|
||||
);
|
||||
|
||||
py_qse_cache_key.def("hash", &gridfire::QSECacheKey::hash,
|
||||
"Get the pre-computed hash value of the key");
|
||||
|
||||
py_qse_cache_key.def_static("bin", &gridfire::QSECacheKey::bin,
|
||||
py::arg("value"),
|
||||
py::arg("tol"),
|
||||
"bin a value based on a tolerance");
|
||||
py_qse_cache_key.def("__eq__", &gridfire::QSECacheKey::operator==,
|
||||
py::arg("other"),
|
||||
"Check if two QSECacheKeys are equal");
|
||||
|
||||
auto py_multiscale_engine_view_bindings = py::class_<gridfire::MultiscalePartitioningEngineView, gridfire::DynamicEngine>(m, "MultiscalePartitioningEngineView");
|
||||
py_multiscale_engine_view_bindings.def(py::init<gridfire::GraphEngine&>(),
|
||||
py::arg("baseEngine"),
|
||||
"Construct a multiscale partitioning engine view with a base engine.");
|
||||
py_multiscale_engine_view_bindings.def("getBaseEngine", &gridfire::MultiscalePartitioningEngineView::getBaseEngine,
|
||||
"Get the base engine associated with this multiscale partitioning engine view.");
|
||||
py_multiscale_engine_view_bindings.def("analyzeTimescalePoolConnectivity", &gridfire::MultiscalePartitioningEngineView::analyzeTimescalePoolConnectivity,
|
||||
py::arg("timescale_pools"),
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Analyze the connectivity of timescale pools in the network.");
|
||||
py_multiscale_engine_view_bindings.def("partitionNetwork", py::overload_cast<const std::vector<double>&, double, double>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork),
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Partition the network based on species timescales and connectivity.");
|
||||
py_multiscale_engine_view_bindings.def("partitionNetwork", py::overload_cast<const gridfire::NetIn&>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork),
|
||||
py::arg("netIn"),
|
||||
"Partition the network based on a NetIn object.");
|
||||
py_multiscale_engine_view_bindings.def("exportToDot", &gridfire::MultiscalePartitioningEngineView::exportToDot,
|
||||
py::arg("filename"),
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Export the network to a DOT file for visualization.");
|
||||
py_multiscale_engine_view_bindings.def("getFastSpecies", &gridfire::MultiscalePartitioningEngineView::getFastSpecies,
|
||||
"Get the list of fast species in the network.");
|
||||
py_multiscale_engine_view_bindings.def("getDynamicSpecies", &gridfire::MultiscalePartitioningEngineView::getDynamicSpecies,
|
||||
"Get the list of dynamic species in the network.");
|
||||
py_multiscale_engine_view_bindings.def("equilibrateNetwork", py::overload_cast<const std::vector<double>&, double, double>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork),
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Equilibrate the network based on species abundances and conditions.");
|
||||
py_multiscale_engine_view_bindings.def("equilibrateNetwork", py::overload_cast<const gridfire::NetIn&>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork),
|
||||
py::arg("netIn"),
|
||||
"Equilibrate the network based on a NetIn object.");
|
||||
|
||||
registerDynamicEngineDefs<gridfire::MultiscalePartitioningEngineView, gridfire::DynamicEngine>(py_multiscale_engine_view_bindings);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
16
src/python/engine/bindings.h
Normal file
16
src/python/engine/bindings.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_engine_bindings(pybind11::module &m);
|
||||
|
||||
void register_base_engine_bindings(pybind11::module &m);
|
||||
|
||||
void register_engine_view_bindings(pybind11::module &m);
|
||||
|
||||
void abs_stype_register_engine_bindings(pybind11::module &m);
|
||||
void abs_stype_register_dynamic_engine_bindings(pybind11::module &m);
|
||||
|
||||
void con_stype_register_graph_engine_bindings(pybind11::module &m);
|
||||
|
||||
|
||||
21
src/python/engine/meson.build
Normal file
21
src/python/engine/meson.build
Normal file
@@ -0,0 +1,21 @@
|
||||
subdir('trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
message('⏳ Python bindings for GridFire Engine are being registered...')
|
||||
shared_module('py_gf_engine',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
message('✅ Python bindings for GridFire Engine registered successfully!')
|
||||
21
src/python/engine/trampoline/meson.build
Normal file
21
src/python/engine/trampoline/meson.build
Normal file
@@ -0,0 +1,21 @@
|
||||
gf_engine_trampoline_sources = files('py_engine.cpp')
|
||||
|
||||
gf_engine_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_engine_trampoline_lib = static_library(
|
||||
'engine_trampolines',
|
||||
gf_engine_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_engine_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_engine_trampoline_dep = declare_dependency(
|
||||
link_with: gf_engine_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_engine_trapoline_dependencies,
|
||||
)
|
||||
219
src/python/engine/trampoline/py_engine.cpp
Normal file
219
src/python/engine/trampoline/py_engine.cpp
Normal file
@@ -0,0 +1,219 @@
|
||||
#include "py_engine.h"
|
||||
|
||||
#include "gridfire/engine/engine.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
|
||||
#include <expected>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::vector<fourdst::atomic::Species>,
|
||||
gridfire::Engine, /* Base class */
|
||||
getNetworkSpecies
|
||||
);
|
||||
}
|
||||
|
||||
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> PyEngine::calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
PYBIND11_TYPE(std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError>),
|
||||
gridfire::Engine,
|
||||
calculateRHSAndEnergy,
|
||||
Y, T9, rho
|
||||
);
|
||||
}
|
||||
|
||||
///////////////////////////////////////
|
||||
/// PyDynamicEngine Implementation ///
|
||||
/////////////////////////////////////
|
||||
|
||||
const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::vector<fourdst::atomic::Species>,
|
||||
gridfire::DynamicEngine, /* Base class */
|
||||
getNetworkSpecies
|
||||
);
|
||||
}
|
||||
std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
PYBIND11_TYPE(std::expected<gridfire::StepDerivatives<double>, gridfire::expectations::StaleEngineError>),
|
||||
gridfire::Engine,
|
||||
calculateRHSAndEnergy,
|
||||
Y, T9, rho
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::DynamicEngine,
|
||||
generateJacobianMatrix,
|
||||
Y_dynamic, T9, rho
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::DynamicEngine,
|
||||
generateJacobianMatrix,
|
||||
Y_dynamic, T9, rho, sparsityPattern
|
||||
);
|
||||
}
|
||||
|
||||
double PyDynamicEngine::getJacobianMatrixEntry(int i, int j) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
double,
|
||||
gridfire::DynamicEngine,
|
||||
getJacobianMatrixEntry,
|
||||
i, j
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::generateStoichiometryMatrix() {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::DynamicEngine,
|
||||
generateStoichiometryMatrix
|
||||
);
|
||||
}
|
||||
|
||||
int PyDynamicEngine::getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
int,
|
||||
gridfire::DynamicEngine,
|
||||
getStoichiometryMatrixEntry,
|
||||
speciesIndex, reactionIndex
|
||||
);
|
||||
}
|
||||
|
||||
double PyDynamicEngine::calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector<double> &Y, double T9, double rho) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
double,
|
||||
gridfire::DynamicEngine,
|
||||
calculateMolarReactionFlow,
|
||||
reaction, Y, T9, rho
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::reaction::LogicalReactionSet& PyDynamicEngine::getNetworkReactions() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::reaction::LogicalReactionSet&,
|
||||
gridfire::DynamicEngine,
|
||||
getNetworkReactions
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::LogicalReactionSet& reactions) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::DynamicEngine,
|
||||
setNetworkReactions,
|
||||
reactions
|
||||
);
|
||||
}
|
||||
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesTimescales(const std::vector<double> &Y, double T9, double rho) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError>),
|
||||
gridfire::DynamicEngine,
|
||||
getSpeciesTimescales,
|
||||
Y, T9, rho
|
||||
);
|
||||
}
|
||||
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesDestructionTimescales(const std::vector<double> &Y, double T9, double rho) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError>),
|
||||
gridfire::DynamicEngine,
|
||||
getSpeciesDestructionTimescales,
|
||||
Y, T9, rho
|
||||
);
|
||||
}
|
||||
|
||||
fourdst::composition::Composition PyDynamicEngine::update(const gridfire::NetIn &netIn) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
fourdst::composition::Composition,
|
||||
gridfire::DynamicEngine,
|
||||
update,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
bool PyDynamicEngine::isStale(const gridfire::NetIn &netIn) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
bool,
|
||||
gridfire::DynamicEngine,
|
||||
isStale,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::setScreeningModel(gridfire::screening::ScreeningType model) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::DynamicEngine,
|
||||
setScreeningModel,
|
||||
model
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::screening::ScreeningType,
|
||||
gridfire::DynamicEngine,
|
||||
getScreeningModel
|
||||
);
|
||||
}
|
||||
|
||||
int PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
int,
|
||||
gridfire::DynamicEngine,
|
||||
getSpeciesIndex,
|
||||
species
|
||||
);
|
||||
}
|
||||
|
||||
std::vector<double> PyDynamicEngine::mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::vector<double>,
|
||||
gridfire::DynamicEngine,
|
||||
mapNetInToMolarAbundanceVector,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::PrimingReport PyDynamicEngine::primeEngine(const gridfire::NetIn &netIn) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::PrimingReport,
|
||||
gridfire::DynamicEngine,
|
||||
primeEngine,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::Engine& PyEngineView::getBaseEngine() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::Engine&,
|
||||
gridfire::EngineView<gridfire::Engine>,
|
||||
getBaseEngine
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::DynamicEngine& PyDynamicEngineView::getBaseEngine() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::DynamicEngine&,
|
||||
gridfire::EngineView<gridfire::DynamicEngine>,
|
||||
getBaseEngine
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
52
src/python/engine/trampoline/py_engine.h
Normal file
52
src/python/engine/trampoline/py_engine.h
Normal file
@@ -0,0 +1,52 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/engine/engine.h"
|
||||
#include "gridfire/expectations/expected_engine.h"
|
||||
|
||||
#include "fourdst/composition/atomicSpecies.h"
|
||||
|
||||
#include <vector>
|
||||
#include <expected>
|
||||
|
||||
|
||||
class PyEngine final : public gridfire::Engine {
|
||||
public:
|
||||
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
|
||||
std::expected<gridfire::StepDerivatives<double>,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const override;
|
||||
};
|
||||
|
||||
class PyDynamicEngine final : public gridfire::DynamicEngine {
|
||||
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
|
||||
std::expected<gridfire::StepDerivatives<double>,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(const std::vector<double> &Y, double T9, double rho) const override;
|
||||
void generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho) const override;
|
||||
void generateJacobianMatrix(const std::vector<double> &Y_dynamic, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const override;
|
||||
double getJacobianMatrixEntry(int i, int j) const override;
|
||||
void generateStoichiometryMatrix() override;
|
||||
int getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const override;
|
||||
double calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector<double> &Y, double T9, double rho) const override;
|
||||
const gridfire::reaction::LogicalReactionSet& getNetworkReactions() const override;
|
||||
void setNetworkReactions(const gridfire::reaction::LogicalReactionSet& reactions) override;
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesTimescales(const std::vector<double> &Y, double T9, double rho) const override;
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesDestructionTimescales(const std::vector<double> &Y, double T9, double rho) const override;
|
||||
fourdst::composition::Composition update(const gridfire::NetIn &netIn) override;
|
||||
bool isStale(const gridfire::NetIn &netIn) override;
|
||||
void setScreeningModel(gridfire::screening::ScreeningType model) override;
|
||||
gridfire::screening::ScreeningType getScreeningModel() const override;
|
||||
int getSpeciesIndex(const fourdst::atomic::Species &species) const override;
|
||||
std::vector<double> mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const override;
|
||||
gridfire::PrimingReport primeEngine(const gridfire::NetIn &netIn) override;
|
||||
gridfire::BuildDepthType getDepth() const override {
|
||||
throw std::logic_error("Network depth not supported by this engine.");
|
||||
}
|
||||
void rebuild(const fourdst::composition::Composition& comp, gridfire::BuildDepthType depth) override {
|
||||
throw std::logic_error("Setting network depth not supported by this engine.");
|
||||
}
|
||||
};
|
||||
|
||||
class PyEngineView final : public gridfire::EngineView<gridfire::Engine> {
|
||||
const gridfire::Engine& getBaseEngine() const override;
|
||||
};
|
||||
|
||||
class PyDynamicEngineView final : public gridfire::EngineView<gridfire::DynamicEngine> {
|
||||
const gridfire::DynamicEngine& getBaseEngine() const override;
|
||||
};
|
||||
45
src/python/exceptions/bindings.cpp
Normal file
45
src/python/exceptions/bindings.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
#include "gridfire/exceptions/exceptions.h"
|
||||
|
||||
void register_exception_bindings(py::module &m) {
|
||||
py::register_exception<gridfire::exceptions::EngineError>(m, "GridFireEngineError");
|
||||
|
||||
// TODO: Make it so that we can grab the stale state in python
|
||||
// m.attr("StaleEngineTrigger") = py::register_exception<gridfire::exceptions::StaleEngineTrigger>(m, "StaleEngineTrigger", m.attr("GridFireEngineError"));
|
||||
m.attr("StaleEngineError") = py::register_exception<gridfire::exceptions::StaleEngineError>(m, "StaleEngineError", m.attr("GridFireEngineError"));
|
||||
m.attr("FailedToPartitionEngineError") = py::register_exception<gridfire::exceptions::FailedToPartitionEngineError>(m, "FailedToPartitionEngineError", m.attr("GridFireEngineError"));
|
||||
m.attr("NetworkResizedError") = py::register_exception<gridfire::exceptions::NetworkResizedError>(m, "NetworkResizedError", m.attr("GridFireEngineError"));
|
||||
m.attr("UnableToSetNetworkReactionsError") = py::register_exception<gridfire::exceptions::UnableToSetNetworkReactionsError>(m, "UnableToSetNetworkReactionsError", m.attr("GridFireEngineError"));
|
||||
|
||||
py::class_<gridfire::exceptions::StaleEngineTrigger::state>(m, "StaleEngineState")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("T9", &gridfire::exceptions::StaleEngineTrigger::state::m_T9)
|
||||
.def_readwrite("rho", &gridfire::exceptions::StaleEngineTrigger::state::m_rho)
|
||||
.def_readwrite("Y", &gridfire::exceptions::StaleEngineTrigger::state::m_Y)
|
||||
.def_readwrite("t", &gridfire::exceptions::StaleEngineTrigger::state::m_t)
|
||||
.def_readwrite("total_steps", &gridfire::exceptions::StaleEngineTrigger::state::m_total_steps)
|
||||
.def_readwrite("eps_nuc", &gridfire::exceptions::StaleEngineTrigger::state::m_eps_nuc);
|
||||
|
||||
py::class_<gridfire::exceptions::StaleEngineTrigger>(m, "StaleEngineTrigger")
|
||||
.def(py::init<const gridfire::exceptions::StaleEngineTrigger::state &>())
|
||||
.def("getState", &gridfire::exceptions::StaleEngineTrigger::getState)
|
||||
.def("numSpecies", &gridfire::exceptions::StaleEngineTrigger::numSpecies)
|
||||
.def("totalSteps", &gridfire::exceptions::StaleEngineTrigger::totalSteps)
|
||||
.def("energy", &gridfire::exceptions::StaleEngineTrigger::energy)
|
||||
.def("getMolarAbundance", &gridfire::exceptions::StaleEngineTrigger::getMolarAbundance)
|
||||
.def("temperature", &gridfire::exceptions::StaleEngineTrigger::temperature)
|
||||
.def("density", &gridfire::exceptions::StaleEngineTrigger::density)
|
||||
.def("__repr__", [&](const gridfire::exceptions::StaleEngineTrigger& self) {
|
||||
return self.what();
|
||||
});
|
||||
|
||||
}
|
||||
5
src/python/exceptions/bindings.h
Normal file
5
src/python/exceptions/bindings.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_exception_bindings(pybind11::module &m);
|
||||
17
src/python/exceptions/meson.build
Normal file
17
src/python/exceptions/meson.build
Normal file
@@ -0,0 +1,17 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_exceptions',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
43
src/python/expectations/bindings.cpp
Normal file
43
src/python/expectations/bindings.cpp
Normal file
@@ -0,0 +1,43 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
#include "gridfire/expectations/expectations.h"
|
||||
|
||||
void register_expectation_bindings(py::module &m) {
|
||||
py::enum_<gridfire::expectations::EngineErrorTypes>(m, "EngineErrorTypes")
|
||||
.value("FAILURE", gridfire::expectations::EngineErrorTypes::FAILURE)
|
||||
.value("INDEX", gridfire::expectations::EngineErrorTypes::INDEX)
|
||||
.value("STALE", gridfire::expectations::EngineErrorTypes::STALE)
|
||||
.export_values();
|
||||
|
||||
py::enum_<gridfire::expectations::StaleEngineErrorTypes>(m, "StaleEngineErrorTypes")
|
||||
.value("SYSTEM_RESIZED", gridfire::expectations::StaleEngineErrorTypes::SYSTEM_RESIZED)
|
||||
.export_values();
|
||||
|
||||
// Bind the base class
|
||||
py::class_<gridfire::expectations::EngineError>(m, "EngineError")
|
||||
.def_readonly("message", &gridfire::expectations::EngineError::m_message)
|
||||
.def_readonly("type", &gridfire::expectations::EngineError::type)
|
||||
.def("__str__", [](const gridfire::expectations::EngineError &e) {return e.m_message;});
|
||||
|
||||
// Bind the EngineIndexError, specifying EngineError as the base
|
||||
py::class_<gridfire::expectations::EngineIndexError, gridfire::expectations::EngineError>(m, "EngineIndexError")
|
||||
.def(py::init<int>(), py::arg("index"))
|
||||
.def_readonly("index", &gridfire::expectations::EngineIndexError::m_index)
|
||||
.def("__str__", [](const gridfire::expectations::EngineIndexError &e) {
|
||||
return e.m_message + " at index " + std::to_string(e.m_index);
|
||||
});
|
||||
|
||||
// Bind the StaleEngineError, specifying EngineError as the base
|
||||
py::class_<gridfire::expectations::StaleEngineError, gridfire::expectations::EngineError>(m, "StaleEngineError")
|
||||
.def(py::init<gridfire::expectations::StaleEngineErrorTypes>(), py::arg("stale_type"))
|
||||
.def_readonly("stale_type", &gridfire::expectations::StaleEngineError::staleType)
|
||||
.def("__str__", [](const gridfire::expectations::StaleEngineError &e) {
|
||||
return static_cast<std::string>(e);
|
||||
});
|
||||
}
|
||||
5
src/python/expectations/bindings.h
Normal file
5
src/python/expectations/bindings.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_expectation_bindings(pybind11::module &m);
|
||||
17
src/python/expectations/meson.build
Normal file
17
src/python/expectations/meson.build
Normal file
@@ -0,0 +1,17 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_expectations',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
29
src/python/io/bindings.cpp
Normal file
29
src/python/io/bindings.cpp
Normal file
@@ -0,0 +1,29 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/io/io.h"
|
||||
#include "trampoline/py_io.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void register_io_bindings(py::module &m) {
|
||||
py::class_<gridfire::io::ParsedNetworkData>(m, "ParsedNetworkData");
|
||||
|
||||
py::class_<gridfire::io::NetworkFileParser, PyNetworkFileParser>(m, "NetworkFileParser");
|
||||
|
||||
py::class_<gridfire::io::SimpleReactionListFileParser, gridfire::io::NetworkFileParser>(m, "SimpleReactionListFileParser")
|
||||
.def("parse", &gridfire::io::SimpleReactionListFileParser::parse,
|
||||
py::arg("filename"),
|
||||
"Parse a simple reaction list file and return a ParsedNetworkData object.");
|
||||
|
||||
// py::class_<gridfire::io::MESANetworkFileParser, gridfire::io::NetworkFileParser>(m, "MESANetworkFileParser")
|
||||
// .def("parse", &gridfire::io::MESANetworkFileParser::parse,
|
||||
// py::arg("filename"),
|
||||
// "Parse a MESA network file and return a ParsedNetworkData object.");
|
||||
}
|
||||
5
src/python/io/bindings.h
Normal file
5
src/python/io/bindings.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_io_bindings(pybind11::module &m);
|
||||
17
src/python/io/meson.build
Normal file
17
src/python/io/meson.build
Normal file
@@ -0,0 +1,17 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_io',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
21
src/python/io/trampoline/meson.build
Normal file
21
src/python/io/trampoline/meson.build
Normal file
@@ -0,0 +1,21 @@
|
||||
gf_io_trampoline_sources = files('py_io.cpp')
|
||||
|
||||
gf_io_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_io_trampoline_lib = static_library(
|
||||
'io_trampolines',
|
||||
gf_io_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_io_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_io_trampoline_dep = declare_dependency(
|
||||
link_with: gf_io_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_io_trapoline_dependencies,
|
||||
)
|
||||
14
src/python/io/trampoline/py_io.cpp
Normal file
14
src/python/io/trampoline/py_io.cpp
Normal file
@@ -0,0 +1,14 @@
|
||||
#include "gridfire/io/io.h"
|
||||
#include "py_io.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
gridfire::io::ParsedNetworkData PyNetworkFileParser::parse(const std::string &filename) const {
|
||||
PYBIND11_OVERLOAD_PURE(
|
||||
gridfire::io::ParsedNetworkData,
|
||||
gridfire::io::NetworkFileParser,
|
||||
parse // Method name
|
||||
);
|
||||
}
|
||||
7
src/python/io/trampoline/py_io.h
Normal file
7
src/python/io/trampoline/py_io.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/io/io.h"
|
||||
|
||||
class PyNetworkFileParser final : public gridfire::io::NetworkFileParser {
|
||||
gridfire::io::ParsedNetworkData parse(const std::string &filename) const override;
|
||||
};
|
||||
10
src/python/meson.build
Normal file
10
src/python/meson.build
Normal file
@@ -0,0 +1,10 @@
|
||||
subdir('types')
|
||||
subdir('utils')
|
||||
subdir('expectations')
|
||||
subdir('exceptions')
|
||||
subdir('io')
|
||||
subdir('partition')
|
||||
subdir('reaction')
|
||||
subdir('screening')
|
||||
subdir('engine')
|
||||
subdir('solver')
|
||||
113
src/python/partition/bindings.cpp
Normal file
113
src/python/partition/bindings.cpp
Normal file
@@ -0,0 +1,113 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/partition/partition.h"
|
||||
|
||||
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
|
||||
|
||||
#include "trampoline/py_partition.h"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
void register_partition_bindings(pybind11::module &m) {
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<PF, PyPartitionFunction>(m, "PartitionFunction");
|
||||
|
||||
register_partition_types_bindings(m);
|
||||
register_ground_state_partition_bindings(m);
|
||||
register_rauscher_thielemann_partition_data_record_bindings(m);
|
||||
register_rauscher_thielemann_partition_bindings(m);
|
||||
|
||||
register_composite_partition_bindings(m);
|
||||
}
|
||||
|
||||
void register_partition_types_bindings(pybind11::module &m) {
|
||||
py::enum_<gridfire::partition::BasePartitionType>(m, "BasePartitionType")
|
||||
.value("RauscherThielemann", gridfire::partition::BasePartitionType::RauscherThielemann)
|
||||
.value("GroundState", gridfire::partition::BasePartitionType::GroundState)
|
||||
.export_values();
|
||||
|
||||
m.def("basePartitionTypeToString", [](gridfire::partition::BasePartitionType type) {
|
||||
return gridfire::partition::basePartitionTypeToString[type];
|
||||
}, py::arg("type"), "Convert BasePartitionType to string.");
|
||||
|
||||
m.def("stringToBasePartitionType", [](const std::string &typeStr) {
|
||||
return gridfire::partition::stringToBasePartitionType[typeStr];
|
||||
}, py::arg("typeStr"), "Convert string to BasePartitionType.");
|
||||
}
|
||||
|
||||
void register_ground_state_partition_bindings(pybind11::module &m) {
|
||||
using GSPF = gridfire::partition::GroundStatePartitionFunction;
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<GSPF, PF>(m, "GroundStatePartitionFunction")
|
||||
.def(py::init<>())
|
||||
.def("evaluate", &gridfire::partition::GroundStatePartitionFunction::evaluate,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the ground state partition function for given Z, A, and T9.")
|
||||
.def("evaluateDerivative", &gridfire::partition::GroundStatePartitionFunction::evaluateDerivative,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the derivative of the ground state partition function for given Z, A, and T9.")
|
||||
.def("supports", &gridfire::partition::GroundStatePartitionFunction::supports,
|
||||
py::arg("z"), py::arg("a"),
|
||||
"Check if the ground state partition function supports given Z and A.")
|
||||
.def("get_type", &gridfire::partition::GroundStatePartitionFunction::type,
|
||||
"Get the type of the partition function (should return 'GroundState').");
|
||||
}
|
||||
|
||||
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m) {
|
||||
py::class_<gridfire::partition::record::RauscherThielemannPartitionDataRecord>(m, "RauscherThielemannPartitionDataRecord")
|
||||
.def_readonly("z", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::z, "Atomic number")
|
||||
.def_readonly("a", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::a, "Mass number")
|
||||
.def_readonly("ground_state_spin", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::ground_state_spin, "Ground state spin")
|
||||
.def_readonly("normalized_g_values", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::normalized_g_values, "Normalized g-values for the first 24 energy levels");
|
||||
}
|
||||
|
||||
|
||||
void register_rauscher_thielemann_partition_bindings(pybind11::module &m) {
|
||||
using RTPF = gridfire::partition::RauscherThielemannPartitionFunction;
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<RTPF, PF>(m, "RauscherThielemannPartitionFunction")
|
||||
.def(py::init<>())
|
||||
.def("evaluate", &gridfire::partition::RauscherThielemannPartitionFunction::evaluate,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the Rauscher-Thielemann partition function for given Z, A, and T9.")
|
||||
.def("evaluateDerivative", &gridfire::partition::RauscherThielemannPartitionFunction::evaluateDerivative,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the derivative of the Rauscher-Thielemann partition function for given Z, A, and T9.")
|
||||
.def("supports", &gridfire::partition::RauscherThielemannPartitionFunction::supports,
|
||||
py::arg("z"), py::arg("a"),
|
||||
"Check if the Rauscher-Thielemann partition function supports given Z and A.")
|
||||
.def("get_type", &gridfire::partition::RauscherThielemannPartitionFunction::type,
|
||||
"Get the type of the partition function (should return 'RauscherThielemann').");
|
||||
}
|
||||
|
||||
void register_composite_partition_bindings(pybind11::module &m) {
|
||||
py::class_<gridfire::partition::CompositePartitionFunction>(m, "CompositePartitionFunction")
|
||||
.def(py::init<const std::vector<gridfire::partition::BasePartitionType>&>(),
|
||||
py::arg("partitionFunctions"),
|
||||
"Create a composite partition function from a list of base partition types.")
|
||||
.def(py::init<const gridfire::partition::CompositePartitionFunction&>(),
|
||||
"Copy constructor for CompositePartitionFunction.")
|
||||
.def("evaluate", &gridfire::partition::CompositePartitionFunction::evaluate,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the composite partition function for given Z, A, and T9.")
|
||||
.def("evaluateDerivative", &gridfire::partition::CompositePartitionFunction::evaluateDerivative,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the derivative of the composite partition function for given Z, A, and T9.")
|
||||
.def("supports", &gridfire::partition::CompositePartitionFunction::supports,
|
||||
py::arg("z"), py::arg("a"),
|
||||
"Check if the composite partition function supports given Z and A.")
|
||||
.def("get_type", &gridfire::partition::CompositePartitionFunction::type,
|
||||
"Get the type of the partition function (should return 'Composite').");
|
||||
}
|
||||
|
||||
|
||||
|
||||
16
src/python/partition/bindings.h
Normal file
16
src/python/partition/bindings.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_partition_bindings(pybind11::module &m);
|
||||
|
||||
void register_partition_types_bindings(pybind11::module &m);
|
||||
|
||||
void register_ground_state_partition_bindings(pybind11::module &m);
|
||||
|
||||
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m);
|
||||
|
||||
void register_rauscher_thielemann_partition_bindings(pybind11::module &m);
|
||||
|
||||
void register_composite_partition_bindings(pybind11::module &m);
|
||||
|
||||
19
src/python/partition/meson.build
Normal file
19
src/python/partition/meson.build
Normal file
@@ -0,0 +1,19 @@
|
||||
subdir('trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_partition',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
21
src/python/partition/trampoline/meson.build
Normal file
21
src/python/partition/trampoline/meson.build
Normal file
@@ -0,0 +1,21 @@
|
||||
gf_partition_trampoline_sources = files('py_partition.cpp')
|
||||
|
||||
gf_partition_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_partition_trampoline_lib = static_library(
|
||||
'partition_trampolines',
|
||||
gf_partition_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_partition_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_partition_trampoline_dep = declare_dependency(
|
||||
link_with: gf_partition_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_partition_trapoline_dependencies,
|
||||
)
|
||||
57
src/python/partition/trampoline/py_partition.cpp
Normal file
57
src/python/partition/trampoline/py_partition.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
#include "py_partition.h"
|
||||
|
||||
#include "gridfire/partition/partition.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
double PyPartitionFunction::evaluate(int z, int a, double T9) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
double,
|
||||
gridfire::partition::PartitionFunction,
|
||||
evaluate,
|
||||
z, a, T9
|
||||
);
|
||||
}
|
||||
|
||||
double PyPartitionFunction::evaluateDerivative(int z, int a, double T9) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
double,
|
||||
gridfire::partition::PartitionFunction,
|
||||
evaluateDerivative,
|
||||
z, a, T9
|
||||
);
|
||||
}
|
||||
|
||||
bool PyPartitionFunction::supports(int z, int a) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
bool,
|
||||
gridfire::partition::PartitionFunction,
|
||||
supports,
|
||||
z, a
|
||||
);
|
||||
}
|
||||
|
||||
std::string PyPartitionFunction::type() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::string,
|
||||
gridfire::partition::PartitionFunction,
|
||||
type
|
||||
);
|
||||
}
|
||||
|
||||
std::unique_ptr<gridfire::partition::PartitionFunction> PyPartitionFunction::clone() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::unique_ptr<gridfire::partition::PartitionFunction>,
|
||||
gridfire::partition::PartitionFunction,
|
||||
clone
|
||||
);
|
||||
}
|
||||
|
||||
15
src/python/partition/trampoline/py_partition.h
Normal file
15
src/python/partition/trampoline/py_partition.h
Normal file
@@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "gridfire/partition/partition.h"
|
||||
|
||||
|
||||
class PyPartitionFunction final : public gridfire::partition::PartitionFunction {
|
||||
double evaluate(int z, int a, double T9) const override;
|
||||
double evaluateDerivative(int z, int a, double T9) const override;
|
||||
bool supports(int z, int a) const override;
|
||||
std::string type() const override;
|
||||
std::unique_ptr<gridfire::partition::PartitionFunction> clone() const override;
|
||||
};
|
||||
171
src/python/reaction/bindings.cpp
Normal file
171
src/python/reaction/bindings.cpp
Normal file
@@ -0,0 +1,171 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/reaction/reaction.h"
|
||||
#include "gridfire/reaction/reaclib.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
void register_reaction_bindings(py::module &m) {
|
||||
py::class_<gridfire::reaction::RateCoefficientSet>(m, "RateCoefficientSet")
|
||||
.def(py::init<double, double, double, double, double, double, double>(),
|
||||
py::arg("a0"), py::arg("a1"), py::arg("a2"), py::arg("a3"),
|
||||
py::arg("a4"), py::arg("a5"), py::arg("a6"),
|
||||
"Construct a RateCoefficientSet with the given parameters."
|
||||
);
|
||||
|
||||
using fourdst::atomic::Species;
|
||||
py::class_<gridfire::reaction::Reaction>(m, "Reaction")
|
||||
.def(py::init<const std::string_view, const std::string_view, int, const std::vector<Species>&, const std::vector<Species>&, double, std::string_view, gridfire::reaction::RateCoefficientSet, bool>(),
|
||||
py::arg("id"), py::arg("peName"), py::arg("chapter"),
|
||||
py::arg("reactants"), py::arg("products"), py::arg("qValue"),
|
||||
py::arg("label"), py::arg("sets"), py::arg("reverse") = false,
|
||||
"Construct a Reaction with the given parameters.")
|
||||
.def("calculate_rate", static_cast<double (gridfire::reaction::Reaction::*)(double) const>(&gridfire::reaction::Reaction::calculate_rate),
|
||||
py::arg("T9"), "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).")
|
||||
.def("peName", &gridfire::reaction::Reaction::peName,
|
||||
"Get the reaction name in (projectile, ejectile) notation (e.g., 'p(p,g)d').")
|
||||
.def("chapter", &gridfire::reaction::Reaction::chapter,
|
||||
"Get the REACLIB chapter number defining the reaction structure.")
|
||||
.def("sourceLabel", &gridfire::reaction::Reaction::sourceLabel,
|
||||
"Get the source label for the rate data (e.g., 'wc12w', 'st08').")
|
||||
.def("rateCoefficients", &gridfire::reaction::Reaction::rateCoefficients,
|
||||
"get the set of rate coefficients.")
|
||||
.def("contains", &gridfire::reaction::Reaction::contains,
|
||||
py::arg("species"), "Check if the reaction contains a specific species.")
|
||||
.def("contains_reactant", &gridfire::reaction::Reaction::contains_reactant,
|
||||
"Check if the reaction contains a specific reactant species.")
|
||||
.def("contains_product", &gridfire::reaction::Reaction::contains_product,
|
||||
"Check if the reaction contains a specific product species.")
|
||||
.def("all_species", &gridfire::reaction::Reaction::all_species,
|
||||
"Get all species involved in the reaction (both reactants and products) as a set.")
|
||||
.def("reactant_species", &gridfire::reaction::Reaction::reactant_species,
|
||||
"Get the reactant species of the reaction as a set.")
|
||||
.def("product_species", &gridfire::reaction::Reaction::product_species,
|
||||
"Get the product species of the reaction as a set.")
|
||||
.def("num_species", &gridfire::reaction::Reaction::num_species,
|
||||
"Count the number of species in the reaction.")
|
||||
.def("stoichiometry", static_cast<int (gridfire::reaction::Reaction::*)(const Species&) const>(&gridfire::reaction::Reaction::stoichiometry),
|
||||
py::arg("species"),
|
||||
"Get the stoichiometry of the reaction as a map from species to their coefficients.")
|
||||
.def("stoichiometry", static_cast<std::unordered_map<Species, int> (gridfire::reaction::Reaction::*)() const>(&gridfire::reaction::Reaction::stoichiometry),
|
||||
"Get the stoichiometry of the reaction as a map from species to their coefficients.")
|
||||
.def("id", &gridfire::reaction::Reaction::id,
|
||||
"Get the unique identifier of the reaction.")
|
||||
.def("qValue", &gridfire::reaction::Reaction::qValue,
|
||||
"Get the Q-value of the reaction in MeV.")
|
||||
.def("reactants", &gridfire::reaction::Reaction::reactants,
|
||||
"Get a list of reactant species in the reaction.")
|
||||
.def("products", &gridfire::reaction::Reaction::products,
|
||||
"Get a list of product species in the reaction.")
|
||||
.def("is_reverse", &gridfire::reaction::Reaction::is_reverse,
|
||||
"Check if this is a reverse reaction rate.")
|
||||
.def("excess_energy", &gridfire::reaction::Reaction::excess_energy,
|
||||
"Calculate the excess energy from the mass difference of reactants and products.")
|
||||
.def("__eq__", &gridfire::reaction::Reaction::operator==,
|
||||
"Equality operator for reactions based on their IDs.")
|
||||
.def("__neq__", &gridfire::reaction::Reaction::operator!=,
|
||||
"Inequality operator for reactions based on their IDs.")
|
||||
.def("hash", &gridfire::reaction::Reaction::hash,
|
||||
py::arg("seed") = 0,
|
||||
"Compute a hash for the reaction based on its ID.")
|
||||
.def("__repr__", [](const gridfire::reaction::Reaction& self) {
|
||||
std::stringstream ss;
|
||||
ss << self; // Use the existing operator<< for Reaction
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
py::class_<gridfire::reaction::LogicalReaction, gridfire::reaction::Reaction>(m, "LogicalReaction")
|
||||
.def(py::init<const std::vector<gridfire::reaction::Reaction>>(),
|
||||
py::arg("reactions"),
|
||||
"Construct a LogicalReaction from a vector of Reaction objects.")
|
||||
.def("add_reaction", &gridfire::reaction::LogicalReaction::add_reaction,
|
||||
py::arg("reaction"),
|
||||
"Add another Reaction source to this logical reaction.")
|
||||
.def("size", &gridfire::reaction::LogicalReaction::size,
|
||||
"Get the number of source rates contributing to this logical reaction.")
|
||||
.def("__len__", &gridfire::reaction::LogicalReaction::size,
|
||||
"Overload len() to return the number of source rates.")
|
||||
.def("sources", &gridfire::reaction::LogicalReaction::sources,
|
||||
"Get the list of source labels for the aggregated rates.")
|
||||
.def("calculate_rate", static_cast<double (gridfire::reaction::LogicalReaction::*)(double) const>(&gridfire::reaction::LogicalReaction::calculate_rate),
|
||||
py::arg("T9"), "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).")
|
||||
.def("calculate_forward_rate_log_derivative", &gridfire::reaction::LogicalReaction::calculate_forward_rate_log_derivative,
|
||||
py::arg("T9"), "Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K).");
|
||||
|
||||
py::class_<gridfire::reaction::LogicalReactionSet>(m, "LogicalReactionSet")
|
||||
.def(py::init<const std::vector<gridfire::reaction::LogicalReaction>>(),
|
||||
py::arg("reactions"),
|
||||
"Construct a LogicalReactionSet from a vector of LogicalReaction objects.")
|
||||
.def(py::init<>(),
|
||||
"Default constructor for an empty LogicalReactionSet.")
|
||||
.def(py::init<const gridfire::reaction::LogicalReactionSet&>(),
|
||||
py::arg("other"),
|
||||
"Copy constructor for LogicalReactionSet.")
|
||||
.def("add_reaction", &gridfire::reaction::LogicalReactionSet::add_reaction,
|
||||
py::arg("reaction"),
|
||||
"Add a LogicalReaction to the set.")
|
||||
.def("remove_reaction", &gridfire::reaction::LogicalReactionSet::remove_reaction,
|
||||
py::arg("reaction"),
|
||||
"Remove a LogicalReaction from the set.")
|
||||
.def("contains", py::overload_cast<const std::string_view&>(&gridfire::reaction::LogicalReactionSet::contains, py::const_),
|
||||
py::arg("id"),
|
||||
"Check if the set contains a specific LogicalReaction.")
|
||||
.def("contains", py::overload_cast<const gridfire::reaction::Reaction&>(&gridfire::reaction::LogicalReactionSet::contains, py::const_),
|
||||
py::arg("reaction"),
|
||||
"Check if the set contains a specific Reaction.")
|
||||
.def("size", &gridfire::reaction::LogicalReactionSet::size,
|
||||
"Get the number of LogicalReactions in the set.")
|
||||
.def("__len__", &gridfire::reaction::LogicalReactionSet::size,
|
||||
"Overload len() to return the number of LogicalReactions.")
|
||||
.def("clear", &gridfire::reaction::LogicalReactionSet::clear,
|
||||
"Remove all LogicalReactions from the set.")
|
||||
.def("containes_species", &gridfire::reaction::LogicalReactionSet::contains_species,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set involves the given species.")
|
||||
.def("contains_reactant", &gridfire::reaction::LogicalReactionSet::contains_reactant,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set has the species as a reactant.")
|
||||
.def("contains_product", &gridfire::reaction::LogicalReactionSet::contains_product,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set has the species as a product.")
|
||||
.def("__getitem__", py::overload_cast<size_t>(&gridfire::reaction::LogicalReactionSet::operator[], py::const_),
|
||||
py::arg("index"),
|
||||
"Get a LogicalReaction by index.")
|
||||
.def("__getitem___", py::overload_cast<const std::string_view&>(&gridfire::reaction::LogicalReactionSet::operator[], py::const_),
|
||||
py::arg("id"),
|
||||
"Get a LogicalReaction by its ID.")
|
||||
.def("__eq__", &gridfire::reaction::LogicalReactionSet::operator==,
|
||||
py::arg("LogicalReactionSet"),
|
||||
"Equality operator for LogicalReactionSets based on their contents.")
|
||||
.def("__ne__", &gridfire::reaction::LogicalReactionSet::operator!=,
|
||||
py::arg("LogicalReactionSet"),
|
||||
"Inequality operator for LogicalReactionSets based on their contents.")
|
||||
.def("hash", &gridfire::reaction::LogicalReactionSet::hash,
|
||||
py::arg("seed") = 0,
|
||||
"Compute a hash for the LogicalReactionSet based on its contents."
|
||||
)
|
||||
.def("__repr__", [](const gridfire::reaction::LogicalReactionSet& self) {
|
||||
std::stringstream ss;
|
||||
ss << self;
|
||||
return ss.str();
|
||||
})
|
||||
.def("getReactionSetSpecies", &gridfire::reaction::LogicalReactionSet::getReactionSetSpecies,
|
||||
"Get all species involved in the reactions of the set as a set of Species objects.");
|
||||
|
||||
m.def("packReactionSetToLogicalReactionSet",
|
||||
&gridfire::reaction::packReactionSetToLogicalReactionSet,
|
||||
py::arg("reactionSet"),
|
||||
"Convert a ReactionSet to a LogicalReactionSet by aggregating reactions with the same peName."
|
||||
);
|
||||
|
||||
m.def("get_all_reactions", &gridfire::reaclib::get_all_reactions,
|
||||
"Get all reactions from the REACLIB database.");
|
||||
}
|
||||
5
src/python/reaction/bindings.h
Normal file
5
src/python/reaction/bindings.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_reaction_bindings(pybind11::module &m);
|
||||
17
src/python/reaction/meson.build
Normal file
17
src/python/reaction/meson.build
Normal file
@@ -0,0 +1,17 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_reaction',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
42
src/python/screening/bindings.cpp
Normal file
42
src/python/screening/bindings.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/screening/screening.h"
|
||||
#include "trampoline/py_screening.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void register_screening_bindings(py::module &m) {
|
||||
py::class_<gridfire::screening::ScreeningModel, PyScreening>(m, "ScreeningModel");
|
||||
|
||||
py::enum_<gridfire::screening::ScreeningType>(m, "ScreeningType")
|
||||
.value("BARE", gridfire::screening::ScreeningType::BARE)
|
||||
.value("WEAK", gridfire::screening::ScreeningType::WEAK)
|
||||
.export_values();
|
||||
|
||||
m.def("selectScreeningModel", &gridfire::screening::selectScreeningModel,
|
||||
py::arg("type"),
|
||||
"Select a screening model based on the specified type. Returns a pointer to the selected model.");
|
||||
|
||||
py::class_<gridfire::screening::BareScreeningModel>(m, "BareScreeningModel")
|
||||
.def(py::init<>())
|
||||
.def("calculateScreeningFactors",
|
||||
py::overload_cast<const gridfire::reaction::LogicalReactionSet&, const std::vector<fourdst::atomic::Species>&, const std::vector<double>&, double, double>(&gridfire::screening::BareScreeningModel::calculateScreeningFactors, py::const_),
|
||||
py::arg("reactions"), py::arg("species"), py::arg("Y"), py::arg("T9"), py::arg("rho"),
|
||||
"Calculate the bare plasma screening factors. This always returns 1.0 (bare)"
|
||||
);
|
||||
|
||||
py::class_<gridfire::screening::WeakScreeningModel>(m, "WeakScreeningModel")
|
||||
.def(py::init<>())
|
||||
.def("calculateScreeningFactors",
|
||||
py::overload_cast<const gridfire::reaction::LogicalReactionSet&, const std::vector<fourdst::atomic::Species>&, const std::vector<double>&, double, double>(&gridfire::screening::WeakScreeningModel::calculateScreeningFactors, py::const_),
|
||||
py::arg("reactions"), py::arg("species"), py::arg("Y"), py::arg("T9"), py::arg("rho"),
|
||||
"Calculate the weak plasma screening factors using the Salpeter (1954) model."
|
||||
);
|
||||
}
|
||||
5
src/python/screening/bindings.h
Normal file
5
src/python/screening/bindings.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_screening_bindings(pybind11::module &m);
|
||||
19
src/python/screening/meson.build
Normal file
19
src/python/screening/meson.build
Normal file
@@ -0,0 +1,19 @@
|
||||
subdir('trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_screening',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
21
src/python/screening/trampoline/meson.build
Normal file
21
src/python/screening/trampoline/meson.build
Normal file
@@ -0,0 +1,21 @@
|
||||
gf_screening_trampoline_sources = files('py_screening.cpp')
|
||||
|
||||
gf_screening_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_screening_trampoline_lib = static_library(
|
||||
'screening_trampolines',
|
||||
gf_screening_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_screening_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_screening_trampoline_dep = declare_dependency(
|
||||
link_with: gf_screening_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_screening_trapoline_dependencies,
|
||||
)
|
||||
31
src/python/screening/trampoline/py_screening.cpp
Normal file
31
src/python/screening/trampoline/py_screening.cpp
Normal file
@@ -0,0 +1,31 @@
|
||||
#include "gridfire/screening/screening.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "py_screening.h"
|
||||
|
||||
#include "cppad/cppad.hpp"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
std::vector<double> PyScreening::calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<double> &Y, const double T9, const double rho) const {
|
||||
PYBIND11_OVERLOAD_PURE(
|
||||
std::vector<double>, // Return type
|
||||
gridfire::screening::ScreeningModel,
|
||||
calculateScreeningFactors // Method name
|
||||
);
|
||||
}
|
||||
|
||||
using ADDouble = gridfire::screening::ScreeningModel::ADDouble;
|
||||
std::vector<ADDouble> PyScreening::calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<ADDouble> &Y, const ADDouble T9, const ADDouble rho) const {
|
||||
PYBIND11_OVERLOAD_PURE(
|
||||
std::vector<ADDouble>, // Return type
|
||||
gridfire::screening::ScreeningModel,
|
||||
calculateScreeningFactors // Method name
|
||||
);
|
||||
}
|
||||
16
src/python/screening/trampoline/py_screening.h
Normal file
16
src/python/screening/trampoline/py_screening.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "gridfire/screening/screening.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "cppad/cppad.hpp"
|
||||
|
||||
class PyScreening final : public gridfire::screening::ScreeningModel {
|
||||
std::vector<double> calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<double> &Y, const double T9, const double rho) const override;
|
||||
std::vector<ADDouble> calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<ADDouble> &Y, const ADDouble T9, const ADDouble rho) const override;
|
||||
};
|
||||
26
src/python/solver/bindings.cpp
Normal file
26
src/python/solver/bindings.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/solver/solver.h"
|
||||
#include "trampoline/py_solver.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
void register_solver_bindings(py::module &m) {
|
||||
auto py_dynamic_network_solving_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
|
||||
auto py_direct_network_solver = py::class_<gridfire::solver::DirectNetworkSolver, gridfire::solver::DynamicNetworkSolverStrategy>(m, "DirectNetworkSolver");
|
||||
|
||||
py_direct_network_solver.def(py::init<gridfire::DynamicEngine&>(),
|
||||
py::arg("engine"),
|
||||
"Constructor for the DirectNetworkSolver. Takes a DynamicEngine instance to use for evaluating the network.");
|
||||
|
||||
py_direct_network_solver.def("evaluate",
|
||||
&gridfire::solver::DirectNetworkSolver::evaluate,
|
||||
py::arg("netIn"),
|
||||
"Evaluate the network for a given timestep. Returns the output conditions after the timestep.");
|
||||
}
|
||||
|
||||
5
src/python/solver/bindings.h
Normal file
5
src/python/solver/bindings.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_solver_bindings(pybind11::module &m);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user