perf(multi): Simple parallel multi zone solver

Added a simple parallel multi-zone solver
This commit is contained in:
2025-12-18 12:47:39 -05:00
parent 4e1edfc142
commit dcfd7b60aa
27 changed files with 1018 additions and 2193 deletions

View File

@@ -11,10 +11,10 @@ namespace gridfire::config {
struct SpectralSolverConfig {
struct Trigger {
double simulationTimeInterval = 1.0e12;
double offDiagonalThreshold = 1.0e10;
double timestepCollapseRatio = 0.5;
size_t maxConvergenceFailures = 2;
double relativeFailureRate = 0.5;
size_t windowSize = 10;
};
struct MonitorFunctionConfig {
double structure_weight = 1.0;

View File

@@ -807,8 +807,6 @@ namespace gridfire::engine {
CppAD::ADFun<double> m_authoritativeADFun;
const size_t m_state_blob_offset;
private:
/**
* @brief Synchronizes the internal maps.

View File

@@ -0,0 +1,43 @@
#pragma once
#include "gridfire/solver/strategies/strategy_abstract.h"
#include <functional>
namespace gridfire::solver {
struct GridSolverContext final : SolverContextBase {
std::vector<std::unique_ptr<SolverContextBase>> solver_workspaces;
std::vector<std::function<void(const TimestepContextBase&)>> timestep_callbacks;
const engine::scratch::StateBlob& ctx_template;
bool zone_completion_logging = true;
bool zone_stdout_logging = false;
bool zone_detailed_logging = false;
void init() override;
void reset();
void set_callback(const std::function<void(const TimestepContextBase&)> &callback);
void set_callback(const std::function<void(const TimestepContextBase&)> &callback, size_t zone_idx);
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;
explicit GridSolverContext(const engine::scratch::StateBlob& ctx_template);
};
class GridSolver final : public MultiZoneDynamicNetworkSolver {
public:
GridSolver(
const engine::DynamicEngine& engine,
const SingleZoneDynamicNetworkSolver& solver
);
~GridSolver() override = default;
std::vector<NetOut> evaluate(
SolverContextBase& ctx,
const std::vector<NetIn>& netIns
) const override;
};
}

View File

@@ -44,8 +44,88 @@
#endif
namespace gridfire::solver {
struct PointSolverTimestepContext final : TimestepContextBase {
const double t; ///< Current integration time [s].
const N_Vector& state; ///< Current CVODE state vector (N_Vector).
const double dt; ///< Last step size [s].
const double last_step_time; ///< Time at last callback [s].
const double T9; ///< Temperature in GK.
const double rho; ///< Density [g cm^-3].
const size_t num_steps; ///< Number of CVODE steps taken so far.
const engine::DynamicEngine& engine; ///< Reference to the engine.
const std::vector<fourdst::atomic::Species>& networkSpecies; ///< Species layout.
const size_t currentConvergenceFailures; ///< Total number of convergence failures
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
PointSolverTimestepContext(
double t,
const N_Vector& state,
double dt,
double last_step_time,
double t9,
double rho,
size_t num_steps,
const engine::DynamicEngine& engine,
const std::vector<fourdst::atomic::Species>& networkSpecies,
size_t currentConvergenceFailure,
size_t currentNonlinearIterations,
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
engine::scratch::StateBlob& state_ctx
);
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
using TimestepCallback = std::function<void(const PointSolverTimestepContext& context)>; ///< Type alias for a timestep callback function.
struct PointSolverContext final : SolverContextBase {
SUNContext sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
void* cvode_mem = nullptr; ///< CVODE memory block.
N_Vector Y = nullptr; ///< CVODE state vector (species + energy accumulator).
N_Vector YErr = nullptr; ///< Estimated local errors.
SUNMatrix J = nullptr; ///< Dense Jacobian matrix.
SUNLinearSolver LS = nullptr; ///< Dense linear solver.
std::unique_ptr<engine::scratch::StateBlob> engine_ctx;
std::optional<TimestepCallback> callback; ///< Optional per-step callback.
int num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool stdout_logging = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> abs_tol; ///< User-specified absolute tolerance.
std::optional<double> rel_tol; ///< User-specified relative tolerance.
bool detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
size_t last_size = 0;
size_t last_composition_hash = 0ULL;
sunrealtype last_good_time_step = 0ULL;
void init() override;
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;
void reset_all();
void reset_user();
void reset_cvode();
void clear_context();
void init_context();
[[nodiscard]] bool has_context() const;
explicit PointSolverContext(const engine::scratch::StateBlob& engine_ctx);
~PointSolverContext() override;
};
/**
* @class CVODESolverStrategy
* @class PointSolver
* @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy.
*
* Integrates the nuclear network abundances along with a final accumulator entry for specific
@@ -78,27 +158,16 @@ namespace gridfire::solver {
* std::cout << "Final energy: " << out.energy << " erg/g\n";
* @endcode
*/
class CVODESolverStrategy final : public SingleZoneDynamicNetworkSolver {
class PointSolver final : public SingleZoneDynamicNetworkSolver {
public:
/**
* @brief Construct the CVODE strategy and create a SUNDIALS context.
* @param engine DynamicEngine used for RHS/Jacobian evaluation and network access.
* @throws std::runtime_error If SUNContext_Create fails.
*/
explicit CVODESolverStrategy(
const engine::DynamicEngine& engine,
const engine::scratch::StateBlob& ctx
explicit PointSolver(
const engine::DynamicEngine& engine
);
/**
* @brief Destructor: cleans CVODE/SUNDIALS resources and frees SUNContext.
*/
~CVODESolverStrategy() override;
// Make the class non-copyable and non-movable to prevent shallow copies of CVODE pointers
CVODESolverStrategy(const CVODESolverStrategy&) = delete;
CVODESolverStrategy& operator=(const CVODESolverStrategy&) = delete;
CVODESolverStrategy(CVODESolverStrategy&&) = delete;
CVODESolverStrategy& operator=(CVODESolverStrategy&&) = delete;
/**
* @brief Integrate from t=0 to netIn.tMax and return final composition and energy.
@@ -114,6 +183,7 @@ namespace gridfire::solver {
* - At the end, converts molar abundances to mass fractions and assembles NetOut,
* including derivatives of energy w.r.t. T and rho from the engine.
*
* @param solver_ctx
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
* @return NetOut containing final Composition, accumulated energy [erg/g], step count,
* and dEps/dT, dEps/dRho.
@@ -122,10 +192,14 @@ namespace gridfire::solver {
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
* during RHS evaluation (captured in the wrapper then rethrown here).
*/
NetOut evaluate(const NetIn& netIn) override;
NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn
) const override;
/**
* @brief Call to evaluate which will let the user control if the trigger reasoning is displayed
* @param solver_ctx
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
* @param displayTrigger Boolean flag to control if trigger reasoning is displayed
* @param forceReinitialize Boolean flag to force reinitialization of CVODE resources at the start
@@ -136,89 +210,13 @@ namespace gridfire::solver {
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
* during RHS evaluation (captured in the wrapper then rethrown here).
*/
NetOut evaluate(const NetIn& netIn, bool displayTrigger, bool forceReinitialize = false);
NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn,
bool displayTrigger,
bool forceReinitialize = false
) const;
/**
* @brief Install a timestep callback.
* @param callback std::any containing TimestepCallback (std::function<void(const TimestepContext&)>).
* @throws std::bad_any_cast If callback is not of the expected type.
*/
void set_callback(const std::any &callback) override;
/**
* @brief Whether per-step logs are printed to stdout and CVode is stepped with CV_ONE_STEP.
*/
[[nodiscard]] bool get_stdout_logging_enabled() const;
/**
* @brief Enable/disable per-step stdout logging.
* @param logging_enabled Flag to control if a timestep summary is written to standard output or not
*/
void set_stdout_logging_enabled(bool logging_enabled);
void set_absTol(double absTol);
void set_relTol(double relTol);
double get_absTol() const;
double get_relTol() const;
/**
* @brief Schema of fields exposed to the timestep callback context.
*/
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
/**
* @struct TimestepContext
* @brief Immutable view of the current integration state passed to callbacks.
*
* Fields capture CVODE time/state, step size, thermodynamic state, the engine reference,
* and the list of network species used to interpret the state vector layout.
*/
struct TimestepContext final : public SolverContextBase {
// This struct can be identical to the one in DirectNetworkSolver
const double t; ///< Current integration time [s].
const N_Vector& state; ///< Current CVODE state vector (N_Vector).
const double dt; ///< Last step size [s].
const double last_step_time; ///< Time at last callback [s].
const double T9; ///< Temperature in GK.
const double rho; ///< Density [g cm^-3].
const size_t num_steps; ///< Number of CVODE steps taken so far.
const engine::DynamicEngine& engine; ///< Reference to the engine.
const std::vector<fourdst::atomic::Species>& networkSpecies; ///< Species layout.
const size_t currentConvergenceFailures; ///< Total number of convergence failures
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
/**
* @brief Construct a context snapshot.
*/
TimestepContext(
double t,
const N_Vector& state,
double dt,
double last_step_time,
double t9,
double rho,
size_t num_steps,
const engine::DynamicEngine& engine,
const std::vector<fourdst::atomic::Species>& networkSpecies,
size_t currentConvergenceFailure,
size_t currentNonlinearIterations,
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
engine::scratch::StateBlob& state_ctx
);
/**
* @brief Human-readable description of the context fields.
*/
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
/**
* @brief Type alias for a timestep callback.
*/
using TimestepCallback = std::function<void(const TimestepContext& context)>; ///< Type alias for a timestep callback function.
private:
/**
* @struct CVODEUserData
@@ -230,7 +228,8 @@ namespace gridfire::solver {
* to CVODE, then the driver loop inspects and rethrows.
*/
struct CVODEUserData {
CVODESolverStrategy* solver_instance{}; // Pointer back to the class instance
const PointSolver* solver_instance{}; // Pointer back to the class instance
PointSolverContext* sctx; // Pointer to the solver context
engine::scratch::StateBlob& ctx;
const engine::DynamicEngine* engine{};
double T9{};
@@ -283,6 +282,7 @@ namespace gridfire::solver {
* step size, creates a dense matrix and dense linear solver, and registers the Jacobian.
*/
void initialize_cvode_integration_resources(
PointSolverContext* ctx,
uint64_t N,
size_t numSpecies,
double current_time,
@@ -290,15 +290,7 @@ namespace gridfire::solver {
double absTol,
double relTol,
double accumulatedEnergy
);
/**
* @brief Destroy CVODE vectors/linear algebra and optionally the CVODE memory block.
* @param memFree If true, also calls CVodeFree on m_cvode_mem.
*/
void cleanup_cvode_resources(bool memFree);
void set_detailed_step_logging(bool enabled);
) const;
/**
@@ -308,31 +300,13 @@ namespace gridfire::solver {
* sorted table of species with the highest error ratios; then invokes diagnostic routines to
* inspect Jacobian stiffness and species balance.
*/
void log_step_diagnostics(engine::scratch::StateBlob &ctx, const CVODEUserData& user_data, bool displayJacobianStiffness, bool
displaySpeciesBalance, bool to_file, std::optional<std::string> filename) const;
private:
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
void* m_cvode_mem = nullptr; ///< CVODE memory block.
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator).
N_Vector m_YErr = nullptr; ///< Estimated local errors.
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix.
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
std::optional<double> m_relTol; ///< User-specified relative tolerance.
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
mutable size_t m_last_size = 0;
mutable size_t m_last_composition_hash = 0ULL;
mutable sunrealtype m_last_good_time_step = 0ULL;
void log_step_diagnostics(
PointSolverContext* sctx_p,
engine::scratch::StateBlob &ctx,
const CVODEUserData& user_data,
bool displayJacobianStiffness,
bool displaySpeciesBalance,
bool to_file, std::optional<std::string> filename
) const;
};
}

View File

@@ -1,221 +0,0 @@
#pragma once
#include "gridfire/solver/strategies/strategy_abstract.h"
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/types/types.h"
#include "gridfire/config/config.h"
#include "fourdst/logging/logging.h"
#include "fourdst/constants/const.h"
#include <vector>
#include <functional>
#include <cvode/cvode.h>
#include <sundials/sundials_types.h>
#include "gridfire/exceptions/error_engine.h"
#ifdef SUNDIALS_HAVE_OPENMP
#include <nvector/nvector_openmp.h>
#endif
#ifdef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_pthreads.hh>
#endif
#ifndef SUNDIALS_HAVE_OPENMP
#ifndef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_serial.h>
#endif
#endif
namespace gridfire::solver {
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolver {
public:
explicit SpectralSolverStrategy(const engine::DynamicEngine& engine);
~SpectralSolverStrategy() override;
std::vector<NetOut> evaluate(
const std::vector<NetIn> &netIns,
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
) override;
void set_callback(const std::any &callback) override;
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
[[nodiscard]] bool get_stdout_logging_enabled() const;
void set_stdout_logging_enabled(bool logging_enabled);
public:
struct TimestepContext final : public SolverContextBase {
TimestepContext(
const double t,
const N_Vector &state,
const double dt,
const double last_time_step,
const engine::DynamicEngine &engine
) :
t(t),
state(state),
dt(dt),
last_time_step(last_time_step),
engine(engine) {}
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
const double t;
const N_Vector& state;
const double dt;
const double last_time_step;
const engine::DynamicEngine& engine;
};
struct BasisEval {
size_t start_idx;
std::vector<double> phi;
};
struct SplineBasis {
std::vector<double> knots;
std::vector<double> quadrature_nodes;
std::vector<double> quadrature_weights;
int degree = 3;
std::vector<BasisEval> quad_evals;
};
public:
using TimestepCallback = std::function<void(const TimestepContext&)>;
private:
enum class ParallelInitializationResult : uint8_t {
SUCCESS,
FAILURE
};
struct SpectralCoefficients {
size_t num_sets;
size_t num_coefficients;
std::vector<double> coefficients;
double operator()(size_t i, size_t j) const;
};
struct GridPoint {
double T9;
double rho;
fourdst::composition::Composition composition;
};
struct Constants {
const double c = fourdst::constant::Constants::getInstance().get("c").value;
const double N_a = fourdst::constant::Constants::getInstance().get("N_a").value;
const double c2 = c * c;
};
struct DenseLinearSolver {
SUNMatrix A;
SUNLinearSolver LS;
N_Vector temp_vector;
SUNContext ctx;
DenseLinearSolver(size_t size, SUNContext sun_ctx);
~DenseLinearSolver();
DenseLinearSolver(const DenseLinearSolver&) = delete;
DenseLinearSolver& operator=(const DenseLinearSolver&) = delete;
void setup() const;
void zero() const;
void init_from_cache(size_t num_basis_funcs, const std::vector<BasisEval>& shell_cache) const;
void init_from_basis(size_t num_basis_funcs, const SplineBasis& basis) const;
void solve_inplace(N_Vector x, size_t num_vars, size_t basis_size) const;
void solve_inplace_ptr(sunrealtype* data_ptr, size_t num_vars, size_t basis_size) const;
};
struct CVODEUserData {
SpectralSolverStrategy* solver_instance{};
std::vector<std::reference_wrapper<engine::scratch::StateBlob>> workspaces;
const engine::DynamicEngine* engine{};
std::unique_ptr<exceptions::EngineError> captured_exception{};
std::vector<double> T9{};
std::vector<double> rho{};
double energy{};
double neutrino_energy_loss_rate = 0.0;
double total_neutrino_flux = 0.0;
DenseLinearSolver* mass_matrix_solver_instance{};
const SplineBasis* basis{};
};
private:
fourdst::config::Config<config::GridFireConfig> m_config;
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
void* m_cvode_mem = nullptr; ///< CVODE memory block.
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator).
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix.
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
std::optional<double> m_relTol; ///< User-specified relative tolerance.
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
mutable size_t m_last_size = 0;
mutable size_t m_last_composition_hash = 0ULL;
mutable sunrealtype m_last_good_time_step = 0ULL;
SplineBasis m_current_basis;
Constants m_constants;
N_Vector m_T_coeffs = nullptr;
N_Vector m_rho_coeffs = nullptr;
std::vector<fourdst::atomic::Species> m_global_species_list;
private:
std::vector<double> evaluate_monitor_function(const std::vector<NetIn>& current_shells) const;
static SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates, size_t actual_elements);
GridPoint reconstruct_at_quadrature(const N_Vector y_coeffs, size_t quad_index, const SplineBasis &basis) const;
std::vector<NetOut> reconstruct_solution(const std::vector<NetIn>& original_inputs, const std::vector<double>& mass_coordinates, const N_Vector final_coeffs, const SplineBasis& basis, double dt) const;
static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector, void* user_data);
static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
int calculate_rhs(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, CVODEUserData* data) const;
int calculate_jacobian(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, SUNMatrix J, const CVODEUserData *data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) const;
static size_t nyquist_elements(size_t requested_elements, size_t num_shells) ;
static void project_specific_variable(
const std::vector<NetIn>& current_shells,
const std::vector<double>& mass_coordinates,
const std::vector<BasisEval>& shell_cache,
const DenseLinearSolver& linear_solver,
N_Vector output_vec,
size_t output_offset,
const std::function<double(const NetIn&)> &getter,
bool use_log
);
void inspect_jacobian(SUNMatrix J, const std::string& context) const;
};
}

View File

@@ -2,5 +2,5 @@
#include "gridfire/solver/strategies/triggers/triggers.h"
#include "gridfire/solver/strategies/strategy_abstract.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "gridfire/solver/strategies/SpectralSolverStrategy.h"
#include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/solver/strategies/GridSolver.h"

View File

@@ -13,17 +13,24 @@ namespace gridfire::solver {
template <typename EngineT>
concept IsEngine = std::is_base_of_v<engine::Engine, EngineT>;
struct SolverContextBase {
virtual void init() = 0;
virtual void set_stdout_logging(bool enable) = 0;
virtual void set_detailed_logging(bool enable) = 0;
virtual ~SolverContextBase() = default;
};
/**
* @struct SolverContextBase
* @struct TimestepContextBase
* @brief Base class for solver callback contexts.
*
* This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces
* that derived classes implement a `describe` method that returns a vector of tuples describing
* the context that a callback will receive when called.
*/
class SolverContextBase {
class TimestepContextBase {
public:
virtual ~SolverContextBase() = default;
virtual ~TimestepContextBase() = default;
/**
* @brief Describe the context for callback functions.
@@ -54,11 +61,9 @@ namespace gridfire::solver {
* @param engine The engine to use for evaluating the network.
*/
explicit SingleZoneNetworkSolver(
const EngineT& engine,
const engine::scratch::StateBlob& ctx
const EngineT& engine
) :
m_engine(engine),
m_scratch_blob(ctx.clone_structure()) {};
m_engine(engine) {};
/**
* @brief Virtual destructor.
@@ -67,58 +72,39 @@ namespace gridfire::solver {
/**
* @brief Evaluates the network for a given timestep.
* @param solver_ctx
* @param engine_ctx
* @param netIn The input conditions for the network.
* @return The output conditions after the timestep.
*/
virtual NetOut evaluate(const NetIn& netIn) = 0;
virtual NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn
) const = 0;
/**
* @brief set the callback function to be called at the end of each timestep.
*
* This function allows the user to set a callback function that will be called at the end of each timestep.
* The callback function will receive a gridfire::solver::<SOMESOLVER>::TimestepContext object. Note that
* depending on the solver, this context may contain different information. Further, the exact
* signature of the callback function is left up to each solver. Every solver should provide a type or type alias
* TimestepCallback that defines the signature of the callback function so that the user can easily
* get that type information.
*
* @param callback The callback function to be called at the end of each timestep.
*/
virtual void set_callback(const std::any& callback) = 0;
/**
* @brief Describe the context that will be passed to the callback function.
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
*
* This method should be overridden by derived classes to provide a description of the context
* that will be passed to the callback function. The intent of this method is that an end user can investigate
* the context that will be passed to the callback function, and use this information to craft their own
* callback function.
*/
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected:
const EngineT& m_engine; ///< The engine used by this solver strategy.
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob;
};
template <IsEngine EngineT>
class MultiZoneNetworkSolver {
public:
explicit MultiZoneNetworkSolver(
const EngineT& engine
const EngineT& engine,
const SingleZoneNetworkSolver<EngineT>& solver
) :
m_engine(engine) {};
m_engine(engine),
m_solver(solver) {};
virtual ~MultiZoneNetworkSolver() = default;
virtual std::vector<NetOut> evaluate(
const std::vector<NetIn>& netIns,
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
) = 0;
virtual void set_callback(const std::any& callback) = 0;
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
SolverContextBase& solver_ctx,
const std::vector<NetIn>& netIns
) const = 0;
protected:
const EngineT& m_engine; ///< The engine used by this solver strategy.
const SingleZoneNetworkSolver<EngineT>& m_solver;
};
/**

View File

@@ -2,7 +2,7 @@
#include "gridfire/trigger/trigger_abstract.h"
#include "gridfire/trigger/trigger_result.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "gridfire/solver/strategies/PointSolver.h"
#include "fourdst/logging/logging.h"
#include <string>
@@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE {
*
* See also: engine_partitioning_trigger.cpp for the concrete logic and logging.
*/
class SimulationTimeTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
class SimulationTimeTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public:
/**
* @brief Construct with a positive time interval between firings.
@@ -62,7 +62,7 @@ namespace gridfire::trigger::solver::CVODE {
*
* @post increments hit/miss counters and may emit trace logs.
*/
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/**
* @brief Update internal state; if check(ctx) is true, advance last_trigger_time.
* @param ctx CVODE timestep context.
@@ -70,9 +70,9 @@ namespace gridfire::trigger::solver::CVODE {
* @note update() calls check(ctx) and, on success, records the overshoot delta
* (ctx.t - last_trigger_time) - interval for diagnostics.
*/
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/**
* @brief Reset counters and last trigger bookkeeping (time and delta) to zero.
*/
@@ -85,7 +85,7 @@ namespace gridfire::trigger::solver::CVODE {
* @param ctx CVODE timestep context.
* @return TriggerResult including name, value, and description.
*/
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including configured interval. */
std::string describe() const override;
/** @brief Number of true evaluations since last reset. */
@@ -130,7 +130,7 @@ namespace gridfire::trigger::solver::CVODE {
* @par See also
* - engine_partitioning_trigger.cpp for concrete logic and trace logging.
*/
class OffDiagonalTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
class OffDiagonalTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public:
/**
* @brief Construct with a non-negative magnitude threshold.
@@ -145,13 +145,13 @@ namespace gridfire::trigger::solver::CVODE {
*
* @post increments hit/miss counters and may emit trace logs.
*/
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/**
* @brief Record an update; does not mutate any Jacobian-related state.
* @param ctx CVODE timestep context (unused except for symmetry with interface).
*/
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/** @brief Reset counters to zero. */
void reset() override;
@@ -161,7 +161,7 @@ namespace gridfire::trigger::solver::CVODE {
* @brief Structured explanation of the evaluation outcome.
* @param ctx CVODE timestep context.
*/
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including configured threshold. */
std::string describe() const override;
/** @brief Number of true evaluations since last reset. */
@@ -206,7 +206,7 @@ namespace gridfire::trigger::solver::CVODE {
*
* See also: engine_partitioning_trigger.cpp for exact logic and logging.
*/
class TimestepCollapseTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
class TimestepCollapseTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public:
/**
* @brief Construct with threshold and relative/absolute mode; window size defaults to 1.
@@ -230,20 +230,20 @@ namespace gridfire::trigger::solver::CVODE {
*
* @post increments hit/miss counters and may emit trace logs.
*/
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/**
* @brief Update sliding window with the most recent dt and increment update counter.
* @param ctx CVODE timestep context.
*/
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/** @brief Reset counters and clear the dt window. */
void reset() override;
/** @brief Stable human-readable name. */
std::string name() const override;
/** @brief Structured explanation of the evaluation outcome. */
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including threshold, mode, and window size. */
std::string describe() const override;
/** @brief Number of true evaluations since last reset. */
@@ -272,15 +272,15 @@ namespace gridfire::trigger::solver::CVODE {
std::deque<double> m_timestep_window;
};
class ConvergenceFailureTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
class ConvergenceFailureTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public:
explicit ConvergenceFailureTrigger(size_t totalFailures, float relativeFailureRate, size_t windowSize);
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void reset() override;
@@ -288,7 +288,7 @@ namespace gridfire::trigger::solver::CVODE {
[[nodiscard]] std::string describe() const override;
[[nodiscard]] TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
[[nodiscard]] TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
[[nodiscard]] size_t numTriggers() const override;
@@ -312,8 +312,8 @@ namespace gridfire::trigger::solver::CVODE {
private:
float current_mean() const;
bool abs_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const;
bool rel_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const;
bool abs_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const;
bool rel_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const;
};
/**
@@ -337,10 +337,10 @@ namespace gridfire::trigger::solver::CVODE {
*
* @note The exact policy is subject to change; this function centralizes that decision.
*/
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger(
const double simulationTimeInterval,
const double offDiagonalThreshold,
const double timestepCollapseRatio,
const size_t maxConvergenceFailures
std::unique_ptr<Trigger<gridfire::solver::PointSolverTimestepContext>> makeEnginePartitioningTrigger(
double simulationTimeInterval,
double offDiagonalThreshold,
double timestepCollapseRatio,
size_t maxConvergenceFailures
);
}

View File

@@ -30,6 +30,7 @@ namespace gridfire::omp {
);
CppAD::thread_alloc::hold_memory(true);
CppAD::CheckSimpleVector<double, std::vector<double>>(0, 1);
s_par_mode_initialized = true;
}
}