feat(trigger): added working robust repartitioning trigger system

More work is needed to identify the most robust set of criteria to trigger on, but the system is now very easy to extend, probe, and use.
This commit is contained in:
2025-09-29 13:35:48 -04:00
parent 4c91f8c525
commit 4f1c260444
12 changed files with 980 additions and 197 deletions

View File

@@ -17,6 +17,8 @@
#include <algorithm>
#include "fourdst/composition/exceptions/exceptions_composition.h"
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
#include "gridfire/trigger/procedures/trigger_pprint.h"
namespace {
@@ -78,6 +80,43 @@ namespace {
namespace gridfire::solver {
/**
 * @brief Builds the per-timestep context object from the solver's current state.
 *
 * Every argument initializes the member of the same name; the only spelling
 * difference is the temperature argument `t9`, which initializes member `T9`.
 * The field meanings listed below follow the labels reported by
 * TimestepContext::describe().
 *
 * @param t              Current time.
 * @param state          Current state vector (N_Vector).
 * @param dt             Size of the last timestep.
 * @param last_step_time Time at the last step.
 * @param t9             Temperature in GK.
 * @param rho            Density in g/cm^3.
 * @param num_steps      Number of steps taken so far.
 * @param engine         The DynamicEngine.
 * @param networkSpecies The list of network species.
 *
 * @note NOTE(review): `state`, `engine` and `networkSpecies` arrive by const
 *       reference, and describe() labels the latter two as references. If the
 *       corresponding members are reference types, the caller must keep the
 *       referents alive for as long as this context is used -- confirm against
 *       the class declaration in the header.
 */
CVODESolverStrategy::TimestepContext::TimestepContext(
const double t,
const N_Vector &state,
const double dt,
const double last_step_time,
const double t9,
const double rho,
const int num_steps,
const DynamicEngine &engine,
const std::vector<fourdst::atomic::Species> &networkSpecies
) :
t(t),
state(state),
dt(dt),
last_step_time(last_step_time),
T9(t9), // argument is spelled t9, member is spelled T9
rho(rho),
num_steps(num_steps),
engine(engine),
networkSpecies(networkSpecies)
{}
/**
 * @brief Lists the fields exposed by a TimestepContext with a short label each.
 *
 * @return A vector of (field name, human-readable description) pairs, one per
 *         field, in the fixed order: t, state, dt, last_step_time, T9, rho,
 *         num_steps, engine, networkSpecies.
 */
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const {
    using FieldInfo = std::tuple<std::string, std::string>;

    // The table is fixed, so build the whole vector in one expression rather
    // than appending entry by entry.
    return std::vector<FieldInfo>{
        FieldInfo{"t",              "Current Time"},
        FieldInfo{"state",          "Current State Vector (N_Vector)"},
        FieldInfo{"dt",             "Last Timestep Size"},
        FieldInfo{"last_step_time", "Time at Last Step"},
        FieldInfo{"T9",             "Temperature in GK"},
        FieldInfo{"rho",            "Density in g/cm^3"},
        FieldInfo{"num_steps",      "Number of Steps Taken So Far"},
        FieldInfo{"engine",         "Reference to the DynamicEngine"},
        FieldInfo{"networkSpecies", "Reference to the list of network species"},
    };
}
CVODESolverStrategy::CVODESolverStrategy(DynamicEngine &engine): NetworkSolverStrategy<DynamicEngine>(engine) {
// TODO: In order to support MPI this function must be changed
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
@@ -96,6 +135,8 @@ namespace gridfire::solver {
}
NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) {
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 1, true, 10);
const double T9 = netIn.temperature / 1e9; // Convert temperature from Kelvin to T9 (T9 = T / 1e9)
const auto absTol = m_config.get<double>("gridfire:solver:CVODESolverStrategy:absTol", 1.0e-8);
@@ -121,74 +162,63 @@ namespace gridfire::solver {
double accumulated_energy = 0.0;
int total_update_stages_triggered = 0;
while (current_time < netIn.tMax) {
try {
user_data.T9 = T9;
user_data.rho = netIn.density;
user_data.networkSpecies = &m_engine.getNetworkSpecies();
user_data.captured_exception.reset();
user_data.T9 = T9;
user_data.rho = netIn.density;
user_data.networkSpecies = &m_engine.getNetworkSpecies();
user_data.captured_exception.reset();
check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
int flag = -1;
if (m_stdout_logging_enabled) {
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, &current_time, CV_ONE_STEP);
} else {
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, &current_time, CV_NORMAL);
}
int flag = -1;
if (m_stdout_logging_enabled) {
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, &current_time, CV_ONE_STEP);
} else {
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, &current_time, CV_NORMAL);
}
if (user_data.captured_exception){
std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception));
}
if (user_data.captured_exception){
std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception));
}
check_cvode_flag(flag, "CVode");
check_cvode_flag(flag, "CVode");
long int n_steps;
double last_step_size;
CVodeGetNumSteps(m_cvode_mem, &n_steps);
CVodeGetLastStep(m_cvode_mem, &last_step_size);
long int nliters, nlcfails;
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
long int n_steps;
double last_step_size;
CVodeGetNumSteps(m_cvode_mem, &n_steps);
CVodeGetLastStep(m_cvode_mem, &last_step_size);
long int nliters, nlcfails;
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
const double current_energy = y_data[numSpecies]; // Specific energy rate
std::cout << std::scientific << std::setprecision(3)
<< "Step: " << std::setw(6) << n_steps
<< " | Time: " << current_time << " [s]"
<< " | Last Step Size: " << last_step_size
<< " | Accumulated Energy: " << current_energy << " [erg/g]"
<< " | NonlinIters: " << std::setw(2) << nliters
<< " | ConvFails: " << std::setw(2) << nlcfails
<< std::endl;
if (n_steps % 300 == 0) {
std::cout << "Manually triggering engine update at step " << n_steps << "..." << std::endl;
exceptions::StaleEngineTrigger::state staleState {
T9,
netIn.density,
std::vector<double>(y_data, y_data + numSpecies),
current_time,
static_cast<int>(n_steps),
current_energy
};
throw exceptions::StaleEngineTrigger(staleState);
}
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
const double current_energy = y_data[numSpecies]; // Specific energy rate
std::cout << std::scientific << std::setprecision(3)
<< "Step: " << std::setw(6) << n_steps
<< " | Time: " << current_time << " [s]"
<< " | Last Step Size: " << last_step_size
<< " | Accumulated Energy: " << current_energy << " [erg/g]"
<< " | NonLinIters: " << std::setw(2) << nliters
<< " | ConvFails: " << std::setw(2) << nlcfails
<< std::endl;
// if (n_steps % 50 == 0) {
// std::cout << "Logging step diagnostics at step " << n_steps << "..." << std::endl;
// log_step_diagnostics(user_data);
// }
// if (n_steps == 300) {
// log_step_diagnostics(user_data);
// exit(0);
// }
// log_step_diagnostics(user_data);
auto ctx = TimestepContext(
current_time,
reinterpret_cast<N_Vector>(y_data),
last_step_size,
last_callback_time,
T9,
netIn.density,
n_steps,
m_engine,
m_engine.getNetworkSpecies());
} catch (const exceptions::StaleEngineTrigger& e) {
exceptions::StaleEngineTrigger::state staleState = e.getState();
accumulated_energy += e.energy(); // Add the specific energy rate to the accumulated energy
if (trigger->check(ctx)) {
trigger::printWhy(trigger->why(ctx));
trigger->update(ctx);
accumulated_energy += current_energy; // Add the specific energy rate to the accumulated energy
LOG_INFO(
m_logger,
"Engine Update Triggered due to StaleEngineTrigger exception at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
"Engine Update Triggered at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
current_time,
total_update_stages_triggered,
total_update_stages_triggered == 1 ? "" : "s",
@@ -197,21 +227,31 @@ namespace gridfire::solver {
fourdst::composition::Composition temp_comp;
std::vector<double> mass_fractions;
size_t num_species_at_stop = e.numSpecies();
size_t num_species_at_stop = m_engine.getNetworkSpecies().size();
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) {
LOG_ERROR(
m_logger,
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
num_species_at_stop,
m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index
);
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
}
mass_fractions.reserve(num_species_at_stop);
for (size_t i = 0; i < num_species_at_stop; ++i) {
const auto& species = m_engine.getNetworkSpecies()[i];
temp_comp.registerSpecies(species);
mass_fractions.push_back(e.getMolarAbundance(i) * species.mass()); // Convert from molar abundance to mass fraction
mass_fractions.push_back(y_data[i] * species.mass()); // Convert from molar abundance to mass fraction
}
temp_comp.setMassFraction(m_engine.getNetworkSpecies(), mass_fractions);
temp_comp.finalize(true);
NetIn netInTemp = netIn;
netInTemp.temperature = e.temperature();
netInTemp.density = e.density();
netInTemp.temperature = T9 * 1e9; // Convert back to Kelvin
netInTemp.density = netIn.density;
netInTemp.composition = temp_comp;
fourdst::composition::Composition currentComposition = m_engine.update(netInTemp);
@@ -225,13 +265,16 @@ namespace gridfire::solver {
numSpecies = m_engine.getNetworkSpecies().size();
N = numSpecies + 1;
cleanup_cvode_resources(true);
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
} catch (fourdst::composition::exceptions::InvalidCompositionError& e) {
log_step_diagnostics(user_data);
std::rethrow_exception(std::make_exception_ptr(e));
}
}
sunrealtype* y_data = N_VGetArrayPointer(m_Y);