feat(python): Python Bindings
Python Bindings are working again
This commit is contained in:
@@ -8,34 +8,8 @@ namespace gridfire::config {
|
||||
double relTol = 1.0e-5;
|
||||
};
|
||||
|
||||
|
||||
struct SpectralSolverConfig {
|
||||
struct Trigger {
|
||||
double timestepCollapseRatio = 0.5;
|
||||
size_t maxConvergenceFailures = 2;
|
||||
double relativeFailureRate = 0.5;
|
||||
size_t windowSize = 10;
|
||||
};
|
||||
struct MonitorFunctionConfig {
|
||||
double structure_weight = 1.0;
|
||||
double abundance_weight = 10.0;
|
||||
double alpha = 0.2;
|
||||
double beta = 0.8;
|
||||
};
|
||||
struct BasisConfig {
|
||||
size_t num_elements = 50;
|
||||
};
|
||||
double absTol = 1.0e-8;
|
||||
double relTol = 1.0e-5;
|
||||
size_t degree = 3;
|
||||
MonitorFunctionConfig monitorFunction;
|
||||
BasisConfig basis;
|
||||
Trigger trigger;
|
||||
};
|
||||
|
||||
struct SolverConfig {
|
||||
CVODESolverConfig cvode;
|
||||
SpectralSolverConfig spectral;
|
||||
};
|
||||
|
||||
struct AdaptiveEngineViewConfig {
|
||||
@@ -57,5 +31,4 @@ namespace gridfire::config {
|
||||
};
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -25,6 +25,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "gridfire/engine/types/engine_types.h"
|
||||
#include "gridfire/engine/scratchpads/blob.h"
|
||||
|
||||
|
||||
namespace gridfire::policy {
|
||||
|
||||
@@ -20,6 +20,9 @@ namespace gridfire::solver {
|
||||
void set_callback(const std::function<void(const TimestepContextBase&)> &callback);
|
||||
void set_callback(const std::function<void(const TimestepContextBase&)> &callback, size_t zone_idx);
|
||||
|
||||
void clear_callback();
|
||||
void clear_callback(size_t zone_idx);
|
||||
|
||||
void set_stdout_logging(bool enable) override;
|
||||
void set_detailed_logging(bool enable) override;
|
||||
|
||||
|
||||
@@ -28,6 +28,19 @@ namespace gridfire::solver {
|
||||
timestep_callbacks[zone_idx] = callback;
|
||||
}
|
||||
|
||||
void GridSolverContext::clear_callback() {
|
||||
for (auto &cb : timestep_callbacks) {
|
||||
cb = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void GridSolverContext::clear_callback(const size_t zone_idx) {
|
||||
if (zone_idx >= timestep_callbacks.size()) {
|
||||
throw exceptions::SolverError("GridSolverContext::clear_callback: zone_idx out of range.");
|
||||
}
|
||||
timestep_callbacks[zone_idx] = nullptr;
|
||||
}
|
||||
|
||||
void GridSolverContext::set_stdout_logging(const bool enable) {
|
||||
zone_stdout_logging = enable;
|
||||
}
|
||||
|
||||
@@ -58,12 +58,21 @@ if get_option('openmp_support')
|
||||
endif
|
||||
|
||||
# Define the libnetwork library so it can be linked against by other parts of the build system
|
||||
libgridfire = library('gridfire',
|
||||
gridfire_sources,
|
||||
include_directories: include_directories('include'),
|
||||
dependencies: gridfire_build_dependencies,
|
||||
objects: [cvode_objs, kinsol_objs],
|
||||
install : true)
|
||||
if get_option('build_python')
|
||||
libgridfire = static_library('gridfire',
|
||||
gridfire_sources,
|
||||
include_directories: include_directories('include'),
|
||||
dependencies: gridfire_build_dependencies,
|
||||
objects: [cvode_objs, kinsol_objs],
|
||||
install : false)
|
||||
else
|
||||
libgridfire = library('gridfire',
|
||||
gridfire_sources,
|
||||
include_directories: include_directories('include'),
|
||||
dependencies: gridfire_build_dependencies,
|
||||
objects: [cvode_objs, kinsol_objs],
|
||||
install : true)
|
||||
endif
|
||||
|
||||
gridfire_dep = declare_dependency(
|
||||
include_directories: include_directories('include'),
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "types/bindings.h"
|
||||
#include "partition/bindings.h"
|
||||
#include "engine/bindings.h"
|
||||
#include "engine/scratchpads/bindings.h"
|
||||
#include "exceptions/bindings.h"
|
||||
#include "io/bindings.h"
|
||||
#include "reaction/bindings.h"
|
||||
@@ -11,6 +12,7 @@
|
||||
#include "solver/bindings.h"
|
||||
#include "utils/bindings.h"
|
||||
#include "policy/bindings.h"
|
||||
#include "config/bindings.h"
|
||||
|
||||
PYBIND11_MODULE(_gridfire, m) {
|
||||
m.doc() = "Python bindings for the fourdst utility modules which are a part of the 4D-STAR project.";
|
||||
@@ -20,6 +22,9 @@ PYBIND11_MODULE(_gridfire, m) {
|
||||
pybind11::module::import("fourdst.config");
|
||||
pybind11::module::import("fourdst.atomic");
|
||||
|
||||
auto configMod = m.def_submodule("config", "GridFire configuration bindings");
|
||||
register_config_bindings(configMod);
|
||||
|
||||
auto typeMod = m.def_submodule("type", "GridFire type bindings");
|
||||
register_type_bindings(typeMod);
|
||||
|
||||
@@ -39,6 +44,12 @@ PYBIND11_MODULE(_gridfire, m) {
|
||||
register_exception_bindings(exceptionMod);
|
||||
|
||||
auto engineMod = m.def_submodule("engine", "Engine and Engine View bindings");
|
||||
auto scratchpadMod = engineMod.def_submodule("scratchpads", "Engine ScratchPad bindings");
|
||||
|
||||
register_scratchpad_types_bindings(scratchpadMod);
|
||||
register_scratchpad_bindings(scratchpadMod);
|
||||
register_state_blob_bindings(scratchpadMod);
|
||||
|
||||
register_engine_bindings(engineMod);
|
||||
|
||||
auto solverMod = m.def_submodule("solver", "GridFire numerical solver bindings");
|
||||
|
||||
34
src/python/config/bindings.cpp
Normal file
34
src/python/config/bindings.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/config/config.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void register_config_bindings(pybind11::module &m) {
|
||||
py::class_<gridfire::config::CVODESolverConfig>(m, "CVODESolverConfig")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("absTol", &gridfire::config::CVODESolverConfig::absTol)
|
||||
.def_readwrite("relTol", &gridfire::config::CVODESolverConfig::relTol);
|
||||
|
||||
py::class_<gridfire::config::SolverConfig>(m, "SolverConfig")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("cvode", &gridfire::config::SolverConfig::cvode);
|
||||
|
||||
py::class_<gridfire::config::AdaptiveEngineViewConfig>(m, "AdaptiveEngineViewConfig")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("relativeCullingThreshold", &gridfire::config::AdaptiveEngineViewConfig::relativeCullingThreshold);
|
||||
|
||||
py::class_<gridfire::config::EngineViewConfig>(m, "EngineViewConfig")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("adaptiveEngineView", &gridfire::config::EngineViewConfig::adaptiveEngineView);
|
||||
|
||||
py::class_<gridfire::config::EngineConfig>(m, "EngineConfig")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("views", &gridfire::config::EngineConfig::views);
|
||||
|
||||
py::class_<gridfire::config::GridFireConfig>(m, "GridFireConfig")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("solver", &gridfire::config::GridFireConfig::solver)
|
||||
.def_readwrite("engine", &gridfire::config::GridFireConfig::engine);
|
||||
}
|
||||
5
src/python/config/bindings.h
Normal file
5
src/python/config/bindings.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_config_bindings(pybind11::module &m);
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace sp = gridfire::engine::scratch;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
@@ -23,16 +24,18 @@ namespace {
|
||||
"calculateRHSAndEnergy",
|
||||
[](
|
||||
const gridfire::engine::DynamicEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& comp,
|
||||
const double T9,
|
||||
const double rho
|
||||
) {
|
||||
auto result = self.calculateRHSAndEnergy(comp, T9, rho);
|
||||
auto result = self.calculateRHSAndEnergy(ctx, comp, T9, rho, false);
|
||||
if (!result.has_value()) {
|
||||
throw gridfire::exceptions::EngineError(std::format("calculateRHSAndEnergy returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
|
||||
}
|
||||
return result.value();
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
@@ -40,6 +43,7 @@ namespace {
|
||||
)
|
||||
.def("calculateEpsDerivatives",
|
||||
&gridfire::engine::DynamicEngine::calculateEpsDerivatives,
|
||||
py::arg("ctx"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
@@ -47,11 +51,13 @@ namespace {
|
||||
)
|
||||
.def("generateJacobianMatrix",
|
||||
[](const gridfire::engine::DynamicEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& comp,
|
||||
const double T9,
|
||||
const double rho) -> gridfire::engine::NetworkJacobian {
|
||||
return self.generateJacobianMatrix(comp, T9, rho);
|
||||
return self.generateJacobianMatrix(ctx, comp, T9, rho);
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
@@ -59,12 +65,14 @@ namespace {
|
||||
)
|
||||
.def("generateJacobianMatrix",
|
||||
[](const gridfire::engine::DynamicEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& comp,
|
||||
const double T9,
|
||||
const double rho,
|
||||
const std::vector<fourdst::atomic::Species>& activeSpecies) -> gridfire::engine::NetworkJacobian {
|
||||
return self.generateJacobianMatrix(comp, T9, rho, activeSpecies);
|
||||
return self.generateJacobianMatrix(ctx, comp, T9, rho, activeSpecies);
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
@@ -73,31 +81,32 @@ namespace {
|
||||
)
|
||||
.def("generateJacobianMatrix",
|
||||
[](const gridfire::engine::DynamicEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& comp,
|
||||
const double T9,
|
||||
const double rho,
|
||||
const gridfire::engine::SparsityPattern& sparsityPattern) -> gridfire::engine::NetworkJacobian {
|
||||
return self.generateJacobianMatrix(comp, T9, rho, sparsityPattern);
|
||||
return self.generateJacobianMatrix(ctx, comp, T9, rho, sparsityPattern);
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
py::arg("sparsityPattern"),
|
||||
"Generate the jacobian matrix for the given sparsity pattern"
|
||||
)
|
||||
.def("generateStoichiometryMatrix",
|
||||
&T::generateStoichiometryMatrix
|
||||
)
|
||||
.def("calculateMolarReactionFlow",
|
||||
[](
|
||||
const gridfire::engine::DynamicEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const gridfire::reaction::Reaction& reaction,
|
||||
const fourdst::composition::Composition& comp,
|
||||
const double T9,
|
||||
const double rho
|
||||
) -> double {
|
||||
return self.calculateMolarReactionFlow(reaction, comp, T9, rho);
|
||||
return self.calculateMolarReactionFlow(ctx, reaction, comp, T9, rho);
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("reaction"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
@@ -110,28 +119,21 @@ namespace {
|
||||
.def("getNetworkReactions", &T::getNetworkReactions,
|
||||
"Get the set of logical reactions in the network."
|
||||
)
|
||||
.def ("setNetworkReactions", &T::setNetworkReactions,
|
||||
py::arg("reactions"),
|
||||
"Set the network reactions to a new set of reactions."
|
||||
)
|
||||
.def("getStoichiometryMatrixEntry", &T::getStoichiometryMatrixEntry,
|
||||
py::arg("species"),
|
||||
py::arg("reaction"),
|
||||
"Get an entry from the stoichiometry matrix."
|
||||
)
|
||||
.def("getSpeciesTimescales",
|
||||
[](
|
||||
const gridfire::engine::DynamicEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& comp,
|
||||
const double T9,
|
||||
const double rho
|
||||
) -> std::unordered_map<fourdst::atomic::Species, double> {
|
||||
const auto result = self.getSpeciesTimescales(comp, T9, rho);
|
||||
const auto result = self.getSpeciesTimescales(ctx, comp, T9, rho);
|
||||
if (!result.has_value()) {
|
||||
throw gridfire::exceptions::EngineError(std::format("getSpeciesTimescales has returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
|
||||
}
|
||||
return result.value();
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
@@ -140,67 +142,48 @@ namespace {
|
||||
.def("getSpeciesDestructionTimescales",
|
||||
[](
|
||||
const gridfire::engine::DynamicEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& comp,
|
||||
const double T9,
|
||||
const double rho
|
||||
) -> std::unordered_map<fourdst::atomic::Species, double> {
|
||||
const auto result = self.getSpeciesDestructionTimescales(comp, T9, rho);
|
||||
const auto result = self.getSpeciesDestructionTimescales(ctx, comp, T9, rho);
|
||||
if (!result.has_value()) {
|
||||
throw gridfire::exceptions::EngineError(std::format("getSpeciesDestructionTimescales has returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
|
||||
}
|
||||
return result.value();
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Get the destruction timescales for each species in the network."
|
||||
)
|
||||
.def("update",
|
||||
&T::update,
|
||||
.def("project",
|
||||
&T::project,
|
||||
py::arg("ctx"),
|
||||
py::arg("netIn"),
|
||||
"Update the engine state based on the provided NetIn object."
|
||||
)
|
||||
.def("setScreeningModel",
|
||||
&T::setScreeningModel,
|
||||
py::arg("screeningModel"),
|
||||
"Set the screening model for the engine."
|
||||
)
|
||||
.def("getScreeningModel",
|
||||
&T::getScreeningModel,
|
||||
"Get the current screening model of the engine."
|
||||
)
|
||||
.def("getSpeciesIndex",
|
||||
&T::getSpeciesIndex,
|
||||
py::arg("ctx"),
|
||||
py::arg("species"),
|
||||
"Get the index of a species in the network."
|
||||
)
|
||||
.def("mapNetInToMolarAbundanceVector",
|
||||
&T::mapNetInToMolarAbundanceVector,
|
||||
py::arg("netIn"),
|
||||
"Map a NetIn object to a vector of molar abundances."
|
||||
)
|
||||
.def("primeEngine",
|
||||
&T::primeEngine,
|
||||
py::arg("ctx"),
|
||||
py::arg("netIn"),
|
||||
"Prime the engine with a NetIn object to prepare for calculations."
|
||||
)
|
||||
.def("getDepth",
|
||||
&T::getDepth,
|
||||
"Get the current build depth of the engine."
|
||||
)
|
||||
.def("rebuild",
|
||||
&T::rebuild,
|
||||
py::arg("composition"),
|
||||
py::arg("depth") = gridfire::engine::NetworkBuildDepth::Full,
|
||||
"Rebuild the engine with a new composition and build depth."
|
||||
)
|
||||
.def("isStale",
|
||||
&T::isStale,
|
||||
py::arg("netIn"),
|
||||
"Check if the engine is stale based on the provided NetIn object."
|
||||
)
|
||||
.def("collectComposition",
|
||||
&T::collectComposition,
|
||||
py::arg("ctx"),
|
||||
py::arg("composition"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
@@ -208,6 +191,7 @@ namespace {
|
||||
)
|
||||
.def("getSpeciesStatus",
|
||||
&T::getSpeciesStatus,
|
||||
py::arg("ctx"),
|
||||
py::arg("species"),
|
||||
"Get the status of a species in the network."
|
||||
);
|
||||
@@ -253,6 +237,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
|
||||
auto diagnostics = m.def_submodule("diagnostics", "A submodule for engine diagnostics");
|
||||
diagnostics.def("report_limiting_species",
|
||||
&gridfire::engine::diagnostics::report_limiting_species,
|
||||
py::arg("ctx"),
|
||||
py::arg("engine"),
|
||||
py::arg("Y_full"),
|
||||
py::arg("E_full"),
|
||||
@@ -264,6 +249,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
|
||||
|
||||
diagnostics.def("inspect_species_balance",
|
||||
&gridfire::engine::diagnostics::inspect_species_balance,
|
||||
py::arg("ctx"),
|
||||
py::arg("engine"),
|
||||
py::arg("species_name"),
|
||||
py::arg("comp"),
|
||||
@@ -274,6 +260,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
|
||||
|
||||
diagnostics.def("inspect_jacobian_stiffness",
|
||||
&gridfire::engine::diagnostics::inspect_jacobian_stiffness,
|
||||
py::arg("ctx"),
|
||||
py::arg("engine"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
@@ -311,6 +298,7 @@ void register_engine_construction_bindings(pybind11::module &m) {
|
||||
void register_engine_priming_bindings(pybind11::module &m) {
|
||||
m.def("primeNetwork",
|
||||
&gridfire::engine::primeNetwork,
|
||||
py::arg("ctx"),
|
||||
py::arg("netIn"),
|
||||
py::arg("engine"),
|
||||
py::arg("ignoredReactionTypes") = std::nullopt,
|
||||
@@ -456,19 +444,16 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
|
||||
py::arg("reactions"),
|
||||
"Initialize GraphEngine with a set of reactions."
|
||||
);
|
||||
py_graph_engine_bindings.def_static("getNetReactionStoichiometry",
|
||||
&gridfire::engine::GraphEngine::getNetReactionStoichiometry,
|
||||
py::arg("reaction"),
|
||||
"Get the net stoichiometry for a given reaction."
|
||||
);
|
||||
py_graph_engine_bindings.def("getSpeciesTimescales",
|
||||
[](const gridfire::engine::GraphEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& composition,
|
||||
const double T9,
|
||||
const double rho,
|
||||
const gridfire::reaction::ReactionSet& activeReactions) {
|
||||
return self.getSpeciesTimescales(composition, T9, rho, activeReactions);
|
||||
return self.getSpeciesTimescales(ctx, composition, T9, rho, activeReactions);
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("composition"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
@@ -476,12 +461,14 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
|
||||
);
|
||||
py_graph_engine_bindings.def("getSpeciesDestructionTimescales",
|
||||
[](const gridfire::engine::GraphEngine& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& composition,
|
||||
const double T9,
|
||||
const double rho,
|
||||
const gridfire::reaction::ReactionSet& activeReactions) {
|
||||
return self.getSpeciesDestructionTimescales(composition, T9, rho, activeReactions);
|
||||
return self.getSpeciesDestructionTimescales(ctx, composition, T9, rho, activeReactions);
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("composition"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
@@ -489,24 +476,22 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
|
||||
);
|
||||
py_graph_engine_bindings.def("involvesSpecies",
|
||||
&gridfire::engine::GraphEngine::involvesSpecies,
|
||||
py::arg("ctx"),
|
||||
py::arg("species"),
|
||||
"Check if a given species is involved in the network."
|
||||
);
|
||||
py_graph_engine_bindings.def("exportToDot",
|
||||
&gridfire::engine::GraphEngine::exportToDot,
|
||||
py::arg("ctx"),
|
||||
py::arg("filename"),
|
||||
"Export the network to a DOT file for visualization."
|
||||
);
|
||||
py_graph_engine_bindings.def("exportToCSV",
|
||||
&gridfire::engine::GraphEngine::exportToCSV,
|
||||
py::arg("ctx"),
|
||||
py::arg("filename"),
|
||||
"Export the network to a CSV file for analysis."
|
||||
);
|
||||
py_graph_engine_bindings.def("setPrecomputation",
|
||||
&gridfire::engine::GraphEngine::setPrecomputation,
|
||||
py::arg("precompute"),
|
||||
"Enable or disable precomputation for the engine."
|
||||
);
|
||||
py_graph_engine_bindings.def("isPrecomputationEnabled",
|
||||
&gridfire::engine::GraphEngine::isPrecomputationEnabled,
|
||||
"Check if precomputation is enabled for the engine."
|
||||
@@ -544,11 +529,6 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
|
||||
&gridfire::engine::GraphEngine::isUsingReverseReactions,
|
||||
"Check if the engine is using reverse reactions."
|
||||
);
|
||||
py_graph_engine_bindings.def("setUseReverseReactions",
|
||||
&gridfire::engine::GraphEngine::setUseReverseReactions,
|
||||
py::arg("useReverse"),
|
||||
"Enable or disable the use of reverse reactions in the engine."
|
||||
);
|
||||
|
||||
// Register the general dynamic engine bindings
|
||||
registerDynamicEngineDefs<gridfire::engine::GraphEngine, gridfire::engine::DynamicEngine>(py_graph_engine_bindings);
|
||||
@@ -587,11 +567,13 @@ void register_engine_view_bindings(const pybind11::module &m) {
|
||||
registerDynamicEngineDefs<gridfire::engine::FileDefinedEngineView, gridfire::engine::DefinedEngineView>(py_file_defined_engine_view_bindings);
|
||||
|
||||
auto py_priming_engine_view_bindings = py::class_<gridfire::engine::NetworkPrimingEngineView, gridfire::engine::DefinedEngineView>(m, "NetworkPrimingEngineView");
|
||||
py_priming_engine_view_bindings.def(py::init<const std::string&, gridfire::engine::GraphEngine&>(),
|
||||
py_priming_engine_view_bindings.def(py::init<sp::StateBlob&, const std::string&, gridfire::engine::GraphEngine&>(),
|
||||
py::arg("ctx"),
|
||||
py::arg("primingSymbol"),
|
||||
py::arg("baseEngine"),
|
||||
"Construct a priming engine view with a priming symbol and a base engine.");
|
||||
py_priming_engine_view_bindings.def(py::init<const fourdst::atomic::Species&, gridfire::engine::GraphEngine&>(),
|
||||
py_priming_engine_view_bindings.def(py::init<sp::StateBlob&, const fourdst::atomic::Species&, gridfire::engine::GraphEngine&>(),
|
||||
py::arg("ctx"),
|
||||
py::arg("primingSpecies"),
|
||||
py::arg("baseEngine"),
|
||||
"Construct a priming engine view with a priming species and a base engine.");
|
||||
@@ -622,15 +604,12 @@ void register_engine_view_bindings(const pybind11::module &m) {
|
||||
);
|
||||
py_multiscale_engine_view_bindings.def("partitionNetwork",
|
||||
&gridfire::engine::MultiscalePartitioningEngineView::partitionNetwork,
|
||||
py::arg("ctx"),
|
||||
py::arg("netIn"),
|
||||
"Partition the network based on species timescales and connectivity.");
|
||||
py_multiscale_engine_view_bindings.def("partitionNetwork",
|
||||
py::overload_cast<const gridfire::NetIn&>(&gridfire::engine::MultiscalePartitioningEngineView::partitionNetwork),
|
||||
py::arg("netIn"),
|
||||
"Partition the network based on a NetIn object."
|
||||
);
|
||||
py_multiscale_engine_view_bindings.def("exportToDot",
|
||||
&gridfire::engine::MultiscalePartitioningEngineView::exportToDot,
|
||||
py::arg("ctx"),
|
||||
py::arg("filename"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
@@ -661,7 +640,16 @@ void register_engine_view_bindings(const pybind11::module &m) {
|
||||
"Check if a given species is involved in the network's dynamic set."
|
||||
);
|
||||
py_multiscale_engine_view_bindings.def("getNormalizedEquilibratedComposition",
|
||||
&gridfire::engine::MultiscalePartitioningEngineView::getNormalizedEquilibratedComposition,
|
||||
[](
|
||||
const gridfire::engine::MultiscalePartitioningEngineView& self,
|
||||
sp::StateBlob& ctx,
|
||||
const fourdst::composition::Composition& comp,
|
||||
const double T9,
|
||||
const double rho
|
||||
) {
|
||||
return self.getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
|
||||
},
|
||||
py::arg("ctx"),
|
||||
py::arg("comp"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
subdir('trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
message('⏳ Python bindings for GridFire Engine are being registered...')
|
||||
shared_module('py_gf_engine',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
message('✅ Python bindings for GridFire Engine registered successfully!')
|
||||
152
src/python/engine/scratchpads/bindings.cpp
Normal file
152
src/python/engine/scratchpads/bindings.cpp
Normal file
@@ -0,0 +1,152 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
|
||||
#include "gridfire/engine/scratchpads/scratchpads.h"
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace sp = gridfire::engine::scratch;
|
||||
|
||||
template<typename... ScratchPadTypes>
|
||||
void build_state_getter(py::module& m) {
|
||||
|
||||
}
|
||||
|
||||
void register_scratchpad_types_bindings(pybind11::module &m) {
|
||||
py::enum_<sp::ScratchPadType>(m, "ScratchPadType")
|
||||
.value("GRAPH_ENGINE_SCRATCHPAD", sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD)
|
||||
.value("MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD)
|
||||
.value("ADAPTIVE_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD)
|
||||
.value("DEFINED_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
void register_scratchpad_bindings(pybind11::module_ &m) {
|
||||
py::enum_<sp::GraphEngineScratchPad::ADFunRegistrationResult>(m, "ADFunRegistrationResult")
|
||||
.value("SUCCESS", sp::GraphEngineScratchPad::ADFunRegistrationResult::SUCCESS)
|
||||
.value("ALREADY_REGISTERED", sp::GraphEngineScratchPad::ADFunRegistrationResult::ALREADY_REGISTERED)
|
||||
.export_values();
|
||||
|
||||
py::class_<sp::GraphEngineScratchPad>(m, "GraphEngineScratchPad")
|
||||
.def(py::init<>())
|
||||
.def("initialize", &sp::GraphEngineScratchPad::initialize, py::arg("engine"))
|
||||
.def("clone", &sp::GraphEngineScratchPad::clone)
|
||||
.def("is_initialized", &sp::GraphEngineScratchPad::is_initialized)
|
||||
.def_readonly("most_recent_rhs_calculation", &sp::GraphEngineScratchPad::most_recent_rhs_calculation)
|
||||
.def_readonly("local_abundance_cache", &sp::GraphEngineScratchPad::local_abundance_cache)
|
||||
.def_readonly("has_initialized", &sp::GraphEngineScratchPad::has_initialized)
|
||||
.def_readonly("stepDerivativesCache", &sp::GraphEngineScratchPad::stepDerivativesCache)
|
||||
.def_readonly_static("ID", &sp::GraphEngineScratchPad::ID)
|
||||
.def("__repr__", [](const sp::GraphEngineScratchPad &self) {
|
||||
return std::format("{}", self);
|
||||
});
|
||||
|
||||
py::class_<sp::MultiscalePartitioningEngineViewScratchPad>(m, "MultiscalePartitioningEngineViewScratchPad")
|
||||
.def(py::init<>())
|
||||
.def("initialize", &sp::MultiscalePartitioningEngineViewScratchPad::initialize)
|
||||
.def("clone", &sp::MultiscalePartitioningEngineViewScratchPad::clone)
|
||||
.def("is_initialized", &sp::MultiscalePartitioningEngineViewScratchPad::is_initialized)
|
||||
.def_readonly("qse_groups", &sp::MultiscalePartitioningEngineViewScratchPad::qse_groups)
|
||||
.def_readonly("dynamic_species", &sp::MultiscalePartitioningEngineViewScratchPad::dynamic_species)
|
||||
.def_readonly("algebraic_species", &sp::MultiscalePartitioningEngineViewScratchPad::algebraic_species)
|
||||
.def_readonly("composition_cache", &sp::MultiscalePartitioningEngineViewScratchPad::composition_cache)
|
||||
.def_readonly("has_initialized", &sp::MultiscalePartitioningEngineViewScratchPad::has_initialized)
|
||||
.def_readonly_static("ID", &sp::MultiscalePartitioningEngineViewScratchPad::ID)
|
||||
.def("__repr__", [](const sp::MultiscalePartitioningEngineViewScratchPad &self) {
|
||||
return std::format("{}", self);
|
||||
});
|
||||
|
||||
py::class_<sp::AdaptiveEngineViewScratchPad>(m, "AdaptiveEngineViewScratchPad")
|
||||
.def(py::init<>())
|
||||
.def("initialize", &sp::AdaptiveEngineViewScratchPad::initialize)
|
||||
.def("clone", &sp::AdaptiveEngineViewScratchPad::clone)
|
||||
.def("is_initialized", &sp::AdaptiveEngineViewScratchPad::is_initialized)
|
||||
.def_readonly("active_species", &sp::AdaptiveEngineViewScratchPad::active_species)
|
||||
.def_readonly("active_reactions", &sp::AdaptiveEngineViewScratchPad::active_reactions)
|
||||
.def_readonly("has_initialized", &sp::AdaptiveEngineViewScratchPad::has_initialized)
|
||||
.def_readonly_static("ID", &sp::AdaptiveEngineViewScratchPad::ID)
|
||||
.def("__repr__", [](const sp::AdaptiveEngineViewScratchPad &self) {
|
||||
return std::format("{}", self);
|
||||
});
|
||||
|
||||
py::class_<sp::DefinedEngineViewScratchPad>(m, "DefinedEngineViewScratchPad")
|
||||
.def(py::init<>())
|
||||
.def("clone", &sp::DefinedEngineViewScratchPad::clone)
|
||||
.def("is_initialized", &sp::DefinedEngineViewScratchPad::is_initialized)
|
||||
.def_readonly("active_species", &sp::DefinedEngineViewScratchPad::active_species)
|
||||
.def_readonly("active_reactions", &sp::DefinedEngineViewScratchPad::active_reactions)
|
||||
.def_readonly("species_index_map", &sp::DefinedEngineViewScratchPad::species_index_map)
|
||||
.def_readonly("reaction_index_map", &sp::DefinedEngineViewScratchPad::reaction_index_map)
|
||||
.def_readonly("has_initialized", &sp::DefinedEngineViewScratchPad::has_initialized)
|
||||
.def_readonly_static("ID", &sp::DefinedEngineViewScratchPad::ID)
|
||||
.def("__repr__", [](const sp::DefinedEngineViewScratchPad &self) {
|
||||
return std::format("{}", self);
|
||||
});
|
||||
}
|
||||
|
||||
void register_state_blob_bindings(pybind11::module_ &m) {
|
||||
py::enum_<sp::StateBlob::Error>(m, "StateBlobError")
|
||||
.value("SCRATCHPAD_OUT_OF_BOUNDS", sp::StateBlob::Error::SCRATCHPAD_OUT_OF_BOUNDS)
|
||||
.value("SCRATCHPAD_NOT_FOUND", sp::StateBlob::Error::SCRATCHPAD_NOT_FOUND)
|
||||
.value("SCRATCHPAD_BAD_CAST", sp::StateBlob::Error::SCRATCHPAD_BAD_CAST)
|
||||
.value("SCRATCHPAD_NOT_INITIALIZED", sp::StateBlob::Error::SCRATCHPAD_NOT_INITIALIZED)
|
||||
.value("SCRATCHPAD_TYPE_COLLISION", sp::StateBlob::Error::SCRATCHPAD_TYPE_COLLISION)
|
||||
.value("SCRATCHPAD_UNKNOWN_ERROR", sp::StateBlob::Error::SCRATCHPAD_UNKNOWN_ERROR)
|
||||
.export_values();
|
||||
|
||||
py::class_<sp::StateBlob>(m, "StateBlob")
|
||||
.def(py::init<>())
|
||||
.def("enroll", [](sp::StateBlob &self, const sp::ScratchPadType type) {
|
||||
switch (type) {
|
||||
case sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD:
|
||||
self.enroll<sp::GraphEngineScratchPad>();
|
||||
break;
|
||||
case sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD:
|
||||
self.enroll<sp::MultiscalePartitioningEngineViewScratchPad>();
|
||||
break;
|
||||
case sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD:
|
||||
self.enroll<sp::AdaptiveEngineViewScratchPad>();
|
||||
break;
|
||||
case sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD:
|
||||
self.enroll<sp::DefinedEngineViewScratchPad>();
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Unknown ScratchPadType for enrollment.");
|
||||
}
|
||||
})
|
||||
.def("get", [](const sp::StateBlob &self, const sp::ScratchPadType type) {
|
||||
auto result = self.get(type);
|
||||
if (!result.has_value()) {
|
||||
throw std::runtime_error("Error retrieving scratchpad: " + sp::StateBlob::error_to_string(result.error()));
|
||||
}
|
||||
return result.value();
|
||||
},
|
||||
pybind11::return_value_policy::reference_internal
|
||||
)
|
||||
.def("clone_structure", &sp::StateBlob::clone_structure)
|
||||
.def("get_registered_scratchpads", &sp::StateBlob::get_registered_scratchpads)
|
||||
.def("get_status", [](const sp::StateBlob &self, const sp::ScratchPadType type) -> sp::StateBlob::ScratchPadStatus {
|
||||
switch (type) {
|
||||
case sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD:
|
||||
return self.get_status<sp::GraphEngineScratchPad>();
|
||||
case sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD:
|
||||
return self.get_status<sp::MultiscalePartitioningEngineViewScratchPad>();
|
||||
case sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD:
|
||||
return self.get_status<sp::AdaptiveEngineViewScratchPad>();
|
||||
case sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD:
|
||||
return self.get_status<sp::DefinedEngineViewScratchPad>();
|
||||
default:
|
||||
throw std::invalid_argument("Unknown ScratchPadType for status retrieval.");
|
||||
}
|
||||
})
|
||||
.def("get_status_map", &sp::StateBlob::get_status_map)
|
||||
.def_static("error_to_string", &sp::StateBlob::error_to_string)
|
||||
.def("__repr__", [](const sp::StateBlob &self) {
|
||||
return std::format("{}", self);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
||||
7
src/python/engine/scratchpads/bindings.h
Normal file
7
src/python/engine/scratchpads/bindings.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_scratchpad_types_bindings(pybind11::module_& m);
|
||||
void register_scratchpad_bindings(pybind11::module_& m);
|
||||
void register_state_blob_bindings(pybind11::module_& m);
|
||||
@@ -1,21 +0,0 @@
|
||||
gf_engine_trampoline_sources = files('py_engine.cpp')
|
||||
|
||||
gf_engine_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_engine_trampoline_lib = static_library(
|
||||
'engine_trampolines',
|
||||
gf_engine_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_engine_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_engine_trampoline_dep = declare_dependency(
|
||||
link_with: gf_engine_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_engine_trapoline_dependencies,
|
||||
)
|
||||
@@ -13,36 +13,29 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies() const {
|
||||
/*
|
||||
* Acquire the GIL (Global Interpreter Lock) for thread safety
|
||||
* with the Python interpreter.
|
||||
*/
|
||||
py::gil_scoped_acquire gil;
|
||||
|
||||
/*
|
||||
* get_override() looks for a Python method that overrides this C++ one.
|
||||
*/
|
||||
|
||||
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
|
||||
const py::object result = override();
|
||||
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
|
||||
return m_species_cache;
|
||||
}
|
||||
|
||||
py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\"");
|
||||
const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies(
|
||||
gridfire::engine::scratch::StateBlob& ctx
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const std::vector<fourdst::atomic::Species>&,
|
||||
gridfire::engine::Engine,
|
||||
getNetworkSpecies,
|
||||
ctx
|
||||
);
|
||||
}
|
||||
|
||||
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> PyEngine::calculateRHSAndEnergy(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
double rho,
|
||||
bool trust
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
PYBIND11_TYPE(std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus>),
|
||||
gridfire::engine::Engine,
|
||||
calculateRHSAndEnergy,
|
||||
comp, T9, rho
|
||||
ctx, comp, T9, rho, trust
|
||||
);
|
||||
}
|
||||
|
||||
@@ -50,41 +43,35 @@ std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::Engin
|
||||
/// PyDynamicEngine Implementation ///
|
||||
/////////////////////////////////////
|
||||
|
||||
const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies() const {
|
||||
/*
|
||||
* Acquire the GIL (Global Interpreter Lock) for thread safety
|
||||
* with the Python interpreter.
|
||||
*/
|
||||
py::gil_scoped_acquire gil;
|
||||
|
||||
/*
|
||||
* get_override() looks for a Python method that overrides this C++ one.
|
||||
*/
|
||||
|
||||
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
|
||||
const py::object result = override();
|
||||
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
|
||||
return m_species_cache;
|
||||
}
|
||||
|
||||
py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\"");
|
||||
const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies(
|
||||
gridfire::engine::scratch::StateBlob& ctx
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const std::vector<fourdst::atomic::Species>&,
|
||||
gridfire::engine::DynamicEngine,
|
||||
getNetworkSpecies,
|
||||
ctx
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> PyDynamicEngine::calculateRHSAndEnergy(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
double rho,
|
||||
bool trust
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
PYBIND11_TYPE(std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus>),
|
||||
gridfire::engine::DynamicEngine,
|
||||
calculateRHSAndEnergy,
|
||||
comp, T9, rho
|
||||
ctx, comp, T9, rho, trust
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract& comp,
|
||||
double T9,
|
||||
double rho
|
||||
@@ -100,6 +87,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
|
||||
}
|
||||
|
||||
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
const double T9,
|
||||
const double rho,
|
||||
@@ -109,6 +97,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
|
||||
gridfire::engine::NetworkJacobian,
|
||||
gridfire::engine::DynamicEngine,
|
||||
generateJacobianMatrix,
|
||||
ctx,
|
||||
comp,
|
||||
T9,
|
||||
rho,
|
||||
@@ -117,6 +106,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
|
||||
}
|
||||
|
||||
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho,
|
||||
@@ -126,6 +116,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
|
||||
gridfire::engine::NetworkJacobian,
|
||||
gridfire::engine::DynamicEngine,
|
||||
generateJacobianMatrix,
|
||||
ctx,
|
||||
comp,
|
||||
T9,
|
||||
rho,
|
||||
@@ -133,28 +124,8 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::generateStoichiometryMatrix() {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::engine::DynamicEngine,
|
||||
generateStoichiometryMatrix
|
||||
);
|
||||
}
|
||||
|
||||
int PyDynamicEngine::getStoichiometryMatrixEntry(
|
||||
const fourdst::atomic::Species& species,
|
||||
const gridfire::reaction::Reaction& reaction
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
int,
|
||||
gridfire::engine::DynamicEngine,
|
||||
getStoichiometryMatrixEntry,
|
||||
species,
|
||||
reaction
|
||||
);
|
||||
}
|
||||
|
||||
double PyDynamicEngine::calculateMolarReactionFlow(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const gridfire::reaction::Reaction &reaction,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
@@ -164,6 +135,7 @@ double PyDynamicEngine::calculateMolarReactionFlow(
|
||||
double,
|
||||
gridfire::engine::DynamicEngine,
|
||||
calculateMolarReactionFlow,
|
||||
ctx,
|
||||
reaction,
|
||||
comp,
|
||||
T9,
|
||||
@@ -171,24 +143,19 @@ double PyDynamicEngine::calculateMolarReactionFlow(
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::reaction::ReactionSet& PyDynamicEngine::getNetworkReactions() const {
|
||||
const gridfire::reaction::ReactionSet& PyDynamicEngine::getNetworkReactions(
|
||||
gridfire::engine::scratch::StateBlob& ctx
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::reaction::ReactionSet&,
|
||||
gridfire::engine::DynamicEngine,
|
||||
getNetworkReactions
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::engine::DynamicEngine,
|
||||
setNetworkReactions,
|
||||
reactions
|
||||
getNetworkReactions,
|
||||
ctx
|
||||
);
|
||||
}
|
||||
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> PyDynamicEngine::getSpeciesTimescales(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
@@ -197,6 +164,7 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
|
||||
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus>),
|
||||
gridfire::engine::DynamicEngine,
|
||||
getSpeciesTimescales,
|
||||
ctx,
|
||||
comp,
|
||||
T9,
|
||||
rho
|
||||
@@ -204,6 +172,7 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
|
||||
}
|
||||
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> PyDynamicEngine::getSpeciesDestructionTimescales(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
@@ -212,80 +181,71 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
|
||||
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus>),
|
||||
gridfire::engine::DynamicEngine,
|
||||
getSpeciesDestructionTimescales,
|
||||
comp, T9, rho
|
||||
ctx, comp, T9, rho
|
||||
);
|
||||
}
|
||||
|
||||
fourdst::composition::Composition PyDynamicEngine::update(const gridfire::NetIn &netIn) {
|
||||
fourdst::composition::Composition PyDynamicEngine::project(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const gridfire::NetIn &netIn
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
fourdst::composition::Composition,
|
||||
gridfire::engine::DynamicEngine,
|
||||
update,
|
||||
project,
|
||||
ctx,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
bool PyDynamicEngine::isStale(const gridfire::NetIn &netIn) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
bool,
|
||||
gridfire::engine::DynamicEngine,
|
||||
isStale,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::setScreeningModel(gridfire::screening::ScreeningType model) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::engine::DynamicEngine,
|
||||
setScreeningModel,
|
||||
model
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel() const {
|
||||
gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel(
|
||||
gridfire::engine::scratch::StateBlob& ctx
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::screening::ScreeningType,
|
||||
gridfire::engine::DynamicEngine,
|
||||
getScreeningModel
|
||||
getScreeningModel,
|
||||
ctx
|
||||
);
|
||||
}
|
||||
|
||||
size_t PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
|
||||
size_t PyDynamicEngine::getSpeciesIndex(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::atomic::Species &species
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
int,
|
||||
gridfire::engine::DynamicEngine,
|
||||
getSpeciesIndex,
|
||||
ctx,
|
||||
species
|
||||
);
|
||||
}
|
||||
|
||||
std::vector<double> PyDynamicEngine::mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::vector<double>,
|
||||
gridfire::engine::DynamicEngine,
|
||||
mapNetInToMolarAbundanceVector,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::engine::PrimingReport PyDynamicEngine::primeEngine(const gridfire::NetIn &netIn) {
|
||||
gridfire::engine::PrimingReport PyDynamicEngine::primeEngine(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const gridfire::NetIn &netIn
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::engine::PrimingReport,
|
||||
gridfire::engine::DynamicEngine,
|
||||
primeEngine,
|
||||
ctx,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::engine::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
const double T9,
|
||||
const double rho) const {
|
||||
const double rho
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::engine::EnergyDerivatives,
|
||||
gridfire::engine::DynamicEngine,
|
||||
calculateEpsDerivatives,
|
||||
ctx,
|
||||
comp,
|
||||
T9,
|
||||
rho
|
||||
@@ -293,6 +253,7 @@ gridfire::engine::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives(
|
||||
}
|
||||
|
||||
fourdst::composition::Composition PyDynamicEngine::collectComposition(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
const double T9,
|
||||
const double rho
|
||||
@@ -301,21 +262,37 @@ fourdst::composition::Composition PyDynamicEngine::collectComposition(
|
||||
fourdst::composition::Composition,
|
||||
gridfire::engine::DynamicEngine,
|
||||
collectComposition,
|
||||
ctx,
|
||||
comp,
|
||||
T9,
|
||||
rho
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::engine::SpeciesStatus PyDynamicEngine::getSpeciesStatus(const fourdst::atomic::Species &species) const {
|
||||
gridfire::engine::SpeciesStatus PyDynamicEngine::getSpeciesStatus(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::atomic::Species &species
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::engine::SpeciesStatus,
|
||||
gridfire::engine::DynamicEngine,
|
||||
getSpeciesStatus,
|
||||
ctx,
|
||||
species
|
||||
);
|
||||
}
|
||||
|
||||
std::optional<gridfire::engine::StepDerivatives<double>> PyDynamicEngine::getMostRecentRHSCalculation(
|
||||
gridfire::engine::scratch::StateBlob &ctx
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
PYBIND11_TYPE(std::optional<gridfire::engine::StepDerivatives<double>>),
|
||||
gridfire::engine::DynamicEngine,
|
||||
getMostRecentRHSCalculation,
|
||||
ctx
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::engine::Engine& PyEngineView::getBaseEngine() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::engine::Engine&,
|
||||
|
||||
@@ -10,12 +10,16 @@
|
||||
|
||||
class PyEngine final : public gridfire::engine::Engine {
|
||||
public:
|
||||
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
|
||||
const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
|
||||
gridfire::engine::scratch::StateBlob& ctx
|
||||
) const override;
|
||||
|
||||
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> calculateRHSAndEnergy(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
double rho,
|
||||
bool trust
|
||||
) const override;
|
||||
private:
|
||||
mutable std::vector<fourdst::atomic::Species> m_species_cache;
|
||||
@@ -23,21 +27,27 @@ private:
|
||||
|
||||
class PyDynamicEngine final : public gridfire::engine::DynamicEngine {
|
||||
public:
|
||||
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
|
||||
const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
|
||||
gridfire::engine::scratch::StateBlob& ctx
|
||||
) const override;
|
||||
|
||||
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> calculateRHSAndEnergy(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
double rho,
|
||||
bool trust
|
||||
) const override;
|
||||
|
||||
gridfire::engine::NetworkJacobian generateJacobianMatrix(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract& comp,
|
||||
double T9,
|
||||
double rho
|
||||
) const override;
|
||||
|
||||
gridfire::engine::NetworkJacobian generateJacobianMatrix(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho,
|
||||
@@ -45,96 +55,81 @@ public:
|
||||
) const override;
|
||||
|
||||
gridfire::engine::NetworkJacobian generateJacobianMatrix(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract& comp,
|
||||
double T9,
|
||||
double rho,
|
||||
const gridfire::engine::SparsityPattern &sparsityPattern
|
||||
) const override;
|
||||
|
||||
void generateStoichiometryMatrix() override;
|
||||
|
||||
int getStoichiometryMatrixEntry(
|
||||
const fourdst::atomic::Species& species,
|
||||
const gridfire::reaction::Reaction& reaction
|
||||
) const override;
|
||||
|
||||
double calculateMolarReactionFlow(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const gridfire::reaction::Reaction &reaction,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
) const override;
|
||||
|
||||
const gridfire::reaction::ReactionSet& getNetworkReactions() const override;
|
||||
|
||||
void setNetworkReactions(
|
||||
const gridfire::reaction::ReactionSet& reactions
|
||||
) override;
|
||||
const gridfire::reaction::ReactionSet& getNetworkReactions(
|
||||
gridfire::engine::scratch::StateBlob& ctx
|
||||
) const override;
|
||||
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> getSpeciesTimescales(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
) const override;
|
||||
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> getSpeciesDestructionTimescales(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
) const override;
|
||||
|
||||
fourdst::composition::Composition update(
|
||||
fourdst::composition::Composition project(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const gridfire::NetIn &netIn
|
||||
) override;
|
||||
) const override;
|
||||
|
||||
bool isStale(
|
||||
const gridfire::NetIn &netIn
|
||||
) override;
|
||||
|
||||
void setScreeningModel(
|
||||
gridfire::screening::ScreeningType model
|
||||
) override;
|
||||
|
||||
gridfire::screening::ScreeningType getScreeningModel() const override;
|
||||
gridfire::screening::ScreeningType getScreeningModel(
|
||||
gridfire::engine::scratch::StateBlob& ctx
|
||||
) const override;
|
||||
|
||||
size_t getSpeciesIndex(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::atomic::Species &species
|
||||
) const override;
|
||||
|
||||
std::vector<double> mapNetInToMolarAbundanceVector(
|
||||
gridfire::engine::PrimingReport primeEngine(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const gridfire::NetIn &netIn
|
||||
) const override;
|
||||
|
||||
gridfire::engine::PrimingReport primeEngine(
|
||||
const gridfire::NetIn &netIn
|
||||
) override;
|
||||
|
||||
gridfire::engine::BuildDepthType getDepth() const override {
|
||||
throw std::logic_error("Network depth not supported by this engine.");
|
||||
}
|
||||
void rebuild(
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
gridfire::engine::BuildDepthType depth
|
||||
) override {
|
||||
throw std::logic_error("Setting network depth not supported by this engine.");
|
||||
}
|
||||
|
||||
[[nodiscard]] gridfire::engine::EnergyDerivatives calculateEpsDerivatives(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
) const override;
|
||||
|
||||
fourdst::composition::Composition collectComposition(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::composition::CompositionAbstract &comp,
|
||||
double T9,
|
||||
double rho
|
||||
) const override;
|
||||
|
||||
gridfire::engine::SpeciesStatus getSpeciesStatus(
|
||||
gridfire::engine::scratch::StateBlob& ctx,
|
||||
const fourdst::atomic::Species &species
|
||||
) const override;
|
||||
|
||||
std::optional<gridfire::engine::StepDerivatives<double>> getMostRecentRHSCalculation(
|
||||
gridfire::engine::scratch::StateBlob &ctx
|
||||
) const override;
|
||||
|
||||
private:
|
||||
mutable std::vector<fourdst::atomic::Species> m_species_cache;
|
||||
};
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/exceptions/error_scratchpad.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
#include "gridfire/exceptions/exceptions.h"
|
||||
@@ -44,4 +46,6 @@ void register_exception_bindings(const py::module &m) {
|
||||
py::register_exception<gridfire::exceptions::CVODESolverFailureError>(m, "CVODESolverFailureError", m.attr("SUNDIALSError"));
|
||||
py::register_exception<gridfire::exceptions::KINSolSolverFailureError>(m, "KINSolSolverFailureError", m.attr("SUNDIALSError"));
|
||||
|
||||
py::register_exception<gridfire::exceptions::ScratchPadError>(m, "ScratchPadError", m.attr("GridFireError"));
|
||||
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_exceptions',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from ._gridfire import *
|
||||
import sys
|
||||
|
||||
from ._gridfire import type, utils, engine, solver, exceptions, partition, reaction, screening, io, policy
|
||||
from ._gridfire import type, utils, engine, solver, exceptions, partition, reaction, screening, io, policy, config
|
||||
|
||||
sys.modules['gridfire.type'] = type
|
||||
sys.modules['gridfire.utils'] = utils
|
||||
@@ -13,8 +13,61 @@ sys.modules['gridfire.reaction'] = reaction
|
||||
sys.modules['gridfire.screening'] = screening
|
||||
sys.modules['gridfire.policy'] = policy
|
||||
sys.modules['gridfire.io'] = io
|
||||
sys.modules['gridfire.config'] = config
|
||||
|
||||
__all__ = ['type', 'utils', 'engine', 'solver', 'exceptions', 'partition', 'reaction', 'screening', 'io', 'policy']
|
||||
__all__ = ['type', 'utils', 'engine', 'solver', 'exceptions', 'partition', 'reaction', 'screening', 'io', 'policy', 'config']
|
||||
|
||||
__version__ = "v0.7.4_rc2"
|
||||
import importlib.metadata
|
||||
|
||||
try:
|
||||
_meta = importlib.metadata.metadata('gridfire')
|
||||
__version__ = _meta['Version']
|
||||
__author__ = _meta['Author']
|
||||
__license__ = _meta['License']
|
||||
__email__ = _meta['Author-email']
|
||||
__url__ = _meta['Home-page'] or _meta.get('Project-URL', '').split(',')[0].split(' ')[-1].strip()
|
||||
__description__ = _meta['Summary']
|
||||
except importlib.metadata.PackageNotFoundError :
|
||||
__version__ = 'unknown - Package not installed'
|
||||
__author__ = 'Emily M. Boudreaux'
|
||||
__license__ = 'GNU General Public License v3.0'
|
||||
__email__ = 'emily.boudreaux@dartmouth.edu'
|
||||
__url__ = 'https://github.com/4D-STAR/GridFire'
|
||||
|
||||
def gf_metadata():
|
||||
return {
|
||||
'version': __version__,
|
||||
'author': __author__,
|
||||
'license': __license__,
|
||||
'email': __email__,
|
||||
'url': __url__,
|
||||
'description': __description__
|
||||
}
|
||||
|
||||
def gf_version():
|
||||
return __version__
|
||||
|
||||
def gf_author():
|
||||
return __author__
|
||||
|
||||
def gf_license():
|
||||
return __license__
|
||||
|
||||
def gf_email():
|
||||
return __email__
|
||||
|
||||
def gf_url():
|
||||
return __url__
|
||||
|
||||
def gf_description():
|
||||
return __description__
|
||||
|
||||
def gf_collaboration():
|
||||
return "4D-STAR Collaboration"
|
||||
|
||||
def gf_credits():
|
||||
return [
|
||||
"Emily M. Boudreaux - Lead Developer",
|
||||
"Aaron Dotter - Co-Developer",
|
||||
"4D-STAR Collaboration - Contributors"
|
||||
]
|
||||
@@ -1,17 +0,0 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_io',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
gf_io_trampoline_sources = files('py_io.cpp')
|
||||
|
||||
gf_io_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_io_trampoline_lib = static_library(
|
||||
'io_trampolines',
|
||||
gf_io_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_io_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_io_trampoline_dep = declare_dependency(
|
||||
link_with: gf_io_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_io_trapoline_dependencies,
|
||||
)
|
||||
@@ -1,10 +0,0 @@
|
||||
subdir('types')
|
||||
subdir('utils')
|
||||
subdir('exceptions')
|
||||
subdir('io')
|
||||
subdir('partition')
|
||||
subdir('reaction')
|
||||
subdir('screening')
|
||||
subdir('engine')
|
||||
subdir('policy')
|
||||
subdir('solver')
|
||||
@@ -1,19 +0,0 @@
|
||||
subdir('trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_partition',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
gf_partition_trampoline_sources = files('py_partition.cpp')
|
||||
|
||||
gf_partition_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_partition_trampoline_lib = static_library(
|
||||
'partition_trampolines',
|
||||
gf_partition_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_partition_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_partition_trampoline_dep = declare_dependency(
|
||||
link_with: gf_partition_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_partition_trapoline_dependencies,
|
||||
)
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
#include "gridfire/policy/policy.h"
|
||||
|
||||
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
@@ -103,9 +102,30 @@ namespace {
|
||||
.def(
|
||||
"construct",
|
||||
&T::construct,
|
||||
py::return_value_policy::reference,
|
||||
"Construct the network according to the policy."
|
||||
|
||||
)
|
||||
.def(
|
||||
"get_engine_stack",
|
||||
[](const T &self) {
|
||||
const auto& stack = self.get_engine_stack();
|
||||
std::vector<gridfire::engine::DynamicEngine*> engine_ptrs;
|
||||
engine_ptrs.reserve(stack.size());
|
||||
for (const auto& engine_uptr : stack) {
|
||||
engine_ptrs.push_back(engine_uptr.get());
|
||||
}
|
||||
|
||||
return engine_ptrs;
|
||||
},
|
||||
py::return_value_policy::reference_internal
|
||||
)
|
||||
.def(
|
||||
"get_stack_scratch_blob",
|
||||
&T::get_stack_scratch_blob
|
||||
)
|
||||
.def(
|
||||
"get_partition_function",
|
||||
&T::get_partition_function
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -215,6 +235,26 @@ void register_network_policy_bindings(pybind11::module &m) {
|
||||
.value("INITIALIZED_VERIFIED", gridfire::policy::NetworkPolicyStatus::INITIALIZED_VERIFIED)
|
||||
.export_values();
|
||||
|
||||
m.def("network_policy_status_to_string",
|
||||
&gridfire::policy::NetworkPolicyStatusToString,
|
||||
py::arg("status"),
|
||||
"Convert a NetworkPolicyStatus enum value to its string representation."
|
||||
);
|
||||
|
||||
py::class_<gridfire::policy::ConstructionResults>(m, "ConstructionResults")
|
||||
.def_property_readonly("engine",
|
||||
[](const gridfire::policy::ConstructionResults &self) -> const gridfire::engine::DynamicEngine& {
|
||||
return self.engine;
|
||||
},
|
||||
py::return_value_policy::reference
|
||||
)
|
||||
.def_property_readonly("scratch_blob",
|
||||
[](const gridfire::policy::ConstructionResults &self) {
|
||||
return self.scratch_blob.get();
|
||||
},
|
||||
py::return_value_policy::reference_internal
|
||||
);
|
||||
|
||||
py::class_<gridfire::policy::NetworkPolicy, PyNetworkPolicy> py_networkPolicy(m, "NetworkPolicy");
|
||||
py::class_<gridfire::policy::MainSequencePolicy, gridfire::policy::NetworkPolicy> py_mainSeqPolicy(m, "MainSequencePolicy");
|
||||
py_mainSeqPolicy.def(
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Define the library
|
||||
subdir('trampoline')
|
||||
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_policy',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
gf_policy_trampoline_sources = files('py_policy.cpp')
|
||||
|
||||
gf_policy_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_policy_trampoline_lib = static_library(
|
||||
'policy_trampolines',
|
||||
gf_policy_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_policy_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_policy_trampoline_dep = declare_dependency(
|
||||
link_with: gf_policy_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_policy_trapoline_dependencies,
|
||||
)
|
||||
@@ -39,9 +39,9 @@ const gridfire::reaction::ReactionSet& PyNetworkPolicy::get_seed_reactions() con
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::engine::DynamicEngine& PyNetworkPolicy::construct() {
|
||||
gridfire::policy::ConstructionResults PyNetworkPolicy::construct() {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::engine::DynamicEngine&,
|
||||
gridfire::policy::ConstructionResults,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
construct
|
||||
);
|
||||
@@ -79,6 +79,14 @@ const std::unique_ptr<gridfire::partition::PartitionFunction>& PyNetworkPolicy::
|
||||
);
|
||||
}
|
||||
|
||||
std::unique_ptr<gridfire::engine::scratch::StateBlob> PyNetworkPolicy::get_stack_scratch_blob() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::unique_ptr<gridfire::engine::scratch::StateBlob>,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
get_stack_scratch_blob
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::reaction::ReactionSet &PyReactionChainPolicy::get_reactions() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::reaction::ReactionSet &,
|
||||
|
||||
@@ -13,7 +13,7 @@ public:
|
||||
|
||||
[[nodiscard]] const gridfire::reaction::ReactionSet& get_seed_reactions() const override;
|
||||
|
||||
[[nodiscard]] gridfire::engine::DynamicEngine& construct() override;
|
||||
[[nodiscard]] gridfire::policy::ConstructionResults construct() override;
|
||||
|
||||
[[nodiscard]] gridfire::policy::NetworkPolicyStatus get_status() const override;
|
||||
|
||||
@@ -22,6 +22,8 @@ public:
|
||||
[[nodiscard]] std::vector<gridfire::engine::EngineTypes> get_engine_types_stack() const override;
|
||||
|
||||
[[nodiscard]] const std::unique_ptr<gridfire::partition::PartitionFunction>& get_partition_function() const override;
|
||||
|
||||
[[nodiscard]] std::unique_ptr<gridfire::engine::scratch::StateBlob> get_stack_scratch_blob() const override;
|
||||
};
|
||||
|
||||
class PyReactionChainPolicy final : public gridfire::policy::ReactionChainPolicy {
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_reaction',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -1,19 +0,0 @@
|
||||
subdir('trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_screening',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
gf_screening_trampoline_sources = files('py_screening.cpp')
|
||||
|
||||
gf_screening_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_screening_trampoline_lib = static_library(
|
||||
'screening_trampolines',
|
||||
gf_screening_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_screening_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_screening_trampoline_dep = declare_dependency(
|
||||
link_with: gf_screening_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_screening_trapoline_dependencies,
|
||||
)
|
||||
@@ -7,125 +7,226 @@
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
||||
#include "gridfire/solver/strategies/PointSolver.h"
|
||||
#include "gridfire/engine/scratchpads/blob.h"
|
||||
#include "trampoline/py_solver.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
void register_solver_bindings(const py::module &m) {
|
||||
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
|
||||
|
||||
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext, gridfire::solver::SolverContextBase>(m, "CVODETimestepContext");
|
||||
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
|
||||
auto py_cvode_timestep_context = py::class_<gridfire::solver::PointSolverTimestepContext>(m, "PointSolverTimestepContext");
|
||||
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::PointSolverTimestepContext::t);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"state",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
|
||||
[](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<double> {
|
||||
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
|
||||
const sunindextype length = N_VGetLength(self.state);
|
||||
return std::vector<double>(nvec_data, nvec_data + length);
|
||||
return {nvec_data, nvec_data + length};
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
|
||||
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
|
||||
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
|
||||
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
|
||||
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
|
||||
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures);
|
||||
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations);
|
||||
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::PointSolverTimestepContext::dt);
|
||||
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::PointSolverTimestepContext::last_step_time);
|
||||
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::PointSolverTimestepContext::T9);
|
||||
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::PointSolverTimestepContext::rho);
|
||||
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::PointSolverTimestepContext::num_steps);
|
||||
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::PointSolverTimestepContext::currentConvergenceFailures);
|
||||
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::PointSolverTimestepContext::currentNonlinearIterations);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"engine",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::engine::DynamicEngine& {
|
||||
[](const gridfire::solver::PointSolverTimestepContext& self) -> const gridfire::engine::DynamicEngine& {
|
||||
return self.engine;
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"networkSpecies",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
|
||||
[](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<fourdst::atomic::Species> {
|
||||
return self.networkSpecies;
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"state_ctx",
|
||||
[](const gridfire::solver::PointSolverTimestepContext& self) {
|
||||
return &(self.state_ctx);
|
||||
},
|
||||
py::return_value_policy::reference_internal
|
||||
);
|
||||
|
||||
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
|
||||
py_dynamic_network_solver_strategy.def(
|
||||
|
||||
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
|
||||
auto py_point_solver_context = py::class_<gridfire::solver::PointSolverContext, gridfire::solver::SolverContextBase>(m, "PointSolverContext");
|
||||
|
||||
py_point_solver_context
|
||||
.def_readonly(
|
||||
"sun_ctx", &gridfire::solver::PointSolverContext::sun_ctx
|
||||
)
|
||||
.def_readonly(
|
||||
"cvode_mem", &gridfire::solver::PointSolverContext::cvode_mem
|
||||
)
|
||||
.def_readonly(
|
||||
"Y", &gridfire::solver::PointSolverContext::Y
|
||||
)
|
||||
.def_readonly(
|
||||
"YErr", &gridfire::solver::PointSolverContext::YErr
|
||||
)
|
||||
.def_readonly(
|
||||
"J", &gridfire::solver::PointSolverContext::J
|
||||
)
|
||||
.def_readonly(
|
||||
"LS", &gridfire::solver::PointSolverContext::LS
|
||||
)
|
||||
.def_property_readonly(
|
||||
"engine_ctx",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> gridfire::engine::scratch::StateBlob& {
|
||||
return *(self.engine_ctx);
|
||||
},
|
||||
py::return_value_policy::reference
|
||||
)
|
||||
.def_readonly(
|
||||
"num_steps", &gridfire::solver::PointSolverContext::num_steps
|
||||
)
|
||||
.def_property(
|
||||
"abs_tol",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> double {
|
||||
return self.abs_tol.value();
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, double abs_tol) -> void {
|
||||
self.abs_tol = abs_tol;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"rel_tol",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> double {
|
||||
return self.rel_tol.value();
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, double rel_tol) -> void {
|
||||
self.rel_tol = rel_tol;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"stdout_logging",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> bool {
|
||||
return self.stdout_logging;
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
|
||||
self.stdout_logging = enable;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"detailed_logging",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> bool {
|
||||
return self.detailed_step_logging;
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
|
||||
self.detailed_step_logging = enable;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"callback",
|
||||
[](const gridfire::solver::PointSolverContext& self) -> std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>> {
|
||||
return self.callback;
|
||||
},
|
||||
[](gridfire::solver::PointSolverContext& self, const std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>>& cb) {
|
||||
self.callback = cb;
|
||||
}
|
||||
)
|
||||
.def("reset_all", &gridfire::solver::PointSolverContext::reset_all)
|
||||
.def("reset_user", &gridfire::solver::PointSolverContext::reset_user)
|
||||
.def("reset_cvode", &gridfire::solver::PointSolverContext::reset_cvode)
|
||||
.def("clear_context", &gridfire::solver::PointSolverContext::clear_context)
|
||||
.def("init_context", &gridfire::solver::PointSolverContext::init_context)
|
||||
.def("has_context", &gridfire::solver::PointSolverContext::has_context)
|
||||
.def("init", &gridfire::solver::PointSolverContext::init)
|
||||
.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("engine_ctx"));
|
||||
|
||||
|
||||
|
||||
auto py_single_zone_dynamic_network_solver = py::class_<gridfire::solver::SingleZoneDynamicNetworkSolver, PySingleZoneDynamicNetworkSolver>(m, "SingleZoneDynamicNetworkSolver");
|
||||
py_single_zone_dynamic_network_solver.def(
|
||||
"evaluate",
|
||||
&gridfire::solver::DynamicNetworkSolverStrategy::evaluate,
|
||||
&gridfire::solver::SingleZoneDynamicNetworkSolver::evaluate,
|
||||
py::arg("solver_ctx"),
|
||||
py::arg("netIn"),
|
||||
"evaluate the dynamic engine using the dynamic engine class"
|
||||
"evaluate the dynamic engine using the dynamic engine class for a single zone"
|
||||
);
|
||||
auto py_multi_zone_dynamic_network_solver = py::class_<gridfire::solver::MultiZoneDynamicNetworkSolver, PyMultiZoneDynamicNetworkSolver>(m, "MultiZoneDynamicNetworkSolver");
|
||||
py_multi_zone_dynamic_network_solver.def(
|
||||
"evaluate",
|
||||
&gridfire::solver::MultiZoneDynamicNetworkSolver::evaluate,
|
||||
py::arg("solver_ctx"),
|
||||
py::arg("netIns"),
|
||||
"evaluate the dynamic engine using the dynamic engine class for multiple zones (using openmp if available)"
|
||||
);
|
||||
|
||||
auto py_point_solver = py::class_<gridfire::solver::PointSolver, gridfire::solver::SingleZoneDynamicNetworkSolver>(m, "PointSolver");
|
||||
|
||||
py_dynamic_network_solver_strategy.def(
|
||||
"describe_callback_context",
|
||||
&gridfire::solver::DynamicNetworkSolverStrategy::describe_callback_context,
|
||||
"Get a structure representing what data is in the callback context in a human readable format"
|
||||
);
|
||||
|
||||
auto py_cvode_solver_strategy = py::class_<gridfire::solver::CVODESolverStrategy, gridfire::solver::DynamicNetworkSolverStrategy>(m, "CVODESolverStrategy");
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
py_point_solver.def(
|
||||
py::init<gridfire::engine::DynamicEngine&>(),
|
||||
py::arg("engine"),
|
||||
"Initialize the CVODESolverStrategy object."
|
||||
"Initialize the PointSolver object."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
py_point_solver.def(
|
||||
"evaluate",
|
||||
py::overload_cast<const gridfire::NetIn&, bool>(&gridfire::solver::CVODESolverStrategy::evaluate),
|
||||
py::overload_cast<gridfire::solver::SolverContextBase&, const gridfire::NetIn&, bool, bool>(&gridfire::solver::PointSolver::evaluate, py::const_),
|
||||
py::arg("solver_ctx"),
|
||||
py::arg("netIn"),
|
||||
py::arg("display_trigger") = false,
|
||||
py::arg("force_reinitialization") = false,
|
||||
"evaluate the dynamic engine using the dynamic engine class"
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"get_stdout_logging_enabled",
|
||||
&gridfire::solver::CVODESolverStrategy::get_stdout_logging_enabled,
|
||||
"Check if solver logging to standard output is enabled."
|
||||
auto py_grid_solver_context = py::class_<gridfire::solver::GridSolverContext, gridfire::solver::SolverContextBase>(m, "GridSolverContext");
|
||||
py_grid_solver_context.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("ctx_template"));
|
||||
py_grid_solver_context.def("init", &gridfire::solver::GridSolverContext::init);
|
||||
py_grid_solver_context.def("reset", &gridfire::solver::GridSolverContext::reset);
|
||||
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"));
|
||||
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&, size_t>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"), py::arg("zone_idx"));
|
||||
py_grid_solver_context.def("clear_callback", py::overload_cast<>(&gridfire::solver::GridSolverContext::clear_callback));
|
||||
py_grid_solver_context.def("clear_callback", py::overload_cast<size_t>(&gridfire::solver::GridSolverContext::clear_callback), py::arg("zone_idx"));
|
||||
py_grid_solver_context.def_property(
|
||||
"stdout_logging",
|
||||
[](const gridfire::solver::GridSolverContext& self) -> bool {
|
||||
return self.zone_stdout_logging;
|
||||
},
|
||||
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
|
||||
self.zone_stdout_logging = enable;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"detailed_logging",
|
||||
[](const gridfire::solver::GridSolverContext& self) -> bool {
|
||||
return self.zone_detailed_logging;
|
||||
},
|
||||
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
|
||||
self.zone_detailed_logging = enable;
|
||||
}
|
||||
)
|
||||
.def_property(
|
||||
"zone_completion_logging",
|
||||
[](const gridfire::solver::GridSolverContext& self) -> bool {
|
||||
return self.zone_completion_logging;
|
||||
},
|
||||
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
|
||||
self.zone_completion_logging = enable;
|
||||
}
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_stdout_logging_enabled",
|
||||
&gridfire::solver::CVODESolverStrategy::set_stdout_logging_enabled,
|
||||
py::arg("logging_enabled"),
|
||||
"Enable logging to standard output."
|
||||
auto py_grid_solver = py::class_<gridfire::solver::GridSolver, gridfire::solver::MultiZoneDynamicNetworkSolver>(m, "GridSolver");
|
||||
py_grid_solver.def(
|
||||
py::init<const gridfire::engine::DynamicEngine&, const gridfire::solver::SingleZoneDynamicNetworkSolver&>(),
|
||||
py::arg("engine"),
|
||||
py::arg("solver"),
|
||||
"Initialize the GridSolver object."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_absTol",
|
||||
&gridfire::solver::CVODESolverStrategy::set_absTol,
|
||||
py::arg("absTol"),
|
||||
"Set the absolute tolerance for the CVODE solver."
|
||||
py_grid_solver.def(
|
||||
"evaluate",
|
||||
&gridfire::solver::GridSolver::evaluate,
|
||||
py::arg("solver_ctx"),
|
||||
py::arg("netIns"),
|
||||
"evaluate the dynamic engine using the dynamic engine class"
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_relTol",
|
||||
&gridfire::solver::CVODESolverStrategy::set_relTol,
|
||||
py::arg("relTol"),
|
||||
"Set the relative tolerance for the CVODE solver."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"get_absTol",
|
||||
&gridfire::solver::CVODESolverStrategy::get_absTol,
|
||||
"Get the absolute tolerance for the CVODE solver."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"get_relTol",
|
||||
&gridfire::solver::CVODESolverStrategy::get_relTol,
|
||||
"Get the relative tolerance for the CVODE solver."
|
||||
);
|
||||
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_callback",
|
||||
[](
|
||||
gridfire::solver::CVODESolverStrategy& self,
|
||||
std::function<void(const gridfire::solver::CVODESolverStrategy::TimestepContext&)> cb
|
||||
) {
|
||||
self.set_callback(std::any(cb));
|
||||
},
|
||||
py::arg("cb"),
|
||||
"Set a callback function which will run at the end of every successful timestep"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_solver',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
gf_solver_trampoline_sources = files('py_solver.cpp')
|
||||
|
||||
gf_solver_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_solver_trampoline_lib = static_library(
|
||||
'solver_trampolines',
|
||||
gf_solver_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_solver_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_solver_trampoline_dep = declare_dependency(
|
||||
link_with: gf_solver_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_solver_trapoline_dependencies,
|
||||
)
|
||||
@@ -13,38 +13,63 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
gridfire::NetOut PyDynamicNetworkSolverStrategy::evaluate(const gridfire::NetIn &netIn) {
|
||||
gridfire::NetOut PySingleZoneDynamicNetworkSolver::evaluate(
|
||||
gridfire::solver::SolverContextBase &solver_ctx,
|
||||
const gridfire::NetIn &netIn
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::NetOut, // Return type
|
||||
gridfire::solver::DynamicNetworkSolverStrategy, // Base class
|
||||
evaluate, // Method name
|
||||
netIn // Arguments
|
||||
gridfire::NetOut,
|
||||
gridfire::solver::SingleZoneDynamicNetworkSolver,
|
||||
evaluate,
|
||||
solver_ctx,
|
||||
netIn
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicNetworkSolverStrategy::set_callback(const std::any &callback) {
|
||||
std::vector<gridfire::NetOut> PyMultiZoneDynamicNetworkSolver::evaluate(
|
||||
gridfire::solver::SolverContextBase &solver_ctx,
|
||||
const std::vector<gridfire::NetIn> &netIns
|
||||
) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::solver::DynamicNetworkSolverStrategy, // Base class
|
||||
set_callback, // Method name
|
||||
callback // Arguments
|
||||
std::vector<gridfire::NetOut>,
|
||||
gridfire::solver::MultiZoneDynamicNetworkSolver,
|
||||
evaluate,
|
||||
solver_ctx,
|
||||
netIns
|
||||
);
|
||||
}
|
||||
|
||||
std::vector<std::tuple<std::string, std::string>> PyDynamicNetworkSolverStrategy::describe_callback_context() const {
|
||||
using DescriptionVector = std::vector<std::tuple<std::string, std::string>>;
|
||||
std::vector<std::tuple<std::string, std::string>> PyTimestepContextBase::describe() const {
|
||||
using ReturnType = std::vector<std::tuple<std::string, std::string>>;
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
DescriptionVector, // Return type
|
||||
gridfire::solver::DynamicNetworkSolverStrategy, // Base class
|
||||
describe_callback_context // Method name
|
||||
);
|
||||
}
|
||||
|
||||
std::vector<std::tuple<std::string, std::string>> PySolverContextBase::describe() const {
|
||||
using DescriptionVector = std::vector<std::tuple<std::string, std::string>>;
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
DescriptionVector,
|
||||
gridfire::solver::SolverContextBase,
|
||||
ReturnType,
|
||||
gridfire::solver::TimestepContextBase,
|
||||
describe
|
||||
);
|
||||
}
|
||||
|
||||
void PySolverContextBase::init() {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::solver::SolverContextBase,
|
||||
init
|
||||
);
|
||||
}
|
||||
|
||||
void PySolverContextBase::set_stdout_logging(bool enable) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::solver::SolverContextBase,
|
||||
set_stdout_logging,
|
||||
enable
|
||||
);
|
||||
}
|
||||
|
||||
void PySolverContextBase::set_detailed_logging(bool enable) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::solver::SolverContextBase,
|
||||
set_detailed_logging,
|
||||
enable
|
||||
);
|
||||
}
|
||||
@@ -7,14 +7,37 @@
|
||||
#include <string>
|
||||
#include <any>
|
||||
|
||||
class PyDynamicNetworkSolverStrategy final : public gridfire::solver::DynamicNetworkSolverStrategy {
|
||||
explicit PyDynamicNetworkSolverStrategy(gridfire::engine::DynamicEngine &engine) : gridfire::solver::DynamicNetworkSolverStrategy(engine) {}
|
||||
gridfire::NetOut evaluate(const gridfire::NetIn &netIn) override;
|
||||
void set_callback(const std::any &callback) override;
|
||||
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
|
||||
class PySingleZoneDynamicNetworkSolver final : public gridfire::solver::SingleZoneDynamicNetworkSolver {
|
||||
public:
|
||||
explicit PySingleZoneDynamicNetworkSolver(const gridfire::engine::DynamicEngine &engine) : gridfire::solver::SingleZoneDynamicNetworkSolver(engine) {}
|
||||
|
||||
gridfire::NetOut evaluate(
|
||||
gridfire::solver::SolverContextBase &solver_ctx,
|
||||
const gridfire::NetIn &netIn
|
||||
) const override;
|
||||
};
|
||||
|
||||
class PyMultiZoneDynamicNetworkSolver final : public gridfire::solver::MultiZoneDynamicNetworkSolver {
|
||||
public:
|
||||
explicit PyMultiZoneDynamicNetworkSolver(
|
||||
const gridfire::engine::DynamicEngine &engine,
|
||||
const gridfire::solver::SingleZoneDynamicNetworkSolver &local_solver
|
||||
) : gridfire::solver::MultiZoneDynamicNetworkSolver(engine, local_solver) {}
|
||||
|
||||
std::vector<gridfire::NetOut> evaluate(
|
||||
gridfire::solver::SolverContextBase &solver_ctx,
|
||||
const std::vector<gridfire::NetIn> &netIns
|
||||
) const override;
|
||||
};
|
||||
|
||||
class PyTimestepContextBase final : public gridfire::solver::TimestepContextBase {
|
||||
public:
|
||||
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
|
||||
};
|
||||
|
||||
class PySolverContextBase final : public gridfire::solver::SolverContextBase {
|
||||
public:
|
||||
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
|
||||
};
|
||||
void init() override;
|
||||
void set_stdout_logging(bool enable) override;
|
||||
void set_detailed_logging(bool enable) override;
|
||||
};
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_types',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -12,6 +12,7 @@ namespace py = pybind11;
|
||||
void register_utils_bindings(py::module &m) {
|
||||
m.def("formatNuclearTimescaleLogString",
|
||||
&gridfire::utils::formatNuclearTimescaleLogString,
|
||||
py::arg("ctx"),
|
||||
py::arg("engine"),
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_utils',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
Reference in New Issue
Block a user