feat(SpectralSolver): Began work on multizone spectral solver

The single zone solver we have is too slow for a true high resolution
multi-zone enviroment. Began work on a spectral element method
multi-zone solver
This commit is contained in:
2025-12-10 12:50:35 -05:00
parent b57ed57166
commit 97a7fd05d2
16 changed files with 1100 additions and 91 deletions

View File

@@ -8,8 +8,27 @@ namespace gridfire::config {
double relTol = 1.0e-5;
};
struct SpectralSolverConfig {
struct MonitorFunctionConfig {
double structure_weight = 1.0;
double abundance_weight = 10.0;
double alpha = 0.2;
double beta = 0.8;
};
struct BasisConfig {
size_t num_elements = 50;
};
double absTol = 1.0e-8;
double relTol = 1.0e-5;
size_t degree = 3;
MonitorFunctionConfig monitorFunction;
BasisConfig basis;
};
struct SolverConfig {
CVODESolverConfig cvode;
SpectralSolverConfig spectral;
};
struct AdaptiveEngineViewConfig {

View File

@@ -514,5 +514,9 @@ namespace gridfire::engine {
*/
[[nodiscard]] virtual SpeciesStatus getSpeciesStatus(const fourdst::atomic::Species& species) const = 0;
[[nodiscard]] virtual std::optional<StepDerivatives<double>> getMostRecentRHSCalculation() const {
return std::nullopt;
}
};
}

View File

