feat(SpectralSolver): Spectral Solver now works in a limited fashion
Major work on spectral solver, can now evolve up to about a year. At that point we likely need to impliment repartitioning logic to stabalize the network or some other scheme based on the jacobian structure
This commit is contained in:
@@ -25,14 +25,14 @@
|
||||
#endif
|
||||
|
||||
namespace gridfire::solver {
|
||||
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolverStrategy {
|
||||
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolver {
|
||||
public:
|
||||
explicit SpectralSolverStrategy(engine::DynamicEngine& engine);
|
||||
explicit SpectralSolverStrategy(const engine::DynamicEngine& engine);
|
||||
~SpectralSolverStrategy() override;
|
||||
|
||||
std::vector<NetOut> evaluate(
|
||||
const std::vector<NetIn> &netIns,
|
||||
const std::vector<double>& mass_coords
|
||||
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
|
||||
) override;
|
||||
|
||||
void set_callback(const std::any &callback) override;
|
||||
@@ -83,6 +83,11 @@ namespace gridfire::solver {
|
||||
using TimestepCallback = std::function<void(const TimestepContext&)>;
|
||||
private:
|
||||
|
||||
enum class ParallelInitializationResult : uint8_t {
|
||||
SUCCESS,
|
||||
FAILURE
|
||||
};
|
||||
|
||||
struct SpectralCoefficients {
|
||||
size_t num_sets;
|
||||
size_t num_coefficients;
|
||||
@@ -122,14 +127,16 @@ namespace gridfire::solver {
|
||||
void init_from_basis(size_t num_basis_funcs, const SplineBasis& basis) const;
|
||||
|
||||
void solve_inplace(N_Vector x, size_t num_vars, size_t basis_size) const;
|
||||
void solve_inplace_ptr(sunrealtype* data_ptr, size_t num_vars, size_t basis_size) const;
|
||||
};
|
||||
|
||||
struct CVODEUserData {
|
||||
SpectralSolverStrategy* solver_instance{};
|
||||
engine::DynamicEngine* engine;
|
||||
const engine::DynamicEngine* engine{};
|
||||
|
||||
DenseLinearSolver* mass_matrix_solver_instance;
|
||||
const SplineBasis* basis;
|
||||
DenseLinearSolver* mass_matrix_solver_instance{};
|
||||
const SplineBasis* basis{};
|
||||
std::vector<std::unique_ptr<engine::scratch::StateBlob>>* workspaces;
|
||||
};
|
||||
|
||||
private:
|
||||
@@ -166,10 +173,12 @@ namespace gridfire::solver {
|
||||
N_Vector m_T_coeffs = nullptr;
|
||||
N_Vector m_rho_coeffs = nullptr;
|
||||
|
||||
std::vector<fourdst::atomic::Species> m_global_species_list;
|
||||
|
||||
|
||||
private:
|
||||
std::vector<double> evaluate_monitor_function(const std::vector<NetIn>& current_shells) const;
|
||||
SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates) const;
|
||||
SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates, size_t actual_elements) const;
|
||||
|
||||
GridPoint reconstruct_at_quadrature(const N_Vector y_coeffs, size_t quad_index, const SplineBasis &basis) const;
|
||||
|
||||
@@ -179,6 +188,9 @@ namespace gridfire::solver {
|
||||
static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
|
||||
|
||||
int calculate_rhs(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, CVODEUserData* data) const;
|
||||
int calculate_jacobian(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, SUNMatrix J, const CVODEUserData *data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) const;
|
||||
|
||||
static size_t nyquist_elements(size_t requested_elements, size_t num_shells) ;
|
||||
|
||||
static void project_specific_variable(
|
||||
const std::vector<NetIn>& current_shells,
|
||||
@@ -191,6 +203,7 @@ namespace gridfire::solver {
|
||||
bool use_log
|
||||
);
|
||||
|
||||
void inspect_jacobian(SUNMatrix J, const std::string& context) const;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@@ -105,23 +105,20 @@ namespace gridfire::solver {
|
||||
class MultiZoneNetworkSolver {
|
||||
public:
|
||||
explicit MultiZoneNetworkSolver(
|
||||
const EngineT& engine,
|
||||
const engine::scratch::StateBlob& ctx
|
||||
const EngineT& engine
|
||||
) :
|
||||
m_engine(engine),
|
||||
m_scratch_blob_structure(ctx.clone_structure()){};
|
||||
m_engine(engine) {};
|
||||
|
||||
virtual ~MultiZoneNetworkSolver() = default;
|
||||
|
||||
virtual std::vector<NetOut> evaluate(
|
||||
const std::vector<NetIn>& netIns,
|
||||
const std::vector<double>& mass_coords
|
||||
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
|
||||
) = 0;
|
||||
virtual void set_callback(const std::any& callback) = 0;
|
||||
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
|
||||
protected:
|
||||
const EngineT& m_engine; ///< The engine used by this solver strategy.
|
||||
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob_structure;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -59,3 +59,26 @@ namespace gridfire {
|
||||
concept IsArithmeticOrAD = std::is_same_v<T, double> || std::is_same_v<T, CppAD::AD<double>>;
|
||||
|
||||
} // namespace nuclearNetwork
|
||||
|
||||
template<>
|
||||
struct std::formatter<gridfire::NetIn> : std::formatter<std::string> {
|
||||
auto format(const gridfire::NetIn& netIn, auto& ctx) {
|
||||
std::string output = "NetIn(, tMax=" + std::to_string(netIn.tMax) +
|
||||
", dt0=" + std::to_string(netIn.dt0) +
|
||||
", temperature=" + std::to_string(netIn.temperature) +
|
||||
", density=" + std::to_string(netIn.density) +
|
||||
", energy=" + std::to_string(netIn.energy) + ")";
|
||||
return std::formatter<std::string>::format(output, ctx);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct std::formatter<gridfire::NetOut> : std::formatter<std::string> {
|
||||
auto format(const gridfire::NetOut& netOut, auto& ctx) {
|
||||
std::string output = "NetOut(, num_steps=" + std::to_string(netOut.num_steps) +
|
||||
", energy=" + std::to_string(netOut.energy) +
|
||||
", dEps_dT=" + std::to_string(netOut.dEps_dT) +
|
||||
", dEps_dRho=" + std::to_string(netOut.dEps_dRho) + ")";
|
||||
return std::formatter<std::string>::format(output, ctx);
|
||||
}
|
||||
};
|
||||
|
||||
50
src/include/gridfire/utils/gf_omp.h
Normal file
50
src/include/gridfire/utils/gf_omp.h
Normal file
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
#include "fourdst/logging/logging.h"
|
||||
#include "quill/LogMacros.h"
|
||||
|
||||
#if defined(GF_USE_OPENMP)
|
||||
|
||||
#include <omp.h>
|
||||
|
||||
namespace gridfire::omp {
|
||||
static bool s_par_mode_initialized = false;
|
||||
|
||||
inline unsigned long get_thread_id() {
|
||||
return static_cast<unsigned long>(omp_get_thread_num());
|
||||
}
|
||||
|
||||
inline bool in_parallel() {
|
||||
return omp_in_parallel() != 0;
|
||||
}
|
||||
|
||||
inline void init_parallel_mode() {
|
||||
if (s_par_mode_initialized) {
|
||||
return; // Only initialize once
|
||||
}
|
||||
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
||||
LOG_INFO(logger, "Initializing OpenMP parallel mode with {} threads", static_cast<unsigned long>(omp_get_max_threads()));
|
||||
CppAD::thread_alloc::parallel_setup(
|
||||
static_cast<size_t>(omp_get_max_threads()), // Max threads
|
||||
[]() -> bool { return in_parallel(); }, // Function to get thread ID
|
||||
[]() -> size_t { return get_thread_id(); } // Function to check parallel state
|
||||
);
|
||||
|
||||
CppAD::thread_alloc::hold_memory(true);
|
||||
s_par_mode_initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
#define GF_PAR_INIT() gridfire::omp::init_parallel_mode();
|
||||
|
||||
#else
|
||||
|
||||
namespace gridfire::omp {
|
||||
inline void log_not_in_parallel_mode() {
|
||||
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
||||
LOG_INFO(logger, "This is not an error! Note: OpenMP parallel mode is not enabled. GF_USE_OPENMP is not defined. Pass -DGF_USE_OPENMP when compiling to enable OpenMP support. When using meson use the option -Dopenmp_support=true");
|
||||
}
|
||||
}
|
||||
|
||||
#define GF_PAR_INIT() gridfire::omp::log_not_in_parallel_mode();
|
||||
|
||||
#endif
|
||||
9
src/include/gridfire/utils/macros.h
Normal file
9
src/include/gridfire/utils/macros.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
|
||||
#if defined(GF_USE_OPENMP)
|
||||
#define GF_OMP_PRAGMA(x) _Pragma(#x)
|
||||
#define GF_OMP(omp_args, _) GF_OMP_PRAGMA(omp omp_args)
|
||||
#else
|
||||
#define GF_OMP(_,fallback_args) fallback_args
|
||||
#endif
|
||||
@@ -5,3 +5,4 @@
|
||||
#include "gridfire/utils/logging.h"
|
||||
#include "gridfire/utils/sundials.h"
|
||||
#include "gridfire/utils/table_format.h"
|
||||
#include "gridfire/utils/macros.h"
|
||||
|
||||
Reference in New Issue
Block a user