feat(solver): added callback functions to solver in C++ and python

This commit is contained in:
2025-07-31 15:04:57 -04:00
parent 5b74155477
commit 24049b2658
482 changed files with 4318 additions and 1467 deletions

View File

@@ -34,7 +34,7 @@ namespace gridfire::solver {
size_t numSpecies = m_engine.getNetworkSpecies().size();
ublas::vector<double> Y(numSpecies + 1);
RHSManager manager(m_engine, T9, netIn.density);
RHSManager manager(m_engine, T9, netIn.density, m_callback, m_engine.getNetworkSpecies());
JacobianFunctor jacobianFunctor(m_engine, T9, netIn.density);
auto populateY = [&](const Composition& comp) {
@@ -149,6 +149,44 @@ namespace gridfire::solver {
return netOut;
}
void DirectNetworkSolver::set_callback(const std::any& callback) {
if (!callback.has_value()) {
m_callback = {};
return;
}
using FunctionPtrType = void (*)(const TimestepContext&);
if (callback.type() == typeid(TimestepCallback)) {
m_callback = std::any_cast<TimestepCallback>(callback);
}
else if (callback.type() == typeid(FunctionPtrType)) {
auto func_ptr = std::any_cast<FunctionPtrType>(callback);
m_callback = func_ptr;
}
else {
throw std::invalid_argument("Unsupported type passed to set_callback. "
"Provide a std::function or a matching function pointer.");
}
}
std::vector<std::tuple<std::string, std::string>> DirectNetworkSolver::describe_callback_context() const {
const TimestepContext context(
0.0, // time
boost::numeric::ublas::vector<double>(), // state
0.0, // dt
0.0, // cached_time
0.0, // last_observed_time
0.0, // last_step_time
0.0, // T9
0.0, // rho
std::nullopt, // cached_result
0, // num_steps
m_engine, // engine,
{}
);
return context.describe();
}
void DirectNetworkSolver::RHSManager::operator()(
const boost::numeric::ublas::vector<double> &Y,
boost::numeric::ublas::vector<double> &dYdt,
@@ -181,6 +219,29 @@ namespace gridfire::solver {
oss << std::scientific << std::setprecision(3);
oss << "(Step: " << std::setw(10) << m_num_steps << ") t = " << t << " (dt = " << dt << ", eps_nuc: " << state(state.size() - 1) << " [erg])\n";
std::cout << oss.str();
// Callback logic
if (m_callback) {
LOG_TRACE_L1(m_logger, "Calling user callback function at t = {:0.3E} with dt = {:0.3E}", t, dt);
const TimestepContext context(
t,
state,
dt,
m_cached_time,
m_last_observed_time,
m_last_step_time,
m_T9,
m_rho,
m_cached_result,
m_num_steps,
m_engine,
m_networkSpecies
);
m_callback(context);
LOG_TRACE_L1(m_logger, "User callback function completed at t = {:0.3E} with dt = {:0.3E}", t, dt);
}
m_last_observed_time = t;
m_last_step_time = dt;
@@ -228,4 +289,49 @@ namespace gridfire::solver {
}
}
DirectNetworkSolver::TimestepContext::TimestepContext(
const double t,
const boost::numeric::ublas::vector<double> &state,
const double dt,
const double cached_time,
const double last_observed_time,
const double last_step_time,
const double t9,
const double rho,
const std::optional<StepDerivatives<double>> &cached_result,
const int num_steps,
const DynamicEngine &engine,
const std::vector<fourdst::atomic::Species> &networkSpecies
)
: t(t),
state(state),
dt(dt),
cached_time(cached_time),
last_observed_time(last_observed_time),
last_step_time(last_step_time),
T9(t9),
rho(rho),
cached_result(cached_result),
num_steps(num_steps),
engine(engine),
networkSpecies(networkSpecies) {}
std::vector<std::tuple<std::string, std::string>> DirectNetworkSolver::TimestepContext::describe() const {
return {
{"time", "double"},
{"state", "boost::numeric::ublas::vector<double>&"},
{"dt", "double"},
{"cached_time", "double"},
{"last_observed_time", "double"},
{"last_step_time", "double"},
{"T9", "double"},
{"rho", "double"},
{"cached_result", "std::optional<StepDerivatives<double>>&"},
{"num_steps", "int"},
{"engine", "DynamicEngine&"},
{"networkSpecies", "std::vector<fourdst::atomic::Species>&"}
};
}
}