feat(SpectralSolver): Spectral Solver now works in a limited fashion

Major work on spectral solver, can now evolve up to about a year. At
that point we likely need to impliment repartitioning logic to stabalize
the network or some other scheme based on the jacobian structure
This commit is contained in:
2025-12-12 17:24:53 -05:00
parent e114c0e240
commit 0b09ed1cb3
17 changed files with 653 additions and 150 deletions

View File

@@ -25,14 +25,14 @@
#endif
namespace gridfire::solver {
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolverStrategy {
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolver {
public:
explicit SpectralSolverStrategy(engine::DynamicEngine& engine);
explicit SpectralSolverStrategy(const engine::DynamicEngine& engine);
~SpectralSolverStrategy() override;
std::vector<NetOut> evaluate(
const std::vector<NetIn> &netIns,
const std::vector<double>& mass_coords
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
) override;
void set_callback(const std::any &callback) override;
@@ -83,6 +83,11 @@ namespace gridfire::solver {
using TimestepCallback = std::function<void(const TimestepContext&)>;
private:
enum class ParallelInitializationResult : uint8_t {
SUCCESS,
FAILURE
};
struct SpectralCoefficients {
size_t num_sets;
size_t num_coefficients;
@@ -122,14 +127,16 @@ namespace gridfire::solver {
void init_from_basis(size_t num_basis_funcs, const SplineBasis& basis) const;
void solve_inplace(N_Vector x, size_t num_vars, size_t basis_size) const;
void solve_inplace_ptr(sunrealtype* data_ptr, size_t num_vars, size_t basis_size) const;
};
struct CVODEUserData {
SpectralSolverStrategy* solver_instance{};
engine::DynamicEngine* engine;
const engine::DynamicEngine* engine{};
DenseLinearSolver* mass_matrix_solver_instance;
const SplineBasis* basis;
DenseLinearSolver* mass_matrix_solver_instance{};
const SplineBasis* basis{};
std::vector<std::unique_ptr<engine::scratch::StateBlob>>* workspaces;
};
private:
@@ -166,10 +173,12 @@ namespace gridfire::solver {
N_Vector m_T_coeffs = nullptr;
N_Vector m_rho_coeffs = nullptr;
std::vector<fourdst::atomic::Species> m_global_species_list;
private:
std::vector<double> evaluate_monitor_function(const std::vector<NetIn>& current_shells) const;
SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates) const;
SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates, size_t actual_elements) const;
GridPoint reconstruct_at_quadrature(const N_Vector y_coeffs, size_t quad_index, const SplineBasis &basis) const;
@@ -179,6 +188,9 @@ namespace gridfire::solver {
static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
int calculate_rhs(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, CVODEUserData* data) const;
int calculate_jacobian(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, SUNMatrix J, const CVODEUserData *data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) const;
static size_t nyquist_elements(size_t requested_elements, size_t num_shells) ;
static void project_specific_variable(
const std::vector<NetIn>& current_shells,
@@ -191,6 +203,7 @@ namespace gridfire::solver {
bool use_log
);
void inspect_jacobian(SUNMatrix J, const std::string& context) const;
};
}

View File

@@ -105,23 +105,20 @@ namespace gridfire::solver {
class MultiZoneNetworkSolver {
public:
explicit MultiZoneNetworkSolver(
const EngineT& engine,
const engine::scratch::StateBlob& ctx
const EngineT& engine
) :
m_engine(engine),
m_scratch_blob_structure(ctx.clone_structure()){};
m_engine(engine) {};
virtual ~MultiZoneNetworkSolver() = default;
virtual std::vector<NetOut> evaluate(
const std::vector<NetIn>& netIns,
const std::vector<double>& mass_coords
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
) = 0;
virtual void set_callback(const std::any& callback) = 0;
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected:
const EngineT& m_engine; ///< The engine used by this solver strategy.
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob_structure;
};
/**

View File

@@ -59,3 +59,26 @@ namespace gridfire {
concept IsArithmeticOrAD = std::is_same_v<T, double> || std::is_same_v<T, CppAD::AD<double>>;
} // namespace nuclearNetwork
template<>
struct std::formatter<gridfire::NetIn> : std::formatter<std::string> {
auto format(const gridfire::NetIn& netIn, auto& ctx) {
std::string output = "NetIn(, tMax=" + std::to_string(netIn.tMax) +
", dt0=" + std::to_string(netIn.dt0) +
", temperature=" + std::to_string(netIn.temperature) +
", density=" + std::to_string(netIn.density) +
", energy=" + std::to_string(netIn.energy) + ")";
return std::formatter<std::string>::format(output, ctx);
}
};
template <>
struct std::formatter<gridfire::NetOut> : std::formatter<std::string> {
auto format(const gridfire::NetOut& netOut, auto& ctx) {
std::string output = "NetOut(, num_steps=" + std::to_string(netOut.num_steps) +
", energy=" + std::to_string(netOut.energy) +
", dEps_dT=" + std::to_string(netOut.dEps_dT) +
", dEps_dRho=" + std::to_string(netOut.dEps_dRho) + ")";
return std::formatter<std::string>::format(output, ctx);
}
};

View File

@@ -0,0 +1,50 @@
#pragma once
#include "fourdst/logging/logging.h"
#include "quill/LogMacros.h"
#if defined(GF_USE_OPENMP)
#include <omp.h>
namespace gridfire::omp {
static bool s_par_mode_initialized = false;
inline unsigned long get_thread_id() {
return static_cast<unsigned long>(omp_get_thread_num());
}
inline bool in_parallel() {
return omp_in_parallel() != 0;
}
inline void init_parallel_mode() {
if (s_par_mode_initialized) {
return; // Only initialize once
}
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
LOG_INFO(logger, "Initializing OpenMP parallel mode with {} threads", static_cast<unsigned long>(omp_get_max_threads()));
CppAD::thread_alloc::parallel_setup(
static_cast<size_t>(omp_get_max_threads()), // Max threads
[]() -> bool { return in_parallel(); }, // Function to get thread ID
[]() -> size_t { return get_thread_id(); } // Function to check parallel state
);
CppAD::thread_alloc::hold_memory(true);
s_par_mode_initialized = true;
}
}
#define GF_PAR_INIT() gridfire::omp::init_parallel_mode();
#else
namespace gridfire::omp {
inline void log_not_in_parallel_mode() {
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
LOG_INFO(logger, "This is not an error! Note: OpenMP parallel mode is not enabled. GF_USE_OPENMP is not defined. Pass -DGF_USE_OPENMP when compiling to enable OpenMP support. When using meson use the option -Dopenmp_support=true");
}
}
#define GF_PAR_INIT() gridfire::omp::log_not_in_parallel_mode();
#endif

View File

@@ -0,0 +1,9 @@
#pragma once
#if defined(GF_USE_OPENMP)
#define GF_OMP_PRAGMA(x) _Pragma(#x)
#define GF_OMP(omp_args, _) GF_OMP_PRAGMA(omp omp_args)
#else
#define GF_OMP(_,fallback_args) fallback_args
#endif

View File

@@ -5,3 +5,4 @@
#include "gridfire/utils/logging.h"
#include "gridfire/utils/sundials.h"
#include "gridfire/utils/table_format.h"
#include "gridfire/utils/macros.h"