fix(python/solver): fidxed any cast issue with set_callback and moved initialization of solver context to before set callback
This commit is contained in:
@@ -16,6 +16,38 @@ namespace py = pybind11;
|
||||
|
||||
|
||||
void register_solver_bindings(const py::module &m) {
|
||||
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
|
||||
|
||||
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext, gridfire::solver::SolverContextBase>(m, "CVODETimestepContext");
|
||||
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"state",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
|
||||
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
|
||||
const sunindextype length = N_VGetLength(self.state);
|
||||
return std::vector<double>(nvec_data, nvec_data + length);
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
|
||||
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
|
||||
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
|
||||
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
|
||||
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
|
||||
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures);
|
||||
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"engine",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& {
|
||||
return self.engine;
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"networkSpecies",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
|
||||
return self.networkSpecies;
|
||||
}
|
||||
);
|
||||
|
||||
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
|
||||
py_dynamic_network_solver_strategy.def(
|
||||
"evaluate",
|
||||
@@ -24,13 +56,6 @@ void register_solver_bindings(const py::module &m) {
|
||||
"evaluate the dynamic engine using the dynamic engine class"
|
||||
);
|
||||
|
||||
py_dynamic_network_solver_strategy.def(
|
||||
"set_callback",
|
||||
[](gridfire::solver::DynamicNetworkSolverStrategy& self, std::function<void(const gridfire::solver::SolverContextBase&)> cb) {
|
||||
self.set_callback(cb);
|
||||
},
|
||||
"Set a callback function which will run at the end of every successful timestep"
|
||||
);
|
||||
|
||||
py_dynamic_network_solver_strategy.def(
|
||||
"describe_callback_context",
|
||||
@@ -67,33 +92,18 @@ void register_solver_bindings(const py::module &m) {
|
||||
"Enable logging to standard output."
|
||||
);
|
||||
|
||||
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext>(m, "CVODETimestepContext");
|
||||
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"state",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
|
||||
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
|
||||
const sunindextype length = N_VGetLength(self.state);
|
||||
return std::vector<double>(nvec_data, nvec_data + length);
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
|
||||
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
|
||||
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
|
||||
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
|
||||
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"engine",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& {
|
||||
return self.engine;
|
||||
}
|
||||
);
|
||||
py_cvode_timestep_context.def_property_readonly(
|
||||
"networkSpecies",
|
||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
|
||||
return self.networkSpecies;
|
||||
}
|
||||
py_cvode_solver_strategy.def(
|
||||
"set_callback",
|
||||
[](
|
||||
gridfire::solver::CVODESolverStrategy& self,
|
||||
std::function<void(const gridfire::solver::CVODESolverStrategy::TimestepContext&)> cb
|
||||
) {
|
||||
self.set_callback(std::any(cb));
|
||||
},
|
||||
py::arg("cb"),
|
||||
"Set a callback function which will run at the end of every successful timestep"
|
||||
);
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user