@@ -14,6 +14,8 @@
#include "gridfire/engine/procedures/construction.h"
#include "gridfire/config/config.h"
#include "ankerl/unordered_dense.h"
#include <string>
#include <unordered_map>
#include <vector>
@@ -764,6 +766,8 @@ namespace gridfire::engine {
m_store_intermediate_reaction_contributions = value;
}
[[nodiscard]] std::optional<StepDerivatives<double>> getMostRecentRHSCalculation() const override;
private:
struct PrecomputedReaction {
@@ -887,6 +891,7 @@ namespace gridfire::engine {
mutable std::unordered_map<size_t, StepDerivatives<double>> m_stepDerivativesCache;
mutable std::unordered_map<size_t, CppAD::sparse_rcv<std::vector<size_t>, std::vector<double>>> m_jacobianSubsetCache;
mutable std::unordered_map<size_t, CppAD::sparse_jac_work> m_jacWorkCache;
mutable std::optional<StepDerivatives<double>> m_most_recent_rhs_calculation;
bool m_has_been_primed = false; ///< Flag indicating if the engine has been primed.

View File

@@ -37,7 +37,6 @@
#ifdef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_pthreads.hh>
#endif
// Default to serial if no parallelism is enabled
#ifndef SUNDIALS_HAVE_OPENMP
#ifndef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_serial.h>
@@ -79,7 +78,7 @@ namespace gridfire::solver {
* std::cout << "Final energy: " << out.energy << " erg/g\n";
* @endcode
*/
class CVODESolverStrategy final : public DynamicNetworkSolverStrategy {
class CVODESolverStrategy final : public SingleZoneDynamicNetworkSolverStrategy {
public:
/**
* @brief Construct the CVODE strategy and create a SUNDIALS context.

View File

@@ -0,0 +1,196 @@
#pragma once
#include "gridfire/solver/strategies/strategy_abstract.h"
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/types/types.h"
#include "gridfire/config/config.h"
#include "fourdst/logging/logging.h"
#include "fourdst/constants/const.h"
#include <vector>
#include <cvode/cvode.h>
#include <sundials/sundials_types.h>
#ifdef SUNDIALS_HAVE_OPENMP
#include <nvector/nvector_openmp.h>
#endif
#ifdef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_pthreads.hh>
#endif
#ifndef SUNDIALS_HAVE_OPENMP
#ifndef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_serial.h>
#endif
#endif
namespace gridfire::solver {
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolverStrategy {
public:
explicit SpectralSolverStrategy(engine::DynamicEngine& engine);
~SpectralSolverStrategy() override;
std::vector<NetOut> evaluate(
const std::vector<NetIn> &netIns,
const std::vector<double>& mass_coords
) override;
void set_callback(const std::any &callback) override;
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
[[nodiscard]] bool get_stdout_logging_enabled() const;
void set_stdout_logging_enabled(bool logging_enabled);
public:
struct TimestepContext final : public SolverContextBase {
TimestepContext(
const double t,
const N_Vector &state,
const double dt,
const double last_time_step,
const engine::DynamicEngine &engine
) :
t(t),
state(state),
dt(dt),
last_time_step(last_time_step),
engine(engine) {}
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
const double t;
const N_Vector& state;
const double dt;
const double last_time_step;
const engine::DynamicEngine& engine;
};
struct BasisEval {
size_t start_idx;
std::vector<double> phi;
};
struct SplineBasis {
std::vector<double> knots;
std::vector<double> quadrature_nodes;
std::vector<double> quadrature_weights;
int degree = 3;
std::vector<BasisEval> quad_evals;
};
public:
using TimestepCallback = std::function<void(const TimestepContext&)>;
private:
struct SpectralCoefficients {
size_t num_sets;
size_t num_coefficients;
std::vector<double> coefficients;
double operator()(size_t i, size_t j) const;
};
struct GridPoint {
double T9;
double rho;
fourdst::composition::Composition composition;
};
struct Constants {
const double c = fourdst::constant::Constants::getInstance().get("c").value;
const double N_a = fourdst::constant::Constants::getInstance().get("N_a").value;
const double c2 = c * c;
};
struct DenseLinearSolver {
SUNMatrix A;
SUNLinearSolver LS;
N_Vector temp_vector;
SUNContext ctx;
DenseLinearSolver(size_t size, SUNContext sun_ctx);
~DenseLinearSolver();
DenseLinearSolver(const DenseLinearSolver&) = delete;
DenseLinearSolver& operator=(const DenseLinearSolver&) = delete;
void setup() const;
void zero() const;
void init_from_cache(size_t num_basis_funcs, const std::vector<BasisEval>& shell_cache) const;
void init_from_basis(size_t num_basis_funcs, const SplineBasis& basis) const;
void solve_inplace(N_Vector x, size_t num_vars, size_t basis_size) const;
};
struct CVODEUserData {
SpectralSolverStrategy* solver_instance{};
engine::DynamicEngine* engine;
DenseLinearSolver* mass_matrix_solver_instance;
const SplineBasis* basis;
};
private:
fourdst::config::Config<config::GridFireConfig> m_config;
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
void* m_cvode_mem = nullptr; ///< CVODE memory block.
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator).
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix.
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
std::optional<double> m_relTol; ///< User-specified relative tolerance.
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
mutable size_t m_last_size = 0;
mutable size_t m_last_composition_hash = 0ULL;
mutable sunrealtype m_last_good_time_step = 0ULL;
SplineBasis m_current_basis;
Constants m_constants;
N_Vector m_T_coeffs = nullptr;
N_Vector m_rho_coeffs = nullptr;
private:
std::vector<double> evaluate_monitor_function(const std::vector<NetIn>& current_shells) const;
SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates) const;
GridPoint reconstruct_at_quadrature(const N_Vector y_coeffs, size_t quad_index, const SplineBasis &basis) const;
std::vector<NetOut> reconstruct_solution(const std::vector<NetIn>& original_inputs, const std::vector<double>& mass_coordinates, const N_Vector final_coeffs, const SplineBasis& basis, double dt) const;
static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector, void* user_data);
static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
int calculate_rhs(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, CVODEUserData* data) const;
static void project_specific_variable(
const std::vector<NetIn>& current_shells,
const std::vector<double>& mass_coordinates,
const std::vector<BasisEval>& shell_cache,
const DenseLinearSolver& linear_solver,
N_Vector output_vec,
size_t output_offset,
const std::function<double(const NetIn&)> &getter,
bool use_log
);
};
}

View File

@@ -2,4 +2,5 @@
#include "gridfire/solver/strategies/triggers/triggers.h"
#include "gridfire/solver/strategies/strategy_abstract.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
#include "gridfire/solver/strategies/SpectralSolverStrategy.h"

View File

@@ -10,6 +10,9 @@
#include <string>
namespace gridfire::solver {
template <typename EngineT>
concept IsEngine = std::is_base_of_v<engine::Engine, EngineT>;
/**
* @struct SolverContextBase
* @brief Base class for solver callback contexts.
@@ -34,7 +37,7 @@ namespace gridfire::solver {
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe() const = 0;
};
/**
* @class NetworkSolverStrategy
* @class SingleZoneNetworkSolverStrategy
* @brief Abstract base class for network solver strategies.
*
* This class defines the interface for network solver strategies, which are responsible
@@ -43,19 +46,19 @@ namespace gridfire::solver {
*
* @tparam EngineT The type of engine to use with this solver strategy. Must inherit from Engine.
*/
template <typename EngineT>
class NetworkSolverStrategy {
template <IsEngine EngineT>
class SingleZoneNetworkSolverStrategy {
public:
/**
* @brief Constructor for the NetworkSolverStrategy.
* @param engine The engine to use for evaluating the network.
*/
explicit NetworkSolverStrategy(EngineT& engine) : m_engine(engine) {};
explicit SingleZoneNetworkSolverStrategy(EngineT& engine) : m_engine(engine) {};
/**
* @brief Virtual destructor.
*/
virtual ~NetworkSolverStrategy() = default;
virtual ~SingleZoneNetworkSolverStrategy() = default;
/**
* @brief Evaluates the network for a given timestep.
@@ -92,8 +95,25 @@ namespace gridfire::solver {
EngineT& m_engine; ///< The engine used by this solver strategy.
};
template <IsEngine EngineT>
class MultiZoneNetworkSolverStrategy {
public:
explicit MultiZoneNetworkSolverStrategy(EngineT& engine) : m_engine(engine) {};
virtual ~MultiZoneNetworkSolverStrategy() = default;
virtual std::vector<NetOut> evaluate(
const std::vector<NetIn>& netIns,
const std::vector<double>& mass_coords
) = 0;
virtual void set_callback(const std::any& callback) = 0;
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected:
EngineT& m_engine; ///< The engine used by this solver strategy.
};
/**
* @brief Type alias for a network solver strategy that uses a DynamicEngine.
*/
using DynamicNetworkSolverStrategy = NetworkSolverStrategy<engine::DynamicEngine>;
using SingleZoneDynamicNetworkSolverStrategy = SingleZoneNetworkSolverStrategy<engine::DynamicEngine>;
using MultiZoneDynamicNetworkSolverStrategy = MultiZoneNetworkSolverStrategy<engine::DynamicEngine>;
}