From d7e959764378c3b0c4313759aa70922984651617 Mon Sep 17 00:00:00 2001 From: Emily Boudreaux Date: Tue, 4 Nov 2025 13:24:29 -0500 Subject: [PATCH] fix(python/solver): fidxed any cast issue with set_callback and moved initialization of solver context to before set callback --- src/python/solver/bindings.cpp | 76 +++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 33 deletions(-) diff --git a/src/python/solver/bindings.cpp b/src/python/solver/bindings.cpp index 3b4fccd9..b00307b8 100644 --- a/src/python/solver/bindings.cpp +++ b/src/python/solver/bindings.cpp @@ -16,6 +16,38 @@ namespace py = pybind11; void register_solver_bindings(const py::module &m) { + auto py_solver_context_base = py::class_(m, "SolverContextBase"); + + auto py_cvode_timestep_context = py::class_(m, "CVODETimestepContext"); + py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t); + py_cvode_timestep_context.def_property_readonly( + "state", + [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector { + const sunrealtype* nvec_data = N_VGetArrayPointer(self.state); + const sunindextype length = N_VGetLength(self.state); + return std::vector(nvec_data, nvec_data + length); + } + ); + py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt); + py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time); + py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9); + py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho); + py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps); + py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures); + py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations); + py_cvode_timestep_context.def_property_readonly( + "engine", + [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& { + return self.engine; + } + ); + py_cvode_timestep_context.def_property_readonly( + "networkSpecies", + [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector { + return self.networkSpecies; + } + ); + auto py_dynamic_network_solver_strategy = py::class_(m, "DynamicNetworkSolverStrategy"); py_dynamic_network_solver_strategy.def( "evaluate", @@ -24,13 +56,6 @@ void register_solver_bindings(const py::module &m) { "evaluate the dynamic engine using the dynamic engine class" ); - py_dynamic_network_solver_strategy.def( - "set_callback", - [](gridfire::solver::DynamicNetworkSolverStrategy& self, std::function cb) { - self.set_callback(cb); - }, - "Set a callback function which will run at the end of every successful timestep" - ); py_dynamic_network_solver_strategy.def( "describe_callback_context", @@ -67,33 +92,18 @@ void register_solver_bindings(const py::module &m) { "Enable logging to standard output." ); - auto py_cvode_timestep_context = py::class_(m, "CVODETimestepContext"); - py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t); - py_cvode_timestep_context.def_property_readonly( - "state", - [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector { - const sunrealtype* nvec_data = N_VGetArrayPointer(self.state); - const sunindextype length = N_VGetLength(self.state); - return std::vector(nvec_data, nvec_data + length); - } - ); - py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt); - py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time); - py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9); - py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho); - py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps); - py_cvode_timestep_context.def_property_readonly( - "engine", - [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& { - return self.engine; - } - ); - py_cvode_timestep_context.def_property_readonly( - "networkSpecies", - [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector { - return self.networkSpecies; - } + py_cvode_solver_strategy.def( + "set_callback", + []( + gridfire::solver::CVODESolverStrategy& self, + std::function cb + ) { + self.set_callback(std::any(cb)); + }, + py::arg("cb"), + "Set a callback function which will run at the end of every successful timestep" ); + }