feat(trigger): added working robust repartitioning trigger system

more work is needed to identify the most robust set of criteria to trigger on but the system is now very easy to exend, probe, and use.
This commit is contained in:
2025-09-29 13:35:48 -04:00
parent 4c91f8c525
commit 4f1c260444
12 changed files with 980 additions and 197 deletions

View File

@@ -81,7 +81,7 @@ subdir('src')
if get_option('build-python')
message('Configuring Python bindings...')
subdir('src-pybind')
subdir('build-python')
else
message('Skipping Python bindings...')
endif

View File

@@ -57,9 +57,9 @@ namespace gridfire {
// TODO: We should probably sort out how to adjust these from absolute to relative tolerances.
QSECacheConfig m_cacheConfig = {
1e-3, // Default tolerance for T9
1e-1, // Default tolerance for rho
1e-3 // Default tolerance for species abundances
1e-10, // Default tolerance for T9
1e-10, // Default tolerance for rho
1e-10 // Default tolerance for species abundances
};
/**
@@ -763,7 +763,7 @@ namespace gridfire {
std::string toString() const;
std::string toString(DynamicEngine &engine) const;
std::string toString(const DynamicEngine &engine) const;
};
/**
@@ -989,6 +989,11 @@ namespace gridfire {
* @brief Species that are treated as algebraic (in QSE) in the QSE groups.
*/
std::vector<fourdst::atomic::Species> m_algebraic_species;
/**
* @breif Stateful storage of the current algebraic species abundances. This is updated every time the update method is called.
*/
std::vector<double> m_Y_algebraic;
/**
* @brief Indices of algebraic species in the full network.
*/
@@ -1003,7 +1008,6 @@ namespace gridfire {
*/
std::vector<size_t> m_activeReactionIndices;
// TODO: Enhance the hashing for the cache to consider not just T and rho but also the current abundance in some careful way that automatically ignores small changes (i.e. network should only be repartitioned sometimes)
/**
* @brief Cache for QSE abundances based on T9, rho, and Y.
*

View File

@@ -0,0 +1,94 @@
#pragma once
#include "gridfire/trigger/trigger_abstract.h"
#include "gridfire/trigger/trigger_result.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "fourdst/logging/logging.h"
#include <string>
#include <deque>
#include <memory>
namespace gridfire::trigger::solver::CVODE {
class SimulationTimeTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
public:
explicit SimulationTimeTrigger(double interval);
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void reset() override;
std::string name() const override;
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
std::string describe() const override;
size_t numTriggers() const override;
size_t numMisses() const override;
private:
quill::Logger* m_logger = LogManager::getInstance().getLogger("log");
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
mutable size_t m_resets = 0;
double m_interval;
mutable double m_last_trigger_time = 0.0;
mutable double m_last_trigger_time_delta = 0.0;
};
class OffDiagonalTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
public:
explicit OffDiagonalTrigger(double threshold);
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void reset() override;
std::string name() const override;
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
std::string describe() const override;
size_t numTriggers() const override;
size_t numMisses() const override;
private:
quill::Logger* m_logger = LogManager::getInstance().getLogger("log");
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
mutable size_t m_resets = 0;
double m_threshold;
};
class TimestepCollapseTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
public:
explicit TimestepCollapseTrigger(double threshold, bool relative);
explicit TimestepCollapseTrigger(double threshold, bool relative, size_t windowSize);
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void reset() override;
std::string name() const override;
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
std::string describe() const override;
size_t numTriggers() const override;
size_t numMisses() const override;
private:
quill::Logger* m_logger = LogManager::getInstance().getLogger("log");
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
mutable size_t m_resets = 0;
double m_threshold;
bool m_relative;
size_t m_windowSize;
std::deque<double> m_timestep_window;
};
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger(
const double simulationTimeInterval,
const double offDiagonalThreshold,
const double timestepGrowthThreshold,
const bool timestepGrowthRelative,
const size_t timestepGrowthWindowSize
);
}

View File

@@ -0,0 +1,17 @@
#pragma once
#include "gridfire/trigger/trigger_result.h"
#include <iostream>
namespace gridfire::trigger {
inline void printWhy(const TriggerResult& result, const int indent = 0) {
const std::string prefix(indent * 2, ' ');
std::cout << prefix << "• [" << (result.value ? "TRUE" : "FALSE")
<< "] " << result.name << ": " << result.description << std::endl;
for (const auto& cause : result.causes) {
printWhy(cause, indent + 1);
}
}
}

View File

@@ -0,0 +1,26 @@
#pragma once
#include "gridfire/trigger/trigger_result.h"
#include <string>
#include <unordered_map>
namespace gridfire::trigger {
template <typename TriggerContextStruct>
class Trigger {
public:
virtual ~Trigger() = default;
virtual bool check(const TriggerContextStruct& ctx) const = 0;
virtual void update(const TriggerContextStruct& ctx) = 0;
virtual void reset() = 0;
virtual std::string name() const = 0;
virtual std::string describe() const = 0;
virtual TriggerResult why(const TriggerContextStruct& ctx) const = 0;
virtual size_t numTriggers() const = 0;
virtual size_t numMisses() const = 0;
};
}

View File

@@ -0,0 +1,440 @@
#pragma once
#include "gridfire/trigger/trigger_abstract.h"
#include "gridfire/trigger/trigger_result.h"
#include <string>
#include <vector>
#include <memory>
namespace gridfire::trigger {
template <typename TriggerContextStruct>
class LogicalTrigger : public Trigger<TriggerContextStruct> {};
template <typename TriggerContextStruct>
class AndTrigger final : public LogicalTrigger<TriggerContextStruct> {
public:
AndTrigger(std::unique_ptr<Trigger<TriggerContextStruct>> A, std::unique_ptr<Trigger<TriggerContextStruct>> B);
~AndTrigger() override = default;
bool check(const TriggerContextStruct& ctx) const override;
void update(const TriggerContextStruct& ctx) override;
void reset() override;
std::string name() const override;
TriggerResult why(const TriggerContextStruct& ctx) const override;
std::string describe() const override;
size_t numTriggers() const override;
size_t numMisses() const override;
private:
std::unique_ptr<Trigger<TriggerContextStruct>> m_A;
std::unique_ptr<Trigger<TriggerContextStruct>> m_B;
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
mutable size_t m_resets = 0;
};
template <typename TriggerContextStruct>
class OrTrigger final : public LogicalTrigger<TriggerContextStruct> {
public:
OrTrigger(std::unique_ptr<Trigger<TriggerContextStruct>> A, std::unique_ptr<Trigger<TriggerContextStruct>> B);
~OrTrigger() override = default;
bool check(const TriggerContextStruct& ctx) const override;
void update(const TriggerContextStruct& ctx) override;
void reset() override;
std::string name() const override;
TriggerResult why(const TriggerContextStruct& ctx) const override;
std::string describe() const override;
size_t numTriggers() const override;
size_t numMisses() const override;
private:
std::unique_ptr<Trigger<TriggerContextStruct>> m_A;
std::unique_ptr<Trigger<TriggerContextStruct>> m_B;
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
mutable size_t m_resets = 0;
};
template <typename TriggerContextStruct>
class NotTrigger final : public LogicalTrigger<TriggerContextStruct> {
public:
explicit NotTrigger(std::unique_ptr<Trigger<TriggerContextStruct>> A);
~NotTrigger() override = default;
bool check(const TriggerContextStruct& ctx) const override;
void update(const TriggerContextStruct& ctx) override;
void reset() override;
std::string name() const override;
TriggerResult why(const TriggerContextStruct& ctx) const override;
std::string describe() const override;
size_t numTriggers() const override;
size_t numMisses() const override;
private:
std::unique_ptr<Trigger<TriggerContextStruct>> m_A;
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
mutable size_t m_resets = 0;
};
template <typename TriggerContextStruct>
class EveryNthTrigger final : public LogicalTrigger<TriggerContextStruct> {
public:
explicit EveryNthTrigger(std::unique_ptr<Trigger<TriggerContextStruct>> A, size_t N);
~EveryNthTrigger() override = default;
bool check(const TriggerContextStruct& ctx) const override;
void update(const TriggerContextStruct& ctx) override;
void reset() override;
std::string name() const override;
TriggerResult why(const TriggerContextStruct& ctx) const override;
std::string describe() const override;
size_t numTriggers() const override;
size_t numMisses() const override;
private:
std::unique_ptr<Trigger<TriggerContextStruct>> m_A;
size_t m_N;
mutable size_t m_counter = 0;
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
mutable size_t m_resets = 0;
};
///////////////////////////////
// Templated Implementations //
///////////////////////////////
template<typename TriggerContextStruct>
AndTrigger<TriggerContextStruct>::AndTrigger(
std::unique_ptr<Trigger<TriggerContextStruct>> A,
std::unique_ptr<Trigger<TriggerContextStruct>> B
) : m_A(std::move(A)), m_B(std::move(B)) {}
template<typename TriggerContextStruct>
bool AndTrigger<TriggerContextStruct>::check(const TriggerContextStruct &ctx) const {
const bool valid = m_A->check(ctx) && m_B->check(ctx);
if (valid) {
m_hits++;
} else {
m_misses++;
}
return valid;
}
template <typename TriggerContextStruct>
void AndTrigger<TriggerContextStruct>::update(const TriggerContextStruct &ctx) {
m_A->update(ctx);
m_B->update(ctx);
m_updates++;
}
template <typename TriggerContextStruct>
void AndTrigger<TriggerContextStruct>::reset() {
m_A->reset();
m_B->reset();
m_resets++;
m_hits = 0;
m_misses = 0;
m_updates = 0;
}
template<typename TriggerContextStruct>
std::string AndTrigger<TriggerContextStruct>::name() const {
return "AND Trigger";
}
template<typename TriggerContextStruct>
TriggerResult AndTrigger<TriggerContextStruct>::why(const TriggerContextStruct &ctx) const {
TriggerResult result;
result.name = name();
TriggerResult A_result = m_A->why(ctx);
result.causes.push_back(A_result);
if (!A_result.value) {
// Short Circuit
result.value = false;
result.description = "Failed because A (" + A_result.name + ") is false.";
return result;
}
TriggerResult B_result = m_B->why(ctx);
result.causes.push_back(B_result);
if (!B_result.value) {
result.value = false;
result.description = "Failed because B (" + B_result.name + ") is false.";
return result;
}
result.value = true;
result.description = "Succeeded because both A (" + A_result.name + ") and B (" + B_result.description + ") are true.";
return result;
}
template <typename TriggerContextStruct>
std::string AndTrigger<TriggerContextStruct>::describe() const {
return "(" + m_A->describe() + ") AND (" + m_B->describe() + ")";
}
template <typename TriggerContextStruct>
size_t AndTrigger<TriggerContextStruct>::numTriggers() const {
return m_hits;
}
template <typename TriggerContextStruct>
size_t AndTrigger<TriggerContextStruct>::numMisses() const {
return m_misses;
}
template <typename TriggerContextStruct>
OrTrigger<TriggerContextStruct>::OrTrigger(
std::unique_ptr<Trigger<TriggerContextStruct>> A,
std::unique_ptr<Trigger<TriggerContextStruct>> B
) : m_A(std::move(A)), m_B(std::move(B)) {}
template <typename TriggerContextStruct>
bool OrTrigger<TriggerContextStruct>::check(const TriggerContextStruct &ctx) const {
const bool valid = m_A->check(ctx) || m_B->check(ctx);
if (valid) {
m_hits++;
} else {
m_misses++;
}
return valid;
}
template <typename TriggerContextStruct>
void OrTrigger<TriggerContextStruct>::update(const TriggerContextStruct &ctx) {
m_A->update(ctx);
m_B->update(ctx);
m_updates++;
}
template <typename TriggerContextStruct>
void OrTrigger<TriggerContextStruct>::reset() {
m_A->reset();
m_B->reset();
m_resets++;
m_hits = 0;
m_misses = 0;
m_updates = 0;
}
template<typename TriggerContextStruct>
std::string OrTrigger<TriggerContextStruct>::name() const {
return "OR Trigger";
}
template<typename TriggerContextStruct>
TriggerResult OrTrigger<TriggerContextStruct>::why(const TriggerContextStruct &ctx) const {
TriggerResult result;
result.name = name();
TriggerResult A_result = m_A->why(ctx);
result.causes.push_back(A_result);
if (A_result.value) {
// Short Circuit
result.value = true;
result.description = "Succeeded because A (" + A_result.name + ") is true.";
return result;
}
TriggerResult B_result = m_B->why(ctx);
result.causes.push_back(B_result);
if (B_result.value) {
result.value = true;
result.description = "Succeeded because B (" + B_result.name + ") is true.";
return result;
}
result.value = false;
result.description = "Failed because both A (" + A_result.name + ") and B (" + B_result.name + ") are false.";
return result;
}
template <typename TriggerContextStruct>
std::string OrTrigger<TriggerContextStruct>::describe() const {
return "(" + m_A->describe() + ") OR (" + m_B->describe() + ")";
}
template <typename TriggerContextStruct>
size_t OrTrigger<TriggerContextStruct>::numTriggers() const {
return m_hits;
}
template <typename TriggerContextStruct>
size_t OrTrigger<TriggerContextStruct>::numMisses() const {
return m_misses;
}
template <typename TriggerContextStruct>
NotTrigger<TriggerContextStruct>::NotTrigger(
std::unique_ptr<Trigger<TriggerContextStruct>> A
) : m_A(std::move(A)) {}
template <typename TriggerContextStruct>
bool NotTrigger<TriggerContextStruct>::check(const TriggerContextStruct &ctx) const {
const bool valid = !m_A->check(ctx);
if (valid) {
m_hits++;
} else {
m_misses++;
}
return valid;
}
template <typename TriggerContextStruct>
void NotTrigger<TriggerContextStruct>::update(const TriggerContextStruct &ctx) {
m_A->update(ctx);
m_updates++;
}
template <typename TriggerContextStruct>
void NotTrigger<TriggerContextStruct>::reset() {
m_A->reset();
m_resets++;
m_hits = 0;
m_misses = 0;
m_updates = 0;
}
template<typename TriggerContextStruct>
std::string NotTrigger<TriggerContextStruct>::name() const {
return "NOT Trigger";
}
template<typename TriggerContextStruct>
TriggerResult NotTrigger<TriggerContextStruct>::why(const TriggerContextStruct &ctx) const {
TriggerResult result;
result.name = name();
TriggerResult A_result = m_A->why(ctx);
result.causes.push_back(A_result);
if (A_result.value) {
result.value = false;
result.description = "Failed because A (" + A_result.name + ") is true.";
return result;
}
result.value = true;
result.description = "Succeeded because A (" + A_result.name + ") is false.";
return result;
}
template <typename TriggerContextStruct>
std::string NotTrigger<TriggerContextStruct>::describe() const {
return "NOT (" + m_A->describe() + ")";
}
template <typename TriggerContextStruct>
size_t NotTrigger<TriggerContextStruct>::numTriggers() const {
return m_hits;
}
template <typename TriggerContextStruct>
size_t NotTrigger<TriggerContextStruct>::numMisses() const {
return m_misses;
}
template <typename TriggerContextStruct>
EveryNthTrigger<TriggerContextStruct>::EveryNthTrigger(std::unique_ptr<Trigger<TriggerContextStruct>> A, const size_t N) : m_A(std::move(A)), m_N(N) {
if (N == 0) {
throw std::invalid_argument("N must be greater than 0.");
}
}
template <typename TriggerContextStruct>
bool EveryNthTrigger<TriggerContextStruct>::check(const TriggerContextStruct &ctx) const
{
if (m_A->check(ctx) && (m_counter % m_N == 0)) {
m_hits++;
return true;
}
m_misses++;
return false;
}
template <typename TriggerContextStruct>
void EveryNthTrigger<TriggerContextStruct>::update(const TriggerContextStruct &ctx) {
if (m_A->check(ctx)) {
m_counter++;
}
m_A->update(ctx);
m_updates++;
}
template <typename TriggerContextStruct>
void EveryNthTrigger<TriggerContextStruct>::reset() {
m_A->reset();
m_resets++;
m_counter = 0;
m_hits = 0;
m_misses = 0;
m_updates = 0;
}
template <typename TriggerContextStruct>
std::string EveryNthTrigger<TriggerContextStruct>::name() const {
return "Every Nth Trigger";
}
template <typename TriggerContextStruct>
TriggerResult EveryNthTrigger<TriggerContextStruct>::why(const TriggerContextStruct &ctx) const {
TriggerResult result;
result.name = name();
TriggerResult A_result = m_A->why(ctx);
result.causes.push_back(A_result);
if (!A_result.value) {
result.value = false;
result.description = "Failed because A (" + A_result.name + ") is false.";
return result;
}
if (m_counter % m_N == 0) {
result.value = true;
result.description = "Succeeded because A (" + A_result.name + ") is true and the counter (" + std::to_string(m_counter) + ") is a multiple of N (" + std::to_string(m_N) + ").";
return result;
}
result.value = false;
result.description = "Failed because the counter (" + std::to_string(m_counter) + ") is not a multiple of N (" + std::to_string(m_N) + ").";
return result;
}
template <typename TriggerContextStruct>
std::string EveryNthTrigger<TriggerContextStruct>::describe() const {
return "Every " + std::to_string(m_N) + "th (" + m_A->describe() + ")";
}
template <typename TriggerContextStruct>
size_t EveryNthTrigger<TriggerContextStruct>::numTriggers() const {
return m_hits;
}
template <typename TriggerContextStruct>
size_t EveryNthTrigger<TriggerContextStruct>::numMisses() const {
return m_misses;
}
}

View File

@@ -0,0 +1,13 @@
#pragma once
#include <vector>
#include <string>
namespace gridfire::trigger {
struct TriggerResult {
std::string name;
std::string description;
bool value;
std::vector<TriggerResult> causes;
};
}

View File

@@ -176,20 +176,6 @@ namespace gridfire {
}
// Check the cache to see if the network needs to be repartitioned. Note that the QSECacheKey manages binning of T9, rho, and Y_full to ensure that small changes (which would likely not result in a repartitioning) do not trigger a cache miss.
const QSECacheKey key(T9, rho, Y_full);
if (! m_qse_abundance_cache.contains(key)) {
m_cacheStats.miss(CacheStats::operators::CalculateRHSAndEnergy);
LOG_ERROR(
m_logger,
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). calculateRHSAndEnergy does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
T9,
rho,
m_cacheStats.misses(),
m_cacheStats.hits()
);
return std::unexpected{expectations::StaleEngineError(expectations::StaleEngineErrorTypes::SYSTEM_RESIZED)};
}
m_cacheStats.hit(CacheStats::operators::CalculateRHSAndEnergy);
const auto result = m_baseEngine.calculateRHSAndEnergy(Y_full, T9, rho);
if (!result) {
return std::unexpected{result.error()};
@@ -215,21 +201,6 @@ namespace gridfire {
const double T9,
const double rho
) const {
const QSECacheKey key(T9, rho, Y_full);
if (!m_qse_abundance_cache.contains(key)) {
m_cacheStats.miss(CacheStats::operators::GenerateJacobianMatrix);
LOG_ERROR(
m_logger,
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). generateJacobianMatrix does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
T9,
rho,
m_cacheStats.misses(),
m_cacheStats.hits()
);
throw exceptions::StaleEngineError("QSE Cache Miss while lacking context for partitioning. This should be caught by the caller and trigger a re-partition stage.");
}
m_cacheStats.hit(CacheStats::operators::GenerateJacobianMatrix);
// TODO: Add sparsity pattern to this to prevent base engine from doing unnecessary work.
m_baseEngine.generateJacobianMatrix(Y_full, T9, rho);
}
@@ -268,27 +239,11 @@ namespace gridfire {
const double T9,
const double rho
) const {
const auto key = QSECacheKey(T9, rho, Y_full);
if (!m_qse_abundance_cache.contains(key)) {
m_cacheStats.miss(CacheStats::operators::CalculateMolarReactionFlow);
LOG_ERROR(
m_logger,
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). calculateMolarReactionFlow does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
T9,
rho,
m_cacheStats.misses(),
m_cacheStats.hits()
);
throw exceptions::StaleEngineError("QSE Cache Miss while lacking context for partitioning. This should be caught by the caller and trigger a re-partition stage.");
}
m_cacheStats.hit(CacheStats::operators::CalculateMolarReactionFlow);
std::vector<double> Y_algebraic = m_qse_abundance_cache.at(key);
assert(Y_algebraic.size() == m_algebraic_species_indices.size());
assert(m_Y_algebraic.size() == m_algebraic_species_indices.size());
// Fix the algebraic species to the equilibrium abundances we calculate.
std::vector<double> Y_mutable = Y_full;
for (const auto& [index, Yi] : std::views::zip(m_algebraic_species_indices, Y_algebraic)) {
for (const auto& [index, Yi] : std::views::zip(m_algebraic_species_indices, m_Y_algebraic)) {
Y_mutable[index] = Yi;
}
@@ -309,20 +264,6 @@ namespace gridfire {
const double T9,
const double rho
) const {
const auto key = QSECacheKey(T9, rho, Y);
if (!m_qse_abundance_cache.contains(key)) {
m_cacheStats.miss(CacheStats::operators::GetSpeciesTimescales);
LOG_ERROR(
m_logger,
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). getSpeciesTimescales does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
T9,
rho,
m_cacheStats.misses(),
m_cacheStats.hits()
);
throw exceptions::StaleEngineError("QSE Cache Miss while lacking context for partitioning. This should be caught by the caller and trigger a re-partition stage.");
}
m_cacheStats.hit(CacheStats::operators::GetSpeciesTimescales);
const auto result = m_baseEngine.getSpeciesTimescales(Y, T9, rho);
if (!result) {
return std::unexpected{result.error()};
@@ -337,23 +278,9 @@ namespace gridfire {
std::expected<std::unordered_map<fourdst::atomic::Species, double>, expectations::StaleEngineError>
MultiscalePartitioningEngineView::getSpeciesDestructionTimescales(
const std::vector<double> &Y,
double T9,
double rho
const double T9,
const double rho
) const {
const auto key = QSECacheKey(T9, rho, Y);
if (!m_qse_abundance_cache.contains(key)) {
m_cacheStats.miss(CacheStats::operators::GetSpeciesDestructionTimescales);
LOG_ERROR(
m_logger,
"QSE abundance cache miss for T9 = {}, rho = {} (misses: {}, hits: {}). getSpeciesDestructionTimescales does not receive sufficient context to partition and stabilize the network. Throwing an error which should be caught by the caller and trigger a re-partition stage.",
T9,
rho,
m_cacheStats.misses(),
m_cacheStats.hits()
);
throw exceptions::StaleEngineError("QSE Cache Miss while lacking context for partitioning. This should be caught by the caller and trigger a re-partition stage.");
}
m_cacheStats.hit(CacheStats::operators::GetSpeciesDestructionTimescales);
const auto result = m_baseEngine.getSpeciesDestructionTimescales(Y, T9, rho);
if (!result) {
return std::unexpected{result.error()};
@@ -367,16 +294,7 @@ namespace gridfire {
fourdst::composition::Composition MultiscalePartitioningEngineView::update(const NetIn &netIn) {
const fourdst::composition::Composition baseUpdatedComposition = m_baseEngine.update(netIn);
double T9 = netIn.temperature / 1.0e9; // Convert temperature from Kelvin to T9 (T9 = T / 1e9)
const auto preKey = QSECacheKey(
T9,
netIn.density,
packCompositionToVector(baseUpdatedComposition, m_baseEngine)
);
if (m_qse_abundance_cache.contains(preKey)) {
return baseUpdatedComposition;
}
NetIn baseUpdatedNetIn = netIn;
baseUpdatedNetIn.composition = baseUpdatedComposition;
const fourdst::composition::Composition equilibratedComposition = equilibrateNetwork(baseUpdatedNetIn);
@@ -386,15 +304,7 @@ namespace gridfire {
Y_algebraic[i] = equilibratedComposition.getMolarAbundance(m_baseEngine.getNetworkSpecies()[species_index]);
}
// We store the algebraic abundances in the cache for both pre- and post-conditions to avoid recalculating them.
m_qse_abundance_cache[preKey] = Y_algebraic;
const auto postKey = QSECacheKey(
T9,
netIn.density,
packCompositionToVector(equilibratedComposition, m_baseEngine)
);
m_qse_abundance_cache[postKey] = Y_algebraic;
m_Y_algebraic = std::move(Y_algebraic);
return equilibratedComposition;
}
@@ -594,11 +504,6 @@ namespace gridfire {
m_qse_groups.size(),
m_qse_groups.size() == 1 ? "" : "s"
);
// throw std::runtime_error(
// "Partitioning complete. Throwing an error to end the program during debugging. This error should not be caught by the caller. "
// );
}
void MultiscalePartitioningEngineView::partitionNetwork(
@@ -1129,29 +1034,6 @@ namespace gridfire {
coupling_flux += flow * coupling_fraction;
}
// if (leakage_flux < 1e-99) {
// LOG_TRACE_L1(
// m_logger,
// "Group containing {} is in equilibrium due to vanishing leakage: leakage flux = {}, coupling flux = {}, ratio = {}",
// [&]() -> std::string {
// std::stringstream ss;
// int count = 0;
// for (const auto& idx : group.algebraic_indices) {
// ss << m_baseEngine.getNetworkSpecies()[idx].name();
// if (count < group.species_indices.size() - 1) {
// ss << ", ";
// }
// count++;
// }
// return ss.str();
// }(),
// leakage_flux,
// coupling_flux,
// coupling_flux / leakage_flux
// );
// validated_groups.emplace_back(group);
// validated_groups.back().is_in_equilibrium = true;
// } else if ((coupling_flux / leakage_flux ) > FLUX_RATIO_THRESHOLD) {
if ((coupling_flux / leakage_flux ) > FLUX_RATIO_THRESHOLD) {
LOG_TRACE_L1(
m_logger,
@@ -1703,7 +1585,7 @@ namespace gridfire {
}
std::string MultiscalePartitioningEngineView::QSEGroup::toString(DynamicEngine &engine) const {
std::string MultiscalePartitioningEngineView::QSEGroup::toString(const DynamicEngine &engine) const {
std::stringstream ss;
ss << "QSEGroup(Algebraic: [";
size_t count = 0;

View File

@@ -17,6 +17,8 @@
#include <algorithm>
#include "fourdst/composition/exceptions/exceptions_composition.h"
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
#include "gridfire/trigger/procedures/trigger_pprint.h"
namespace {
@@ -78,6 +80,43 @@ namespace {
namespace gridfire::solver {
CVODESolverStrategy::TimestepContext::TimestepContext(
const double t,
const N_Vector &state,
const double dt,
const double last_step_time,
const double t9,
const double rho,
const int num_steps,
const DynamicEngine &engine,
const std::vector<fourdst::atomic::Species> &networkSpecies
) :
t(t),
state(state),
dt(dt),
last_step_time(last_step_time),
T9(t9),
rho(rho),
num_steps(num_steps),
engine(engine),
networkSpecies(networkSpecies)
{}
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const {
std::vector<std::tuple<std::string, std::string>> description;
description.emplace_back("t", "Current Time");
description.emplace_back("state", "Current State Vector (N_Vector)");
description.emplace_back("dt", "Last Timestep Size");
description.emplace_back("last_step_time", "Time at Last Step");
description.emplace_back("T9", "Temperature in GK");
description.emplace_back("rho", "Density in g/cm^3");
description.emplace_back("num_steps", "Number of Steps Taken So Far");
description.emplace_back("engine", "Reference to the DynamicEngine");
description.emplace_back("networkSpecies", "Reference to the list of network species");
return description;
}
CVODESolverStrategy::CVODESolverStrategy(DynamicEngine &engine): NetworkSolverStrategy<DynamicEngine>(engine) {
// TODO: In order to support MPI this function must be changed
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
@@ -96,6 +135,8 @@ namespace gridfire::solver {
}
NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) {
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 1, true, 10);
const double T9 = netIn.temperature / 1e9; // Convert temperature from Kelvin to T9 (T9 = T / 1e9)
const auto absTol = m_config.get<double>("gridfire:solver:CVODESolverStrategy:absTol", 1.0e-8);
@@ -121,7 +162,6 @@ namespace gridfire::solver {
double accumulated_energy = 0.0;
int total_update_stages_triggered = 0;
while (current_time < netIn.tMax) {
try {
user_data.T9 = T9;
user_data.rho = netIn.density;
user_data.networkSpecies = &m_engine.getNetworkSpecies();
@@ -157,38 +197,28 @@ namespace gridfire::solver {
<< " | Time: " << current_time << " [s]"
<< " | Last Step Size: " << last_step_size
<< " | Accumulated Energy: " << current_energy << " [erg/g]"
<< " | NonlinIters: " << std::setw(2) << nliters
<< " | NonLinIters: " << std::setw(2) << nliters
<< " | ConvFails: " << std::setw(2) << nlcfails
<< std::endl;
if (n_steps % 300 == 0) {
std::cout << "Manually triggering engine update at step " << n_steps << "..." << std::endl;
exceptions::StaleEngineTrigger::state staleState {
auto ctx = TimestepContext(
current_time,
reinterpret_cast<N_Vector>(y_data),
last_step_size,
last_callback_time,
T9,
netIn.density,
std::vector<double>(y_data, y_data + numSpecies),
current_time,
static_cast<int>(n_steps),
current_energy
};
throw exceptions::StaleEngineTrigger(staleState);
}
n_steps,
m_engine,
m_engine.getNetworkSpecies());
// if (n_steps % 50 == 0) {
// std::cout << "Logging step diagnostics at step " << n_steps << "..." << std::endl;
// log_step_diagnostics(user_data);
// }
// if (n_steps == 300) {
// log_step_diagnostics(user_data);
// exit(0);
// }
// log_step_diagnostics(user_data);
} catch (const exceptions::StaleEngineTrigger& e) {
exceptions::StaleEngineTrigger::state staleState = e.getState();
accumulated_energy += e.energy(); // Add the specific energy rate to the accumulated energy
if (trigger->check(ctx)) {
trigger::printWhy(trigger->why(ctx));
trigger->update(ctx);
accumulated_energy += current_energy; // Add the specific energy rate to the accumulated energy
LOG_INFO(
m_logger,
"Engine Update Triggered due to StaleEngineTrigger exception at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
"Engine Update Triggered at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
current_time,
total_update_stages_triggered,
total_update_stages_triggered == 1 ? "" : "s",
@@ -197,21 +227,31 @@ namespace gridfire::solver {
fourdst::composition::Composition temp_comp;
std::vector<double> mass_fractions;
size_t num_species_at_stop = e.numSpecies();
size_t num_species_at_stop = m_engine.getNetworkSpecies().size();
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) {
LOG_ERROR(
m_logger,
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
num_species_at_stop,
m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index
);
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
}
mass_fractions.reserve(num_species_at_stop);
for (size_t i = 0; i < num_species_at_stop; ++i) {
const auto& species = m_engine.getNetworkSpecies()[i];
temp_comp.registerSpecies(species);
mass_fractions.push_back(e.getMolarAbundance(i) * species.mass()); // Convert from molar abundance to mass fraction
mass_fractions.push_back(y_data[i] * species.mass()); // Convert from molar abundance to mass fraction
}
temp_comp.setMassFraction(m_engine.getNetworkSpecies(), mass_fractions);
temp_comp.finalize(true);
NetIn netInTemp = netIn;
netInTemp.temperature = e.temperature();
netInTemp.density = e.density();
netInTemp.temperature = T9 * 1e9; // Convert back to Kelvin
netInTemp.density = netIn.density;
netInTemp.composition = temp_comp;
fourdst::composition::Composition currentComposition = m_engine.update(netInTemp);
@@ -225,13 +265,16 @@ namespace gridfire::solver {
numSpecies = m_engine.getNetworkSpecies().size();
N = numSpecies + 1;
cleanup_cvode_resources(true);
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
} catch (fourdst::composition::exceptions::InvalidCompositionError& e) {
log_step_diagnostics(user_data);
std::rethrow_exception(std::make_exception_ptr(e));
}
}
sunrealtype* y_data = N_VGetArrayPointer(m_Y);

View File

@@ -0,0 +1,263 @@
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "gridfire/trigger/trigger_logical.h"
#include "gridfire/trigger/trigger_abstract.h"
#include "quill/LogMacros.h"
#include <memory>
#include <deque>
#include <string>
namespace {
template <typename T>
void push_to_fixed_deque(std::deque<T>& dq, T value, size_t max_size) {
dq.push_back(value);
if (dq.size() > max_size) {
dq.pop_front();
}
}
}
namespace gridfire::trigger::solver::CVODE {
SimulationTimeTrigger::SimulationTimeTrigger(double interval) : m_interval(interval) {
if (interval <= 0.0) {
LOG_ERROR(m_logger, "Interval must be positive, currently it is {}", interval);
throw std::invalid_argument("Interval must be positive, currently it is " + std::to_string(interval));
}
}
bool SimulationTimeTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
if (ctx.t - m_last_trigger_time >= m_interval) {
m_hits++;
LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta);
return true;
}
m_misses++;
return false;
}
void SimulationTimeTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
if (check(ctx)) {
m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval;
m_last_trigger_time = ctx.t;
m_updates++;
}
}
void SimulationTimeTrigger::reset() {
m_misses = 0;
m_hits = 0;
m_updates = 0;
m_last_trigger_time = 0.0;
m_last_trigger_time_delta = 0.0;
m_resets++;
}
std::string SimulationTimeTrigger::name() const {
return "Simulation Time Trigger";
}
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
TriggerResult result;
result.name = name();
if (check(ctx)) {
result.value = true;
result.description = "Triggered because current time " + std::to_string(ctx.t) + " - last trigger time " + std::to_string(m_last_trigger_time - m_last_trigger_time_delta) + " >= interval " + std::to_string(m_interval);
} else {
result.value = false;
result.description = "Not triggered because current time " + std::to_string(ctx.t) + " - last trigger time " + std::to_string(m_last_trigger_time) + " < interval " + std::to_string(m_interval);
}
return result;
}
std::string SimulationTimeTrigger::describe() const {
return "SimulationTimeTrigger(interval=" + std::to_string(m_interval) + ")";
}
size_t SimulationTimeTrigger::numTriggers() const {
return m_hits;
}
size_t SimulationTimeTrigger::numMisses() const {
return m_misses;
}
OffDiagonalTrigger::OffDiagonalTrigger(
double threshold
) : m_threshold(threshold) {
if (threshold < 0.0) {
LOG_ERROR(m_logger, "Threshold must be non-negative, currently it is {}", threshold);
throw std::invalid_argument("Threshold must be non-negative, currently it is " + std::to_string(threshold));
}
}
bool OffDiagonalTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
const size_t numSpecies = ctx.engine.getNetworkSpecies().size();
for (int row = 0; row < numSpecies; ++row) {
for (int col = 0; col < numSpecies; ++col) {
double DRowDCol = std::abs(ctx.engine.getJacobianMatrixEntry(row, col));
if (row != col && DRowDCol > m_threshold) {
m_hits++;
LOG_TRACE_L2(m_logger, "OffDiagonalTrigger triggered at t = {} due to entry ({}, {}) = {}", ctx.t, row, col, DRowDCol);
return true;
}
}
}
m_misses++;
return false;
}
void OffDiagonalTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
m_updates++;
}
void OffDiagonalTrigger::reset() {
m_misses = 0;
m_hits = 0;
m_updates = 0;
m_resets++;
}
std::string OffDiagonalTrigger::name() const {
return "Off-Diagonal Trigger";
}
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
TriggerResult result;
result.name = name();
if (check(ctx)) {
result.value = true;
result.description = "Triggered because an off-diagonal Jacobian entry exceeded the threshold " + std::to_string(m_threshold);
} else {
result.value = false;
result.description = "Not triggered because no off-diagonal Jacobian entry exceeded the threshold " + std::to_string(m_threshold);
}
return result;
}
std::string OffDiagonalTrigger::describe() const {
return "OffDiagonalTrigger(threshold=" + std::to_string(m_threshold) + ")";
}
size_t OffDiagonalTrigger::numTriggers() const {
return m_hits;
}
size_t OffDiagonalTrigger::numMisses() const {
return m_misses;
}
TimestepCollapseTrigger::TimestepCollapseTrigger(
const double threshold,
const bool relative
) : TimestepCollapseTrigger(threshold, relative, 1){}
TimestepCollapseTrigger::TimestepCollapseTrigger(
double threshold,
const bool relative,
const size_t windowSize
) : m_threshold(threshold), m_relative(relative), m_windowSize(windowSize) {
if (threshold < 0.0) {
LOG_ERROR(m_logger, "Threshold must be non-negative, currently it is {}", threshold);
throw std::invalid_argument("Threshold must be non-negative, currently it is " + std::to_string(threshold));
}
if (relative && threshold > 1.0) {
LOG_ERROR(m_logger, "Relative threshold must be between 0 and 1, currently it is {}", threshold);
throw std::invalid_argument("Relative threshold must be between 0 and 1, currently it is " + std::to_string(threshold));
}
}
bool TimestepCollapseTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
if (m_timestep_window.size() < 1) {
m_misses++;
return false;
}
double averageTimestep = 0.0;
for (const auto& dt : m_timestep_window) {
averageTimestep += dt;
}
averageTimestep /= m_timestep_window.size();
if (m_relative && (std::abs(ctx.dt - averageTimestep) / averageTimestep) >= m_threshold) {
m_hits++;
LOG_TRACE_L2(m_logger, "TimestepCollapseTrigger triggered at t = {} due to relative growth: dt = {}, average dt = {}, threshold = {}", ctx.t, ctx.dt, averageTimestep, m_threshold);
return true;
} else if (!m_relative && std::abs(ctx.dt - averageTimestep) >= m_threshold) {
m_hits++;
LOG_TRACE_L2(m_logger, "TimestepCollapseTrigger triggered at t = {} due to absolute growth: dt = {}, average dt = {}, threshold = {}", ctx.t, ctx.dt, averageTimestep, m_threshold);
return true;
}
m_misses++;
return false;
}
void TimestepCollapseTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize);
m_updates++;
}
void TimestepCollapseTrigger::reset() {
m_misses = 0;
m_hits = 0;
m_updates = 0;
m_resets++;
m_timestep_window.clear();
}
std::string TimestepCollapseTrigger::name() const {
return "TimestepCollapseTrigger";
}
TriggerResult TimestepCollapseTrigger::why(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
) const {
TriggerResult result;
result.name = name();
if (check(ctx)) {
result.value = true;
result.description = "Triggered because timestep change exceeded the threshold " + std::to_string(m_threshold);
} else {
result.value = false;
result.description = "Not triggered because timestep change did not exceed the threshold " + std::to_string(m_threshold);
}
return result;
}
std::string TimestepCollapseTrigger::describe() const {
return "TimestepCollapseTrigger(threshold=" + std::to_string(m_threshold) + ", relative=" + (m_relative ? "true" : "false") + ", windowSize=" + std::to_string(m_windowSize) + ")";
}
size_t TimestepCollapseTrigger::numTriggers() const {
return m_hits;
}
size_t TimestepCollapseTrigger::numMisses() const {
return m_misses;
}
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger(
const double simulationTimeInterval,
const double offDiagonalThreshold,
const double timestepGrowthThreshold,
const bool timestepGrowthRelative,
const size_t timestepGrowthWindowSize
) {
using ctx_t = gridfire::solver::CVODESolverStrategy::TimestepContext;
// Create the individual conditions that can trigger a repartitioning
auto simulationTimeTrigger = std::make_unique<EveryNthTrigger<ctx_t>>(std::make_unique<SimulationTimeTrigger>(simulationTimeInterval), 1000);
auto offDiagTrigger = std::make_unique<OffDiagonalTrigger>(offDiagonalThreshold);
auto timestepGrowthTrigger = std::make_unique<EveryNthTrigger<ctx_t>>(std::make_unique<TimestepCollapseTrigger>(timestepGrowthThreshold, timestepGrowthRelative, timestepGrowthWindowSize), 10);
// Combine the triggers using logical OR
auto orTriggerA = std::make_unique<OrTrigger<ctx_t>>(std::move(simulationTimeTrigger), std::move(offDiagTrigger));
auto orTriggerB = std::make_unique<OrTrigger<ctx_t>>(std::move(orTriggerA), std::move(timestepGrowthTrigger));
return orTriggerB;
}
}

View File

@@ -15,6 +15,7 @@ gridfire_sources = files(
'lib/io/network_file.cpp',
'lib/solver/solver.cpp',
'lib/solver/strategies/CVODE_solver_strategy.cpp',
'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp',
'lib/screening/screening_types.cpp',
'lib/screening/screening_weak.cpp',
'lib/screening/screening_bare.cpp',
@@ -56,7 +57,7 @@ install_subdir('include/gridfire', install_dir: get_option('includedir'))
if get_option('build-python')
message('Configuring Python bindings...')
subdir('src-pybind')
subdir('python')
else
message('Skipping Python bindings...')
endif

View File

@@ -122,7 +122,7 @@ int main(int argc, char* argv[]){
netIn.temperature = 1.5e7;
netIn.density = 1.6e2;
netIn.energy = 0;
netIn.tMax = 3e13;
netIn.tMax = 3e16;
// netIn.tMax = 1e-14;
netIn.dt0 = 1e-12;