fix(python/solver): fidxed any cast issue with set_callback and moved initialization of solver context to before set callback

This commit is contained in:
2025-11-04 13:24:29 -05:00
parent 72a3f5bf4c
commit d7e9597643

View File

@@ -16,6 +16,38 @@ namespace py = pybind11;
void register_solver_bindings(const py::module &m) { void register_solver_bindings(const py::module &m) {
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext, gridfire::solver::SolverContextBase>(m, "CVODETimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
py_cvode_timestep_context.def_property_readonly(
"state",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
const sunindextype length = N_VGetLength(self.state);
return std::vector<double>(nvec_data, nvec_data + length);
}
);
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures);
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations);
py_cvode_timestep_context.def_property_readonly(
"engine",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& {
return self.engine;
}
);
py_cvode_timestep_context.def_property_readonly(
"networkSpecies",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
return self.networkSpecies;
}
);
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy"); auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
py_dynamic_network_solver_strategy.def( py_dynamic_network_solver_strategy.def(
"evaluate", "evaluate",
@@ -24,13 +56,6 @@ void register_solver_bindings(const py::module &m) {
"evaluate the dynamic engine using the dynamic engine class" "evaluate the dynamic engine using the dynamic engine class"
); );
py_dynamic_network_solver_strategy.def(
"set_callback",
[](gridfire::solver::DynamicNetworkSolverStrategy& self, std::function<void(const gridfire::solver::SolverContextBase&)> cb) {
self.set_callback(cb);
},
"Set a callback function which will run at the end of every successful timestep"
);
py_dynamic_network_solver_strategy.def( py_dynamic_network_solver_strategy.def(
"describe_callback_context", "describe_callback_context",
@@ -67,33 +92,18 @@ void register_solver_bindings(const py::module &m) {
"Enable logging to standard output." "Enable logging to standard output."
); );
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext>(m, "CVODETimestepContext"); py_cvode_solver_strategy.def(
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t); "set_callback",
py_cvode_timestep_context.def_property_readonly( [](
"state", gridfire::solver::CVODESolverStrategy& self,
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> { std::function<void(const gridfire::solver::CVODESolverStrategy::TimestepContext&)> cb
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state); ) {
const sunindextype length = N_VGetLength(self.state); self.set_callback(std::any(cb));
return std::vector<double>(nvec_data, nvec_data + length); },
} py::arg("cb"),
); "Set a callback function which will run at the end of every successful timestep"
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
py_cvode_timestep_context.def_property_readonly(
"engine",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& {
return self.engine;
}
);
py_cvode_timestep_context.def_property_readonly(
"networkSpecies",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
return self.networkSpecies;
}
); );
} }