feat(trigger): added working robust repartitioning trigger system
more work is needed to identify the most robust set of criteria to trigger on but the system is now very easy to exend, probe, and use.
This commit is contained in:
@@ -176,20 +176,6 @@ namespace gridfire {
|
||||
}
|
||||
|
||||
// Check the cache to see if the network needs to be repartitioned. Note that the QSECacheKey manages binning of T9, rho, and Y_full to ensure that small changes (which would likely not result in a repartitioning) do not trigger a cache miss.
|
||||
const QSECacheKey key(T9, rho, Y_full);
|
||||
if (! m_qse_abundance_cache.contains(key)) {
|
||||
m_cacheStats.miss(CacheStats::operators::CalculateRHSAndEnergy);
|
||||
LOG_ERROR(
|
||||
m_logger,
|
||||
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). calculateRHSAndEnergy does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
|
||||
T9,
|
||||
rho,
|
||||
m_cacheStats.misses(),
|
||||
m_cacheStats.hits()
|
||||
);
|
||||
return std::unexpected{expectations::StaleEngineError(expectations::StaleEngineErrorTypes::SYSTEM_RESIZED)};
|
||||
}
|
||||
m_cacheStats.hit(CacheStats::operators::CalculateRHSAndEnergy);
|
||||
const auto result = m_baseEngine.calculateRHSAndEnergy(Y_full, T9, rho);
|
||||
if (!result) {
|
||||
return std::unexpected{result.error()};
|
||||
@@ -215,21 +201,6 @@ namespace gridfire {
|
||||
const double T9,
|
||||
const double rho
|
||||
) const {
|
||||
const QSECacheKey key(T9, rho, Y_full);
|
||||
if (!m_qse_abundance_cache.contains(key)) {
|
||||
m_cacheStats.miss(CacheStats::operators::GenerateJacobianMatrix);
|
||||
LOG_ERROR(
|
||||
m_logger,
|
||||
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). generateJacobianMatrix does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
|
||||
T9,
|
||||
rho,
|
||||
m_cacheStats.misses(),
|
||||
m_cacheStats.hits()
|
||||
);
|
||||
throw exceptions::StaleEngineError("QSE Cache Miss while lacking context for partitioning. This should be caught by the caller and trigger a re-partition stage.");
|
||||
}
|
||||
m_cacheStats.hit(CacheStats::operators::GenerateJacobianMatrix);
|
||||
|
||||
// TODO: Add sparsity pattern to this to prevent base engine from doing unnecessary work.
|
||||
m_baseEngine.generateJacobianMatrix(Y_full, T9, rho);
|
||||
}
|
||||
@@ -268,27 +239,11 @@ namespace gridfire {
|
||||
const double T9,
|
||||
const double rho
|
||||
) const {
|
||||
const auto key = QSECacheKey(T9, rho, Y_full);
|
||||
if (!m_qse_abundance_cache.contains(key)) {
|
||||
m_cacheStats.miss(CacheStats::operators::CalculateMolarReactionFlow);
|
||||
LOG_ERROR(
|
||||
m_logger,
|
||||
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). calculateMolarReactionFlow does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
|
||||
T9,
|
||||
rho,
|
||||
m_cacheStats.misses(),
|
||||
m_cacheStats.hits()
|
||||
);
|
||||
throw exceptions::StaleEngineError("QSE Cache Miss while lacking context for partitioning. This should be caught by the caller and trigger a re-partition stage.");
|
||||
}
|
||||
m_cacheStats.hit(CacheStats::operators::CalculateMolarReactionFlow);
|
||||
std::vector<double> Y_algebraic = m_qse_abundance_cache.at(key);
|
||||
|
||||
assert(Y_algebraic.size() == m_algebraic_species_indices.size());
|
||||
assert(m_Y_algebraic.size() == m_algebraic_species_indices.size());
|
||||
|
||||
// Fix the algebraic species to the equilibrium abundances we calculate.
|
||||
std::vector<double> Y_mutable = Y_full;
|
||||
for (const auto& [index, Yi] : std::views::zip(m_algebraic_species_indices, Y_algebraic)) {
|
||||
for (const auto& [index, Yi] : std::views::zip(m_algebraic_species_indices, m_Y_algebraic)) {
|
||||
Y_mutable[index] = Yi;
|
||||
|
||||
}
|
||||
@@ -309,20 +264,6 @@ namespace gridfire {
|
||||
const double T9,
|
||||
const double rho
|
||||
) const {
|
||||
const auto key = QSECacheKey(T9, rho, Y);
|
||||
if (!m_qse_abundance_cache.contains(key)) {
|
||||
m_cacheStats.miss(CacheStats::operators::GetSpeciesTimescales);
|
||||
LOG_ERROR(
|
||||
m_logger,
|
||||
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). getSpeciesTimescales does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
|
||||
T9,
|
||||
rho,
|
||||
m_cacheStats.misses(),
|
||||
m_cacheStats.hits()
|
||||
);
|
||||
throw exceptions::StaleEngineError("QSE Cache Miss while lacking context for partitioning. This should be caught by the caller and trigger a re-partition stage.");
|
||||
}
|
||||
m_cacheStats.hit(CacheStats::operators::GetSpeciesTimescales);
|
||||
const auto result = m_baseEngine.getSpeciesTimescales(Y, T9, rho);
|
||||
if (!result) {
|
||||
return std::unexpected{result.error()};
|
||||
@@ -337,23 +278,9 @@ namespace gridfire {
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, expectations::StaleEngineError>
|
||||
MultiscalePartitioningEngineView::getSpeciesDestructionTimescales(
|
||||
const std::vector<double> &Y,
|
||||
double T9,
|
||||
double rho
|
||||
const double T9,
|
||||
const double rho
|
||||
) const {
|
||||
const auto key = QSECacheKey(T9, rho, Y);
|
||||
if (!m_qse_abundance_cache.contains(key)) {
|
||||
m_cacheStats.miss(CacheStats::operators::GetSpeciesDestructionTimescales);
|
||||
LOG_ERROR(
|
||||
m_logger,
|
||||
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). getSpeciesDestructionTimescales does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
|
||||
T9,
|
||||
rho,
|
||||
m_cacheStats.misses(),
|
||||
m_cacheStats.hits()
|
||||
);
|
||||
throw exceptions::StaleEngineError("QSE Cache Miss while lacking context for partitioning. This should be caught by the caller and trigger a re-partition stage.");
|
||||
}
|
||||
m_cacheStats.hit(CacheStats::operators::GetSpeciesDestructionTimescales);
|
||||
const auto result = m_baseEngine.getSpeciesDestructionTimescales(Y, T9, rho);
|
||||
if (!result) {
|
||||
return std::unexpected{result.error()};
|
||||
@@ -367,16 +294,7 @@ namespace gridfire {
|
||||
|
||||
fourdst::composition::Composition MultiscalePartitioningEngineView::update(const NetIn &netIn) {
|
||||
const fourdst::composition::Composition baseUpdatedComposition = m_baseEngine.update(netIn);
|
||||
double T9 = netIn.temperature / 1.0e9; // Convert temperature from Kelvin to T9 (T9 = T / 1e9)
|
||||
|
||||
const auto preKey = QSECacheKey(
|
||||
T9,
|
||||
netIn.density,
|
||||
packCompositionToVector(baseUpdatedComposition, m_baseEngine)
|
||||
);
|
||||
if (m_qse_abundance_cache.contains(preKey)) {
|
||||
return baseUpdatedComposition;
|
||||
}
|
||||
NetIn baseUpdatedNetIn = netIn;
|
||||
baseUpdatedNetIn.composition = baseUpdatedComposition;
|
||||
const fourdst::composition::Composition equilibratedComposition = equilibrateNetwork(baseUpdatedNetIn);
|
||||
@@ -386,15 +304,7 @@ namespace gridfire {
|
||||
Y_algebraic[i] = equilibratedComposition.getMolarAbundance(m_baseEngine.getNetworkSpecies()[species_index]);
|
||||
}
|
||||
|
||||
// We store the algebraic abundances in the cache for both pre- and post-conditions to avoid recalculating them.
|
||||
m_qse_abundance_cache[preKey] = Y_algebraic;
|
||||
|
||||
const auto postKey = QSECacheKey(
|
||||
T9,
|
||||
netIn.density,
|
||||
packCompositionToVector(equilibratedComposition, m_baseEngine)
|
||||
);
|
||||
m_qse_abundance_cache[postKey] = Y_algebraic;
|
||||
m_Y_algebraic = std::move(Y_algebraic);
|
||||
|
||||
return equilibratedComposition;
|
||||
}
|
||||
@@ -594,11 +504,6 @@ namespace gridfire {
|
||||
m_qse_groups.size(),
|
||||
m_qse_groups.size() == 1 ? "" : "s"
|
||||
);
|
||||
|
||||
// throw std::runtime_error(
|
||||
// "Partitioning complete. Throwing an error to end the program during debugging. This error should not be caught by the caller. "
|
||||
// );
|
||||
|
||||
}
|
||||
|
||||
void MultiscalePartitioningEngineView::partitionNetwork(
|
||||
@@ -1129,29 +1034,6 @@ namespace gridfire {
|
||||
coupling_flux += flow * coupling_fraction;
|
||||
}
|
||||
|
||||
// if (leakage_flux < 1e-99) {
|
||||
// LOG_TRACE_L1(
|
||||
// m_logger,
|
||||
// "Group containing {} is in equilibrium due to vanishing leakage: leakage flux = {}, coupling flux = {}, ratio = {}",
|
||||
// [&]() -> std::string {
|
||||
// std::stringstream ss;
|
||||
// int count = 0;
|
||||
// for (const auto& idx : group.algebraic_indices) {
|
||||
// ss << m_baseEngine.getNetworkSpecies()[idx].name();
|
||||
// if (count < group.species_indices.size() - 1) {
|
||||
// ss << ", ";
|
||||
// }
|
||||
// count++;
|
||||
// }
|
||||
// return ss.str();
|
||||
// }(),
|
||||
// leakage_flux,
|
||||
// coupling_flux,
|
||||
// coupling_flux / leakage_flux
|
||||
// );
|
||||
// validated_groups.emplace_back(group);
|
||||
// validated_groups.back().is_in_equilibrium = true;
|
||||
// } else if ((coupling_flux / leakage_flux ) > FLUX_RATIO_THRESHOLD) {
|
||||
if ((coupling_flux / leakage_flux ) > FLUX_RATIO_THRESHOLD) {
|
||||
LOG_TRACE_L1(
|
||||
m_logger,
|
||||
@@ -1703,7 +1585,7 @@ namespace gridfire {
|
||||
|
||||
}
|
||||
|
||||
std::string MultiscalePartitioningEngineView::QSEGroup::toString(DynamicEngine &engine) const {
|
||||
std::string MultiscalePartitioningEngineView::QSEGroup::toString(const DynamicEngine &engine) const {
|
||||
std::stringstream ss;
|
||||
ss << "QSEGroup(Algebraic: [";
|
||||
size_t count = 0;
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "fourdst/composition/exceptions/exceptions_composition.h"
|
||||
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
|
||||
#include "gridfire/trigger/procedures/trigger_pprint.h"
|
||||
|
||||
|
||||
namespace {
|
||||
@@ -78,6 +80,43 @@ namespace {
|
||||
|
||||
namespace gridfire::solver {
|
||||
|
||||
CVODESolverStrategy::TimestepContext::TimestepContext(
|
||||
const double t,
|
||||
const N_Vector &state,
|
||||
const double dt,
|
||||
const double last_step_time,
|
||||
const double t9,
|
||||
const double rho,
|
||||
const int num_steps,
|
||||
const DynamicEngine &engine,
|
||||
const std::vector<fourdst::atomic::Species> &networkSpecies
|
||||
) :
|
||||
t(t),
|
||||
state(state),
|
||||
dt(dt),
|
||||
last_step_time(last_step_time),
|
||||
T9(t9),
|
||||
rho(rho),
|
||||
num_steps(num_steps),
|
||||
engine(engine),
|
||||
networkSpecies(networkSpecies)
|
||||
{}
|
||||
|
||||
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const {
|
||||
std::vector<std::tuple<std::string, std::string>> description;
|
||||
description.emplace_back("t", "Current Time");
|
||||
description.emplace_back("state", "Current State Vector (N_Vector)");
|
||||
description.emplace_back("dt", "Last Timestep Size");
|
||||
description.emplace_back("last_step_time", "Time at Last Step");
|
||||
description.emplace_back("T9", "Temperature in GK");
|
||||
description.emplace_back("rho", "Density in g/cm^3");
|
||||
description.emplace_back("num_steps", "Number of Steps Taken So Far");
|
||||
description.emplace_back("engine", "Reference to the DynamicEngine");
|
||||
description.emplace_back("networkSpecies", "Reference to the list of network species");
|
||||
return description;
|
||||
}
|
||||
|
||||
|
||||
CVODESolverStrategy::CVODESolverStrategy(DynamicEngine &engine): NetworkSolverStrategy<DynamicEngine>(engine) {
|
||||
// TODO: In order to support MPI this function must be changed
|
||||
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
|
||||
@@ -96,6 +135,8 @@ namespace gridfire::solver {
|
||||
}
|
||||
|
||||
NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) {
|
||||
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 1, true, 10);
|
||||
|
||||
const double T9 = netIn.temperature / 1e9; // Convert temperature from Kelvin to T9 (T9 = T / 1e9)
|
||||
|
||||
const auto absTol = m_config.get<double>("gridfire:solver:CVODESolverStrategy:absTol", 1.0e-8);
|
||||
@@ -121,74 +162,63 @@ namespace gridfire::solver {
|
||||
double accumulated_energy = 0.0;
|
||||
int total_update_stages_triggered = 0;
|
||||
while (current_time < netIn.tMax) {
|
||||
try {
|
||||
user_data.T9 = T9;
|
||||
user_data.rho = netIn.density;
|
||||
user_data.networkSpecies = &m_engine.getNetworkSpecies();
|
||||
user_data.captured_exception.reset();
|
||||
user_data.T9 = T9;
|
||||
user_data.rho = netIn.density;
|
||||
user_data.networkSpecies = &m_engine.getNetworkSpecies();
|
||||
user_data.captured_exception.reset();
|
||||
|
||||
check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
|
||||
check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
|
||||
|
||||
int flag = -1;
|
||||
if (m_stdout_logging_enabled) {
|
||||
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP);
|
||||
} else {
|
||||
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_NORMAL);
|
||||
}
|
||||
int flag = -1;
|
||||
if (m_stdout_logging_enabled) {
|
||||
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP);
|
||||
} else {
|
||||
flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_NORMAL);
|
||||
}
|
||||
|
||||
if (user_data.captured_exception){
|
||||
std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception));
|
||||
}
|
||||
if (user_data.captured_exception){
|
||||
std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception));
|
||||
}
|
||||
|
||||
check_cvode_flag(flag, "CVode");
|
||||
check_cvode_flag(flag, "CVode");
|
||||
|
||||
long int n_steps;
|
||||
double last_step_size;
|
||||
CVodeGetNumSteps(m_cvode_mem, &n_steps);
|
||||
CVodeGetLastStep(m_cvode_mem, &last_step_size);
|
||||
long int nliters, nlcfails;
|
||||
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
|
||||
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
|
||||
long int n_steps;
|
||||
double last_step_size;
|
||||
CVodeGetNumSteps(m_cvode_mem, &n_steps);
|
||||
CVodeGetLastStep(m_cvode_mem, &last_step_size);
|
||||
long int nliters, nlcfails;
|
||||
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
|
||||
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
|
||||
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
const double current_energy = y_data[numSpecies]; // Specific energy rate
|
||||
std::cout << std::scientific << std::setprecision(3)
|
||||
<< "Step: " << std::setw(6) << n_steps
|
||||
<< " | Time: " << current_time << " [s]"
|
||||
<< " | Last Step Size: " << last_step_size
|
||||
<< " | Accumulated Energy: " << current_energy << " [erg/g]"
|
||||
<< " | NonlinIters: " << std::setw(2) << nliters
|
||||
<< " | ConvFails: " << std::setw(2) << nlcfails
|
||||
<< std::endl;
|
||||
if (n_steps % 300 == 0) {
|
||||
std::cout << "Manually triggering engine update at step " << n_steps << "..." << std::endl;
|
||||
exceptions::StaleEngineTrigger::state staleState {
|
||||
T9,
|
||||
netIn.density,
|
||||
std::vector<double>(y_data, y_data + numSpecies),
|
||||
current_time,
|
||||
static_cast<int>(n_steps),
|
||||
current_energy
|
||||
};
|
||||
throw exceptions::StaleEngineTrigger(staleState);
|
||||
}
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
const double current_energy = y_data[numSpecies]; // Specific energy rate
|
||||
std::cout << std::scientific << std::setprecision(3)
|
||||
<< "Step: " << std::setw(6) << n_steps
|
||||
<< " | Time: " << current_time << " [s]"
|
||||
<< " | Last Step Size: " << last_step_size
|
||||
<< " | Accumulated Energy: " << current_energy << " [erg/g]"
|
||||
<< " | NonLinIters: " << std::setw(2) << nliters
|
||||
<< " | ConvFails: " << std::setw(2) << nlcfails
|
||||
<< std::endl;
|
||||
|
||||
// if (n_steps % 50 == 0) {
|
||||
// std::cout << "Logging step diagnostics at step " << n_steps << "..." << std::endl;
|
||||
// log_step_diagnostics(user_data);
|
||||
// }
|
||||
// if (n_steps == 300) {
|
||||
// log_step_diagnostics(user_data);
|
||||
// exit(0);
|
||||
// }
|
||||
// log_step_diagnostics(user_data);
|
||||
auto ctx = TimestepContext(
|
||||
current_time,
|
||||
reinterpret_cast<N_Vector>(y_data),
|
||||
last_step_size,
|
||||
last_callback_time,
|
||||
T9,
|
||||
netIn.density,
|
||||
n_steps,
|
||||
m_engine,
|
||||
m_engine.getNetworkSpecies());
|
||||
|
||||
} catch (const exceptions::StaleEngineTrigger& e) {
|
||||
exceptions::StaleEngineTrigger::state staleState = e.getState();
|
||||
accumulated_energy += e.energy(); // Add the specific energy rate to the accumulated energy
|
||||
if (trigger->check(ctx)) {
|
||||
trigger::printWhy(trigger->why(ctx));
|
||||
trigger->update(ctx);
|
||||
accumulated_energy += current_energy; // Add the specific energy rate to the accumulated energy
|
||||
LOG_INFO(
|
||||
m_logger,
|
||||
"Engine Update Triggered due to StaleEngineTrigger exception at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
|
||||
"Engine Update Triggered at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
|
||||
current_time,
|
||||
total_update_stages_triggered,
|
||||
total_update_stages_triggered == 1 ? "" : "s",
|
||||
@@ -197,21 +227,31 @@ namespace gridfire::solver {
|
||||
|
||||
fourdst::composition::Composition temp_comp;
|
||||
std::vector<double> mass_fractions;
|
||||
size_t num_species_at_stop = e.numSpecies();
|
||||
size_t num_species_at_stop = m_engine.getNetworkSpecies().size();
|
||||
|
||||
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) {
|
||||
LOG_ERROR(
|
||||
m_logger,
|
||||
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
|
||||
num_species_at_stop,
|
||||
m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index
|
||||
);
|
||||
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
|
||||
}
|
||||
|
||||
mass_fractions.reserve(num_species_at_stop);
|
||||
|
||||
for (size_t i = 0; i < num_species_at_stop; ++i) {
|
||||
const auto& species = m_engine.getNetworkSpecies()[i];
|
||||
temp_comp.registerSpecies(species);
|
||||
mass_fractions.push_back(e.getMolarAbundance(i) * species.mass()); // Convert from molar abundance to mass fraction
|
||||
mass_fractions.push_back(y_data[i] * species.mass()); // Convert from molar abundance to mass fraction
|
||||
}
|
||||
temp_comp.setMassFraction(m_engine.getNetworkSpecies(), mass_fractions);
|
||||
temp_comp.finalize(true);
|
||||
|
||||
NetIn netInTemp = netIn;
|
||||
netInTemp.temperature = e.temperature();
|
||||
netInTemp.density = e.density();
|
||||
netInTemp.temperature = T9 * 1e9; // Convert back to Kelvin
|
||||
netInTemp.density = netIn.density;
|
||||
netInTemp.composition = temp_comp;
|
||||
|
||||
fourdst::composition::Composition currentComposition = m_engine.update(netInTemp);
|
||||
@@ -225,13 +265,16 @@ namespace gridfire::solver {
|
||||
numSpecies = m_engine.getNetworkSpecies().size();
|
||||
N = numSpecies + 1;
|
||||
|
||||
cleanup_cvode_resources(true);
|
||||
|
||||
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
|
||||
check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
|
||||
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
|
||||
|
||||
check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
|
||||
} catch (fourdst::composition::exceptions::InvalidCompositionError& e) {
|
||||
log_step_diagnostics(user_data);
|
||||
std::rethrow_exception(std::make_exception_ptr(e));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
|
||||
@@ -0,0 +1,263 @@
|
||||
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
|
||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
||||
|
||||
#include "gridfire/trigger/trigger_logical.h"
|
||||
#include "gridfire/trigger/trigger_abstract.h"
|
||||
|
||||
#include "quill/LogMacros.h"
|
||||
|
||||
#include <memory>
|
||||
#include <deque>
|
||||
#include <string>
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
void push_to_fixed_deque(std::deque<T>& dq, T value, size_t max_size) {
|
||||
dq.push_back(value);
|
||||
if (dq.size() > max_size) {
|
||||
dq.pop_front();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace gridfire::trigger::solver::CVODE {
|
||||
SimulationTimeTrigger::SimulationTimeTrigger(double interval) : m_interval(interval) {
|
||||
if (interval <= 0.0) {
|
||||
LOG_ERROR(m_logger, "Interval must be positive, currently it is {}", interval);
|
||||
throw std::invalid_argument("Interval must be positive, currently it is " + std::to_string(interval));
|
||||
}
|
||||
}
|
||||
|
||||
bool SimulationTimeTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
if (ctx.t - m_last_trigger_time >= m_interval) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta);
|
||||
return true;
|
||||
}
|
||||
m_misses++;
|
||||
return false;
|
||||
}
|
||||
|
||||
void SimulationTimeTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
if (check(ctx)) {
|
||||
m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval;
|
||||
m_last_trigger_time = ctx.t;
|
||||
m_updates++;
|
||||
}
|
||||
}
|
||||
|
||||
void SimulationTimeTrigger::reset() {
|
||||
m_misses = 0;
|
||||
m_hits = 0;
|
||||
m_updates = 0;
|
||||
m_last_trigger_time = 0.0;
|
||||
m_last_trigger_time_delta = 0.0;
|
||||
m_resets++;
|
||||
}
|
||||
|
||||
std::string SimulationTimeTrigger::name() const {
|
||||
return "Simulation Time Trigger";
|
||||
}
|
||||
|
||||
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
if (check(ctx)) {
|
||||
result.value = true;
|
||||
result.description = "Triggered because current time " + std::to_string(ctx.t) + " - last trigger time " + std::to_string(m_last_trigger_time - m_last_trigger_time_delta) + " >= interval " + std::to_string(m_interval);
|
||||
} else {
|
||||
result.value = false;
|
||||
result.description = "Not triggered because current time " + std::to_string(ctx.t) + " - last trigger time " + std::to_string(m_last_trigger_time) + " < interval " + std::to_string(m_interval);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string SimulationTimeTrigger::describe() const {
|
||||
return "SimulationTimeTrigger(interval=" + std::to_string(m_interval) + ")";
|
||||
}
|
||||
|
||||
size_t SimulationTimeTrigger::numTriggers() const {
|
||||
return m_hits;
|
||||
}
|
||||
|
||||
size_t SimulationTimeTrigger::numMisses() const {
|
||||
return m_misses;
|
||||
}
|
||||
|
||||
OffDiagonalTrigger::OffDiagonalTrigger(
|
||||
double threshold
|
||||
) : m_threshold(threshold) {
|
||||
if (threshold < 0.0) {
|
||||
LOG_ERROR(m_logger, "Threshold must be non-negative, currently it is {}", threshold);
|
||||
throw std::invalid_argument("Threshold must be non-negative, currently it is " + std::to_string(threshold));
|
||||
}
|
||||
}
|
||||
|
||||
bool OffDiagonalTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
const size_t numSpecies = ctx.engine.getNetworkSpecies().size();
|
||||
for (int row = 0; row < numSpecies; ++row) {
|
||||
for (int col = 0; col < numSpecies; ++col) {
|
||||
double DRowDCol = std::abs(ctx.engine.getJacobianMatrixEntry(row, col));
|
||||
if (row != col && DRowDCol > m_threshold) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "OffDiagonalTrigger triggered at t = {} due to entry ({}, {}) = {}", ctx.t, row, col, DRowDCol);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
m_misses++;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
void OffDiagonalTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
m_updates++;
|
||||
}
|
||||
|
||||
void OffDiagonalTrigger::reset() {
|
||||
m_misses = 0;
|
||||
m_hits = 0;
|
||||
m_updates = 0;
|
||||
m_resets++;
|
||||
}
|
||||
|
||||
std::string OffDiagonalTrigger::name() const {
|
||||
return "Off-Diagonal Trigger";
|
||||
}
|
||||
|
||||
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
|
||||
if (check(ctx)) {
|
||||
result.value = true;
|
||||
result.description = "Triggered because an off-diagonal Jacobian entry exceeded the threshold " + std::to_string(m_threshold);
|
||||
} else {
|
||||
result.value = false;
|
||||
result.description = "Not triggered because no off-diagonal Jacobian entry exceeded the threshold " + std::to_string(m_threshold);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string OffDiagonalTrigger::describe() const {
|
||||
return "OffDiagonalTrigger(threshold=" + std::to_string(m_threshold) + ")";
|
||||
}
|
||||
|
||||
size_t OffDiagonalTrigger::numTriggers() const {
|
||||
return m_hits;
|
||||
}
|
||||
|
||||
size_t OffDiagonalTrigger::numMisses() const {
|
||||
return m_misses;
|
||||
}
|
||||
|
||||
TimestepCollapseTrigger::TimestepCollapseTrigger(
|
||||
const double threshold,
|
||||
const bool relative
|
||||
) : TimestepCollapseTrigger(threshold, relative, 1){}
|
||||
|
||||
|
||||
TimestepCollapseTrigger::TimestepCollapseTrigger(
|
||||
double threshold,
|
||||
const bool relative,
|
||||
const size_t windowSize
|
||||
) : m_threshold(threshold), m_relative(relative), m_windowSize(windowSize) {
|
||||
if (threshold < 0.0) {
|
||||
LOG_ERROR(m_logger, "Threshold must be non-negative, currently it is {}", threshold);
|
||||
throw std::invalid_argument("Threshold must be non-negative, currently it is " + std::to_string(threshold));
|
||||
}
|
||||
if (relative && threshold > 1.0) {
|
||||
LOG_ERROR(m_logger, "Relative threshold must be between 0 and 1, currently it is {}", threshold);
|
||||
throw std::invalid_argument("Relative threshold must be between 0 and 1, currently it is " + std::to_string(threshold));
|
||||
}
|
||||
}
|
||||
|
||||
bool TimestepCollapseTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
||||
if (m_timestep_window.size() < 1) {
|
||||
m_misses++;
|
||||
return false;
|
||||
}
|
||||
double averageTimestep = 0.0;
|
||||
for (const auto& dt : m_timestep_window) {
|
||||
averageTimestep += dt;
|
||||
}
|
||||
averageTimestep /= m_timestep_window.size();
|
||||
if (m_relative && (std::abs(ctx.dt - averageTimestep) / averageTimestep) >= m_threshold) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "TimestepCollapseTrigger triggered at t = {} due to relative growth: dt = {}, average dt = {}, threshold = {}", ctx.t, ctx.dt, averageTimestep, m_threshold);
|
||||
return true;
|
||||
} else if (!m_relative && std::abs(ctx.dt - averageTimestep) >= m_threshold) {
|
||||
m_hits++;
|
||||
LOG_TRACE_L2(m_logger, "TimestepCollapseTrigger triggered at t = {} due to absolute growth: dt = {}, average dt = {}, threshold = {}", ctx.t, ctx.dt, averageTimestep, m_threshold);
|
||||
return true;
|
||||
}
|
||||
m_misses++;
|
||||
return false;
|
||||
}
|
||||
|
||||
void TimestepCollapseTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
||||
push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize);
|
||||
m_updates++;
|
||||
}
|
||||
|
||||
void TimestepCollapseTrigger::reset() {
|
||||
m_misses = 0;
|
||||
m_hits = 0;
|
||||
m_updates = 0;
|
||||
m_resets++;
|
||||
m_timestep_window.clear();
|
||||
}
|
||||
|
||||
std::string TimestepCollapseTrigger::name() const {
|
||||
return "TimestepCollapseTrigger";
|
||||
}
|
||||
|
||||
TriggerResult TimestepCollapseTrigger::why(
|
||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
||||
) const {
|
||||
TriggerResult result;
|
||||
result.name = name();
|
||||
|
||||
if (check(ctx)) {
|
||||
result.value = true;
|
||||
result.description = "Triggered because timestep change exceeded the threshold " + std::to_string(m_threshold);
|
||||
} else {
|
||||
result.value = false;
|
||||
result.description = "Not triggered because timestep change did not exceed the threshold " + std::to_string(m_threshold);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string TimestepCollapseTrigger::describe() const {
|
||||
return "TimestepCollapseTrigger(threshold=" + std::to_string(m_threshold) + ", relative=" + (m_relative ? "true" : "false") + ", windowSize=" + std::to_string(m_windowSize) + ")";
|
||||
}
|
||||
|
||||
size_t TimestepCollapseTrigger::numTriggers() const {
|
||||
return m_hits;
|
||||
}
|
||||
|
||||
size_t TimestepCollapseTrigger::numMisses() const {
|
||||
return m_misses;
|
||||
}
|
||||
|
||||
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger(
|
||||
const double simulationTimeInterval,
|
||||
const double offDiagonalThreshold,
|
||||
const double timestepGrowthThreshold,
|
||||
const bool timestepGrowthRelative,
|
||||
const size_t timestepGrowthWindowSize
|
||||
) {
|
||||
using ctx_t = gridfire::solver::CVODESolverStrategy::TimestepContext;
|
||||
|
||||
// Create the individual conditions that can trigger a repartitioning
|
||||
auto simulationTimeTrigger = std::make_unique<EveryNthTrigger<ctx_t>>(std::make_unique<SimulationTimeTrigger>(simulationTimeInterval), 1000);
|
||||
auto offDiagTrigger = std::make_unique<OffDiagonalTrigger>(offDiagonalThreshold);
|
||||
auto timestepGrowthTrigger = std::make_unique<EveryNthTrigger<ctx_t>>(std::make_unique<TimestepCollapseTrigger>(timestepGrowthThreshold, timestepGrowthRelative, timestepGrowthWindowSize), 10);
|
||||
|
||||
// Combine the triggers using logical OR
|
||||
auto orTriggerA = std::make_unique<OrTrigger<ctx_t>>(std::move(simulationTimeTrigger), std::move(offDiagTrigger));
|
||||
auto orTriggerB = std::make_unique<OrTrigger<ctx_t>>(std::move(orTriggerA), std::move(timestepGrowthTrigger));
|
||||
|
||||
return orTriggerB;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user