feat(solver): added callback functions to solver in C++ and python
This commit is contained in:
@@ -34,7 +34,7 @@ namespace gridfire::solver {
|
||||
size_t numSpecies = m_engine.getNetworkSpecies().size();
|
||||
ublas::vector<double> Y(numSpecies + 1);
|
||||
|
||||
RHSManager manager(m_engine, T9, netIn.density);
|
||||
RHSManager manager(m_engine, T9, netIn.density, m_callback, m_engine.getNetworkSpecies());
|
||||
JacobianFunctor jacobianFunctor(m_engine, T9, netIn.density);
|
||||
|
||||
auto populateY = [&](const Composition& comp) {
|
||||
@@ -149,6 +149,44 @@ namespace gridfire::solver {
|
||||
return netOut;
|
||||
}
|
||||
|
||||
void DirectNetworkSolver::set_callback(const std::any& callback) {
|
||||
if (!callback.has_value()) {
|
||||
m_callback = {};
|
||||
return;
|
||||
}
|
||||
|
||||
using FunctionPtrType = void (*)(const TimestepContext&);
|
||||
|
||||
if (callback.type() == typeid(TimestepCallback)) {
|
||||
m_callback = std::any_cast<TimestepCallback>(callback);
|
||||
}
|
||||
else if (callback.type() == typeid(FunctionPtrType)) {
|
||||
auto func_ptr = std::any_cast<FunctionPtrType>(callback);
|
||||
m_callback = func_ptr;
|
||||
}
|
||||
else {
|
||||
throw std::invalid_argument("Unsupported type passed to set_callback. "
|
||||
"Provide a std::function or a matching function pointer.");
|
||||
}
|
||||
}
|
||||
std::vector<std::tuple<std::string, std::string>> DirectNetworkSolver::describe_callback_context() const {
|
||||
const TimestepContext context(
|
||||
0.0, // time
|
||||
boost::numeric::ublas::vector<double>(), // state
|
||||
0.0, // dt
|
||||
0.0, // cached_time
|
||||
0.0, // last_observed_time
|
||||
0.0, // last_step_time
|
||||
0.0, // T9
|
||||
0.0, // rho
|
||||
std::nullopt, // cached_result
|
||||
0, // num_steps
|
||||
m_engine, // engine,
|
||||
{}
|
||||
);
|
||||
return context.describe();
|
||||
}
|
||||
|
||||
void DirectNetworkSolver::RHSManager::operator()(
|
||||
const boost::numeric::ublas::vector<double> &Y,
|
||||
boost::numeric::ublas::vector<double> &dYdt,
|
||||
@@ -181,6 +219,29 @@ namespace gridfire::solver {
|
||||
oss << std::scientific << std::setprecision(3);
|
||||
oss << "(Step: " << std::setw(10) << m_num_steps << ") t = " << t << " (dt = " << dt << ", eps_nuc: " << state(state.size() - 1) << " [erg])\n";
|
||||
std::cout << oss.str();
|
||||
|
||||
// Callback logic
|
||||
if (m_callback) {
|
||||
LOG_TRACE_L1(m_logger, "Calling user callback function at t = {:0.3E} with dt = {:0.3E}", t, dt);
|
||||
const TimestepContext context(
|
||||
t,
|
||||
state,
|
||||
dt,
|
||||
m_cached_time,
|
||||
m_last_observed_time,
|
||||
m_last_step_time,
|
||||
m_T9,
|
||||
m_rho,
|
||||
m_cached_result,
|
||||
m_num_steps,
|
||||
m_engine,
|
||||
m_networkSpecies
|
||||
);
|
||||
|
||||
m_callback(context);
|
||||
LOG_TRACE_L1(m_logger, "User callback function completed at t = {:0.3E} with dt = {:0.3E}", t, dt);
|
||||
}
|
||||
|
||||
m_last_observed_time = t;
|
||||
m_last_step_time = dt;
|
||||
|
||||
@@ -228,4 +289,49 @@ namespace gridfire::solver {
|
||||
}
|
||||
}
|
||||
|
||||
DirectNetworkSolver::TimestepContext::TimestepContext(
|
||||
const double t,
|
||||
const boost::numeric::ublas::vector<double> &state,
|
||||
const double dt,
|
||||
const double cached_time,
|
||||
const double last_observed_time,
|
||||
const double last_step_time,
|
||||
const double t9,
|
||||
const double rho,
|
||||
const std::optional<StepDerivatives<double>> &cached_result,
|
||||
const int num_steps,
|
||||
const DynamicEngine &engine,
|
||||
const std::vector<fourdst::atomic::Species> &networkSpecies
|
||||
)
|
||||
: t(t),
|
||||
state(state),
|
||||
dt(dt),
|
||||
cached_time(cached_time),
|
||||
last_observed_time(last_observed_time),
|
||||
last_step_time(last_step_time),
|
||||
T9(t9),
|
||||
rho(rho),
|
||||
cached_result(cached_result),
|
||||
num_steps(num_steps),
|
||||
engine(engine),
|
||||
networkSpecies(networkSpecies) {}
|
||||
|
||||
std::vector<std::tuple<std::string, std::string>> DirectNetworkSolver::TimestepContext::describe() const {
|
||||
return {
|
||||
{"time", "double"},
|
||||
{"state", "boost::numeric::ublas::vector<double>&"},
|
||||
{"dt", "double"},
|
||||
{"cached_time", "double"},
|
||||
{"last_observed_time", "double"},
|
||||
{"last_step_time", "double"},
|
||||
{"T9", "double"},
|
||||
{"rho", "double"},
|
||||
{"cached_result", "std::optional<StepDerivatives<double>>&"},
|
||||
{"num_steps", "int"},
|
||||
{"engine", "DynamicEngine&"},
|
||||
{"networkSpecies", "std::vector<fourdst::atomic::Species>&"}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user