feat(trigger): added working robust repartitioning trigger system
more work is needed to identify the most robust set of criteria to trigger on but the system is now very easy to exend, probe, and use.
This commit is contained in:
@@ -17,6 +17,8 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "fourdst/composition/exceptions/exceptions_composition.h"
|
||||
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
|
||||
#include "gridfire/trigger/procedures/trigger_pprint.h"
|
||||
|
||||
|
||||
namespace {
|
||||
@@ -78,6 +80,43 @@ namespace {
|
||||
|
||||
namespace gridfire::solver {
|
||||
|
||||
CVODESolverStrategy::TimestepContext::TimestepContext(
|
||||
const double t,
|
||||
const N_Vector &state,
|
||||
const double dt,
|
||||
const double last_step_time,
|
||||
const double t9,
|
||||
const double rho,
|
||||
const int num_steps,
|
||||
const DynamicEngine &engine,
|
||||
const std::vector<fourdst::atomic::Species> &networkSpecies
|
||||
) :
|
||||
t(t),
|
||||
state(state),
|
||||
dt(dt),
|
||||
last_step_time(last_step_time),
|
||||
T9(t9),
|
||||
rho(rho),
|
||||
num_steps(num_steps),
|
||||
engine(engine),
|
||||
networkSpecies(networkSpecies)
|
||||
{}
|
||||
|
||||
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const {
|
||||
std::vector<std::tuple<std::string, std::string>> description;
|
||||
description.emplace_back("t", "Current Time");
|
||||
description.emplace_back("state", "Current State Vector (N_Vector)");
|
||||
description.emplace_back("dt", "Last Timestep Size");
|
||||
description.emplace_back("last_step_time", "Time at Last Step");
|
||||
description.emplace_back("T9", "Temperature in GK");
|
||||
description.emplace_back("rho", "Density in g/cm^3");
|
||||
description.emplace_back("num_steps", "Number of Steps Taken So Far");
|
||||
description.emplace_back("engine", "Reference to the DynamicEngine");
|
||||
description.emplace_back("networkSpecies", "Reference to the list of network species");
|
||||
return description;
|
||||
}
|
||||
|
||||
|
||||
CVODESolverStrategy::CVODESolverStrategy(DynamicEngine &engine): NetworkSolverStrategy<DynamicEngine>(engine) {
|
||||
// TODO: In order to support MPI this function must be changed
|
||||
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
|
||||
@@ -96,6 +135,8 @@ namespace gridfire::solver {
|
||||
}
|
||||
|
||||
NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) {
|
||||
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 1, true, 10);
|
||||
|
||||
const double T9 = netIn.temperature / 1e9; // Convert temperature from Kelvin to T9 (T9 = T / 1e9)
|
||||
|
||||
const auto absTol = m_config.get<double>("gridfire:solver:CVODESolverStrategy:absTol", 1.0e-8);
|
||||
@@ -121,74 +162,63 @@ namespace gridfire::solver {
|
||||
double accumulated_energy = 0.0;
|
||||
int total_update_stages_triggered = 0;
|
||||
while (current_time < netIn.tMax) {
|
||||
try {
|
||||
user_data.T9 = T9;
|
||||
user_data.rho = netIn.density;
|
||||
user_data.networkSpecies = &m_engine.getNetworkSpecies();
|
||||
user_data.captured_exception.reset();
|
||||
user_data.T9 = T9;
|
||||
user_data.rho = netIn.density;
|
||||
user_data.networkSpecies = &m_engine.getNetworkSpecies();
|
||||
user_data.captured_exception.reset();
|
||||
|
||||
check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
|
||||
check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
|
||||
|
||||
int flag = -1;
|
||||
if (m_stdout_logging_enabled) {
|
||||
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP);
|
||||
} else {
|
||||
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_NORMAL);
|
||||
}
|
||||
int flag = -1;
|
||||
if (m_stdout_logging_enabled) {
|
||||
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP);
|
||||
} else {
|
||||
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_NORMAL);
|
||||
}
|
||||
|
||||
if (user_data.captured_exception){
|
||||
std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception));
|
||||
}
|
||||
if (user_data.captured_exception){
|
||||
std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception));
|
||||
}
|
||||
|
||||
check_cvode_flag(flag, "CVode");
|
||||
check_cvode_flag(flag, "CVode");
|
||||
|
||||
long int n_steps;
|
||||
double last_step_size;
|
||||
CVodeGetNumSteps(m_cvode_mem, &n_steps);
|
||||
CVodeGetLastStep(m_cvode_mem, &last_step_size);
|
||||
long int nliters, nlcfails;
|
||||
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
|
||||
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
|
||||
long int n_steps;
|
||||
double last_step_size;
|
||||
CVodeGetNumSteps(m_cvode_mem, &n_steps);
|
||||
CVodeGetLastStep(m_cvode_mem, &last_step_size);
|
||||
long int nliters, nlcfails;
|
||||
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
|
||||
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
|
||||
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
const double current_energy = y_data[numSpecies]; // Specific energy rate
|
||||
std::cout << std::scientific << std::setprecision(3)
|
||||
<< "Step: " << std::setw(6) << n_steps
|
||||
<< " | Time: " << current_time << " [s]"
|
||||
<< " | Last Step Size: " << last_step_size
|
||||
<< " | Accumulated Energy: " << current_energy << " [erg/g]"
|
||||
<< " | NonlinIters: " << std::setw(2) << nliters
|
||||
<< " | ConvFails: " << std::setw(2) << nlcfails
|
||||
<< std::endl;
|
||||
if (n_steps % 300 == 0) {
|
||||
std::cout << "Manually triggering engine update at step " << n_steps << "..." << std::endl;
|
||||
exceptions::StaleEngineTrigger::state staleState {
|
||||
T9,
|
||||
netIn.density,
|
||||
std::vector<double>(y_data, y_data + numSpecies),
|
||||
current_time,
|
||||
static_cast<int>(n_steps),
|
||||
current_energy
|
||||
};
|
||||
throw exceptions::StaleEngineTrigger(staleState);
|
||||
}
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
const double current_energy = y_data[numSpecies]; // Specific energy rate
|
||||
std::cout << std::scientific << std::setprecision(3)
|
||||
<< "Step: " << std::setw(6) << n_steps
|
||||
<< " | Time: " << current_time << " [s]"
|
||||
<< " | Last Step Size: " << last_step_size
|
||||
<< " | Accumulated Energy: " << current_energy << " [erg/g]"
|
||||
<< " | NonLinIters: " << std::setw(2) << nliters
|
||||
<< " | ConvFails: " << std::setw(2) << nlcfails
|
||||
<< std::endl;
|
||||
|
||||
// if (n_steps % 50 == 0) {
|
||||
// std::cout << "Logging step diagnostics at step " << n_steps << "..." << std::endl;
|
||||
// log_step_diagnostics(user_data);
|
||||
// }
|
||||
// if (n_steps == 300) {
|
||||
// log_step_diagnostics(user_data);
|
||||
// exit(0);
|
||||
// }
|
||||
// log_step_diagnostics(user_data);
|
||||
auto ctx = TimestepContext(
|
||||
current_time,
|
||||
reinterpret_cast<N_Vector>(y_data),
|
||||
last_step_size,
|
||||
last_callback_time,
|
||||
T9,
|
||||
netIn.density,
|
||||
n_steps,
|
||||
m_engine,
|
||||
m_engine.getNetworkSpecies());
|
||||
|
||||
} catch (const exceptions::StaleEngineTrigger& e) {
|
||||
exceptions::StaleEngineTrigger::state staleState = e.getState();
|
||||
accumulated_energy += e.energy(); // Add the specific energy rate to the accumulated energy
|
||||
if (trigger->check(ctx)) {
|
||||
trigger::printWhy(trigger->why(ctx));
|
||||
trigger->update(ctx);
|
||||
accumulated_energy += current_energy; // Add the specific energy rate to the accumulated energy
|
||||
LOG_INFO(
|
||||
m_logger,
|
||||
"Engine Update Triggered due to StaleEngineTrigger exception at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
|
||||
"Engine Update Triggered at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
|
||||
current_time,
|
||||
total_update_stages_triggered,
|
||||
total_update_stages_triggered == 1 ? "" : "s",
|
||||
@@ -197,21 +227,31 @@ namespace gridfire::solver {
|
||||
|
||||
fourdst::composition::Composition temp_comp;
|
||||
std::vector<double> mass_fractions;
|
||||
size_t num_species_at_stop = e.numSpecies();
|
||||
size_t num_species_at_stop = m_engine.getNetworkSpecies().size();
|
||||
|
||||
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) {
|
||||
LOG_ERROR(
|
||||
m_logger,
|
||||
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
|
||||
num_species_at_stop,
|
||||
m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index
|
||||
);
|
||||
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
|
||||
}
|
||||
|
||||
mass_fractions.reserve(num_species_at_stop);
|
||||
|
||||
for (size_t i = 0; i < num_species_at_stop; ++i) {
|
||||
const auto& species = m_engine.getNetworkSpecies()[i];
|
||||
temp_comp.registerSpecies(species);
|
||||
mass_fractions.push_back(e.getMolarAbundance(i) * species.mass()); // Convert from molar abundance to mass fraction
|
||||
mass_fractions.push_back(y_data[i] * species.mass()); // Convert from molar abundance to mass fraction
|
||||
}
|
||||
temp_comp.setMassFraction(m_engine.getNetworkSpecies(), mass_fractions);
|
||||
temp_comp.finalize(true);
|
||||
|
||||
NetIn netInTemp = netIn;
|
||||
netInTemp.temperature = e.temperature();
|
||||
netInTemp.density = e.density();
|
||||
netInTemp.temperature = T9 * 1e9; // Convert back to Kelvin
|
||||
netInTemp.density = netIn.density;
|
||||
netInTemp.composition = temp_comp;
|
||||
|
||||
fourdst::composition::Composition currentComposition = m_engine.update(netInTemp);
|
||||
@@ -225,13 +265,16 @@ namespace gridfire::solver {
|
||||
numSpecies = m_engine.getNetworkSpecies().size();
|
||||
N = numSpecies + 1;
|
||||
|
||||
cleanup_cvode_resources(true);
|
||||
|
||||
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
|
||||
check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
|
||||
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
|
||||
|
||||
check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
|
||||
} catch (fourdst::composition::exceptions::InvalidCompositionError& e) {
|
||||
log_step_diagnostics(user_data);
|
||||
std::rethrow_exception(std::make_exception_ptr(e));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
|
||||
@@ -0,0 +1,263 @@
|
||||
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
|
||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
||||
|
||||
#include "gridfire/trigger/trigger_logical.h"
|
||||
#include "gridfire/trigger/trigger_abstract.h"
|
||||
|
||||
#include "quill/LogMacros.h"
|
||||
|
||||
#include <memory>
|
||||
#include <deque>
|
||||
#include <string>
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
void push_to_fixed_deque(std::deque<T>& dq, T value, size_t max_size) {
|
||||
dq.push_back(value);
|
||||
if (dq.size() > max_size) {
|
||||
dq.pop_front();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace gridfire::trigger::solver::CVODE {
|
||||
SimulationTimeTrigger::SimulationTimeTrigger(double interval) : m_interval(interval) {
|
||||
if (interval <= 0.0) {
|
||||
LOG_ERROR(m_logger, "Interval must be positive, currently it is {}", interval);
|
||||
throw std::invalid_argument("Interval must be positive, currently it is " + std::to_string(interval));
|
||||
}
|
||||
}
|
||||
|
||||
bool SimulationTimeTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
if (ctx.t - m_last_trigger_time >= m_interval) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta);
|
||||
return true;
|
||||
}
|
||||
m_misses++;
|
||||
return false;
|
||||
}
|
||||
|
||||
void SimulationTimeTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
if (check(ctx)) {
|
||||
m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval;
|
||||
m_last_trigger_time = ctx.t;
|
||||
m_updates++;
|
||||
}
|
||||
}
|
||||
|
||||
void SimulationTimeTrigger::reset() {
|
||||
m_misses = 0;
|
||||
m_hits = 0;
|
||||
m_updates = 0;
|
||||
m_last_trigger_time = 0.0;
|
||||
m_last_trigger_time_delta = 0.0;
|
||||
m_resets++;
|
||||
}
|
||||
|
||||
std::string SimulationTimeTrigger::name() const {
|
||||
return "Simulation Time Trigger";
|
||||
}
|
||||
|
||||
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
if (check(ctx)) {
|
||||
result.value = true;
|
||||
result.description = "Triggered because current time " + std::to_string(ctx.t) + " - last trigger time " + std::to_string(m_last_trigger_time - m_last_trigger_time_delta) + " >= interval " + std::to_string(m_interval);
|
||||
} else {
|
||||
result.value = false;
|
||||
result.description = "Not triggered because current time " + std::to_string(ctx.t) + " - last trigger time " + std::to_string(m_last_trigger_time) + " < interval " + std::to_string(m_interval);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string SimulationTimeTrigger::describe() const {
|
||||
return "SimulationTimeTrigger(interval=" + std::to_string(m_interval) + ")";
|
||||
}
|
||||
|
||||
size_t SimulationTimeTrigger::numTriggers() const {
|
||||
return m_hits;
|
||||
}
|
||||
|
||||
size_t SimulationTimeTrigger::numMisses() const {
|
||||
return m_misses;
|
||||
}
|
||||
|
||||
OffDiagonalTrigger::OffDiagonalTrigger(
|
||||
double threshold
|
||||
) : m_threshold(threshold) {
|
||||
if (threshold < 0.0) {
|
||||
LOG_ERROR(m_logger, "Threshold must be non-negative, currently it is {}", threshold);
|
||||
throw std::invalid_argument("Threshold must be non-negative, currently it is " + std::to_string(threshold));
|
||||
}
|
||||
}
|
||||
|
||||
bool OffDiagonalTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
const size_t numSpecies = ctx.engine.getNetworkSpecies().size();
|
||||
for (int row = 0; row < numSpecies; ++row) {
|
||||
for (int col = 0; col < numSpecies; ++col) {
|
||||
double DRowDCol = std::abs(ctx.engine.getJacobianMatrixEntry(row, col));
|
||||
if (row != col && DRowDCol > m_threshold) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "OffDiagonalTrigger triggered at t = {} due to entry ({}, {}) = {}", ctx.t, row, col, DRowDCol);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
m_misses++;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
void OffDiagonalTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
m_updates++;
|
||||
}
|
||||
|
||||
void OffDiagonalTrigger::reset() {
|
||||
m_misses = 0;
|
||||
m_hits = 0;
|
||||
m_updates = 0;
|
||||
m_resets++;
|
||||
}
|
||||
|
||||
std::string OffDiagonalTrigger::name() const {
|
||||
return "Off-Diagonal Trigger";
|
||||
}
|
||||
|
||||
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
|
||||
if (check(ctx)) {
|
||||
result.value = true;
|
||||
result.description = "Triggered because an off-diagonal Jacobian entry exceeded the threshold " + std::to_string(m_threshold);
|
||||
} else {
|
||||
result.value = false;
|
||||
result.description = "Not triggered because no off-diagonal Jacobian entry exceeded the threshold " + std::to_string(m_threshold);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string OffDiagonalTrigger::describe() const {
|
||||
return "OffDiagonalTrigger(threshold=" + std::to_string(m_threshold) + ")";
|
||||
}
|
||||
|
||||
size_t OffDiagonalTrigger::numTriggers() const {
|
||||
return m_hits;
|
||||
}
|
||||
|
||||
size_t OffDiagonalTrigger::numMisses() const {
|
||||
return m_misses;
|
||||
}
|
||||
|
||||
TimestepCollapseTrigger::TimestepCollapseTrigger(
|
||||
const double threshold,
|
||||
const bool relative
|
||||
) : TimestepCollapseTrigger(threshold, relative, 1){}
|
||||
|
||||
|
||||
TimestepCollapseTrigger::TimestepCollapseTrigger(
|
||||
double threshold,
|
||||
const bool relative,
|
||||
const size_t windowSize
|
||||
) : m_threshold(threshold), m_relative(relative), m_windowSize(windowSize) {
|
||||
if (threshold < 0.0) {
|
||||
LOG_ERROR(m_logger, "Threshold must be non-negative, currently it is {}", threshold);
|
||||
throw std::invalid_argument("Threshold must be non-negative, currently it is " + std::to_string(threshold));
|
||||
}
|
||||
if (relative && threshold > 1.0) {
|
||||
LOG_ERROR(m_logger, "Relative threshold must be between 0 and 1, currently it is {}", threshold);
|
||||
throw std::invalid_argument("Relative threshold must be between 0 and 1, currently it is " + std::to_string(threshold));
|
||||
}
|
||||
}
|
||||
|
||||
bool TimestepCollapseTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
if (m_timestep_window.size() < 1) {
|
||||
m_misses++;
|
||||
return false;
|
||||
}
|
||||
double averageTimestep = 0.0;
|
||||
for (const auto& dt : m_timestep_window) {
|
||||
averageTimestep += dt;
|
||||
}
|
||||
averageTimestep /= m_timestep_window.size();
|
||||
if (m_relative && (std::abs(ctx.dt - averageTimestep) / averageTimestep) >= m_threshold) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "TimestepCollapseTrigger triggered at t = {} due to relative growth: dt = {}, average dt = {}, threshold = {}", ctx.t, ctx.dt, averageTimestep, m_threshold);
|
||||
return true;
|
||||
} else if (!m_relative && std::abs(ctx.dt - averageTimestep) >= m_threshold) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "TimestepCollapseTrigger triggered at t = {} due to absolute growth: dt = {}, average dt = {}, threshold = {}", ctx.t, ctx.dt, averageTimestep, m_threshold);
|
||||
return true;
|
||||
}
|
||||
m_misses++;
|
||||
return false;
|
||||
}
|
||||
|
||||
void TimestepCollapseTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize);
|
||||
m_updates++;
|
||||
}
|
||||
|
||||
void TimestepCollapseTrigger::reset() {
|
||||
m_misses = 0;
|
||||
m_hits = 0;
|
||||
m_updates = 0;
|
||||
m_resets++;
|
||||
m_timestep_window.clear();
|
||||
}
|
||||
|
||||
std::string TimestepCollapseTrigger::name() const {
|
||||
return "TimestepCollapseTrigger";
|
||||
}
|
||||
|
||||
TriggerResult TimestepCollapseTrigger::why(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
|
||||
if (check(ctx)) {
|
||||
result.value = true;
|
||||
result.description = "Triggered because timestep change exceeded the threshold " + std::to_string(m_threshold);
|
||||
} else {
|
||||
result.value = false;
|
||||
result.description = "Not triggered because timestep change did not exceed the threshold " + std::to_string(m_threshold);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string TimestepCollapseTrigger::describe() const {
|
||||
return "TimestepCollapseTrigger(threshold=" + std::to_string(m_threshold) + ", relative=" + (m_relative ? "true" : "false") + ", windowSize=" + std::to_string(m_windowSize) + ")";
|
||||
}
|
||||
|
||||
size_t TimestepCollapseTrigger::numTriggers() const {
|
||||
return m_hits;
|
||||
}
|
||||
|
||||
size_t TimestepCollapseTrigger::numMisses() const {
|
||||
return m_misses;
|
||||
}
|
||||
|
||||
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger(
|
||||
const double simulationTimeInterval,
|
||||
const double offDiagonalThreshold,
|
||||
const double timestepGrowthThreshold,
|
||||
const bool timestepGrowthRelative,
|
||||
const size_t timestepGrowthWindowSize
|
||||
) {
|
||||
using ctx_t = gridfire::solver::CVODESolverStrategy::TimestepContext;
|
||||
|
||||
// Create the individual conditions that can trigger a repartitioning
|
||||
auto simulationTimeTrigger = std::make_unique<EveryNthTrigger<ctx_t>>(std::make_unique<SimulationTimeTrigger>(simulationTimeInterval), 1000);
|
||||
auto offDiagTrigger = std::make_unique<OffDiagonalTrigger>(offDiagonalThreshold);
|
||||
auto timestepGrowthTrigger = std::make_unique<EveryNthTrigger<ctx_t>>(std::make_unique<TimestepCollapseTrigger>(timestepGrowthThreshold, timestepGrowthRelative, timestepGrowthWindowSize), 10);
|
||||
|
||||
// Combine the triggers using logical OR
|
||||
auto orTriggerA = std::make_unique<OrTrigger<ctx_t>>(std::move(simulationTimeTrigger), std::move(offDiagTrigger));
|
||||
auto orTriggerB = std::make_unique<OrTrigger<ctx_t>>(std::move(orTriggerA), std::move(timestepGrowthTrigger));
|
||||
|
||||
return orTriggerB;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user