feat(python): Python Bindings

Python Bindings are working again
This commit is contained in:
2025-12-20 16:02:52 -05:00
parent d65c237b26
commit 11a596b75b
78 changed files with 4411 additions and 1110 deletions

View File

@@ -8,34 +8,8 @@ namespace gridfire::config {
double relTol = 1.0e-5;
};
struct SpectralSolverConfig {
struct Trigger {
double timestepCollapseRatio = 0.5;
size_t maxConvergenceFailures = 2;
double relativeFailureRate = 0.5;
size_t windowSize = 10;
};
struct MonitorFunctionConfig {
double structure_weight = 1.0;
double abundance_weight = 10.0;
double alpha = 0.2;
double beta = 0.8;
};
struct BasisConfig {
size_t num_elements = 50;
};
double absTol = 1.0e-8;
double relTol = 1.0e-5;
size_t degree = 3;
MonitorFunctionConfig monitorFunction;
BasisConfig basis;
Trigger trigger;
};
struct SolverConfig {
CVODESolverConfig cvode;
SpectralSolverConfig spectral;
};
struct AdaptiveEngineViewConfig {
@@ -57,5 +31,4 @@ namespace gridfire::config {
};
}

View File

@@ -25,6 +25,7 @@
#include <set>
#include "gridfire/engine/types/engine_types.h"
#include "gridfire/engine/scratchpads/blob.h"
namespace gridfire::policy {

View File

@@ -20,6 +20,9 @@ namespace gridfire::solver {
void set_callback(const std::function<void(const TimestepContextBase&)> &callback);
void set_callback(const std::function<void(const TimestepContextBase&)> &callback, size_t zone_idx);
void clear_callback();
void clear_callback(size_t zone_idx);
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;

View File

@@ -28,6 +28,19 @@ namespace gridfire::solver {
timestep_callbacks[zone_idx] = callback;
}
void GridSolverContext::clear_callback() {
for (auto &cb : timestep_callbacks) {
cb = nullptr;
}
}
void GridSolverContext::clear_callback(const size_t zone_idx) {
if (zone_idx >= timestep_callbacks.size()) {
throw exceptions::SolverError("GridSolverContext::clear_callback: zone_idx out of range.");
}
timestep_callbacks[zone_idx] = nullptr;
}
void GridSolverContext::set_stdout_logging(const bool enable) {
zone_stdout_logging = enable;
}

View File

@@ -58,12 +58,21 @@ if get_option('openmp_support')
endif
# Define the libnetwork library so it can be linked against by other parts of the build system
libgridfire = library('gridfire',
gridfire_sources,
include_directories: include_directories('include'),
dependencies: gridfire_build_dependencies,
objects: [cvode_objs, kinsol_objs],
install : true)
if get_option('build_python')
libgridfire = static_library('gridfire',
gridfire_sources,
include_directories: include_directories('include'),
dependencies: gridfire_build_dependencies,
objects: [cvode_objs, kinsol_objs],
install : false)
else
libgridfire = library('gridfire',
gridfire_sources,
include_directories: include_directories('include'),
dependencies: gridfire_build_dependencies,
objects: [cvode_objs, kinsol_objs],
install : true)
endif
gridfire_dep = declare_dependency(
include_directories: include_directories('include'),

View File

@@ -4,6 +4,7 @@
#include "types/bindings.h"
#include "partition/bindings.h"
#include "engine/bindings.h"
#include "engine/scratchpads/bindings.h"
#include "exceptions/bindings.h"
#include "io/bindings.h"
#include "reaction/bindings.h"
@@ -11,6 +12,7 @@
#include "solver/bindings.h"
#include "utils/bindings.h"
#include "policy/bindings.h"
#include "config/bindings.h"
PYBIND11_MODULE(_gridfire, m) {
m.doc() = "Python bindings for the fourdst utility modules which are a part of the 4D-STAR project.";
@@ -20,6 +22,9 @@ PYBIND11_MODULE(_gridfire, m) {
pybind11::module::import("fourdst.config");
pybind11::module::import("fourdst.atomic");
auto configMod = m.def_submodule("config", "GridFire configuration bindings");
register_config_bindings(configMod);
auto typeMod = m.def_submodule("type", "GridFire type bindings");
register_type_bindings(typeMod);
@@ -39,6 +44,12 @@ PYBIND11_MODULE(_gridfire, m) {
register_exception_bindings(exceptionMod);
auto engineMod = m.def_submodule("engine", "Engine and Engine View bindings");
auto scratchpadMod = engineMod.def_submodule("scratchpads", "Engine ScratchPad bindings");
register_scratchpad_types_bindings(scratchpadMod);
register_scratchpad_bindings(scratchpadMod);
register_state_blob_bindings(scratchpadMod);
register_engine_bindings(engineMod);
auto solverMod = m.def_submodule("solver", "GridFire numerical solver bindings");

View File

@@ -0,0 +1,34 @@
#include "bindings.h"
#include "gridfire/config/config.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
void register_config_bindings(pybind11::module &m) {
py::class_<gridfire::config::CVODESolverConfig>(m, "CVODESolverConfig")
.def(py::init<>())
.def_readwrite("absTol", &gridfire::config::CVODESolverConfig::absTol)
.def_readwrite("relTol", &gridfire::config::CVODESolverConfig::relTol);
py::class_<gridfire::config::SolverConfig>(m, "SolverConfig")
.def(py::init<>())
.def_readwrite("cvode", &gridfire::config::SolverConfig::cvode);
py::class_<gridfire::config::AdaptiveEngineViewConfig>(m, "AdaptiveEngineViewConfig")
.def(py::init<>())
.def_readwrite("relativeCullingThreshold", &gridfire::config::AdaptiveEngineViewConfig::relativeCullingThreshold);
py::class_<gridfire::config::EngineViewConfig>(m, "EngineViewConfig")
.def(py::init<>())
.def_readwrite("adaptiveEngineView", &gridfire::config::EngineViewConfig::adaptiveEngineView);
py::class_<gridfire::config::EngineConfig>(m, "EngineConfig")
.def(py::init<>())
.def_readwrite("views", &gridfire::config::EngineConfig::views);
py::class_<gridfire::config::GridFireConfig>(m, "GridFireConfig")
.def(py::init<>())
.def_readwrite("solver", &gridfire::config::GridFireConfig::solver)
.def_readwrite("engine", &gridfire::config::GridFireConfig::engine);
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include <pybind11/pybind11.h>
void register_config_bindings(pybind11::module &m);

View File

@@ -12,6 +12,7 @@
namespace py = pybind11;
namespace sp = gridfire::engine::scratch;
namespace {
template <typename T>
@@ -23,16 +24,18 @@ namespace {
"calculateRHSAndEnergy",
[](
const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) {
auto result = self.calculateRHSAndEnergy(comp, T9, rho);
auto result = self.calculateRHSAndEnergy(ctx, comp, T9, rho, false);
if (!result.has_value()) {
throw gridfire::exceptions::EngineError(std::format("calculateRHSAndEnergy returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
}
return result.value();
},
py::arg("ctx"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
@@ -40,6 +43,7 @@ namespace {
)
.def("calculateEpsDerivatives",
&gridfire::engine::DynamicEngine::calculateEpsDerivatives,
py::arg("ctx"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
@@ -47,11 +51,13 @@ namespace {
)
.def("generateJacobianMatrix",
[](const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp,
const double T9,
const double rho) -> gridfire::engine::NetworkJacobian {
return self.generateJacobianMatrix(comp, T9, rho);
return self.generateJacobianMatrix(ctx, comp, T9, rho);
},
py::arg("ctx"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
@@ -59,12 +65,14 @@ namespace {
)
.def("generateJacobianMatrix",
[](const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp,
const double T9,
const double rho,
const std::vector<fourdst::atomic::Species>& activeSpecies) -> gridfire::engine::NetworkJacobian {
return self.generateJacobianMatrix(comp, T9, rho, activeSpecies);
return self.generateJacobianMatrix(ctx, comp, T9, rho, activeSpecies);
},
py::arg("ctx"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
@@ -73,31 +81,32 @@ namespace {
)
.def("generateJacobianMatrix",
[](const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp,
const double T9,
const double rho,
const gridfire::engine::SparsityPattern& sparsityPattern) -> gridfire::engine::NetworkJacobian {
return self.generateJacobianMatrix(comp, T9, rho, sparsityPattern);
return self.generateJacobianMatrix(ctx, comp, T9, rho, sparsityPattern);
},
py::arg("ctx"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
py::arg("sparsityPattern"),
"Generate the jacobian matrix for the given sparsity pattern"
)
.def("generateStoichiometryMatrix",
&T::generateStoichiometryMatrix
)
.def("calculateMolarReactionFlow",
[](
const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const gridfire::reaction::Reaction& reaction,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) -> double {
return self.calculateMolarReactionFlow(reaction, comp, T9, rho);
return self.calculateMolarReactionFlow(ctx, reaction, comp, T9, rho);
},
py::arg("ctx"),
py::arg("reaction"),
py::arg("comp"),
py::arg("T9"),
@@ -110,28 +119,21 @@ namespace {
.def("getNetworkReactions", &T::getNetworkReactions,
"Get the set of logical reactions in the network."
)
.def ("setNetworkReactions", &T::setNetworkReactions,
py::arg("reactions"),
"Set the network reactions to a new set of reactions."
)
.def("getStoichiometryMatrixEntry", &T::getStoichiometryMatrixEntry,
py::arg("species"),
py::arg("reaction"),
"Get an entry from the stoichiometry matrix."
)
.def("getSpeciesTimescales",
[](
const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) -> std::unordered_map<fourdst::atomic::Species, double> {
const auto result = self.getSpeciesTimescales(comp, T9, rho);
const auto result = self.getSpeciesTimescales(ctx, comp, T9, rho);
if (!result.has_value()) {
throw gridfire::exceptions::EngineError(std::format("getSpeciesTimescales has returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
}
return result.value();
},
py::arg("ctx"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
@@ -140,67 +142,48 @@ namespace {
.def("getSpeciesDestructionTimescales",
[](
const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) -> std::unordered_map<fourdst::atomic::Species, double> {
const auto result = self.getSpeciesDestructionTimescales(comp, T9, rho);
const auto result = self.getSpeciesDestructionTimescales(ctx, comp, T9, rho);
if (!result.has_value()) {
throw gridfire::exceptions::EngineError(std::format("getSpeciesDestructionTimescales has returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
}
return result.value();
},
py::arg("ctx"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),
"Get the destruction timescales for each species in the network."
)
.def("update",
&T::update,
.def("project",
&T::project,
py::arg("ctx"),
py::arg("netIn"),
"Update the engine state based on the provided NetIn object."
)
.def("setScreeningModel",
&T::setScreeningModel,
py::arg("screeningModel"),
"Set the screening model for the engine."
)
.def("getScreeningModel",
&T::getScreeningModel,
"Get the current screening model of the engine."
)
.def("getSpeciesIndex",
&T::getSpeciesIndex,
py::arg("ctx"),
py::arg("species"),
"Get the index of a species in the network."
)
.def("mapNetInToMolarAbundanceVector",
&T::mapNetInToMolarAbundanceVector,
py::arg("netIn"),
"Map a NetIn object to a vector of molar abundances."
)
.def("primeEngine",
&T::primeEngine,
py::arg("ctx"),
py::arg("netIn"),
"Prime the engine with a NetIn object to prepare for calculations."
)
.def("getDepth",
&T::getDepth,
"Get the current build depth of the engine."
)
.def("rebuild",
&T::rebuild,
py::arg("composition"),
py::arg("depth") = gridfire::engine::NetworkBuildDepth::Full,
"Rebuild the engine with a new composition and build depth."
)
.def("isStale",
&T::isStale,
py::arg("netIn"),
"Check if the engine is stale based on the provided NetIn object."
)
.def("collectComposition",
&T::collectComposition,
py::arg("ctx"),
py::arg("composition"),
py::arg("T9"),
py::arg("rho"),
@@ -208,6 +191,7 @@ namespace {
)
.def("getSpeciesStatus",
&T::getSpeciesStatus,
py::arg("ctx"),
py::arg("species"),
"Get the status of a species in the network."
);
@@ -253,6 +237,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
auto diagnostics = m.def_submodule("diagnostics", "A submodule for engine diagnostics");
diagnostics.def("report_limiting_species",
&gridfire::engine::diagnostics::report_limiting_species,
py::arg("ctx"),
py::arg("engine"),
py::arg("Y_full"),
py::arg("E_full"),
@@ -264,6 +249,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
diagnostics.def("inspect_species_balance",
&gridfire::engine::diagnostics::inspect_species_balance,
py::arg("ctx"),
py::arg("engine"),
py::arg("species_name"),
py::arg("comp"),
@@ -274,6 +260,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
diagnostics.def("inspect_jacobian_stiffness",
&gridfire::engine::diagnostics::inspect_jacobian_stiffness,
py::arg("ctx"),
py::arg("engine"),
py::arg("comp"),
py::arg("T9"),
@@ -311,6 +298,7 @@ void register_engine_construction_bindings(pybind11::module &m) {
void register_engine_priming_bindings(pybind11::module &m) {
m.def("primeNetwork",
&gridfire::engine::primeNetwork,
py::arg("ctx"),
py::arg("netIn"),
py::arg("engine"),
py::arg("ignoredReactionTypes") = std::nullopt,
@@ -456,19 +444,16 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
py::arg("reactions"),
"Initialize GraphEngine with a set of reactions."
);
py_graph_engine_bindings.def_static("getNetReactionStoichiometry",
&gridfire::engine::GraphEngine::getNetReactionStoichiometry,
py::arg("reaction"),
"Get the net stoichiometry for a given reaction."
);
py_graph_engine_bindings.def("getSpeciesTimescales",
[](const gridfire::engine::GraphEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& composition,
const double T9,
const double rho,
const gridfire::reaction::ReactionSet& activeReactions) {
return self.getSpeciesTimescales(composition, T9, rho, activeReactions);
return self.getSpeciesTimescales(ctx, composition, T9, rho, activeReactions);
},
py::arg("ctx"),
py::arg("composition"),
py::arg("T9"),
py::arg("rho"),
@@ -476,12 +461,14 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
);
py_graph_engine_bindings.def("getSpeciesDestructionTimescales",
[](const gridfire::engine::GraphEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& composition,
const double T9,
const double rho,
const gridfire::reaction::ReactionSet& activeReactions) {
return self.getSpeciesDestructionTimescales(composition, T9, rho, activeReactions);
return self.getSpeciesDestructionTimescales(ctx, composition, T9, rho, activeReactions);
},
py::arg("ctx"),
py::arg("composition"),
py::arg("T9"),
py::arg("rho"),
@@ -489,24 +476,22 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
);
py_graph_engine_bindings.def("involvesSpecies",
&gridfire::engine::GraphEngine::involvesSpecies,
py::arg("ctx"),
py::arg("species"),
"Check if a given species is involved in the network."
);
py_graph_engine_bindings.def("exportToDot",
&gridfire::engine::GraphEngine::exportToDot,
py::arg("ctx"),
py::arg("filename"),
"Export the network to a DOT file for visualization."
);
py_graph_engine_bindings.def("exportToCSV",
&gridfire::engine::GraphEngine::exportToCSV,
py::arg("ctx"),
py::arg("filename"),
"Export the network to a CSV file for analysis."
);
py_graph_engine_bindings.def("setPrecomputation",
&gridfire::engine::GraphEngine::setPrecomputation,
py::arg("precompute"),
"Enable or disable precomputation for the engine."
);
py_graph_engine_bindings.def("isPrecomputationEnabled",
&gridfire::engine::GraphEngine::isPrecomputationEnabled,
"Check if precomputation is enabled for the engine."
@@ -544,11 +529,6 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
&gridfire::engine::GraphEngine::isUsingReverseReactions,
"Check if the engine is using reverse reactions."
);
py_graph_engine_bindings.def("setUseReverseReactions",
&gridfire::engine::GraphEngine::setUseReverseReactions,
py::arg("useReverse"),
"Enable or disable the use of reverse reactions in the engine."
);
// Register the general dynamic engine bindings
registerDynamicEngineDefs<gridfire::engine::GraphEngine, gridfire::engine::DynamicEngine>(py_graph_engine_bindings);
@@ -587,11 +567,13 @@ void register_engine_view_bindings(const pybind11::module &m) {
registerDynamicEngineDefs<gridfire::engine::FileDefinedEngineView, gridfire::engine::DefinedEngineView>(py_file_defined_engine_view_bindings);
auto py_priming_engine_view_bindings = py::class_<gridfire::engine::NetworkPrimingEngineView, gridfire::engine::DefinedEngineView>(m, "NetworkPrimingEngineView");
py_priming_engine_view_bindings.def(py::init<const std::string&, gridfire::engine::GraphEngine&>(),
py_priming_engine_view_bindings.def(py::init<sp::StateBlob&, const std::string&, gridfire::engine::GraphEngine&>(),
py::arg("ctx"),
py::arg("primingSymbol"),
py::arg("baseEngine"),
"Construct a priming engine view with a priming symbol and a base engine.");
py_priming_engine_view_bindings.def(py::init<const fourdst::atomic::Species&, gridfire::engine::GraphEngine&>(),
py_priming_engine_view_bindings.def(py::init<sp::StateBlob&, const fourdst::atomic::Species&, gridfire::engine::GraphEngine&>(),
py::arg("ctx"),
py::arg("primingSpecies"),
py::arg("baseEngine"),
"Construct a priming engine view with a priming species and a base engine.");
@@ -622,15 +604,12 @@ void register_engine_view_bindings(const pybind11::module &m) {
);
py_multiscale_engine_view_bindings.def("partitionNetwork",
&gridfire::engine::MultiscalePartitioningEngineView::partitionNetwork,
py::arg("ctx"),
py::arg("netIn"),
"Partition the network based on species timescales and connectivity.");
py_multiscale_engine_view_bindings.def("partitionNetwork",
py::overload_cast<const gridfire::NetIn&>(&gridfire::engine::MultiscalePartitioningEngineView::partitionNetwork),
py::arg("netIn"),
"Partition the network based on a NetIn object."
);
py_multiscale_engine_view_bindings.def("exportToDot",
&gridfire::engine::MultiscalePartitioningEngineView::exportToDot,
py::arg("ctx"),
py::arg("filename"),
py::arg("comp"),
py::arg("T9"),
@@ -661,7 +640,16 @@ void register_engine_view_bindings(const pybind11::module &m) {
"Check if a given species is involved in the network's dynamic set."
);
py_multiscale_engine_view_bindings.def("getNormalizedEquilibratedComposition",
&gridfire::engine::MultiscalePartitioningEngineView::getNormalizedEquilibratedComposition,
[](
const gridfire::engine::MultiscalePartitioningEngineView& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) {
return self.getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
},
py::arg("ctx"),
py::arg("comp"),
py::arg("T9"),
py::arg("rho"),

View File

@@ -1,21 +0,0 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
message('⏳ Python bindings for GridFire Engine are being registered...')
shared_module('py_gf_engine',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)
message('✅ Python bindings for GridFire Engine registered successfully!')

View File

@@ -0,0 +1,152 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include "gridfire/engine/scratchpads/scratchpads.h"
#include "bindings.h"
namespace py = pybind11;
namespace sp = gridfire::engine::scratch;
template<typename... ScratchPadTypes>
void build_state_getter(py::module& m) {
}
void register_scratchpad_types_bindings(pybind11::module &m) {
py::enum_<sp::ScratchPadType>(m, "ScratchPadType")
.value("GRAPH_ENGINE_SCRATCHPAD", sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD)
.value("MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD)
.value("ADAPTIVE_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD)
.value("DEFINED_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD)
.export_values();
}
void register_scratchpad_bindings(pybind11::module_ &m) {
py::enum_<sp::GraphEngineScratchPad::ADFunRegistrationResult>(m, "ADFunRegistrationResult")
.value("SUCCESS", sp::GraphEngineScratchPad::ADFunRegistrationResult::SUCCESS)
.value("ALREADY_REGISTERED", sp::GraphEngineScratchPad::ADFunRegistrationResult::ALREADY_REGISTERED)
.export_values();
py::class_<sp::GraphEngineScratchPad>(m, "GraphEngineScratchPad")
.def(py::init<>())
.def("initialize", &sp::GraphEngineScratchPad::initialize, py::arg("engine"))
.def("clone", &sp::GraphEngineScratchPad::clone)
.def("is_initialized", &sp::GraphEngineScratchPad::is_initialized)
.def_readonly("most_recent_rhs_calculation", &sp::GraphEngineScratchPad::most_recent_rhs_calculation)
.def_readonly("local_abundance_cache", &sp::GraphEngineScratchPad::local_abundance_cache)
.def_readonly("has_initialized", &sp::GraphEngineScratchPad::has_initialized)
.def_readonly("stepDerivativesCache", &sp::GraphEngineScratchPad::stepDerivativesCache)
.def_readonly_static("ID", &sp::GraphEngineScratchPad::ID)
.def("__repr__", [](const sp::GraphEngineScratchPad &self) {
return std::format("{}", self);
});
py::class_<sp::MultiscalePartitioningEngineViewScratchPad>(m, "MultiscalePartitioningEngineViewScratchPad")
.def(py::init<>())
.def("initialize", &sp::MultiscalePartitioningEngineViewScratchPad::initialize)
.def("clone", &sp::MultiscalePartitioningEngineViewScratchPad::clone)
.def("is_initialized", &sp::MultiscalePartitioningEngineViewScratchPad::is_initialized)
.def_readonly("qse_groups", &sp::MultiscalePartitioningEngineViewScratchPad::qse_groups)
.def_readonly("dynamic_species", &sp::MultiscalePartitioningEngineViewScratchPad::dynamic_species)
.def_readonly("algebraic_species", &sp::MultiscalePartitioningEngineViewScratchPad::algebraic_species)
.def_readonly("composition_cache", &sp::MultiscalePartitioningEngineViewScratchPad::composition_cache)
.def_readonly("has_initialized", &sp::MultiscalePartitioningEngineViewScratchPad::has_initialized)
.def_readonly_static("ID", &sp::MultiscalePartitioningEngineViewScratchPad::ID)
.def("__repr__", [](const sp::MultiscalePartitioningEngineViewScratchPad &self) {
return std::format("{}", self);
});
py::class_<sp::AdaptiveEngineViewScratchPad>(m, "AdaptiveEngineViewScratchPad")
.def(py::init<>())
.def("initialize", &sp::AdaptiveEngineViewScratchPad::initialize)
.def("clone", &sp::AdaptiveEngineViewScratchPad::clone)
.def("is_initialized", &sp::AdaptiveEngineViewScratchPad::is_initialized)
.def_readonly("active_species", &sp::AdaptiveEngineViewScratchPad::active_species)
.def_readonly("active_reactions", &sp::AdaptiveEngineViewScratchPad::active_reactions)
.def_readonly("has_initialized", &sp::AdaptiveEngineViewScratchPad::has_initialized)
.def_readonly_static("ID", &sp::AdaptiveEngineViewScratchPad::ID)
.def("__repr__", [](const sp::AdaptiveEngineViewScratchPad &self) {
return std::format("{}", self);
});
py::class_<sp::DefinedEngineViewScratchPad>(m, "DefinedEngineViewScratchPad")
.def(py::init<>())
.def("clone", &sp::DefinedEngineViewScratchPad::clone)
.def("is_initialized", &sp::DefinedEngineViewScratchPad::is_initialized)
.def_readonly("active_species", &sp::DefinedEngineViewScratchPad::active_species)
.def_readonly("active_reactions", &sp::DefinedEngineViewScratchPad::active_reactions)
.def_readonly("species_index_map", &sp::DefinedEngineViewScratchPad::species_index_map)
.def_readonly("reaction_index_map", &sp::DefinedEngineViewScratchPad::reaction_index_map)
.def_readonly("has_initialized", &sp::DefinedEngineViewScratchPad::has_initialized)
.def_readonly_static("ID", &sp::DefinedEngineViewScratchPad::ID)
.def("__repr__", [](const sp::DefinedEngineViewScratchPad &self) {
return std::format("{}", self);
});
}
void register_state_blob_bindings(pybind11::module_ &m) {
py::enum_<sp::StateBlob::Error>(m, "StateBlobError")
.value("SCRATCHPAD_OUT_OF_BOUNDS", sp::StateBlob::Error::SCRATCHPAD_OUT_OF_BOUNDS)
.value("SCRATCHPAD_NOT_FOUND", sp::StateBlob::Error::SCRATCHPAD_NOT_FOUND)
.value("SCRATCHPAD_BAD_CAST", sp::StateBlob::Error::SCRATCHPAD_BAD_CAST)
.value("SCRATCHPAD_NOT_INITIALIZED", sp::StateBlob::Error::SCRATCHPAD_NOT_INITIALIZED)
.value("SCRATCHPAD_TYPE_COLLISION", sp::StateBlob::Error::SCRATCHPAD_TYPE_COLLISION)
.value("SCRATCHPAD_UNKNOWN_ERROR", sp::StateBlob::Error::SCRATCHPAD_UNKNOWN_ERROR)
.export_values();
py::class_<sp::StateBlob>(m, "StateBlob")
.def(py::init<>())
.def("enroll", [](sp::StateBlob &self, const sp::ScratchPadType type) {
switch (type) {
case sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD:
self.enroll<sp::GraphEngineScratchPad>();
break;
case sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD:
self.enroll<sp::MultiscalePartitioningEngineViewScratchPad>();
break;
case sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD:
self.enroll<sp::AdaptiveEngineViewScratchPad>();
break;
case sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD:
self.enroll<sp::DefinedEngineViewScratchPad>();
break;
default:
throw std::invalid_argument("Unknown ScratchPadType for enrollment.");
}
})
.def("get", [](const sp::StateBlob &self, const sp::ScratchPadType type) {
auto result = self.get(type);
if (!result.has_value()) {
throw std::runtime_error("Error retrieving scratchpad: " + sp::StateBlob::error_to_string(result.error()));
}
return result.value();
},
pybind11::return_value_policy::reference_internal
)
.def("clone_structure", &sp::StateBlob::clone_structure)
.def("get_registered_scratchpads", &sp::StateBlob::get_registered_scratchpads)
.def("get_status", [](const sp::StateBlob &self, const sp::ScratchPadType type) -> sp::StateBlob::ScratchPadStatus {
switch (type) {
case sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD:
return self.get_status<sp::GraphEngineScratchPad>();
case sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD:
return self.get_status<sp::MultiscalePartitioningEngineViewScratchPad>();
case sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD:
return self.get_status<sp::AdaptiveEngineViewScratchPad>();
case sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD:
return self.get_status<sp::DefinedEngineViewScratchPad>();
default:
throw std::invalid_argument("Unknown ScratchPadType for status retrieval.");
}
})
.def("get_status_map", &sp::StateBlob::get_status_map)
.def_static("error_to_string", &sp::StateBlob::error_to_string)
.def("__repr__", [](const sp::StateBlob &self) {
return std::format("{}", self);
});
}

View File

@@ -0,0 +1,7 @@
#pragma once
#include <pybind11/pybind11.h>
void register_scratchpad_types_bindings(pybind11::module_& m);
void register_scratchpad_bindings(pybind11::module_& m);
void register_state_blob_bindings(pybind11::module_& m);

View File

@@ -1,21 +0,0 @@
gf_engine_trampoline_sources = files('py_engine.cpp')
gf_engine_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_engine_trampoline_lib = static_library(
'engine_trampolines',
gf_engine_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_engine_trapoline_dependencies,
install: false,
)
gr_engine_trampoline_dep = declare_dependency(
link_with: gf_engine_trampoline_lib,
include_directories: ('.'),
dependencies: gf_engine_trapoline_dependencies,
)

View File

@@ -13,36 +13,29 @@
namespace py = pybind11;
const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies() const {
/*
* Acquire the GIL (Global Interpreter Lock) for thread safety
* with the Python interpreter.
*/
py::gil_scoped_acquire gil;
/*
* get_override() looks for a Python method that overrides this C++ one.
*/
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
const py::object result = override();
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
return m_species_cache;
}
py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\"");
const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies(
gridfire::engine::scratch::StateBlob& ctx
) const {
PYBIND11_OVERRIDE_PURE(
const std::vector<fourdst::atomic::Species>&,
gridfire::engine::Engine,
getNetworkSpecies,
ctx
);
}
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> PyEngine::calculateRHSAndEnergy(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
double rho,
bool trust
) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus>),
gridfire::engine::Engine,
calculateRHSAndEnergy,
comp, T9, rho
ctx, comp, T9, rho, trust
);
}
@@ -50,41 +43,35 @@ std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::Engin
/// PyDynamicEngine Implementation ///
/////////////////////////////////////
const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies() const {
/*
* Acquire the GIL (Global Interpreter Lock) for thread safety
* with the Python interpreter.
*/
py::gil_scoped_acquire gil;
/*
* get_override() looks for a Python method that overrides this C++ one.
*/
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
const py::object result = override();
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
return m_species_cache;
}
py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\"");
const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies(
gridfire::engine::scratch::StateBlob& ctx
) const {
PYBIND11_OVERRIDE_PURE(
const std::vector<fourdst::atomic::Species>&,
gridfire::engine::DynamicEngine,
getNetworkSpecies,
ctx
);
}
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> PyDynamicEngine::calculateRHSAndEnergy(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
double rho,
bool trust
) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus>),
gridfire::engine::DynamicEngine,
calculateRHSAndEnergy,
comp, T9, rho
ctx, comp, T9, rho, trust
);
}
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract& comp,
double T9,
double rho
@@ -100,6 +87,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
}
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho,
@@ -109,6 +97,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::NetworkJacobian,
gridfire::engine::DynamicEngine,
generateJacobianMatrix,
ctx,
comp,
T9,
rho,
@@ -117,6 +106,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
}
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -126,6 +116,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::NetworkJacobian,
gridfire::engine::DynamicEngine,
generateJacobianMatrix,
ctx,
comp,
T9,
rho,
@@ -133,28 +124,8 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
);
}
void PyDynamicEngine::generateStoichiometryMatrix() {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::engine::DynamicEngine,
generateStoichiometryMatrix
);
}
int PyDynamicEngine::getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const gridfire::reaction::Reaction& reaction
) const {
PYBIND11_OVERRIDE_PURE(
int,
gridfire::engine::DynamicEngine,
getStoichiometryMatrixEntry,
species,
reaction
);
}
double PyDynamicEngine::calculateMolarReactionFlow(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp,
double T9,
@@ -164,6 +135,7 @@ double PyDynamicEngine::calculateMolarReactionFlow(
double,
gridfire::engine::DynamicEngine,
calculateMolarReactionFlow,
ctx,
reaction,
comp,
T9,
@@ -171,24 +143,19 @@ double PyDynamicEngine::calculateMolarReactionFlow(
);
}
const gridfire::reaction::ReactionSet& PyDynamicEngine::getNetworkReactions() const {
const gridfire::reaction::ReactionSet& PyDynamicEngine::getNetworkReactions(
gridfire::engine::scratch::StateBlob& ctx
) const {
PYBIND11_OVERRIDE_PURE(
const gridfire::reaction::ReactionSet&,
gridfire::engine::DynamicEngine,
getNetworkReactions
);
}
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::engine::DynamicEngine,
setNetworkReactions,
reactions
getNetworkReactions,
ctx
);
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> PyDynamicEngine::getSpeciesTimescales(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -197,6 +164,7 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus>),
gridfire::engine::DynamicEngine,
getSpeciesTimescales,
ctx,
comp,
T9,
rho
@@ -204,6 +172,7 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
}
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> PyDynamicEngine::getSpeciesDestructionTimescales(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
@@ -212,80 +181,71 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus>),
gridfire::engine::DynamicEngine,
getSpeciesDestructionTimescales,
comp, T9, rho
ctx, comp, T9, rho
);
}
fourdst::composition::Composition PyDynamicEngine::update(const gridfire::NetIn &netIn) {
fourdst::composition::Composition PyDynamicEngine::project(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::NetIn &netIn
) const {
PYBIND11_OVERRIDE_PURE(
fourdst::composition::Composition,
gridfire::engine::DynamicEngine,
update,
project,
ctx,
netIn
);
}
bool PyDynamicEngine::isStale(const gridfire::NetIn &netIn) {
PYBIND11_OVERRIDE_PURE(
bool,
gridfire::engine::DynamicEngine,
isStale,
netIn
);
}
void PyDynamicEngine::setScreeningModel(gridfire::screening::ScreeningType model) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::engine::DynamicEngine,
setScreeningModel,
model
);
}
gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel() const {
gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel(
gridfire::engine::scratch::StateBlob& ctx
) const {
PYBIND11_OVERRIDE_PURE(
gridfire::screening::ScreeningType,
gridfire::engine::DynamicEngine,
getScreeningModel
getScreeningModel,
ctx
);
}
size_t PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
size_t PyDynamicEngine::getSpeciesIndex(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const {
PYBIND11_OVERRIDE_PURE(
int,
gridfire::engine::DynamicEngine,
getSpeciesIndex,
ctx,
species
);
}
std::vector<double> PyDynamicEngine::mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const {
PYBIND11_OVERRIDE_PURE(
std::vector<double>,
gridfire::engine::DynamicEngine,
mapNetInToMolarAbundanceVector,
netIn
);
}
gridfire::engine::PrimingReport PyDynamicEngine::primeEngine(const gridfire::NetIn &netIn) {
gridfire::engine::PrimingReport PyDynamicEngine::primeEngine(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::NetIn &netIn
) const {
PYBIND11_OVERRIDE_PURE(
gridfire::engine::PrimingReport,
gridfire::engine::DynamicEngine,
primeEngine,
ctx,
netIn
);
}
gridfire::engine::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho) const {
const double rho
) const {
PYBIND11_OVERRIDE_PURE(
gridfire::engine::EnergyDerivatives,
gridfire::engine::DynamicEngine,
calculateEpsDerivatives,
ctx,
comp,
T9,
rho
@@ -293,6 +253,7 @@ gridfire::engine::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives(
}
fourdst::composition::Composition PyDynamicEngine::collectComposition(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
const double T9,
const double rho
@@ -301,21 +262,37 @@ fourdst::composition::Composition PyDynamicEngine::collectComposition(
fourdst::composition::Composition,
gridfire::engine::DynamicEngine,
collectComposition,
ctx,
comp,
T9,
rho
);
}
gridfire::engine::SpeciesStatus PyDynamicEngine::getSpeciesStatus(const fourdst::atomic::Species &species) const {
gridfire::engine::SpeciesStatus PyDynamicEngine::getSpeciesStatus(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const {
PYBIND11_OVERRIDE_PURE(
gridfire::engine::SpeciesStatus,
gridfire::engine::DynamicEngine,
getSpeciesStatus,
ctx,
species
);
}
std::optional<gridfire::engine::StepDerivatives<double>> PyDynamicEngine::getMostRecentRHSCalculation(
gridfire::engine::scratch::StateBlob &ctx
) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::optional<gridfire::engine::StepDerivatives<double>>),
gridfire::engine::DynamicEngine,
getMostRecentRHSCalculation,
ctx
);
}
const gridfire::engine::Engine& PyEngineView::getBaseEngine() const {
PYBIND11_OVERRIDE_PURE(
const gridfire::engine::Engine&,

View File

@@ -10,12 +10,16 @@
class PyEngine final : public gridfire::engine::Engine {
public:
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
gridfire::engine::scratch::StateBlob& ctx
) const override;
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> calculateRHSAndEnergy(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
double rho,
bool trust
) const override;
private:
mutable std::vector<fourdst::atomic::Species> m_species_cache;
@@ -23,21 +27,27 @@ private:
class PyDynamicEngine final : public gridfire::engine::DynamicEngine {
public:
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override;
const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
gridfire::engine::scratch::StateBlob& ctx
) const override;
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> calculateRHSAndEnergy(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
double rho,
bool trust
) const override;
gridfire::engine::NetworkJacobian generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract& comp,
double T9,
double rho
) const override;
gridfire::engine::NetworkJacobian generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho,
@@ -45,96 +55,81 @@ public:
) const override;
gridfire::engine::NetworkJacobian generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract& comp,
double T9,
double rho,
const gridfire::engine::SparsityPattern &sparsityPattern
) const override;
void generateStoichiometryMatrix() override;
int getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const gridfire::reaction::Reaction& reaction
) const override;
double calculateMolarReactionFlow(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
const gridfire::reaction::ReactionSet& getNetworkReactions() const override;
void setNetworkReactions(
const gridfire::reaction::ReactionSet& reactions
) override;
const gridfire::reaction::ReactionSet& getNetworkReactions(
gridfire::engine::scratch::StateBlob& ctx
) const override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> getSpeciesTimescales(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> getSpeciesDestructionTimescales(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
fourdst::composition::Composition update(
fourdst::composition::Composition project(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::NetIn &netIn
) override;
) const override;
bool isStale(
const gridfire::NetIn &netIn
) override;
void setScreeningModel(
gridfire::screening::ScreeningType model
) override;
gridfire::screening::ScreeningType getScreeningModel() const override;
gridfire::screening::ScreeningType getScreeningModel(
gridfire::engine::scratch::StateBlob& ctx
) const override;
size_t getSpeciesIndex(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
std::vector<double> mapNetInToMolarAbundanceVector(
gridfire::engine::PrimingReport primeEngine(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::NetIn &netIn
) const override;
gridfire::engine::PrimingReport primeEngine(
const gridfire::NetIn &netIn
) override;
gridfire::engine::BuildDepthType getDepth() const override {
throw std::logic_error("Network depth not supported by this engine.");
}
void rebuild(
const fourdst::composition::CompositionAbstract &comp,
gridfire::engine::BuildDepthType depth
) override {
throw std::logic_error("Setting network depth not supported by this engine.");
}
[[nodiscard]] gridfire::engine::EnergyDerivatives calculateEpsDerivatives(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
fourdst::composition::Composition collectComposition(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp,
double T9,
double rho
) const override;
gridfire::engine::SpeciesStatus getSpeciesStatus(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const override;
std::optional<gridfire::engine::StepDerivatives<double>> getMostRecentRHSCalculation(
gridfire::engine::scratch::StateBlob &ctx
) const override;
private:
mutable std::vector<fourdst::atomic::Species> m_species_cache;
};

View File

@@ -2,6 +2,8 @@
#include "bindings.h"
#include "gridfire/exceptions/error_scratchpad.h"
namespace py = pybind11;
#include "gridfire/exceptions/exceptions.h"
@@ -44,4 +46,6 @@ void register_exception_bindings(const py::module &m) {
py::register_exception<gridfire::exceptions::CVODESolverFailureError>(m, "CVODESolverFailureError", m.attr("SUNDIALSError"));
py::register_exception<gridfire::exceptions::KINSolSolverFailureError>(m, "KINSolSolverFailureError", m.attr("SUNDIALSError"));
py::register_exception<gridfire::exceptions::ScratchPadError>(m, "ScratchPadError", m.attr("GridFireError"));
}

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_exceptions',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,7 +1,7 @@
from ._gridfire import *
import sys
from ._gridfire import type, utils, engine, solver, exceptions, partition, reaction, screening, io, policy
from ._gridfire import type, utils, engine, solver, exceptions, partition, reaction, screening, io, policy, config
sys.modules['gridfire.type'] = type
sys.modules['gridfire.utils'] = utils
@@ -13,8 +13,61 @@ sys.modules['gridfire.reaction'] = reaction
sys.modules['gridfire.screening'] = screening
sys.modules['gridfire.policy'] = policy
sys.modules['gridfire.io'] = io
sys.modules['gridfire.config'] = config
__all__ = ['type', 'utils', 'engine', 'solver', 'exceptions', 'partition', 'reaction', 'screening', 'io', 'policy']
__all__ = ['type', 'utils', 'engine', 'solver', 'exceptions', 'partition', 'reaction', 'screening', 'io', 'policy', 'config']
__version__ = "v0.7.4_rc2"
import importlib.metadata
try:
_meta = importlib.metadata.metadata('gridfire')
__version__ = _meta['Version']
__author__ = _meta['Author']
__license__ = _meta['License']
__email__ = _meta['Author-email']
__url__ = _meta['Home-page'] or _meta.get('Project-URL', '').split(',')[0].split(' ')[-1].strip()
__description__ = _meta['Summary']
except importlib.metadata.PackageNotFoundError :
__version__ = 'unknown - Package not installed'
__author__ = 'Emily M. Boudreaux'
__license__ = 'GNU General Public License v3.0'
__email__ = 'emily.boudreaux@dartmouth.edu'
__url__ = 'https://github.com/4D-STAR/GridFire'
def gf_metadata():
return {
'version': __version__,
'author': __author__,
'license': __license__,
'email': __email__,
'url': __url__,
'description': __description__
}
def gf_version():
return __version__
def gf_author():
return __author__
def gf_license():
return __license__
def gf_email():
return __email__
def gf_url():
return __url__
def gf_description():
return __description__
def gf_collaboration():
return "4D-STAR Collaboration"
def gf_credits():
return [
"Emily M. Boudreaux - Lead Developer",
"Aaron Dotter - Co-Developer",
"4D-STAR Collaboration - Contributors"
]

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_io',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_io_trampoline_sources = files('py_io.cpp')
gf_io_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_io_trampoline_lib = static_library(
'io_trampolines',
gf_io_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_io_trapoline_dependencies,
install: false,
)
gr_io_trampoline_dep = declare_dependency(
link_with: gf_io_trampoline_lib,
include_directories: ('.'),
dependencies: gf_io_trapoline_dependencies,
)

View File

@@ -1,10 +0,0 @@
subdir('types')
subdir('utils')
subdir('exceptions')
subdir('io')
subdir('partition')
subdir('reaction')
subdir('screening')
subdir('engine')
subdir('policy')
subdir('solver')

View File

@@ -1,19 +0,0 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_partition',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_partition_trampoline_sources = files('py_partition.cpp')
gf_partition_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_partition_trampoline_lib = static_library(
'partition_trampolines',
gf_partition_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_partition_trapoline_dependencies,
install: false,
)
gr_partition_trampoline_dep = declare_dependency(
link_with: gf_partition_trampoline_lib,
include_directories: ('.'),
dependencies: gf_partition_trapoline_dependencies,
)

View File

@@ -8,7 +8,6 @@
#include "gridfire/policy/policy.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
namespace py = pybind11;
@@ -103,9 +102,30 @@ namespace {
.def(
"construct",
&T::construct,
py::return_value_policy::reference,
"Construct the network according to the policy."
)
.def(
"get_engine_stack",
[](const T &self) {
const auto& stack = self.get_engine_stack();
std::vector<gridfire::engine::DynamicEngine*> engine_ptrs;
engine_ptrs.reserve(stack.size());
for (const auto& engine_uptr : stack) {
engine_ptrs.push_back(engine_uptr.get());
}
return engine_ptrs;
},
py::return_value_policy::reference_internal
)
.def(
"get_stack_scratch_blob",
&T::get_stack_scratch_blob
)
.def(
"get_partition_function",
&T::get_partition_function
);
}
}
@@ -215,6 +235,26 @@ void register_network_policy_bindings(pybind11::module &m) {
.value("INITIALIZED_VERIFIED", gridfire::policy::NetworkPolicyStatus::INITIALIZED_VERIFIED)
.export_values();
m.def("network_policy_status_to_string",
&gridfire::policy::NetworkPolicyStatusToString,
py::arg("status"),
"Convert a NetworkPolicyStatus enum value to its string representation."
);
py::class_<gridfire::policy::ConstructionResults>(m, "ConstructionResults")
.def_property_readonly("engine",
[](const gridfire::policy::ConstructionResults &self) -> const gridfire::engine::DynamicEngine& {
return self.engine;
},
py::return_value_policy::reference
)
.def_property_readonly("scratch_blob",
[](const gridfire::policy::ConstructionResults &self) {
return self.scratch_blob.get();
},
py::return_value_policy::reference_internal
);
py::class_<gridfire::policy::NetworkPolicy, PyNetworkPolicy> py_networkPolicy(m, "NetworkPolicy");
py::class_<gridfire::policy::MainSequencePolicy, gridfire::policy::NetworkPolicy> py_mainSeqPolicy(m, "MainSequencePolicy");
py_mainSeqPolicy.def(

View File

@@ -1,19 +0,0 @@
# Define the library
subdir('trampoline')
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_policy',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_policy_trampoline_sources = files('py_policy.cpp')
gf_policy_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_policy_trampoline_lib = static_library(
'policy_trampolines',
gf_policy_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_policy_trapoline_dependencies,
install: false,
)
gr_policy_trampoline_dep = declare_dependency(
link_with: gf_policy_trampoline_lib,
include_directories: ('.'),
dependencies: gf_policy_trapoline_dependencies,
)

View File

@@ -39,9 +39,9 @@ const gridfire::reaction::ReactionSet& PyNetworkPolicy::get_seed_reactions() con
);
}
gridfire::engine::DynamicEngine& PyNetworkPolicy::construct() {
gridfire::policy::ConstructionResults PyNetworkPolicy::construct() {
PYBIND11_OVERRIDE_PURE(
gridfire::engine::DynamicEngine&,
gridfire::policy::ConstructionResults,
gridfire::policy::NetworkPolicy,
construct
);
@@ -79,6 +79,14 @@ const std::unique_ptr<gridfire::partition::PartitionFunction>& PyNetworkPolicy::
);
}
std::unique_ptr<gridfire::engine::scratch::StateBlob> PyNetworkPolicy::get_stack_scratch_blob() const {
PYBIND11_OVERRIDE_PURE(
std::unique_ptr<gridfire::engine::scratch::StateBlob>,
gridfire::policy::NetworkPolicy,
get_stack_scratch_blob
);
}
const gridfire::reaction::ReactionSet &PyReactionChainPolicy::get_reactions() const {
PYBIND11_OVERRIDE_PURE(
const gridfire::reaction::ReactionSet &,

View File

@@ -13,7 +13,7 @@ public:
[[nodiscard]] const gridfire::reaction::ReactionSet& get_seed_reactions() const override;
[[nodiscard]] gridfire::engine::DynamicEngine& construct() override;
[[nodiscard]] gridfire::policy::ConstructionResults construct() override;
[[nodiscard]] gridfire::policy::NetworkPolicyStatus get_status() const override;
@@ -22,6 +22,8 @@ public:
[[nodiscard]] std::vector<gridfire::engine::EngineTypes> get_engine_types_stack() const override;
[[nodiscard]] const std::unique_ptr<gridfire::partition::PartitionFunction>& get_partition_function() const override;
[[nodiscard]] std::unique_ptr<gridfire::engine::scratch::StateBlob> get_stack_scratch_blob() const override;
};
class PyReactionChainPolicy final : public gridfire::policy::ReactionChainPolicy {

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_reaction',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,19 +0,0 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_screening',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_screening_trampoline_sources = files('py_screening.cpp')
gf_screening_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_screening_trampoline_lib = static_library(
'screening_trampolines',
gf_screening_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_screening_trapoline_dependencies,
install: false,
)
gr_screening_trampoline_dep = declare_dependency(
link_with: gf_screening_trampoline_lib,
include_directories: ('.'),
dependencies: gf_screening_trapoline_dependencies,
)

View File

@@ -7,125 +7,226 @@
#include "bindings.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "trampoline/py_solver.h"
namespace py = pybind11;
void register_solver_bindings(const py::module &m) {
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext, gridfire::solver::SolverContextBase>(m, "CVODETimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
auto py_cvode_timestep_context = py::class_<gridfire::solver::PointSolverTimestepContext>(m, "PointSolverTimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::PointSolverTimestepContext::t);
py_cvode_timestep_context.def_property_readonly(
"state",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
[](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<double> {
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
const sunindextype length = N_VGetLength(self.state);
return std::vector<double>(nvec_data, nvec_data + length);
return {nvec_data, nvec_data + length};
}
);
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures);
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations);
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::PointSolverTimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::PointSolverTimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::PointSolverTimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::PointSolverTimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::PointSolverTimestepContext::num_steps);
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::PointSolverTimestepContext::currentConvergenceFailures);
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::PointSolverTimestepContext::currentNonlinearIterations);
py_cvode_timestep_context.def_property_readonly(
"engine",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::engine::DynamicEngine& {
[](const gridfire::solver::PointSolverTimestepContext& self) -> const gridfire::engine::DynamicEngine& {
return self.engine;
}
);
py_cvode_timestep_context.def_property_readonly(
"networkSpecies",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
[](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<fourdst::atomic::Species> {
return self.networkSpecies;
}
);
py_cvode_timestep_context.def_property_readonly(
"state_ctx",
[](const gridfire::solver::PointSolverTimestepContext& self) {
return &(self.state_ctx);
},
py::return_value_policy::reference_internal
);
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
py_dynamic_network_solver_strategy.def(
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
auto py_point_solver_context = py::class_<gridfire::solver::PointSolverContext, gridfire::solver::SolverContextBase>(m, "PointSolverContext");
py_point_solver_context
.def_readonly(
"sun_ctx", &gridfire::solver::PointSolverContext::sun_ctx
)
.def_readonly(
"cvode_mem", &gridfire::solver::PointSolverContext::cvode_mem
)
.def_readonly(
"Y", &gridfire::solver::PointSolverContext::Y
)
.def_readonly(
"YErr", &gridfire::solver::PointSolverContext::YErr
)
.def_readonly(
"J", &gridfire::solver::PointSolverContext::J
)
.def_readonly(
"LS", &gridfire::solver::PointSolverContext::LS
)
.def_property_readonly(
"engine_ctx",
[](const gridfire::solver::PointSolverContext& self) -> gridfire::engine::scratch::StateBlob& {
return *(self.engine_ctx);
},
py::return_value_policy::reference
)
.def_readonly(
"num_steps", &gridfire::solver::PointSolverContext::num_steps
)
.def_property(
"abs_tol",
[](const gridfire::solver::PointSolverContext& self) -> double {
return self.abs_tol.value();
},
[](gridfire::solver::PointSolverContext& self, double abs_tol) -> void {
self.abs_tol = abs_tol;
}
)
.def_property(
"rel_tol",
[](const gridfire::solver::PointSolverContext& self) -> double {
return self.rel_tol.value();
},
[](gridfire::solver::PointSolverContext& self, double rel_tol) -> void {
self.rel_tol = rel_tol;
}
)
.def_property(
"stdout_logging",
[](const gridfire::solver::PointSolverContext& self) -> bool {
return self.stdout_logging;
},
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
self.stdout_logging = enable;
}
)
.def_property(
"detailed_logging",
[](const gridfire::solver::PointSolverContext& self) -> bool {
return self.detailed_step_logging;
},
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
self.detailed_step_logging = enable;
}
)
.def_property(
"callback",
[](const gridfire::solver::PointSolverContext& self) -> std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>> {
return self.callback;
},
[](gridfire::solver::PointSolverContext& self, const std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>>& cb) {
self.callback = cb;
}
)
.def("reset_all", &gridfire::solver::PointSolverContext::reset_all)
.def("reset_user", &gridfire::solver::PointSolverContext::reset_user)
.def("reset_cvode", &gridfire::solver::PointSolverContext::reset_cvode)
.def("clear_context", &gridfire::solver::PointSolverContext::clear_context)
.def("init_context", &gridfire::solver::PointSolverContext::init_context)
.def("has_context", &gridfire::solver::PointSolverContext::has_context)
.def("init", &gridfire::solver::PointSolverContext::init)
.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("engine_ctx"));
auto py_single_zone_dynamic_network_solver = py::class_<gridfire::solver::SingleZoneDynamicNetworkSolver, PySingleZoneDynamicNetworkSolver>(m, "SingleZoneDynamicNetworkSolver");
py_single_zone_dynamic_network_solver.def(
"evaluate",
&gridfire::solver::DynamicNetworkSolverStrategy::evaluate,
&gridfire::solver::SingleZoneDynamicNetworkSolver::evaluate,
py::arg("solver_ctx"),
py::arg("netIn"),
"evaluate the dynamic engine using the dynamic engine class"
"evaluate the dynamic engine using the dynamic engine class for a single zone"
);
auto py_multi_zone_dynamic_network_solver = py::class_<gridfire::solver::MultiZoneDynamicNetworkSolver, PyMultiZoneDynamicNetworkSolver>(m, "MultiZoneDynamicNetworkSolver");
py_multi_zone_dynamic_network_solver.def(
"evaluate",
&gridfire::solver::MultiZoneDynamicNetworkSolver::evaluate,
py::arg("solver_ctx"),
py::arg("netIns"),
"evaluate the dynamic engine using the dynamic engine class for multiple zones (using openmp if available)"
);
auto py_point_solver = py::class_<gridfire::solver::PointSolver, gridfire::solver::SingleZoneDynamicNetworkSolver>(m, "PointSolver");
py_dynamic_network_solver_strategy.def(
"describe_callback_context",
&gridfire::solver::DynamicNetworkSolverStrategy::describe_callback_context,
"Get a structure representing what data is in the callback context in a human readable format"
);
auto py_cvode_solver_strategy = py::class_<gridfire::solver::CVODESolverStrategy, gridfire::solver::DynamicNetworkSolverStrategy>(m, "CVODESolverStrategy");
py_cvode_solver_strategy.def(
py_point_solver.def(
py::init<gridfire::engine::DynamicEngine&>(),
py::arg("engine"),
"Initialize the CVODESolverStrategy object."
"Initialize the PointSolver object."
);
py_cvode_solver_strategy.def(
py_point_solver.def(
"evaluate",
py::overload_cast<const gridfire::NetIn&, bool>(&gridfire::solver::CVODESolverStrategy::evaluate),
py::overload_cast<gridfire::solver::SolverContextBase&, const gridfire::NetIn&, bool, bool>(&gridfire::solver::PointSolver::evaluate, py::const_),
py::arg("solver_ctx"),
py::arg("netIn"),
py::arg("display_trigger") = false,
py::arg("force_reinitialization") = false,
"evaluate the dynamic engine using the dynamic engine class"
);
py_cvode_solver_strategy.def(
"get_stdout_logging_enabled",
&gridfire::solver::CVODESolverStrategy::get_stdout_logging_enabled,
"Check if solver logging to standard output is enabled."
auto py_grid_solver_context = py::class_<gridfire::solver::GridSolverContext, gridfire::solver::SolverContextBase>(m, "GridSolverContext");
py_grid_solver_context.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("ctx_template"));
py_grid_solver_context.def("init", &gridfire::solver::GridSolverContext::init);
py_grid_solver_context.def("reset", &gridfire::solver::GridSolverContext::reset);
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"));
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&, size_t>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"), py::arg("zone_idx"));
py_grid_solver_context.def("clear_callback", py::overload_cast<>(&gridfire::solver::GridSolverContext::clear_callback));
py_grid_solver_context.def("clear_callback", py::overload_cast<size_t>(&gridfire::solver::GridSolverContext::clear_callback), py::arg("zone_idx"));
py_grid_solver_context.def_property(
"stdout_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_stdout_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_stdout_logging = enable;
}
)
.def_property(
"detailed_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_detailed_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_detailed_logging = enable;
}
)
.def_property(
"zone_completion_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_completion_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_completion_logging = enable;
}
);
py_cvode_solver_strategy.def(
"set_stdout_logging_enabled",
&gridfire::solver::CVODESolverStrategy::set_stdout_logging_enabled,
py::arg("logging_enabled"),
"Enable logging to standard output."
auto py_grid_solver = py::class_<gridfire::solver::GridSolver, gridfire::solver::MultiZoneDynamicNetworkSolver>(m, "GridSolver");
py_grid_solver.def(
py::init<const gridfire::engine::DynamicEngine&, const gridfire::solver::SingleZoneDynamicNetworkSolver&>(),
py::arg("engine"),
py::arg("solver"),
"Initialize the GridSolver object."
);
py_cvode_solver_strategy.def(
"set_absTol",
&gridfire::solver::CVODESolverStrategy::set_absTol,
py::arg("absTol"),
"Set the absolute tolerance for the CVODE solver."
py_grid_solver.def(
"evaluate",
&gridfire::solver::GridSolver::evaluate,
py::arg("solver_ctx"),
py::arg("netIns"),
"evaluate the dynamic engine using the dynamic engine class"
);
py_cvode_solver_strategy.def(
"set_relTol",
&gridfire::solver::CVODESolverStrategy::set_relTol,
py::arg("relTol"),
"Set the relative tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"get_absTol",
&gridfire::solver::CVODESolverStrategy::get_absTol,
"Get the absolute tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"get_relTol",
&gridfire::solver::CVODESolverStrategy::get_relTol,
"Get the relative tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"set_callback",
[](
gridfire::solver::CVODESolverStrategy& self,
std::function<void(const gridfire::solver::CVODESolverStrategy::TimestepContext&)> cb
) {
self.set_callback(std::any(cb));
},
py::arg("cb"),
"Set a callback function which will run at the end of every successful timestep"
);
}

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_solver',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_solver_trampoline_sources = files('py_solver.cpp')
gf_solver_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_solver_trampoline_lib = static_library(
'solver_trampolines',
gf_solver_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_solver_trapoline_dependencies,
install: false,
)
gr_solver_trampoline_dep = declare_dependency(
link_with: gf_solver_trampoline_lib,
include_directories: ('.'),
dependencies: gf_solver_trapoline_dependencies,
)

View File

@@ -13,38 +13,63 @@
namespace py = pybind11;
gridfire::NetOut PyDynamicNetworkSolverStrategy::evaluate(const gridfire::NetIn &netIn) {
gridfire::NetOut PySingleZoneDynamicNetworkSolver::evaluate(
gridfire::solver::SolverContextBase &solver_ctx,
const gridfire::NetIn &netIn
) const {
PYBIND11_OVERRIDE_PURE(
gridfire::NetOut, // Return type
gridfire::solver::DynamicNetworkSolverStrategy, // Base class
evaluate, // Method name
netIn // Arguments
gridfire::NetOut,
gridfire::solver::SingleZoneDynamicNetworkSolver,
evaluate,
solver_ctx,
netIn
);
}
void PyDynamicNetworkSolverStrategy::set_callback(const std::any &callback) {
std::vector<gridfire::NetOut> PyMultiZoneDynamicNetworkSolver::evaluate(
gridfire::solver::SolverContextBase &solver_ctx,
const std::vector<gridfire::NetIn> &netIns
) const {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::solver::DynamicNetworkSolverStrategy, // Base class
set_callback, // Method name
callback // Arguments
std::vector<gridfire::NetOut>,
gridfire::solver::MultiZoneDynamicNetworkSolver,
evaluate,
solver_ctx,
netIns
);
}
std::vector<std::tuple<std::string, std::string>> PyDynamicNetworkSolverStrategy::describe_callback_context() const {
using DescriptionVector = std::vector<std::tuple<std::string, std::string>>;
std::vector<std::tuple<std::string, std::string>> PyTimestepContextBase::describe() const {
using ReturnType = std::vector<std::tuple<std::string, std::string>>;
PYBIND11_OVERRIDE_PURE(
DescriptionVector, // Return type
gridfire::solver::DynamicNetworkSolverStrategy, // Base class
describe_callback_context // Method name
);
}
std::vector<std::tuple<std::string, std::string>> PySolverContextBase::describe() const {
using DescriptionVector = std::vector<std::tuple<std::string, std::string>>;
PYBIND11_OVERRIDE_PURE(
DescriptionVector,
gridfire::solver::SolverContextBase,
ReturnType,
gridfire::solver::TimestepContextBase,
describe
);
}
void PySolverContextBase::init() {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::solver::SolverContextBase,
init
);
}
void PySolverContextBase::set_stdout_logging(bool enable) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::solver::SolverContextBase,
set_stdout_logging,
enable
);
}
void PySolverContextBase::set_detailed_logging(bool enable) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::solver::SolverContextBase,
set_detailed_logging,
enable
);
}

View File

@@ -7,14 +7,37 @@
#include <string>
#include <any>
class PyDynamicNetworkSolverStrategy final : public gridfire::solver::DynamicNetworkSolverStrategy {
explicit PyDynamicNetworkSolverStrategy(gridfire::engine::DynamicEngine &engine) : gridfire::solver::DynamicNetworkSolverStrategy(engine) {}
gridfire::NetOut evaluate(const gridfire::NetIn &netIn) override;
void set_callback(const std::any &callback) override;
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
class PySingleZoneDynamicNetworkSolver final : public gridfire::solver::SingleZoneDynamicNetworkSolver {
public:
explicit PySingleZoneDynamicNetworkSolver(const gridfire::engine::DynamicEngine &engine) : gridfire::solver::SingleZoneDynamicNetworkSolver(engine) {}
gridfire::NetOut evaluate(
gridfire::solver::SolverContextBase &solver_ctx,
const gridfire::NetIn &netIn
) const override;
};
class PyMultiZoneDynamicNetworkSolver final : public gridfire::solver::MultiZoneDynamicNetworkSolver {
public:
explicit PyMultiZoneDynamicNetworkSolver(
const gridfire::engine::DynamicEngine &engine,
const gridfire::solver::SingleZoneDynamicNetworkSolver &local_solver
) : gridfire::solver::MultiZoneDynamicNetworkSolver(engine, local_solver) {}
std::vector<gridfire::NetOut> evaluate(
gridfire::solver::SolverContextBase &solver_ctx,
const std::vector<gridfire::NetIn> &netIns
) const override;
};
class PyTimestepContextBase final : public gridfire::solver::TimestepContextBase {
public:
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
class PySolverContextBase final : public gridfire::solver::SolverContextBase {
public:
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
void init() override;
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;
};

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_types',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -12,6 +12,7 @@ namespace py = pybind11;
void register_utils_bindings(py::module &m) {
m.def("formatNuclearTimescaleLogString",
&gridfire::utils::formatNuclearTimescaleLogString,
py::arg("ctx"),
py::arg("engine"),
py::arg("Y"),
py::arg("T9"),

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_utils',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)