fix(python/solver): fidxed any cast issue with set_callback and moved initialization of solver context to before set callback
This commit is contained in:
@@ -16,6 +16,38 @@ namespace py = pybind11;
|
|||||||
|
|
||||||
|
|
||||||
void register_solver_bindings(const py::module &m) {
|
void register_solver_bindings(const py::module &m) {
|
||||||
|
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
|
||||||
|
|
||||||
|
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext, gridfire::solver::SolverContextBase>(m, "CVODETimestepContext");
|
||||||
|
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
|
||||||
|
py_cvode_timestep_context.def_property_readonly(
|
||||||
|
"state",
|
||||||
|
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
|
||||||
|
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
|
||||||
|
const sunindextype length = N_VGetLength(self.state);
|
||||||
|
return std::vector<double>(nvec_data, nvec_data + length);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
|
||||||
|
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
|
||||||
|
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
|
||||||
|
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
|
||||||
|
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
|
||||||
|
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures);
|
||||||
|
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations);
|
||||||
|
py_cvode_timestep_context.def_property_readonly(
|
||||||
|
"engine",
|
||||||
|
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& {
|
||||||
|
return self.engine;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
py_cvode_timestep_context.def_property_readonly(
|
||||||
|
"networkSpecies",
|
||||||
|
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
|
||||||
|
return self.networkSpecies;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
|
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
|
||||||
py_dynamic_network_solver_strategy.def(
|
py_dynamic_network_solver_strategy.def(
|
||||||
"evaluate",
|
"evaluate",
|
||||||
@@ -24,13 +56,6 @@ void register_solver_bindings(const py::module &m) {
|
|||||||
"evaluate the dynamic engine using the dynamic engine class"
|
"evaluate the dynamic engine using the dynamic engine class"
|
||||||
);
|
);
|
||||||
|
|
||||||
py_dynamic_network_solver_strategy.def(
|
|
||||||
"set_callback",
|
|
||||||
[](gridfire::solver::DynamicNetworkSolverStrategy& self, std::function<void(const gridfire::solver::SolverContextBase&)> cb) {
|
|
||||||
self.set_callback(cb);
|
|
||||||
},
|
|
||||||
"Set a callback function which will run at the end of every successful timestep"
|
|
||||||
);
|
|
||||||
|
|
||||||
py_dynamic_network_solver_strategy.def(
|
py_dynamic_network_solver_strategy.def(
|
||||||
"describe_callback_context",
|
"describe_callback_context",
|
||||||
@@ -67,33 +92,18 @@ void register_solver_bindings(const py::module &m) {
|
|||||||
"Enable logging to standard output."
|
"Enable logging to standard output."
|
||||||
);
|
);
|
||||||
|
|
||||||
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext>(m, "CVODETimestepContext");
|
py_cvode_solver_strategy.def(
|
||||||
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
|
"set_callback",
|
||||||
py_cvode_timestep_context.def_property_readonly(
|
[](
|
||||||
"state",
|
gridfire::solver::CVODESolverStrategy& self,
|
||||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> {
|
std::function<void(const gridfire::solver::CVODESolverStrategy::TimestepContext&)> cb
|
||||||
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
|
) {
|
||||||
const sunindextype length = N_VGetLength(self.state);
|
self.set_callback(std::any(cb));
|
||||||
return std::vector<double>(nvec_data, nvec_data + length);
|
},
|
||||||
}
|
py::arg("cb"),
|
||||||
);
|
"Set a callback function which will run at the end of every successful timestep"
|
||||||
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt);
|
|
||||||
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time);
|
|
||||||
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9);
|
|
||||||
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho);
|
|
||||||
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps);
|
|
||||||
py_cvode_timestep_context.def_property_readonly(
|
|
||||||
"engine",
|
|
||||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& {
|
|
||||||
return self.engine;
|
|
||||||
}
|
|
||||||
);
|
|
||||||
py_cvode_timestep_context.def_property_readonly(
|
|
||||||
"networkSpecies",
|
|
||||||
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> {
|
|
||||||
return self.networkSpecies;
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user