diff --git a/benchmarks/SingleZoneSolver/main.cpp b/benchmarks/SingleZoneSolver/main.cpp index 5eb0cbcf..72edb819 100644 --- a/benchmarks/SingleZoneSolver/main.cpp +++ b/benchmarks/SingleZoneSolver/main.cpp @@ -63,7 +63,7 @@ int main() { std::println("Scratch Blob State: {}", *construct.scratch_blob); - constexpr size_t runs = 1000; + constexpr size_t runs = 10; auto startTime = std::chrono::high_resolution_clock::now(); // arrays to store timings @@ -72,14 +72,15 @@ int main() { std::array serial_results; for (size_t i = 0; i < runs; ++i) { auto start_setup_time = std::chrono::high_resolution_clock::now(); - solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob); - solver.set_stdout_logging_enabled(false); + solver::PointSolverContext solverCtx(*construct.scratch_blob); + solverCtx.set_stdout_logging(false); + solver::PointSolver solver(construct.engine); auto end_setup_time = std::chrono::high_resolution_clock::now(); std::chrono::duration setup_elapsed = end_setup_time - start_setup_time; setup_times[i] = setup_elapsed; auto start_eval_time = std::chrono::high_resolution_clock::now(); - const NetOut netOut = solver.evaluate(netIn); + const NetOut netOut = solver.evaluate(solverCtx, netIn); auto end_eval_time = std::chrono::high_resolution_clock::now(); serial_results[i] = netOut; std::chrono::duration eval_elapsed = end_eval_time - start_eval_time; @@ -99,7 +100,6 @@ int main() { std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs); std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs); std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count()); - std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1)); std::array parallelResults; @@ -114,16 +114,16 @@ int main() { // Parallel runs startTime = std::chrono::high_resolution_clock::now(); - GF_OMP(parallel for,) - for (size_t i = 0; i < runs; ++i) { + GF_OMP(parallel for, for (size_t i = 0; i < runs; ++i)) { auto start_setup_time = std::chrono::high_resolution_clock::now(); - solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]); - solver.set_stdout_logging_enabled(false); + solver::PointSolverContext solverCtx(*construct.scratch_blob); + solverCtx.set_stdout_logging(false); + solver::PointSolver solver(construct.engine); auto end_setup_time = std::chrono::high_resolution_clock::now(); std::chrono::duration setup_elapsed = end_setup_time - start_setup_time; setupTimes[i] = setup_elapsed; auto start_eval_time = std::chrono::high_resolution_clock::now(); - parallelResults[i] = solver.evaluate(netIn); + parallelResults[i] = solver.evaluate(solverCtx, netIn); auto end_eval_time = std::chrono::high_resolution_clock::now(); std::chrono::duration eval_elapsed = end_eval_time - start_eval_time; evalTimes[i] = eval_elapsed; @@ -144,10 +144,6 @@ int main() { std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs); std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count()); - std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) { - return result.composition.getMolarAbundance(fourdst::atomic::H_1); - })); - std::println("========== Summary =========="); std::println("Serial Runs:"); std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs); diff --git a/build-check/CPPC/meson.build b/build-check/CPPC/meson.build index e3d8b05c..743ebc36 100644 --- a/build-check/CPPC/meson.build +++ b/build-check/CPPC/meson.build @@ -35,3 +35,16 @@ endif if get_option('openmp_support') add_project_arguments('-DGF_USE_OPENMP', language: 'cpp') endif + +if get_option('asan') and get_option('buildtype') != 'debug' and get_option('buildtype') != 'debugoptimized' + error('AddressSanitizer (ASan) can only be enabled for debug or debugoptimized builds') +endif + +if get_option('asan') and (get_option('buildtype') == 'debugoptimized' or get_option('buildtype') == 'debug') + message('enabling AddressSanitizer (ASan) support') + add_project_arguments('-fsanitize=address,undefined', language: 'cpp') + add_project_arguments('-fno-omit-frame-pointer', language: 'cpp') + + add_project_link_arguments('-fsanitize=address,undefined', language: 'cpp') + add_project_link_arguments('-fno-omit-frame-pointer', language: 'cpp') +endif diff --git a/meson_options.txt b/meson_options.txt index 9df91c27..0d7a71d1 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -11,4 +11,5 @@ option('build_c_api', type: 'boolean', value: true, description: 'compile the C option('build_tools', type: 'boolean', value: true, description: 'build the GridFire command line tools') option('openmp_support', type: 'boolean', value: false, description: 'Enable OpenMP support for parallelization') option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.') -option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite') \ No newline at end of file +option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite') +option('asan', type: 'boolean', value: false, description: 'Enable AddressSanitizer (ASan) support for detecting memory errors') \ No newline at end of file diff --git a/src/include/gridfire/config/config.h b/src/include/gridfire/config/config.h index c5a1f067..52ebf648 100644 --- a/src/include/gridfire/config/config.h +++ b/src/include/gridfire/config/config.h @@ -11,10 +11,10 @@ namespace gridfire::config { struct SpectralSolverConfig { struct Trigger { - double simulationTimeInterval = 1.0e12; - double offDiagonalThreshold = 1.0e10; double timestepCollapseRatio = 0.5; size_t maxConvergenceFailures = 2; + double relativeFailureRate = 0.5; + size_t windowSize = 10; }; struct MonitorFunctionConfig { double structure_weight = 1.0; diff --git a/src/include/gridfire/engine/engine_graph.h b/src/include/gridfire/engine/engine_graph.h index c977445f..25c374a1 100644 --- a/src/include/gridfire/engine/engine_graph.h +++ b/src/include/gridfire/engine/engine_graph.h @@ -807,8 +807,6 @@ namespace gridfire::engine { CppAD::ADFun m_authoritativeADFun; - const size_t m_state_blob_offset; - private: /** * @brief Synchronizes the internal maps. diff --git a/src/include/gridfire/solver/strategies/GridSolver.h b/src/include/gridfire/solver/strategies/GridSolver.h new file mode 100644 index 00000000..1b94c92e --- /dev/null +++ b/src/include/gridfire/solver/strategies/GridSolver.h @@ -0,0 +1,43 @@ +#pragma once + +#include "gridfire/solver/strategies/strategy_abstract.h" + +#include + +namespace gridfire::solver { + struct GridSolverContext final : SolverContextBase { + std::vector> solver_workspaces; + std::vector> timestep_callbacks; + const engine::scratch::StateBlob& ctx_template; + + bool zone_completion_logging = true; + bool zone_stdout_logging = false; + bool zone_detailed_logging = false; + + void init() override; + void reset(); + + void set_callback(const std::function &callback); + void set_callback(const std::function &callback, size_t zone_idx); + + void set_stdout_logging(bool enable) override; + void set_detailed_logging(bool enable) override; + + explicit GridSolverContext(const engine::scratch::StateBlob& ctx_template); + }; + + class GridSolver final : public MultiZoneDynamicNetworkSolver { + public: + GridSolver( + const engine::DynamicEngine& engine, + const SingleZoneDynamicNetworkSolver& solver + ); + + ~GridSolver() override = default; + + std::vector evaluate( + SolverContextBase& ctx, + const std::vector& netIns + ) const override; + }; +} \ No newline at end of file diff --git a/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h b/src/include/gridfire/solver/strategies/PointSolver.h similarity index 70% rename from src/include/gridfire/solver/strategies/CVODE_solver_strategy.h rename to src/include/gridfire/solver/strategies/PointSolver.h index b1538ebf..480aebf7 100644 --- a/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h +++ b/src/include/gridfire/solver/strategies/PointSolver.h @@ -44,8 +44,88 @@ #endif namespace gridfire::solver { + struct PointSolverTimestepContext final : TimestepContextBase { + const double t; ///< Current integration time [s]. + const N_Vector& state; ///< Current CVODE state vector (N_Vector). + const double dt; ///< Last step size [s]. + const double last_step_time; ///< Time at last callback [s]. + const double T9; ///< Temperature in GK. + const double rho; ///< Density [g cm^-3]. + const size_t num_steps; ///< Number of CVODE steps taken so far. + const engine::DynamicEngine& engine; ///< Reference to the engine. + const std::vector& networkSpecies; ///< Species layout. + const size_t currentConvergenceFailures; ///< Total number of convergence failures + const size_t currentNonlinearIterations; ///< Total number of non-linear iterations + const std::map>& reactionContributionMap; ///< Map of reaction contributions for the current step + engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob + + PointSolverTimestepContext( + double t, + const N_Vector& state, + double dt, + double last_step_time, + double t9, + double rho, + size_t num_steps, + const engine::DynamicEngine& engine, + const std::vector& networkSpecies, + size_t currentConvergenceFailure, + size_t currentNonlinearIterations, + const std::map> &reactionContributionMap, + engine::scratch::StateBlob& state_ctx + ); + + [[nodiscard]] std::vector> describe() const override; + }; + + using TimestepCallback = std::function; ///< Type alias for a timestep callback function. + + struct PointSolverContext final : SolverContextBase { + SUNContext sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver). + void* cvode_mem = nullptr; ///< CVODE memory block. + N_Vector Y = nullptr; ///< CVODE state vector (species + energy accumulator). + N_Vector YErr = nullptr; ///< Estimated local errors. + SUNMatrix J = nullptr; ///< Dense Jacobian matrix. + SUNLinearSolver LS = nullptr; ///< Dense linear solver. + + std::unique_ptr engine_ctx; + + + std::optional callback; ///< Optional per-step callback. + int num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers). + + bool stdout_logging = true; ///< If true, print per-step logs and use CV_ONE_STEP. + + N_Vector constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries). + + std::optional abs_tol; ///< User-specified absolute tolerance. + std::optional rel_tol; ///< User-specified relative tolerance. + + bool detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance). + + size_t last_size = 0; + size_t last_composition_hash = 0ULL; + sunrealtype last_good_time_step = 0ULL; + + void init() override; + void set_stdout_logging(bool enable) override; + void set_detailed_logging(bool enable) override; + + void reset_all(); + void reset_user(); + void reset_cvode(); + void clear_context(); + void init_context(); + + [[nodiscard]] bool has_context() const; + + explicit PointSolverContext(const engine::scratch::StateBlob& engine_ctx); + ~PointSolverContext() override; + + }; + /** - * @class CVODESolverStrategy + * @class PointSolver * @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy. * * Integrates the nuclear network abundances along with a final accumulator entry for specific @@ -78,27 +158,16 @@ namespace gridfire::solver { * std::cout << "Final energy: " << out.energy << " erg/g\n"; * @endcode */ - class CVODESolverStrategy final : public SingleZoneDynamicNetworkSolver { + class PointSolver final : public SingleZoneDynamicNetworkSolver { public: /** * @brief Construct the CVODE strategy and create a SUNDIALS context. * @param engine DynamicEngine used for RHS/Jacobian evaluation and network access. * @throws std::runtime_error If SUNContext_Create fails. */ - explicit CVODESolverStrategy( - const engine::DynamicEngine& engine, - const engine::scratch::StateBlob& ctx + explicit PointSolver( + const engine::DynamicEngine& engine ); - /** - * @brief Destructor: cleans CVODE/SUNDIALS resources and frees SUNContext. - */ - ~CVODESolverStrategy() override; - - // Make the class non-copyable and non-movable to prevent shallow copies of CVODE pointers - CVODESolverStrategy(const CVODESolverStrategy&) = delete; - CVODESolverStrategy& operator=(const CVODESolverStrategy&) = delete; - CVODESolverStrategy(CVODESolverStrategy&&) = delete; - CVODESolverStrategy& operator=(CVODESolverStrategy&&) = delete; /** * @brief Integrate from t=0 to netIn.tMax and return final composition and energy. @@ -114,6 +183,7 @@ namespace gridfire::solver { * - At the end, converts molar abundances to mass fractions and assembles NetOut, * including derivatives of energy w.r.t. T and rho from the engine. * + * @param solver_ctx * @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition. * @return NetOut containing final Composition, accumulated energy [erg/g], step count, * and dEps/dT, dEps/dRho. @@ -122,10 +192,14 @@ namespace gridfire::solver { * @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state * during RHS evaluation (captured in the wrapper then rethrown here). */ - NetOut evaluate(const NetIn& netIn) override; + NetOut evaluate( + SolverContextBase& solver_ctx, + const NetIn& netIn + ) const override; /** * @brief Call to evaluate which will let the user control if the trigger reasoning is displayed + * @param solver_ctx * @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition. * @param displayTrigger Boolean flag to control if trigger reasoning is displayed * @param forceReinitialize Boolean flag to force reinitialization of CVODE resources at the start @@ -136,89 +210,13 @@ namespace gridfire::solver { * @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state * during RHS evaluation (captured in the wrapper then rethrown here). */ - NetOut evaluate(const NetIn& netIn, bool displayTrigger, bool forceReinitialize = false); + NetOut evaluate( + SolverContextBase& solver_ctx, + const NetIn& netIn, + bool displayTrigger, + bool forceReinitialize = false + ) const; - /** - * @brief Install a timestep callback. - * @param callback std::any containing TimestepCallback (std::function). - * @throws std::bad_any_cast If callback is not of the expected type. - */ - void set_callback(const std::any &callback) override; - - /** - * @brief Whether per-step logs are printed to stdout and CVode is stepped with CV_ONE_STEP. - */ - [[nodiscard]] bool get_stdout_logging_enabled() const; - - /** - * @brief Enable/disable per-step stdout logging. - * @param logging_enabled Flag to control if a timestep summary is written to standard output or not - */ - void set_stdout_logging_enabled(bool logging_enabled); - - void set_absTol(double absTol); - void set_relTol(double relTol); - - double get_absTol() const; - double get_relTol() const; - - /** - * @brief Schema of fields exposed to the timestep callback context. - */ - [[nodiscard]] std::vector> describe_callback_context() const override; - - /** - * @struct TimestepContext - * @brief Immutable view of the current integration state passed to callbacks. - * - * Fields capture CVODE time/state, step size, thermodynamic state, the engine reference, - * and the list of network species used to interpret the state vector layout. - */ - struct TimestepContext final : public SolverContextBase { - // This struct can be identical to the one in DirectNetworkSolver - const double t; ///< Current integration time [s]. - const N_Vector& state; ///< Current CVODE state vector (N_Vector). - const double dt; ///< Last step size [s]. - const double last_step_time; ///< Time at last callback [s]. - const double T9; ///< Temperature in GK. - const double rho; ///< Density [g cm^-3]. - const size_t num_steps; ///< Number of CVODE steps taken so far. - const engine::DynamicEngine& engine; ///< Reference to the engine. - const std::vector& networkSpecies; ///< Species layout. - const size_t currentConvergenceFailures; ///< Total number of convergence failures - const size_t currentNonlinearIterations; ///< Total number of non-linear iterations - const std::map>& reactionContributionMap; ///< Map of reaction contributions for the current step - engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob - - /** - * @brief Construct a context snapshot. - */ - TimestepContext( - double t, - const N_Vector& state, - double dt, - double last_step_time, - double t9, - double rho, - size_t num_steps, - const engine::DynamicEngine& engine, - const std::vector& networkSpecies, - size_t currentConvergenceFailure, - size_t currentNonlinearIterations, - const std::map> &reactionContributionMap, - engine::scratch::StateBlob& state_ctx - ); - - /** - * @brief Human-readable description of the context fields. - */ - [[nodiscard]] std::vector> describe() const override; - }; - - /** - * @brief Type alias for a timestep callback. - */ - using TimestepCallback = std::function; ///< Type alias for a timestep callback function. private: /** * @struct CVODEUserData @@ -230,7 +228,8 @@ namespace gridfire::solver { * to CVODE, then the driver loop inspects and rethrows. */ struct CVODEUserData { - CVODESolverStrategy* solver_instance{}; // Pointer back to the class instance + const PointSolver* solver_instance{}; // Pointer back to the class instance + PointSolverContext* sctx; // Pointer to the solver context engine::scratch::StateBlob& ctx; const engine::DynamicEngine* engine{}; double T9{}; @@ -283,6 +282,7 @@ namespace gridfire::solver { * step size, creates a dense matrix and dense linear solver, and registers the Jacobian. */ void initialize_cvode_integration_resources( + PointSolverContext* ctx, uint64_t N, size_t numSpecies, double current_time, @@ -290,15 +290,7 @@ namespace gridfire::solver { double absTol, double relTol, double accumulatedEnergy - ); - - /** - * @brief Destroy CVODE vectors/linear algebra and optionally the CVODE memory block. - * @param memFree If true, also calls CVodeFree on m_cvode_mem. - */ - void cleanup_cvode_resources(bool memFree); - - void set_detailed_step_logging(bool enabled); + ) const; /** @@ -308,31 +300,13 @@ namespace gridfire::solver { * sorted table of species with the highest error ratios; then invokes diagnostic routines to * inspect Jacobian stiffness and species balance. */ - void log_step_diagnostics(engine::scratch::StateBlob &ctx, const CVODEUserData& user_data, bool displayJacobianStiffness, bool - displaySpeciesBalance, bool to_file, std::optional filename) const; - private: - SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver). - void* m_cvode_mem = nullptr; ///< CVODE memory block. - N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator). - N_Vector m_YErr = nullptr; ///< Estimated local errors. - SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix. - SUNLinearSolver m_LS = nullptr; ///< Dense linear solver. - - - std::optional m_callback; ///< Optional per-step callback. - int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers). - - bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP. - - N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries). - - std::optional m_absTol; ///< User-specified absolute tolerance. - std::optional m_relTol; ///< User-specified relative tolerance. - - bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance). - - mutable size_t m_last_size = 0; - mutable size_t m_last_composition_hash = 0ULL; - mutable sunrealtype m_last_good_time_step = 0ULL; + void log_step_diagnostics( + PointSolverContext* sctx_p, + engine::scratch::StateBlob &ctx, + const CVODEUserData& user_data, + bool displayJacobianStiffness, + bool displaySpeciesBalance, + bool to_file, std::optional filename + ) const; }; } \ No newline at end of file diff --git a/src/include/gridfire/solver/strategies/SpectralSolverStrategy.h b/src/include/gridfire/solver/strategies/SpectralSolverStrategy.h deleted file mode 100644 index d1df6663..00000000 --- a/src/include/gridfire/solver/strategies/SpectralSolverStrategy.h +++ /dev/null @@ -1,221 +0,0 @@ -#pragma once - -#include "gridfire/solver/strategies/strategy_abstract.h" -#include "gridfire/engine/engine_abstract.h" -#include "gridfire/types/types.h" -#include "gridfire/config/config.h" - -#include "fourdst/logging/logging.h" -#include "fourdst/constants/const.h" - -#include -#include -#include -#include - -#include "gridfire/exceptions/error_engine.h" - -#ifdef SUNDIALS_HAVE_OPENMP - #include -#endif -#ifdef SUNDIALS_HAVE_PTHREADS - #include -#endif -#ifndef SUNDIALS_HAVE_OPENMP - #ifndef SUNDIALS_HAVE_PTHREADS - #include - #endif -#endif - -namespace gridfire::solver { - class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolver { - public: - explicit SpectralSolverStrategy(const engine::DynamicEngine& engine); - ~SpectralSolverStrategy() override; - - std::vector evaluate( - const std::vector &netIns, - const std::vector& mass_coords, const engine::scratch::StateBlob &ctx_template - ) override; - - void set_callback(const std::any &callback) override; - [[nodiscard]] std::vector> describe_callback_context() const override; - - [[nodiscard]] bool get_stdout_logging_enabled() const; - void set_stdout_logging_enabled(bool logging_enabled); - - public: - struct TimestepContext final : public SolverContextBase { - TimestepContext( - const double t, - const N_Vector &state, - const double dt, - const double last_time_step, - const engine::DynamicEngine &engine - ) : - t(t), - state(state), - dt(dt), - last_time_step(last_time_step), - engine(engine) {} - - [[nodiscard]] std::vector> describe() const override; - - const double t; - const N_Vector& state; - const double dt; - const double last_time_step; - const engine::DynamicEngine& engine; - }; - - struct BasisEval { - size_t start_idx; - std::vector phi; - }; - - struct SplineBasis { - std::vector knots; - std::vector quadrature_nodes; - std::vector quadrature_weights; - int degree = 3; - - - std::vector quad_evals; - }; - public: - using TimestepCallback = std::function; - private: - - enum class ParallelInitializationResult : uint8_t { - SUCCESS, - FAILURE - }; - - struct SpectralCoefficients { - size_t num_sets; - size_t num_coefficients; - std::vector coefficients; - - double operator()(size_t i, size_t j) const; - }; - - struct GridPoint { - double T9; - double rho; - fourdst::composition::Composition composition; - }; - - struct Constants { - const double c = fourdst::constant::Constants::getInstance().get("c").value; - const double N_a = fourdst::constant::Constants::getInstance().get("N_a").value; - const double c2 = c * c; - }; - - struct DenseLinearSolver { - SUNMatrix A; - SUNLinearSolver LS; - N_Vector temp_vector; - SUNContext ctx; - - DenseLinearSolver(size_t size, SUNContext sun_ctx); - ~DenseLinearSolver(); - - DenseLinearSolver(const DenseLinearSolver&) = delete; - DenseLinearSolver& operator=(const DenseLinearSolver&) = delete; - - void setup() const; - void zero() const; - - void init_from_cache(size_t num_basis_funcs, const std::vector& shell_cache) const; - void init_from_basis(size_t num_basis_funcs, const SplineBasis& basis) const; - - void solve_inplace(N_Vector x, size_t num_vars, size_t basis_size) const; - void solve_inplace_ptr(sunrealtype* data_ptr, size_t num_vars, size_t basis_size) const; - }; - - struct CVODEUserData { - SpectralSolverStrategy* solver_instance{}; - std::vector> workspaces; - const engine::DynamicEngine* engine{}; - std::unique_ptr captured_exception{}; - - std::vector T9{}; - std::vector rho{}; - double energy{}; - - double neutrino_energy_loss_rate = 0.0; - double total_neutrino_flux = 0.0; - - DenseLinearSolver* mass_matrix_solver_instance{}; - const SplineBasis* basis{}; - }; - - private: - fourdst::config::Config m_config; - quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log"); - - SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver). - void* m_cvode_mem = nullptr; ///< CVODE memory block. - N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator). - SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix. - SUNLinearSolver m_LS = nullptr; ///< Dense linear solver. - - - std::optional m_callback; ///< Optional per-step callback. - int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers). - - bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP. - - N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries). - - std::optional m_absTol; ///< User-specified absolute tolerance. - std::optional m_relTol; ///< User-specified relative tolerance. - - bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance). - - mutable size_t m_last_size = 0; - mutable size_t m_last_composition_hash = 0ULL; - mutable sunrealtype m_last_good_time_step = 0ULL; - - SplineBasis m_current_basis; - - Constants m_constants; - - N_Vector m_T_coeffs = nullptr; - N_Vector m_rho_coeffs = nullptr; - - std::vector m_global_species_list; - - - private: - std::vector evaluate_monitor_function(const std::vector& current_shells) const; - - static SplineBasis generate_basis_from_monitor(const std::vector& monitor_values, const std::vector& mass_coordinates, size_t actual_elements); - - GridPoint reconstruct_at_quadrature(const N_Vector y_coeffs, size_t quad_index, const SplineBasis &basis) const; - - std::vector reconstruct_solution(const std::vector& original_inputs, const std::vector& mass_coordinates, const N_Vector final_coeffs, const SplineBasis& basis, double dt) const; - - static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector, void* user_data); - static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3); - - int calculate_rhs(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, CVODEUserData* data) const; - int calculate_jacobian(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, SUNMatrix J, const CVODEUserData *data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) const; - - static size_t nyquist_elements(size_t requested_elements, size_t num_shells) ; - - static void project_specific_variable( - const std::vector& current_shells, - const std::vector& mass_coordinates, - const std::vector& shell_cache, - const DenseLinearSolver& linear_solver, - N_Vector output_vec, - size_t output_offset, - const std::function &getter, - bool use_log - ); - - void inspect_jacobian(SUNMatrix J, const std::string& context) const; - }; - -} diff --git a/src/include/gridfire/solver/strategies/strategies.h b/src/include/gridfire/solver/strategies/strategies.h index f296dd3e..c902dba4 100644 --- a/src/include/gridfire/solver/strategies/strategies.h +++ b/src/include/gridfire/solver/strategies/strategies.h @@ -2,5 +2,5 @@ #include "gridfire/solver/strategies/triggers/triggers.h" #include "gridfire/solver/strategies/strategy_abstract.h" -#include "gridfire/solver/strategies/CVODE_solver_strategy.h" -#include "gridfire/solver/strategies/SpectralSolverStrategy.h" \ No newline at end of file +#include "gridfire/solver/strategies/PointSolver.h" +#include "gridfire/solver/strategies/GridSolver.h" \ No newline at end of file diff --git a/src/include/gridfire/solver/strategies/strategy_abstract.h b/src/include/gridfire/solver/strategies/strategy_abstract.h index a9f3e232..8d3d29f4 100644 --- a/src/include/gridfire/solver/strategies/strategy_abstract.h +++ b/src/include/gridfire/solver/strategies/strategy_abstract.h @@ -13,17 +13,24 @@ namespace gridfire::solver { template concept IsEngine = std::is_base_of_v; + struct SolverContextBase { + virtual void init() = 0; + virtual void set_stdout_logging(bool enable) = 0; + virtual void set_detailed_logging(bool enable) = 0; + virtual ~SolverContextBase() = default; + }; + /** - * @struct SolverContextBase + * @struct TimestepContextBase * @brief Base class for solver callback contexts. * * This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces * that derived classes implement a `describe` method that returns a vector of tuples describing * the context that a callback will receive when called. */ - class SolverContextBase { + class TimestepContextBase { public: - virtual ~SolverContextBase() = default; + virtual ~TimestepContextBase() = default; /** * @brief Describe the context for callback functions. @@ -54,11 +61,9 @@ namespace gridfire::solver { * @param engine The engine to use for evaluating the network. */ explicit SingleZoneNetworkSolver( - const EngineT& engine, - const engine::scratch::StateBlob& ctx + const EngineT& engine ) : - m_engine(engine), - m_scratch_blob(ctx.clone_structure()) {}; + m_engine(engine) {}; /** * @brief Virtual destructor. @@ -67,58 +72,39 @@ namespace gridfire::solver { /** * @brief Evaluates the network for a given timestep. + * @param solver_ctx + * @param engine_ctx * @param netIn The input conditions for the network. * @return The output conditions after the timestep. */ - virtual NetOut evaluate(const NetIn& netIn) = 0; + virtual NetOut evaluate( + SolverContextBase& solver_ctx, + const NetIn& netIn + ) const = 0; - /** - * @brief set the callback function to be called at the end of each timestep. - * - * This function allows the user to set a callback function that will be called at the end of each timestep. - * The callback function will receive a gridfire::solver::::TimestepContext object. Note that - * depending on the solver, this context may contain different information. Further, the exact - * signature of the callback function is left up to each solver. Every solver should provide a type or type alias - * TimestepCallback that defines the signature of the callback function so that the user can easily - * get that type information. - * - * @param callback The callback function to be called at the end of each timestep. - */ - virtual void set_callback(const std::any& callback) = 0; - - /** - * @brief Describe the context that will be passed to the callback function. - * @return A vector of tuples, each containing a string for the parameter's name and a string for its type. - * - * This method should be overridden by derived classes to provide a description of the context - * that will be passed to the callback function. The intent of this method is that an end user can investigate - * the context that will be passed to the callback function, and use this information to craft their own - * callback function. - */ - [[nodiscard]] virtual std::vector> describe_callback_context() const = 0; protected: const EngineT& m_engine; ///< The engine used by this solver strategy. - std::unique_ptr m_scratch_blob; }; template class MultiZoneNetworkSolver { public: explicit MultiZoneNetworkSolver( - const EngineT& engine + const EngineT& engine, + const SingleZoneNetworkSolver& solver ) : - m_engine(engine) {}; + m_engine(engine), + m_solver(solver) {}; virtual ~MultiZoneNetworkSolver() = default; virtual std::vector evaluate( - const std::vector& netIns, - const std::vector& mass_coords, const engine::scratch::StateBlob &ctx_template - ) = 0; - virtual void set_callback(const std::any& callback) = 0; - [[nodiscard]] virtual std::vector> describe_callback_context() const = 0; + SolverContextBase& solver_ctx, + const std::vector& netIns + ) const = 0; protected: const EngineT& m_engine; ///< The engine used by this solver strategy. + const SingleZoneNetworkSolver& m_solver; }; /** diff --git a/src/include/gridfire/solver/strategies/triggers/engine_partitioning_trigger.h b/src/include/gridfire/solver/strategies/triggers/engine_partitioning_trigger.h index 93088596..98b7f0e2 100644 --- a/src/include/gridfire/solver/strategies/triggers/engine_partitioning_trigger.h +++ b/src/include/gridfire/solver/strategies/triggers/engine_partitioning_trigger.h @@ -2,7 +2,7 @@ #include "gridfire/trigger/trigger_abstract.h" #include "gridfire/trigger/trigger_result.h" -#include "gridfire/solver/strategies/CVODE_solver_strategy.h" +#include "gridfire/solver/strategies/PointSolver.h" #include "fourdst/logging/logging.h" #include @@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE { * * See also: engine_partitioning_trigger.cpp for the concrete logic and logging. */ - class SimulationTimeTrigger final : public Trigger { + class SimulationTimeTrigger final : public Trigger { public: /** * @brief Construct with a positive time interval between firings. @@ -62,7 +62,7 @@ namespace gridfire::trigger::solver::CVODE { * * @post increments hit/miss counters and may emit trace logs. */ - bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override; /** * @brief Update internal state; if check(ctx) is true, advance last_trigger_time. * @param ctx CVODE timestep context. @@ -70,9 +70,9 @@ namespace gridfire::trigger::solver::CVODE { * @note update() calls check(ctx) and, on success, records the overshoot delta * (ctx.t - last_trigger_time) - interval for diagnostics. */ - void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + void update(const gridfire::solver::PointSolverTimestepContext &ctx) override; - void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + void step(const gridfire::solver::PointSolverTimestepContext &ctx) override; /** * @brief Reset counters and last trigger bookkeeping (time and delta) to zero. */ @@ -85,7 +85,7 @@ namespace gridfire::trigger::solver::CVODE { * @param ctx CVODE timestep context. * @return TriggerResult including name, value, and description. */ - TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override; /** @brief Textual description including configured interval. */ std::string describe() const override; /** @brief Number of true evaluations since last reset. */ @@ -130,7 +130,7 @@ namespace gridfire::trigger::solver::CVODE { * @par See also * - engine_partitioning_trigger.cpp for concrete logic and trace logging. */ - class OffDiagonalTrigger final : public Trigger { + class OffDiagonalTrigger final : public Trigger { public: /** * @brief Construct with a non-negative magnitude threshold. @@ -145,13 +145,13 @@ namespace gridfire::trigger::solver::CVODE { * * @post increments hit/miss counters and may emit trace logs. */ - bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override; /** * @brief Record an update; does not mutate any Jacobian-related state. * @param ctx CVODE timestep context (unused except for symmetry with interface). */ - void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; - void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + void update(const gridfire::solver::PointSolverTimestepContext &ctx) override; + void step(const gridfire::solver::PointSolverTimestepContext &ctx) override; /** @brief Reset counters to zero. */ void reset() override; @@ -161,7 +161,7 @@ namespace gridfire::trigger::solver::CVODE { * @brief Structured explanation of the evaluation outcome. * @param ctx CVODE timestep context. */ - TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override; /** @brief Textual description including configured threshold. */ std::string describe() const override; /** @brief Number of true evaluations since last reset. */ @@ -206,7 +206,7 @@ namespace gridfire::trigger::solver::CVODE { * * See also: engine_partitioning_trigger.cpp for exact logic and logging. */ - class TimestepCollapseTrigger final : public Trigger { + class TimestepCollapseTrigger final : public Trigger { public: /** * @brief Construct with threshold and relative/absolute mode; window size defaults to 1. @@ -230,20 +230,20 @@ namespace gridfire::trigger::solver::CVODE { * * @post increments hit/miss counters and may emit trace logs. */ - bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override; /** * @brief Update sliding window with the most recent dt and increment update counter. * @param ctx CVODE timestep context. */ - void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; - void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + void update(const gridfire::solver::PointSolverTimestepContext &ctx) override; + void step(const gridfire::solver::PointSolverTimestepContext &ctx) override; /** @brief Reset counters and clear the dt window. */ void reset() override; /** @brief Stable human-readable name. */ std::string name() const override; /** @brief Structured explanation of the evaluation outcome. */ - TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override; /** @brief Textual description including threshold, mode, and window size. */ std::string describe() const override; /** @brief Number of true evaluations since last reset. */ @@ -272,15 +272,15 @@ namespace gridfire::trigger::solver::CVODE { std::deque m_timestep_window; }; - class ConvergenceFailureTrigger final : public Trigger { + class ConvergenceFailureTrigger final : public Trigger { public: explicit ConvergenceFailureTrigger(size_t totalFailures, float relativeFailureRate, size_t windowSize); - bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override; - void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + void update(const gridfire::solver::PointSolverTimestepContext &ctx) override; - void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; + void step(const gridfire::solver::PointSolverTimestepContext &ctx) override; void reset() override; @@ -288,7 +288,7 @@ namespace gridfire::trigger::solver::CVODE { [[nodiscard]] std::string describe() const override; - [[nodiscard]] TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; + [[nodiscard]] TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override; [[nodiscard]] size_t numTriggers() const override; @@ -312,8 +312,8 @@ namespace gridfire::trigger::solver::CVODE { private: float current_mean() const; - bool abs_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const; - bool rel_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const; + bool abs_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const; + bool rel_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const; }; /** @@ -337,10 +337,10 @@ namespace gridfire::trigger::solver::CVODE { * * @note The exact policy is subject to change; this function centralizes that decision. */ - std::unique_ptr> makeEnginePartitioningTrigger( - const double simulationTimeInterval, - const double offDiagonalThreshold, - const double timestepCollapseRatio, - const size_t maxConvergenceFailures + std::unique_ptr> makeEnginePartitioningTrigger( + double simulationTimeInterval, + double offDiagonalThreshold, + double timestepCollapseRatio, + size_t maxConvergenceFailures ); } diff --git a/src/include/gridfire/utils/gf_omp.h b/src/include/gridfire/utils/gf_omp.h index a52ff2cf..07b645ec 100644 --- a/src/include/gridfire/utils/gf_omp.h +++ b/src/include/gridfire/utils/gf_omp.h @@ -30,6 +30,7 @@ namespace gridfire::omp { ); CppAD::thread_alloc::hold_memory(true); + CppAD::CheckSimpleVector>(0, 1); s_par_mode_initialized = true; } } diff --git a/src/lib/engine/engine_graph.cpp b/src/lib/engine/engine_graph.cpp index 523c57b1..de8195f2 100644 --- a/src/lib/engine/engine_graph.cpp +++ b/src/lib/engine/engine_graph.cpp @@ -118,8 +118,7 @@ namespace gridfire::engine { m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA), m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)), m_partitionFunction(partitionFunction.clone()), - m_depth(buildDepth), - m_state_blob_offset(0) // For a base engine the offset is always 0 + m_depth(buildDepth) { syncInternalMaps(); } @@ -128,8 +127,7 @@ namespace gridfire::engine { const reaction::ReactionSet &reactions ) : m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA), - m_reactions(reactions), - m_state_blob_offset(0) + m_reactions(reactions) { syncInternalMaps(); } diff --git a/src/lib/engine/procedures/priming.cpp b/src/lib/engine/procedures/priming.cpp index 0d4756ae..547f1630 100644 --- a/src/lib/engine/procedures/priming.cpp +++ b/src/lib/engine/procedures/priming.cpp @@ -2,7 +2,6 @@ #include "fourdst/atomic/species.h" #include "fourdst/composition/utils.h" -#include "gridfire/engine/views/engine_priming.h" #include "gridfire/solver/solver.h" #include "gridfire/engine/engine_abstract.h" @@ -13,7 +12,7 @@ #include "gridfire/engine/scratchpads/engine_graph_scratchpad.h" #include "fourdst/logging/logging.h" -#include "gridfire/solver/strategies/CVODE_solver_strategy.h" +#include "gridfire/solver/strategies/PointSolver.h" #include "quill/Logger.h" #include "quill/LogMacros.h" @@ -28,13 +27,12 @@ namespace gridfire::engine { const GraphEngine& engine, const std::optional>& ignoredReactionTypes ) { const auto logger = LogManager::getInstance().getLogger("log"); - solver::CVODESolverStrategy integrator(engine, ctx); + solver::PointSolver integrator(engine); + solver::PointSolverContext solverCtx(ctx); + solverCtx.abs_tol = 1e-3; + solverCtx.rel_tol = 1e-3; + solverCtx.stdout_logging = false; - // Do not need high precision for priming - integrator.set_absTol(1e-3); - integrator.set_relTol(1e-3); - - integrator.set_stdout_logging_enabled(false); NetIn solverInput(netIn); solverInput.tMax = 1e-15; @@ -43,7 +41,7 @@ namespace gridfire::engine { LOG_INFO(logger, "Short timescale ({}) network ignition started.", solverInput.tMax); PrimingReport report; try { - const NetOut netOut = integrator.evaluate(solverInput, false); + const NetOut netOut = integrator.evaluate(solverCtx, solverInput); LOG_INFO(logger, "Network ignition completed."); LOG_TRACE_L2( logger, diff --git a/src/lib/engine/views/engine_multiscale.cpp b/src/lib/engine/views/engine_multiscale.cpp index 458aeea4..ec3d1a7c 100644 --- a/src/lib/engine/views/engine_multiscale.cpp +++ b/src/lib/engine/views/engine_multiscale.cpp @@ -2005,7 +2005,32 @@ namespace gridfire::engine { LOG_INFO(getLogger(), "KINSol failed to converge within the maximum number of iterations, but achieved acceptable accuracy with function norm {} < {}. Proceeding with solution.", fnorm, ACCEPTABLE_FTOL); } else { - LOG_WARNING(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Error {}", utils::kinsol_ret_code_map.at(flag), fnorm); + LOG_CRITICAL(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Flag No.: {}, Error (fNorm): {}", utils::kinsol_ret_code_map.at(flag), flag, fnorm); + LOG_CRITICAL(getLogger(), "State prior to failure: {}", + [&comp, &data]() -> std::string { + std::ostringstream oss; + oss << "Solve species: <"; + size_t count = 0; + for (const auto& species : data.qse_solve_species) { + oss << species.name(); + if (count < data.qse_solve_species.size() - 1) { + oss << ", "; + } + count++; + } + oss << "> | Abundances and rates at failure: "; + count = 0; + for (const auto& [species, abundance] : comp) { + oss << species.name() << ": Y = " << abundance; + if (count < comp.size() - 1) { + oss << ", "; + } + count++; + } + oss << " | Temperature: " << data.T9 << ", Density: " << data.rho; + return oss.str(); + }() + ); throw exceptions::InvalidQSESolutionError("KINSol failed to converge while solving QSE abundances. " + utils::kinsol_ret_code_map.at(flag)); } } diff --git a/src/lib/solver/strategies/GridSolver.cpp b/src/lib/solver/strategies/GridSolver.cpp new file mode 100644 index 00000000..f69f58da --- /dev/null +++ b/src/lib/solver/strategies/GridSolver.cpp @@ -0,0 +1,94 @@ +#include "gridfire/solver/strategies/GridSolver.h" + +#include "gridfire/exceptions/error_solver.h" +#include "gridfire/solver/strategies/PointSolver.h" +#include "gridfire/utils/macros.h" +#include "gridfire/utils/gf_omp.h" + +#include +#include + +namespace gridfire::solver { + void GridSolverContext::init() {} + void GridSolverContext::reset() { + solver_workspaces.clear(); + timestep_callbacks.clear(); + } + + void GridSolverContext::set_callback(const std::function &callback) { + for (auto &cb : timestep_callbacks) { + cb = callback; + } + } + + void GridSolverContext::set_callback(const std::function &callback, const size_t zone_idx) { + if (zone_idx >= timestep_callbacks.size()) { + throw exceptions::SolverError("GridSolverContext::set_callback: zone_idx out of range."); + } + timestep_callbacks[zone_idx] = callback; + } + + void GridSolverContext::set_stdout_logging(const bool enable) { + zone_stdout_logging = enable; + } + + void GridSolverContext::set_detailed_logging(const bool enable) { + zone_detailed_logging = enable; + } + + GridSolverContext::GridSolverContext( + const engine::scratch::StateBlob &ctx_template + ) : + ctx_template(ctx_template) {} + + + GridSolver::GridSolver( + const engine::DynamicEngine &engine, + const SingleZoneDynamicNetworkSolver &solver + ) : + MultiZoneNetworkSolver(engine, solver) { + GF_PAR_INIT(); + } + + std::vector GridSolver::evaluate( + SolverContextBase& ctx, + const std::vector& netIns + ) const { + auto* sctx_p = dynamic_cast(&ctx); + if (!sctx_p) { + throw exceptions::SolverError("GridSolver::evaluate: SolverContextBase is not of type GridSolverContext."); + } + + const size_t n_zones = netIns.size(); + if (n_zones == 0) { return {}; } + + std::vector results(n_zones); + + sctx_p->solver_workspaces.resize(n_zones); + + GF_OMP( + parallel for default(none) shared(sctx_p, n_zones), + for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) { + sctx_p->solver_workspaces[zone_idx] = std::make_unique(sctx_p->ctx_template); + sctx_p->solver_workspaces[zone_idx]->set_stdout_logging(sctx_p->zone_stdout_logging); + sctx_p->solver_workspaces[zone_idx]->set_detailed_logging(sctx_p->zone_detailed_logging); + } + + GF_OMP( + parallel for default(none) shared(results, sctx_p, netIns, n_zones), + for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) { + try { + results[zone_idx] = m_solver.evaluate( + *sctx_p->solver_workspaces[zone_idx], + netIns[zone_idx] + ); + } catch (exceptions::GridFireError& e) { + std::println("CVODE Solver Failure in zone {}: {}", zone_idx, e.what()); + } + if (sctx_p->zone_completion_logging) { + std::println("Thread {} completed zone {}", GF_OMP_THREAD_NUM, zone_idx); + } + } + return results; + } +} diff --git a/src/lib/solver/strategies/CVODE_solver_strategy.cpp b/src/lib/solver/strategies/PointSolver.cpp similarity index 78% rename from src/lib/solver/strategies/CVODE_solver_strategy.cpp rename to src/lib/solver/strategies/PointSolver.cpp index bb47b476..01d3aa59 100644 --- a/src/lib/solver/strategies/CVODE_solver_strategy.cpp +++ b/src/lib/solver/strategies/PointSolver.cpp @@ -1,4 +1,4 @@ -#include "gridfire/solver/strategies/CVODE_solver_strategy.h" +#include "gridfire/solver/strategies/PointSolver.h" #include "gridfire/types/types.h" #include "gridfire/utils/table_format.h" @@ -28,7 +28,7 @@ namespace gridfire::solver { using namespace gridfire::engine; - CVODESolverStrategy::TimestepContext::TimestepContext( + PointSolverTimestepContext::PointSolverTimestepContext( const double t, const N_Vector &state, const double dt, @@ -58,7 +58,7 @@ namespace gridfire::solver { state_ctx(ctx) {} - std::vector> CVODESolverStrategy::TimestepContext::describe() const { + std::vector> PointSolverTimestepContext::describe() const { std::vector> description; description.emplace_back("t", "Current Time"); description.emplace_back("state", "Current State Vector (N_Vector)"); @@ -74,36 +74,112 @@ namespace gridfire::solver { return description; } + void PointSolverContext::init() { + reset_all(); + init_context(); + } - CVODESolverStrategy::CVODESolverStrategy( - const DynamicEngine &engine, - const scratch::StateBlob& ctx - ): SingleZoneNetworkSolver(engine, ctx) { - // PERF: In order to support MPI this function must be changed - const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx); - if (flag < 0) { - throw std::runtime_error("Failed to create SUNDIALS context (SUNDIALS Errno: " + std::to_string(flag) + ")"); + void PointSolverContext::set_stdout_logging(const bool enable) { + stdout_logging = enable; + } + + void PointSolverContext::set_detailed_logging(const bool enable) { + detailed_step_logging = enable; + } + + void PointSolverContext::reset_all() { + reset_user(); + reset_cvode(); + } + + void PointSolverContext::reset_user() { + callback.reset(); + num_steps = 0; + stdout_logging = true; + abs_tol.reset(); + rel_tol.reset(); + detailed_step_logging = false; + last_size = 0; + last_composition_hash = 0ULL; + } + + void PointSolverContext::reset_cvode() { + if (LS) { + SUNLinSolFree(LS); + LS = nullptr; + } + if (J) { + SUNMatDestroy(J); + J = nullptr; + } + if (Y) { + N_VDestroy(Y); + Y = nullptr; + } + if (YErr) { + N_VDestroy(YErr); + YErr = nullptr; + } + if (constraints) { + N_VDestroy(constraints); + constraints = nullptr; + } + if (cvode_mem) { + CVodeFree(&cvode_mem); + cvode_mem = nullptr; } } - CVODESolverStrategy::~CVODESolverStrategy() { - LOG_TRACE_L1(m_logger, "Cleaning up CVODE resources..."); - cleanup_cvode_resources(true); - - if (m_sun_ctx) { - SUNContext_Free(&m_sun_ctx); + void PointSolverContext::clear_context() { + if (sun_ctx) { + SUNContext_Free(&sun_ctx); + sun_ctx = nullptr; } } - NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) { - return evaluate(netIn, false); + void PointSolverContext::init_context() { + if (!sun_ctx) { + utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE); + } } - NetOut CVODESolverStrategy::evaluate( + bool PointSolverContext::has_context() const { + return sun_ctx != nullptr; + } + + PointSolverContext::PointSolverContext( + const scratch::StateBlob& engine_ctx + ) : + engine_ctx(engine_ctx.clone_structure()) + { + utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE); + } + + PointSolverContext::~PointSolverContext() { + reset_cvode(); + clear_context(); + } + + + PointSolver::PointSolver( + const DynamicEngine &engine + ): SingleZoneNetworkSolver(engine) {} + + NetOut PointSolver::evaluate( + SolverContextBase& solver_ctx, + const NetIn& netIn + ) const { + return evaluate(solver_ctx, netIn, false); + } + + NetOut PointSolver::evaluate( + SolverContextBase& solver_ctx, const NetIn &netIn, bool displayTrigger, bool forceReinitialize - ) { + ) const { + auto* sctx_p = dynamic_cast(&solver_ctx); + LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density); LOG_TRACE_L1(m_logger, "Building engine update trigger...."); auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 0.5, 2); @@ -117,23 +193,24 @@ namespace gridfire::solver { // 2. If the user has set tolerances in code, those override the config // 3. If the user has not set tolerances in code and the config does not have them, use hardcoded defaults - auto absTol = m_config->solver.cvode.absTol; - auto relTol = m_config->solver.cvode.relTol; - - if (m_absTol) { - absTol = *m_absTol; + if (!sctx_p->abs_tol.has_value()) { + sctx_p->abs_tol = m_config->solver.cvode.absTol; } - if (m_relTol) { - relTol = *m_relTol; + if (!sctx_p->rel_tol.has_value()) { + sctx_p->rel_tol = m_config->solver.cvode.relTol; } - bool resourcesExist = (m_cvode_mem != nullptr) && (m_Y != nullptr); - bool inconsistentComposition = netIn.composition.hash() != m_last_composition_hash; + bool resourcesExist = (sctx_p->cvode_mem != nullptr) && (sctx_p->Y != nullptr); + + bool inconsistentComposition = netIn.composition.hash() != sctx_p->last_composition_hash; fourdst::composition::Composition equilibratedComposition; if (forceReinitialize || !resourcesExist || inconsistentComposition) { - cleanup_cvode_resources(true); + sctx_p->reset_cvode(); + if (!sctx_p->has_context()) { + sctx_p->init_context(); + } LOG_INFO( m_logger, "Preforming full CVODE initialization (Reason: {})", @@ -141,26 +218,24 @@ namespace gridfire::solver { (!resourcesExist ? "CVODE resources do not exist" : "Input composition inconsistent with previous state")); LOG_TRACE_L1(m_logger, "Starting engine update chain..."); - equilibratedComposition = m_engine.project(*m_scratch_blob, netIn); + equilibratedComposition = m_engine.project(*sctx_p->engine_ctx, netIn); LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!"); - size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); + size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size(); uint64_t N = numSpecies + 1; LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N); LOG_TRACE_L1(m_logger, "Initializing CVODE resources"); - m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx); - utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate"); - initialize_cvode_integration_resources(N, numSpecies, 0.0, equilibratedComposition, absTol, relTol, 0.0); - m_last_size = N; + initialize_cvode_integration_resources(sctx_p, N, numSpecies, 0.0, equilibratedComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), 0.0); + sctx_p->last_size = N; } else { - LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", m_last_size); + LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", sctx_p->last_size); - const size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); - sunrealtype *y_data = N_VGetArrayPointer(m_Y); + const size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size(); + sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y); for (size_t i = 0; i < numSpecies; i++) { - const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; + const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i]; if (netIn.composition.contains(species)) { y_data[i] = netIn.composition.getMolarAbundance(species); } else { @@ -168,16 +243,17 @@ namespace gridfire::solver { } } y_data[numSpecies] = 0.0; // Reset energy accumulator - utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances"); - utils::check_cvode_flag(CVodeReInit(m_cvode_mem, 0.0, m_Y), "CVodeReInit"); + utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, sctx_p->rel_tol.value(), sctx_p->abs_tol.value()), "CVodeSStolerances"); + utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, 0.0, sctx_p->Y), "CVodeReInit"); equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state } - size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); + size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size(); CVODEUserData user_data { .solver_instance = this, - .ctx = *m_scratch_blob, + .sctx = sctx_p, + .ctx = *sctx_p->engine_ctx, .engine = &m_engine, }; LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!"); @@ -185,7 +261,7 @@ namespace gridfire::solver { double current_time = 0; // ReSharper disable once CppTooWideScope [[maybe_unused]] double last_callback_time = 0; - m_num_steps = 0; + sctx_p->num_steps = 0; double accumulated_energy = 0.0; double accumulated_neutrino_energy_loss = 0.0; @@ -205,13 +281,13 @@ namespace gridfire::solver { while (current_time < netIn.tMax) { user_data.T9 = T9; user_data.rho = netIn.density; - user_data.networkSpecies = &m_engine.getNetworkSpecies(*m_scratch_blob); + user_data.networkSpecies = &m_engine.getNetworkSpecies(*sctx_p->engine_ctx); user_data.captured_exception.reset(); - utils::check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData"); + utils::check_cvode_flag(CVodeSetUserData(sctx_p->cvode_mem, &user_data), "CVodeSetUserData"); LOG_TRACE_L2(m_logger, "Taking one CVODE step..."); - int flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP); + int flag = CVode(sctx_p->cvode_mem, netIn.tMax, sctx_p->Y, ¤t_time, CV_ONE_STEP); LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag)); if (user_data.captured_exception){ @@ -223,13 +299,13 @@ namespace gridfire::solver { long int n_steps; double last_step_size; - CVodeGetNumSteps(m_cvode_mem, &n_steps); - CVodeGetLastStep(m_cvode_mem, &last_step_size); + CVodeGetNumSteps(sctx_p->cvode_mem, &n_steps); + CVodeGetLastStep(sctx_p->cvode_mem, &last_step_size); long int nliters, nlcfails; - CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters); - CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails); + CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nliters); + CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nlcfails); - sunrealtype* y_data = N_VGetArrayPointer(m_Y); + sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y); const double current_energy = y_data[numSpecies]; // Specific energy rate // TODO: Accumulate neutrino loss through the state vector directly which will allow CVODE to properly integrate it @@ -238,7 +314,7 @@ namespace gridfire::solver { size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations; size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures; - if (m_stdout_logging_enabled) { + if (sctx_p->stdout_logging) { std::println( "Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})", total_steps + n_steps, @@ -253,20 +329,16 @@ namespace gridfire::solver { ); } for (size_t i = 0; i < numSpecies; ++i) { - const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; + const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i]; if (y_data[i] > 0.0) { postStep.setMolarAbundance(species, y_data[i]); } } - // fourdst::composition::Composition collectedComposition = m_engine.collectComposition(postStep, netIn.temperature/1e9, netIn.density); - // for (size_t i = 0; i < numSpecies; ++i) { - // y_data[i] = collectedComposition.getMolarAbundance(m_engine.getNetworkSpecies()[i]); - // } LOG_INFO(m_logger, "Completed {:5} steps to time {:10.4E} [s] (dt = {:15.6E} [s]). Current specific energy: {:15.6E} [erg/g]", total_steps + n_steps, current_time, last_step_size, current_energy); LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string { std::stringstream ss; for (size_t i = 0; i < numSpecies; ++i) { - const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; + const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i]; ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")"; if (i < numSpecies - 1) { ss << ", "; @@ -282,36 +354,44 @@ namespace gridfire::solver { ? user_data.reaction_contribution_map.value() : kEmptyMap; - auto ctx = TimestepContext( + auto ctx = PointSolverTimestepContext( current_time, - m_Y, + sctx_p->Y, last_step_size, last_callback_time, T9, netIn.density, n_steps, m_engine, - m_engine.getNetworkSpecies(*m_scratch_blob), + m_engine.getNetworkSpecies(*sctx_p->engine_ctx), convFail_diff, iter_diff, rcMap, - *m_scratch_blob + *sctx_p->engine_ctx ); prev_nonlinear_iterations = nliters + total_nonlinear_iterations; prev_convergence_failures = nlcfails + total_convergence_failures; - if (m_callback.has_value()) { - m_callback.value()(ctx); + if (sctx_p->callback.has_value()) { + sctx_p->callback.value()(ctx); } trigger->step(ctx); - if (m_detailed_step_logging) { - log_step_diagnostics(*m_scratch_blob, user_data, true, true, true, "step_" + std::to_string(total_steps + n_steps) + ".json"); + if (sctx_p->detailed_step_logging) { + log_step_diagnostics( + sctx_p, + *sctx_p->engine_ctx, + user_data, + true, + true, + true, + "step_" + std::to_string(total_steps + n_steps) + ".json" + ); } if (trigger->check(ctx)) { - if (m_stdout_logging_enabled && displayTrigger) { + if (sctx_p->stdout_logging && displayTrigger) { trigger::printWhy(trigger->why(ctx)); } trigger->update(ctx); @@ -333,20 +413,20 @@ namespace gridfire::solver { fourdst::composition::Composition temp_comp; std::vector mass_fractions; - auto num_species_at_stop = static_cast(m_engine.getNetworkSpecies(*m_scratch_blob).size()); + auto num_species_at_stop = static_cast(m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size()); - if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) { + if (num_species_at_stop > sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1) { LOG_ERROR( m_logger, "Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.", num_species_at_stop, - m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index + sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1 // -1 due to energy in the last index ); throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen."); } - for (const auto& species: m_engine.getNetworkSpecies(*m_scratch_blob)) { - const size_t sid = m_engine.getSpeciesIndex(*m_scratch_blob, species); + for (const auto& species: m_engine.getNetworkSpecies(*sctx_p->engine_ctx)) { + const size_t sid = m_engine.getSpeciesIndex(*sctx_p->engine_ctx, species); temp_comp.registerSpecies(species); double y = end_of_step_abundances[sid]; if (y > 0.0) { @@ -356,7 +436,7 @@ namespace gridfire::solver { #ifndef NDEBUG for (long int i = 0; i < num_species_at_stop; ++i) { - const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; + const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i]; if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) { throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification."); } @@ -391,7 +471,7 @@ namespace gridfire::solver { "Prior to Engine Update active reactions are: {}", [&]() -> std::string { std::stringstream ss; - const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob); + const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx); size_t count = 0; for (const auto& reaction : reactions) { ss << reaction -> id(); @@ -403,7 +483,7 @@ namespace gridfire::solver { return ss.str(); }() ); - fourdst::composition::Composition currentComposition = m_engine.project(*m_scratch_blob, netInTemp); + fourdst::composition::Composition currentComposition = m_engine.project(*sctx_p->engine_ctx, netInTemp); LOG_DEBUG( m_logger, "After to Engine update composition is (molar abundance) {}", @@ -450,7 +530,7 @@ namespace gridfire::solver { "After Engine Update active reactions are: {}", [&]() -> std::string { std::stringstream ss; - const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob); + const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx); size_t count = 0; for (const auto& reaction : reactions) { ss << reaction -> id(); @@ -466,34 +546,29 @@ namespace gridfire::solver { m_logger, "Due to a triggered engine update the composition was updated from size {} to {} species.", num_species_at_stop, - m_engine.getNetworkSpecies(*m_scratch_blob).size() + m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size() ); - numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); + numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size(); size_t N = numSpecies + 1; LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update..."); - cleanup_cvode_resources(true); + sctx_p->reset_cvode(); + initialize_cvode_integration_resources(sctx_p, N, numSpecies, current_time, currentComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), accumulated_energy); - m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx); - utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate"); - - initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy); - - utils::check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit"); - // throw exceptions::DebugException("Debug"); + utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, current_time, sctx_p->Y), "CVodeReInit"); LOG_INFO(m_logger, "Done reinitializing CVODE after engine update. The next log messages will be from the first step after reinitialization..."); } } - if (m_stdout_logging_enabled) { // Flush the buffer if standard out logging is enabled + if (sctx_p->stdout_logging) { // Flush the buffer if standard out logging is enabled std::cout << std::flush; } LOG_INFO(m_logger, "CVODE iteration complete"); - sunrealtype* y_data = N_VGetArrayPointer(m_Y); + sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y); accumulated_energy += y_data[numSpecies]; std::vector y_vec(y_data, y_data + numSpecies); @@ -505,7 +580,7 @@ namespace gridfire::solver { LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies); - fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec); + fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*sctx_p->engine_ctx), y_vec); LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string { std::ostringstream ss; size_t i = 0; @@ -520,7 +595,7 @@ namespace gridfire::solver { }()); LOG_INFO(m_logger, "Collecting final composition..."); - fourdst::composition::Composition outputComposition = m_engine.collectComposition(*m_scratch_blob, topLevelComposition, netIn.temperature/1e9, netIn.density); + fourdst::composition::Composition outputComposition = m_engine.collectComposition(*sctx_p->engine_ctx, topLevelComposition, netIn.temperature/1e9, netIn.density); assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size()); @@ -541,11 +616,11 @@ namespace gridfire::solver { NetOut netOut; netOut.composition = outputComposition; netOut.energy = accumulated_energy; - utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast(&netOut.num_steps)), "CVodeGetNumSteps"); + utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, reinterpret_cast(&netOut.num_steps)), "CVodeGetNumSteps"); LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives..."); auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives( - *m_scratch_blob, + *sctx_p->engine_ctx, outputComposition, T9, netIn.density @@ -559,53 +634,13 @@ namespace gridfire::solver { LOG_TRACE_L2(m_logger, "Output data built!"); LOG_TRACE_L2(m_logger, "Solver evaluation complete!."); - m_last_composition_hash = netOut.composition.hash(); - m_last_size = netOut.composition.size() + 1; - CVodeGetLastStep(m_cvode_mem, &m_last_good_time_step); + sctx_p->last_composition_hash = netOut.composition.hash(); + sctx_p->last_size = netOut.composition.size() + 1; + CVodeGetLastStep(sctx_p->cvode_mem, &sctx_p->last_good_time_step); return netOut; } - void CVODESolverStrategy::set_callback(const std::any &callback) { - m_callback = std::any_cast(callback); - } - - bool CVODESolverStrategy::get_stdout_logging_enabled() const { - return m_stdout_logging_enabled; - } - - void CVODESolverStrategy::set_stdout_logging_enabled(const bool logging_enabled) { - m_stdout_logging_enabled = logging_enabled; - } - - void CVODESolverStrategy::set_absTol(double absTol) { - m_absTol = absTol; - } - - void CVODESolverStrategy::set_relTol(double relTol) { - m_relTol = relTol; - } - - double CVODESolverStrategy::get_absTol() const { - if (m_absTol.has_value()) { - return m_absTol.value(); - } else { - return -1.0; - } - } - - double CVODESolverStrategy::get_relTol() const { - if (m_relTol.has_value()) { - return m_relTol.value(); - } else { - return -1.0; - } - } - - std::vector> CVODESolverStrategy::describe_callback_context() const { - return {}; - } - - int CVODESolverStrategy::cvode_rhs_wrapper( + int PointSolver::cvode_rhs_wrapper( const sunrealtype t, const N_Vector y, const N_Vector ydot, @@ -633,7 +668,7 @@ namespace gridfire::solver { } } - int CVODESolverStrategy::cvode_jac_wrapper( + int PointSolver::cvode_jac_wrapper( sunrealtype t, N_Vector y, N_Vector ydot, @@ -754,7 +789,7 @@ namespace gridfire::solver { return 0; } - CVODESolverStrategy::CVODERHSOutputData CVODESolverStrategy::calculate_rhs( + PointSolver::CVODERHSOutputData PointSolver::calculate_rhs( const sunrealtype t, N_Vector y, N_Vector ydot, @@ -772,10 +807,10 @@ namespace gridfire::solver { } } std::vector y_vec(y_data, y_data + numSpecies); - fourdst::composition::Composition composition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec); + fourdst::composition::Composition composition(m_engine.getNetworkSpecies(data->ctx), y_vec); LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size()); - const auto result = m_engine.calculateRHSAndEnergy(*m_scratch_blob, composition, data->T9, data->rho, false); + const auto result = m_engine.calculateRHSAndEnergy(data->ctx, composition, data->T9, data->rho, false); if (!result) { LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())); throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()))); @@ -805,7 +840,7 @@ namespace gridfire::solver { }()); for (size_t i = 0; i < numSpecies; ++i) { - fourdst::atomic::Species species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; + fourdst::atomic::Species species = m_engine.getNetworkSpecies(data->ctx)[i]; ydot_data[i] = dydt.at(species); } ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate @@ -813,7 +848,8 @@ namespace gridfire::solver { return {reactionContributions, result.value().neutrinoEnergyLossRate, result.value().totalNeutrinoFlux}; } - void CVODESolverStrategy::initialize_cvode_integration_resources( + void PointSolver::initialize_cvode_integration_resources( + PointSolverContext* sctx_p, const uint64_t N, const size_t numSpecies, const double current_time, @@ -821,16 +857,18 @@ namespace gridfire::solver { const double absTol, const double relTol, const double accumulatedEnergy - ) { + ) const { LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol); - cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones + sctx_p->reset_cvode(); - m_Y = utils::init_sun_vector(N, m_sun_ctx); - m_YErr = N_VClone(m_Y); + sctx_p->cvode_mem = CVodeCreate(CV_BDF, sctx_p->sun_ctx); + utils::check_cvode_flag(sctx_p->cvode_mem == nullptr ? -1 : 0, "CVodeCreate"); + sctx_p->Y = utils::init_sun_vector(N, sctx_p->sun_ctx); + sctx_p->YErr = N_VClone(sctx_p->Y); - sunrealtype *y_data = N_VGetArrayPointer(m_Y); + sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y); for (size_t i = 0; i < numSpecies; i++) { - const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; + const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i]; if (composition.contains(species)) { y_data[i] = composition.getMolarAbundance(species); } else { @@ -840,8 +878,8 @@ namespace gridfire::solver { y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero - utils::check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit"); - utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances"); + utils::check_cvode_flag(CVodeInit(sctx_p->cvode_mem, cvode_rhs_wrapper, current_time, sctx_p->Y), "CVodeInit"); + utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, relTol, absTol), "CVodeSStolerances"); // Constraints // We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual @@ -854,53 +892,30 @@ namespace gridfire::solver { // -2.0: The corresponding component of y is constrained to be < 0 // Here we use 1.0 for all species to ensure they remain non-negative. - m_constraints = N_VClone(m_Y); - if (m_constraints == nullptr) { + sctx_p->constraints = N_VClone(sctx_p->Y); + if (sctx_p->constraints == nullptr) { LOG_ERROR(m_logger, "Failed to create constraints vector for CVODE"); throw std::runtime_error("Failed to create constraints vector for CVODE"); } - N_VConst(1.0, m_constraints); // Set all constraints to >= 0 (note this is where the flag values are set) + N_VConst(1.0, sctx_p->constraints); // Set all constraints to >= 0 (note this is where the flag values are set) - utils::check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints"); + utils::check_cvode_flag(CVodeSetConstraints(sctx_p->cvode_mem, sctx_p->constraints), "CVodeSetConstraints"); - utils::check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep"); + utils::check_cvode_flag(CVodeSetMaxStep(sctx_p->cvode_mem, 1.0e20), "CVodeSetMaxStep"); - m_J = SUNDenseMatrix(static_cast(N), static_cast(N), m_sun_ctx); - utils::check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix"); - m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx); - utils::check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense"); + sctx_p->J = SUNDenseMatrix(static_cast(N), static_cast(N), sctx_p->sun_ctx); + utils::check_cvode_flag(sctx_p->J == nullptr ? -1 : 0, "SUNDenseMatrix"); + sctx_p->LS = SUNLinSol_Dense(sctx_p->Y, sctx_p->J, sctx_p->sun_ctx); + utils::check_cvode_flag(sctx_p->LS == nullptr ? -1 : 0, "SUNLinSol_Dense"); - utils::check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver"); - utils::check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn"); + utils::check_cvode_flag(CVodeSetLinearSolver(sctx_p->cvode_mem, sctx_p->LS, sctx_p->J), "CVodeSetLinearSolver"); + utils::check_cvode_flag(CVodeSetJacFn(sctx_p->cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn"); LOG_TRACE_L2(m_logger, "CVODE solver initialized"); } - void CVODESolverStrategy::cleanup_cvode_resources(const bool memFree) { - LOG_TRACE_L2(m_logger, "Cleaning up cvode resources"); - if (m_LS) SUNLinSolFree(m_LS); - if (m_J) SUNMatDestroy(m_J); - if (m_Y) N_VDestroy(m_Y); - if (m_YErr) N_VDestroy(m_YErr); - if (m_constraints) N_VDestroy(m_constraints); - m_LS = nullptr; - m_J = nullptr; - m_Y = nullptr; - m_YErr = nullptr; - m_constraints = nullptr; - - if (memFree) { - if (m_cvode_mem) CVodeFree(&m_cvode_mem); - m_cvode_mem = nullptr; - } - LOG_TRACE_L2(m_logger, "Done Cleaning up cvode resources"); - } - - void CVODESolverStrategy::set_detailed_step_logging(const bool enabled) { - m_detailed_step_logging = enabled; - } - - void CVODESolverStrategy::log_step_diagnostics( + void PointSolver::log_step_diagnostics( + PointSolverContext* sctx_p, scratch::StateBlob &ctx, const CVODEUserData &user_data, bool displayJacobianStiffness, @@ -916,10 +931,10 @@ namespace gridfire::solver { sunrealtype hlast, hcur, tcur; int qlast; - utils::check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep"); - utils::check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep"); - utils::check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder"); - utils::check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime"); + utils::check_cvode_flag(CVodeGetLastStep(sctx_p->cvode_mem, &hlast), "CVodeGetLastStep"); + utils::check_cvode_flag(CVodeGetCurrentStep(sctx_p->cvode_mem, &hcur), "CVodeGetCurrentStep"); + utils::check_cvode_flag(CVodeGetLastOrder(sctx_p->cvode_mem, &qlast), "CVodeGetLastOrder"); + utils::check_cvode_flag(CVodeGetCurrentTime(sctx_p->cvode_mem, &tcur), "CVodeGetCurrentTime"); nlohmann::json j; { @@ -941,13 +956,13 @@ namespace gridfire::solver { // These are the CRITICAL counters for diagnosing your problem long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails; - utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps"); - utils::check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals"); - utils::check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups"); - utils::check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails"); - utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters"); - utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails"); - utils::check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails"); + utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, &nsteps), "CVodeGetNumSteps"); + utils::check_cvode_flag(CVodeGetNumRhsEvals(sctx_p->cvode_mem, &nfevals), "CVodeGetNumRhsEvals"); + utils::check_cvode_flag(CVodeGetNumLinSolvSetups(sctx_p->cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups"); + utils::check_cvode_flag(CVodeGetNumErrTestFails(sctx_p->cvode_mem, &netfails), "CVodeGetNumErrTestFails"); + utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters"); + utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails"); + utils::check_cvode_flag(CVodeGetNumLinConvFails(sctx_p->cvode_mem, &nsetfails), "CVodeGetNumLinConvFails"); { @@ -975,22 +990,26 @@ namespace gridfire::solver { } // --- 3. Get Estimated Local Errors (Your Original Logic) --- - utils::check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors"); + utils::check_cvode_flag(CVodeGetEstLocalErrors(sctx_p->cvode_mem, sctx_p->YErr), "CVodeGetEstLocalErrors"); - sunrealtype *y_data = N_VGetArrayPointer(m_Y); - sunrealtype *y_err_data = N_VGetArrayPointer(m_YErr); - - const auto absTol = m_config->solver.cvode.absTol; - const auto relTol = m_config->solver.cvode.relTol; + sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y); + sunrealtype *y_err_data = N_VGetArrayPointer(sctx_p->YErr); std::vector err_ratios; - const size_t num_components = N_VGetLength(m_Y); + const size_t num_components = N_VGetLength(sctx_p->Y); err_ratios.resize(num_components - 1); // Assuming -1 is for Energy or similar std::vector Y_full(y_data, y_data + num_components - 1); std::vector E_full(y_err_data, y_err_data + num_components - 1); - auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, relTol, absTol, 10, to_file); + if (!sctx_p->abs_tol.has_value()) { + sctx_p->abs_tol = m_config->solver.cvode.absTol; + } + if (!sctx_p->rel_tol.has_value()) { + sctx_p->rel_tol = m_config->solver.cvode.relTol; + } + + auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, sctx_p->rel_tol.value(), sctx_p->abs_tol.value(), 10, to_file); if (to_file && result.has_value()) { j["Limiting_Species"] = result.value(); } @@ -1003,8 +1022,9 @@ namespace gridfire::solver { 0.0 ); + for (size_t i = 0; i < num_components - 1; i++) { - const double weight = relTol * std::abs(y_data[i]) + absTol; + const double weight = sctx_p->rel_tol.value() * std::abs(y_data[i]) + sctx_p->abs_tol.value(); if (weight == 0.0) { err_ratios[i] = 0.0; // Avoid division by zero continue; @@ -1013,11 +1033,11 @@ namespace gridfire::solver { err_ratios[i] = err_ratio; } - fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*m_scratch_blob), Y_full); - fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*m_scratch_blob, composition, user_data.T9, user_data.rho); + fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*sctx_p->engine_ctx), Y_full); + fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*sctx_p->engine_ctx, composition, user_data.T9, user_data.rho); - auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho); - auto netTimescales = user_data.engine->getSpeciesTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho); + auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho); + auto netTimescales = user_data.engine->getSpeciesTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho); bool timescaleOkay = false; if (destructionTimescales && netTimescales) timescaleOkay = true; @@ -1037,7 +1057,7 @@ namespace gridfire::solver { if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp)); else destructionTimescales_list.emplace_back(std::numeric_limits::infinity()); - speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*m_scratch_blob, sp))); + speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*sctx_p->engine_ctx, sp))); } utils::Column speciesColumn("Species", species_list); diff --git a/src/lib/solver/strategies/SpectralSolverStrategy.cpp b/src/lib/solver/strategies/SpectralSolverStrategy.cpp deleted file mode 100644 index 31853ba2..00000000 --- a/src/lib/solver/strategies/SpectralSolverStrategy.cpp +++ /dev/null @@ -1,1292 +0,0 @@ -#include "gridfire/solver/strategies/SpectralSolverStrategy.h" -#include "gridfire/utils/macros.h" -#include "gridfire/utils/table_format.h" - -#include - -#include "gridfire/utils/sundials.h" -#include "gridfire/solver/strategies/triggers/triggers.h" - -#include "quill/LogMacros.h" -#include "sunmatrix/sunmatrix_dense.h" -#include - -#include "gridfire/utils/gf_omp.h" -#include "gridfire/utils/formatters/jacobian_format.h" -#include "gridfire/utils/logging.h" - -namespace { - std::pair> evaluate_bspline( - const double x, - const gridfire::solver::SpectralSolverStrategy::SplineBasis& basis - ) { - const int p = basis.degree; - const std::vector& t = basis.knots; - - const auto it = std::ranges::upper_bound(t, x); - size_t i = std::distance(t.begin(), it) - 1; - - if (i < static_cast(p)) i = p; - if (i >= t.size() - 1 - p) i = t.size() - 2 - p; - - if (x >= t.back()) { - i = t.size() - p - 2; - } - - // Cox-de Boor algorithm - std::vector N(p + 1); - std::vector left(p + 1); - std::vector right(p + 1); - - N[0] = 1.0; - - for (int j = 1; j <= p; ++j) { - left[j] = x - t[i + 1 - j]; - right[j] = t[i + j] - x; - double saved = 0.0; - - for (int r = 0; r < j; ++r) { - const double temp = N[r] / (right[r + 1] + left[j - r]); - N[r] = saved + right[r + 1] * temp; - - saved = left[j - r] * temp; - } - N[j] = saved; - } - return {i - p, N}; - } -} - - - -namespace gridfire::solver { - - SpectralSolverStrategy::SpectralSolverStrategy( - const engine::DynamicEngine& engine - ) : MultiZoneNetworkSolver (engine) { - LOG_INFO(m_logger, "Initializing SpectralSolverStrategy"); - - utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - - m_absTol = m_config->solver.spectral.absTol; - m_relTol = m_config->solver.spectral.relTol; - - LOG_INFO(m_logger, "SpectralSolverStrategy initialized successfully"); - - GF_PAR_INIT() - } - - SpectralSolverStrategy::~SpectralSolverStrategy() { - LOG_INFO(m_logger, "Destroying SpectralSolverStrategy"); - - - if (m_cvode_mem) { - CVodeFree(&m_cvode_mem); - m_cvode_mem = nullptr; - } - - if (m_LS) SUNLinSolFree(m_LS); - if (m_J) SUNMatDestroy(m_J); - if (m_Y) N_VDestroy(m_Y); - if (m_constraints) N_VDestroy(m_constraints); - - if (m_sun_ctx) { - SUNContext_Free(&m_sun_ctx); - m_sun_ctx = nullptr; - } - - if (m_T_coeffs) N_VDestroy(m_T_coeffs); - if (m_rho_coeffs) N_VDestroy(m_rho_coeffs); - - LOG_INFO(m_logger, "SpectralSolverStrategy destroyed successfully"); - } - - //////////////////////////////////////////////////////////////////////////////// - /// Main Evaluation Loop - ///////////////////////////////////////////////////////////////////////////////// - - std::vector SpectralSolverStrategy::evaluate( - const std::vector& netIns, - const std::vector& mass_coords, - const engine::scratch::StateBlob &ctx_template - ) { - LOG_INFO(m_logger, "Starting spectral solver evaluation for {} zones", netIns.size()); - - assert(std::ranges::all_of(netIns, [&netIns](const NetIn& in) { return in.tMax == netIns[0].tMax; }) && "All NetIn entries must have the same tMax for spectral solver evaluation."); - - std::vector updatedNetIns = netIns; - - std::vector> workspaces; - workspaces.resize(updatedNetIns.size()); - - LOG_INFO(m_logger, "Building workspaces..."); - GF_OMP(parallel for,) - for (size_t shellID = 0; shellID < updatedNetIns.size(); ++shellID) { - workspaces[shellID] = ctx_template.clone_structure(); - } - LOG_INFO(m_logger, "Workspaces built successfully."); - - LOG_INFO(m_logger, "Projecting initial conditions onto engine network..."); - GF_OMP(parallel for,) - for (size_t shellID = 0; shellID < updatedNetIns.size(); ++shellID) { - updatedNetIns[shellID].composition = m_engine.project(*workspaces[shellID], updatedNetIns[shellID]); - } - LOG_INFO(m_logger, "Initial conditions projected successfully."); - - ///////////////////////////////////// - /// Build the species union set /// - ///////////////////////////////////// - - LOG_INFO(m_logger, "Collecting global species set from all zones..."); - std::set global_active_species; - for (const auto &updatedNetIn : updatedNetIns) { - const auto& local_species = updatedNetIn.composition.getRegisteredSpecies(); - global_active_species.insert(local_species.begin(), local_species.end()); - } - m_global_species_list.clear(); - m_global_species_list.reserve(global_active_species.size()); - m_global_species_list.insert( - m_global_species_list.end(), - global_active_species.begin(), - global_active_species.end() - ); - LOG_INFO(m_logger, "Done collecting global species set. Total unique species: {}", m_global_species_list.size()); - - - ///////////////////////////////////// - /// Evaluate the monitor function /// - ///////////////////////////////////// - LOG_INFO(m_logger, "Evaluating monitor function..."); - const std::vector monitor_function = evaluate_monitor_function(updatedNetIns); - LOG_INFO(m_logger, "Monitor function evaluated successfully..."); - - ///////////////////////////////////////////// - /// Determine number of quadratude nodes /// - ///////////////////////////////////////////// - const size_t num_nodes = nyquist_elements( - m_config->solver.spectral.basis.num_elements, - updatedNetIns.size() - ); - LOG_INFO(m_logger, "Configuration requested {} quadrature nodes, actually using {} based on Nyquist criterion and number of zones.", m_config->solver.spectral.basis.num_elements, num_nodes); - - ///////////////////////////////////// - /// Generate Quadrature Basis /// - ///////////////////////////////////// - LOG_INFO(m_logger, "Generating quadrature basis..."); - m_current_basis = generate_basis_from_monitor(monitor_function, mass_coords, num_nodes); - LOG_INFO(m_logger, "Quadrature basis generated successfully with {} basis functions and {} quadrature nodes.", m_current_basis.knots.size() - m_current_basis.degree - 1, m_current_basis.quadrature_nodes.size()); - - ///////////////////////////////////////// - /// Rebuild workspaces for only basis /// - ///////////////////////////////////////// - LOG_INFO(m_logger, "Rebuilding workspaces for {} quadrature nodes...", m_current_basis.quadrature_nodes.size()); - workspaces.clear(); - workspaces.resize(m_current_basis.quadrature_nodes.size()); - GF_OMP(parallel for,) - for (size_t shellID = 0; shellID < workspaces.size(); ++shellID) { // NOLINT(*-loop-convert) - workspaces[shellID] = ctx_template.clone_structure(); - } - - LOG_INFO(m_logger, "Workspaces rebuilt successfully."); - - //////////////////////////////////////// - /// Project Initial Conditions /// - //////////////////////////////////////// - LOG_INFO(m_logger, "Projecting initial conditions onto quadrature nodes..."); - const size_t num_basis_funcs = m_current_basis.knots.size() - m_current_basis.degree - 1; - - std::vector shell_cache(updatedNetIns.size()); - - GF_OMP(parallel for,) - for (size_t shellID = 0; shellID < shell_cache.size(); ++shellID) { - auto [start, phi] = evaluate_bspline(mass_coords[shellID], m_current_basis); - shell_cache[shellID] = {.start_idx=start, .phi=phi}; - } - - const DenseLinearSolver proj_solver(num_basis_funcs, m_sun_ctx); - proj_solver.init_from_cache(num_basis_funcs, shell_cache); - - if (m_T_coeffs) N_VDestroy(m_T_coeffs); - m_T_coeffs = N_VNew_Serial(static_cast(num_basis_funcs), m_sun_ctx); - project_specific_variable(updatedNetIns, mass_coords, shell_cache, proj_solver, m_T_coeffs, 0, [](const NetIn& s) { return s.temperature; }, true); - - if (m_rho_coeffs) N_VDestroy(m_rho_coeffs); - m_rho_coeffs = N_VNew_Serial(static_cast(num_basis_funcs), m_sun_ctx); - project_specific_variable(updatedNetIns, mass_coords, shell_cache, proj_solver, m_rho_coeffs, 0, [](const NetIn& s) { return s.density; }, true); - - const size_t num_species = m_global_species_list.size(); - size_t current_offset = 0; - - const size_t total_coefficients = num_basis_funcs * (num_species + 1); - - if (m_Y) N_VDestroy(m_Y); - if (m_constraints) N_VDestroy(m_constraints); - - m_Y = N_VNew_Serial(static_cast(total_coefficients), m_sun_ctx); - N_VConst(0.0, m_Y); - m_constraints = N_VClone(m_Y); - N_VConst(0.0, m_constraints); // For now no constraints on coefficients - - for (const auto& sp : m_global_species_list) { - project_specific_variable( - updatedNetIns, - mass_coords, - shell_cache, - proj_solver, - m_Y, - current_offset, - [&sp](const NetIn& s) { - if (!s.composition.contains(sp)) return 0.0; - return s.composition.getMolarAbundance(sp); - }, - false - ); - current_offset += num_basis_funcs; - } - - sunrealtype* y_data = N_VGetArrayPointer(m_Y); - const size_t energy_offset = num_species * num_basis_funcs; - - assert(energy_offset == current_offset && "Energy offset calculation mismatch in spectral solver initialization."); - - for (size_t i = 0; i < num_basis_funcs; ++i) { - y_data[energy_offset + i] = 0.0; - } - LOG_INFO(m_logger, "Done projecting initial conditions onto quadrature nodes."); - - LOG_INFO(m_logger, "Projecting quadrature conditions onto workspaces"); - for (size_t q = 0; q < m_current_basis.quadrature_nodes.size(); ++q) { - // ReSharper disable once CppUseStructuredBinding - GridPoint gp = reconstruct_at_quadrature(m_Y, q, m_current_basis); - NetIn quad_netin; - quad_netin.temperature = gp.T9 * 1e9; - quad_netin.density = gp.rho; - quad_netin.tMax = 1; // Not needed for projection - quad_netin.composition = gp.composition; - - (void)m_engine.project(*workspaces[q], quad_netin); // We do not need to capture this output just do the projection. Since project is marked nodiscard we use a void cast to suppress warnings. - } - LOG_INFO(m_logger, "Done projecting initial conditions onto workspaces."); - - - DenseLinearSolver mass_solver(num_basis_funcs, m_sun_ctx); - mass_solver.init_from_basis(num_basis_funcs, m_current_basis); - - ///////////////////////////////////// - /// CVODE Initialization /// - ///////////////////////////////////// - - LOG_INFO(m_logger, "Initializing CVODE resources..."); - CVODEUserData data; - data.solver_instance = this; - data.engine = &m_engine; - data.mass_matrix_solver_instance = &mass_solver; - data.basis = &m_current_basis; - - data.workspaces.reserve(workspaces.size()); - for (const auto& ws_ptr : workspaces) { - data.workspaces.emplace_back(*ws_ptr); - } - - const double absTol = m_absTol.value_or(1e-10); - const double relTol = m_relTol.value_or(1e-6); - - const bool size_changed = m_last_size != total_coefficients; - m_last_size = total_coefficients; - - if (m_cvode_mem == nullptr || size_changed) { - if (m_cvode_mem) { - CVodeFree(&m_cvode_mem); - m_cvode_mem = nullptr; - } - if (m_LS) { - SUNLinSolFree(m_LS); - m_LS = nullptr; - } - if (m_J) { - SUNMatDestroy(m_J); - m_J = nullptr; - } - - m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx); - utils::check_sundials_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - - utils::check_sundials_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, 0.0, m_Y), "CVodeInit", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - m_J = SUNDenseMatrix(static_cast(total_coefficients), static_cast(total_coefficients), m_sun_ctx); - m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx); - utils::check_sundials_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - - utils::check_sundials_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - - } else { - utils::check_sundials_flag(CVodeReInit(m_cvode_mem, 0.0, m_Y), "CVodeReInit", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - } - - utils::check_sundials_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - utils::check_sundials_flag(CVodeSetUserData(m_cvode_mem, &data), "CVodeSetUserData", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - LOG_INFO(m_logger, "CVODE resources initialized successfully."); - - ///////////////////////////////////// - /// Setup Trigger /// - ///////////////////////////////////// - - LOG_INFO(m_logger, "Setting up projection trigger..."); - - const double simulationTimeInterval = m_config->solver.spectral.trigger.simulationTimeInterval; - const double offDiagonalThreshold = m_config->solver.spectral.trigger.offDiagonalThreshold; - const double timestepCollapseRatio = m_config->solver.spectral.trigger.timestepCollapseRatio; - const size_t maxConvergenceFailures = m_config->solver.spectral.trigger.maxConvergenceFailures; - auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger( - simulationTimeInterval, - offDiagonalThreshold, - timestepCollapseRatio, - maxConvergenceFailures - ); - LOG_INFO(m_logger, "Projection trigger setup with parameters: simulationTimeInterval = {}, offDiagonalThreshold = {}, timestepCollapseRatio = {}, maxConvergenceFailures = {}", - simulationTimeInterval, - offDiagonalThreshold, - timestepCollapseRatio, - maxConvergenceFailures - ); - - - ///////////////////////////////////// - /// Time Integration Loop /// - ///////////////////////////////////// - - const double target_time = updatedNetIns[0].tMax; - double current_time = 0.0; - - double accumulated_energy = 0.0; - double accumulated_neutrino_energy_loss = 0.0; - double accumulated_total_neutrino_flux = 0.0; - - size_t total_convergence_failures = 0; - size_t total_nonlinear_iterations = 0; - size_t total_update_stages_triggered = 0; - - size_t prev_nonlinear_iterations = 0; - size_t prev_convergence_failures = 0; - - size_t total_steps = 0; - - LOG_INFO(m_logger, "Starting CVODE interation...", target_time); - while (current_time < target_time) { - data.captured_exception.reset(); - - int flag = CVode(m_cvode_mem, target_time, m_Y, ¤t_time, CV_ONE_STEP); - if (data.captured_exception) { - LOG_CRITICAL(m_logger, "An exception was captured during RHS evaluation ({}). Rethrowing...", data.captured_exception->what()); - std::rethrow_exception(std::make_exception_ptr(*data.captured_exception)); - } - utils::check_sundials_flag(flag, "CVode", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - - long int n_steps; - double last_step_size; - CVodeGetNumSteps(m_cvode_mem, &n_steps); - CVodeGetLastStep(m_cvode_mem, &last_step_size); - long int nliters, nlcfails; - CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters); - CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails); - - accumulated_neutrino_energy_loss += data.neutrino_energy_loss_rate * last_step_size; - accumulated_total_neutrino_flux += data.total_neutrino_flux * last_step_size; - - size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations; - size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures; - - std::string out_msg = std::format( - "Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})", - total_steps + n_steps, - total_update_stages_triggered, - n_steps, - current_time, - last_step_size, - prev_nonlinear_iterations + nliters, - iter_diff, - prev_convergence_failures + nlcfails, - convFail_diff - ); - prev_nonlinear_iterations = total_nonlinear_iterations + nliters; - prev_convergence_failures = total_convergence_failures + nlcfails; - - std::println("{}", out_msg); - LOG_INFO(m_logger, "{}", out_msg); - } - LOG_INFO(m_logger, "Time integration complete. Reconstructing solution..."); - - std::vector results = reconstruct_solution(updatedNetIns, mass_coords, m_Y, m_current_basis, target_time); - LOG_INFO(m_logger, "Spectral solver evaluation complete for all zones."); - return results; - } - - void SpectralSolverStrategy::set_callback(const std::any &callback) { - m_callback = std::any_cast(callback); - } - - std::vector> SpectralSolverStrategy::describe_callback_context() const { - throw std::runtime_error("SpectralSolverStrategy does not yet implement describe_callback_context."); - } - - bool SpectralSolverStrategy::get_stdout_logging_enabled() const { - return m_stdout_logging_enabled; - } - - void SpectralSolverStrategy::set_stdout_logging_enabled(bool logging_enabled) { - m_stdout_logging_enabled = logging_enabled; - } - - //////////////////////////////////////////////////////////////////////////////// - /// Static Wrappers for SUNDIALS Callbacks - //////////////////////////////////////////////////////////////////////////////// - - int SpectralSolverStrategy::cvode_rhs_wrapper( - const sunrealtype t, - const N_Vector y_coeffs, - const N_Vector ydot_coeffs, - void *user_data - ) { - auto *data = static_cast(user_data); - const auto *instance = data->solver_instance; - - try { - return instance -> calculate_rhs(t, y_coeffs, ydot_coeffs, data); - } catch (const std::exception& e) { - LOG_CRITICAL(instance->m_logger, "Uncaught exception in Spectral Solver RHS wrapper at time {}: {}", t, e.what()); - return -1; - } catch (...) { - LOG_CRITICAL(instance->m_logger, "Unknown uncaught exception in Spectral Solver RHS wrapper at time {}", t); - return -1; - } - } - - int SpectralSolverStrategy::cvode_jac_wrapper( - const sunrealtype t, - const N_Vector y, - const N_Vector ydot, - const SUNMatrix J, - void *user_data, - const N_Vector tmp1, - const N_Vector tmp2, - const N_Vector tmp3 - ) { - const auto *data = static_cast(user_data); - const auto *instance = data->solver_instance; - - try { - return instance->calculate_jacobian(t, y, ydot, J, data, tmp1, tmp2, tmp3); - } catch (const std::exception& e) { - LOG_CRITICAL(instance->m_logger, "Uncaught exception in Spectral Solver Jacobian wrapper at time {}: {}", t, e.what()); - return -1; - } catch (...) { - LOG_CRITICAL(instance->m_logger, "Unknown uncaught exception in Spectral Solver Jacobian wrapper at time {}", t); - return -1; - } - } - - //////////////////////////////////////////////////////////////////////////////// - /// RHS implementation - //////////////////////////////////////////////////////////////////////////////// - - // ReSharper disable once CppDFAUnreachableFunctionCall - int SpectralSolverStrategy::calculate_rhs( - sunrealtype t, - N_Vector y_coeffs, - N_Vector ydot_coeffs, - CVODEUserData* data - ) const { - const auto& basis = m_current_basis; - DenseLinearSolver* mass_solver = data->mass_matrix_solver_instance; - - const size_t num_basis_funcs = basis.knots.size() - basis.degree - 1; - const size_t num_species = m_global_species_list.size(); - - const size_t total_dofs = num_basis_funcs * (num_species + 1); - - sunrealtype* global_rhs_data = N_VGetArrayPointer(ydot_coeffs); - N_VConst(0.0, ydot_coeffs); - - std::atomic failure_flag{false}; - - GF_OMP(parallel,) { - std::vector local_rhs(total_dofs, 0.0); - - GF_OMP(for nowait,) - for (size_t q = 0; q < basis.quadrature_nodes.size(); ++q) { - if (failure_flag.load(std::memory_order_relaxed)) continue; - - double wq = basis.quadrature_weights[q]; - const auto& [start_idx, phi] = basis.quad_evals[q]; - - GridPoint gp = reconstruct_at_quadrature(y_coeffs, q, basis); - LOG_TRACE_L2(m_logger, "RHS Evaluation at time {}: Quad Node {}, T9 = {:10.4e}, rho = {:10.4e}", t, q, gp.T9, gp.rho); - auto results = m_engine.calculateRHSAndEnergy( - data->workspaces[q], - gp.composition, - gp.T9, - gp.rho, - false - ); - if (!results) { - failure_flag.store(true); - LOG_CRITICAL(m_logger, "Engine failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(results.error())); - GF_OMP(critical, { - throw std::runtime_error("Engine failure during RHS calculation."); - }) { - std::fill_n(global_rhs_data, total_dofs, -1.0); - } - continue; - } - - const auto& [dydt, eps_nuc, contributions, nu_loss, nu_flux] = results.value(); - for (size_t s = 0; s < num_species; ++s) { - const auto& sp = m_global_species_list[s]; - double rate = 0.0; - if (dydt.contains(sp)) { - rate = dydt.at(sp); - } - - size_t species_offset = s * num_basis_funcs; - - for (size_t k = 0; k < phi.size(); ++k) { - size_t global_idx = species_offset + start_idx + k; - local_rhs[global_idx] += wq * phi[k] * rate; - } - } - - size_t energy_offset = num_species * num_basis_funcs; - - for (size_t k = 0; k < phi.size(); ++k) { - size_t global_idx = energy_offset + start_idx + k; - local_rhs[global_idx] += eps_nuc * wq * phi[k]; - } - } - - GF_OMP(critical,) { - if (!failure_flag.load(std::memory_order_relaxed)) { - for (size_t i = 0; i < total_dofs; ++i) { - global_rhs_data[i] += local_rhs[i]; - } - } - } - } - - if (failure_flag.load()) { - return -1; - } - - size_t total_vars = num_species + 1; - mass_solver -> solve_inplace(ydot_coeffs, total_vars, num_basis_funcs); - return 0; - } - - // ReSharper disable once CppDFAUnreachableFunctionCall - int SpectralSolverStrategy::calculate_jacobian( - sunrealtype t, - N_Vector y_coeffs, - N_Vector ydot_coeffs, - SUNMatrix J, - const CVODEUserData *data, - N_Vector tmp1, - N_Vector tmp2, - N_Vector tmp3 - ) const { - const auto& basis = m_current_basis; - DenseLinearSolver* mass_solver = data->mass_matrix_solver_instance; - - const size_t num_basis_funcs = basis.knots.size() - basis.degree - 1; - const size_t num_species = m_global_species_list.size(); - const size_t total_dofs = num_basis_funcs * (num_species + 1); - - SUNMatZero(J); - sunrealtype* J_global_data = SUNDenseMatrix_Data(J); - - std::atomic failure_flag(false); - - GF_OMP(parallel,) { -#if defined(GF_USE_OPENMP) - int thread_id = omp_get_thread_num(); -#else - [[maybe_unused]] int thread_id = 0; -#endif - GF_OMP(for schedule(static) nowait,) - for (size_t q = 0; q < basis.quadrature_nodes.size(); ++q) { - if (failure_flag.load(std::memory_order_relaxed)) continue; - - double wq = basis.quadrature_weights[q]; - const auto& [start_idx, phi] = basis.quad_evals[q]; - - const GridPoint gp = reconstruct_at_quadrature(y_coeffs, q, basis); - try { - engine::NetworkJacobian jac = m_engine.generateJacobianMatrix( - data->workspaces[q], - gp.composition, - gp.T9, - gp.rho - ); - - auto accumulate_term = [&](size_t row_offset, size_t col_offset, double value) { - if (std::abs(value) < 1e-100) return; // regularization - - const double w_val = wq * value; - - for (size_t k_col = 0; k_col < phi.size(); ++k_col) { - const size_t global_col = col_offset + start_idx + k_col; - const size_t col_data_idx = global_col * total_dofs; - - for (size_t k_row = 0; k_row < phi.size(); ++k_row) { - const size_t global_row = row_offset + start_idx + k_row; - const double contribution = w_val * phi[k_row] * phi[k_col]; - - GF_OMP(atomic update,) - J_global_data[global_row + col_data_idx] += contribution; - } - } - }; - - for (size_t s_row = 0; s_row < num_species; ++s_row) { - const auto& sp_row = m_global_species_list[s_row]; - - if (!gp.composition.contains(sp_row)) continue; - - for (size_t s_col = 0; s_col < num_species; ++s_col) { - const auto& sp_col = m_global_species_list[s_col]; - - if (!gp.composition.contains(sp_col)) continue; - - double val = jac(sp_row, sp_col); - - accumulate_term(s_row * num_basis_funcs, s_col * num_basis_funcs, val); - } - } - } catch (const exceptions::GridFireError &e) { - failure_flag.store(true); - std::string error_msg = std::format("Engine failed to calculate Jacobian, due to known internal GridFire error, at time {}: {}", t, e.what()); - LOG_CRITICAL(m_logger, "{}", error_msg); - GF_OMP(critical, { - throw std::runtime_error(error_msg); - }) { - SUNMatZero(J); - } - continue; - } catch (const std::exception &e) { - failure_flag.store(true); - std::string error_msg = std::format("Engine failed to calculate Jacobian, due to some known yet non GridFire error, at time {}: {}", t, e.what()); - LOG_CRITICAL(m_logger, "{}", error_msg); - GF_OMP(critical, { - throw std::runtime_error(error_msg); - }) { - SUNMatZero(J); - } - continue; - } catch (...) { - failure_flag.store(true); - std::string error_msg = std::format("Engine failed to calculate Jacobian, due to unknown error, at time {}.", t); - LOG_CRITICAL(m_logger, "{}", error_msg); - GF_OMP(critical, { - throw std::runtime_error(error_msg); - }) { - SUNMatZero(J); - } - continue; - } - } - } - - if (failure_flag.load()) { return -1; } - inspect_jacobian(J, "Physics Assembly (Pre-Mass-Matrix)"); - - GF_OMP(parallel for,) - for (size_t col = 0; col < total_dofs; ++col) { - sunrealtype* col_ptr = J_global_data + (col * total_dofs); - mass_solver->solve_inplace_ptr(col_ptr, num_species + 1, num_basis_funcs); - } - inspect_jacobian(J, "Final Jacobian (Post-Mass-Matrix)"); - return 0; - } - - size_t SpectralSolverStrategy::nyquist_elements( - const size_t requested_elements, - const size_t num_shells - ) { - const size_t max_allowed_elements = std::max(1uz, num_shells/2); - size_t actual_elements = std::min(requested_elements, max_allowed_elements); - if (num_shells <= 5) { - actual_elements = 1; - } - return actual_elements; - } - - - //////////////////////////////////////////////////////////////////////////////// - /// Spectral Utilities - /// These include basis generation, monitor function evaluation - /// projection and reconstruction routines. - //////////////////////////////////////////////////////////////////////////////// - - std::vector SpectralSolverStrategy::evaluate_monitor_function( - const std::vector& current_shells - ) const { - assert(!m_global_species_list.empty() && "Global species list must be initialized before evaluating monitor function."); - assert(!current_shells.empty() && "Current shells list must not be empty when evaluating monitor function."); - - const size_t n_shells = current_shells.size(); - if (n_shells < 3) { - return std::vector(n_shells, 1.0); // NOLINT(*-return-braced-init-list) - } - - std::vector M(n_shells, 1.0); - std::vector data(n_shells); - - auto accumulate_variable = [&](auto getter, double weight, bool use_log) { - double min_val = std::numeric_limits::max(); - double max_val = std::numeric_limits::lowest(); - - for (size_t i = 0 ; i < n_shells; ++i) { - double val = getter(current_shells[i]); - if (use_log) { - val = std::log10(std::max(val, 1e-100)); - } - - data[i] = val; - - if (val < min_val) min_val = val; - if (val > max_val) max_val = val; - } - - const double scale = max_val - min_val; - if (scale < 1e-10) return; - - for (size_t i = 1; i < n_shells - 1; ++i) { - const double v_prev = data[i-1]; - const double v_curr = data[i]; - const double v_next = data[i+1]; - - // Finite difference estimates for first and second derivatives - double d1 = std::abs(v_next - v_prev) / 2.0; - double d2 = std::abs(v_next - 2.0 * v_curr + v_prev); - - d1 /= scale; - d2 /= scale; - - const double alpha = m_config->solver.spectral.monitorFunction.alpha; - const double beta = m_config->solver.spectral.monitorFunction.beta; - - M[i] += weight * (alpha * d1 + beta * d2); - } - }; - - auto safe_get_abundance = [](const NetIn& netIn, const fourdst::atomic::Species& sp) -> double { - if (netIn.composition.contains(sp)) { - return netIn.composition.getMolarAbundance(sp); - } - return 0.0; - }; - - const double structure_weight = m_config->solver.spectral.monitorFunction.structure_weight; - double abundance_weight = m_config->solver.spectral.monitorFunction.abundance_weight; - accumulate_variable([](const NetIn& s) { return s.temperature; }, structure_weight, true); - accumulate_variable([](const NetIn& s) { return s.density; }, structure_weight, true); - - for (const auto& sp : m_global_species_list) { - accumulate_variable([&sp, &safe_get_abundance](const NetIn& s) { return safe_get_abundance(s, sp); }, abundance_weight, false); - } - - ////////////////////////////// - /// Smoothing the Monitor /// - ////////////////////////////// - - std::vector M_smooth = M; - for (size_t i = 1; i < n_shells - 1; ++i) { - M_smooth[i] = (M[i-1] + 2.0 * M[i] + M[i+1]) / 4.0; - } - - M_smooth[0] = M_smooth[1]; - M_smooth[n_shells-1] = M_smooth[n_shells-2]; - - return M_smooth; - - } - - SpectralSolverStrategy::SplineBasis SpectralSolverStrategy::generate_basis_from_monitor( - const std::vector& monitor_values, - const std::vector& mass_coordinates, - const size_t actual_elements - ) { - SplineBasis basis; - basis.degree = 3; // Cubic Spline - - const size_t n_shells = monitor_values.size(); - - std::vector I(n_shells, 0.0); - double current_integral = 0.0; - - for (size_t i = 1; i < n_shells; ++i) { - const double dx = mass_coordinates[i] - mass_coordinates[i-1]; - double dI = 0.5 * (monitor_values[i] + monitor_values[i-1]) * dx; - - dI = std::max(dI, 1e-30); - current_integral += dI; - I[i] = current_integral; - } - - const double total_integral = I.back(); - for (size_t i = 0; i < n_shells; ++i) { - I[i] /= total_integral; - } - - basis.knots.reserve(actual_elements + 1 + 2 * basis.degree); - - // Note that these imply that mass_coordinates must be sorted in increasing order - double min_mass = mass_coordinates.front(); - double max_mass = mass_coordinates.back(); - - for (int i = 0; i < basis.degree; ++i) { - basis.knots.push_back(min_mass); - } - - for (size_t k = 1; k < actual_elements; ++k) { - double target_I = static_cast(k) / static_cast(actual_elements); - - auto it = std::ranges::lower_bound(I, target_I); - size_t idx = std::distance(I.begin(), it); - - if (idx == 0) idx = 1; - if (idx >= n_shells) idx = n_shells - 1; - - double I0 = I[idx-1]; - double I1 = I[idx]; - double m0 = mass_coordinates[idx-1]; - double m1 = mass_coordinates[idx]; - - double fraction = (target_I - I0) / (I1 - I0); - double knot_location = m0 + fraction * (m1 - m0); - - basis.knots.push_back(knot_location); - } - - for (int i = 0; i < basis.degree; ++i) { - basis.knots.push_back(max_mass); - } - - constexpr double sqrt_3_over_5 = 0.77459666924; - constexpr double five_over_nine = 5.0 / 9.0; - constexpr double eight_over_nine = 8.0 / 9.0; - static constexpr std::array gl_nodes = {-sqrt_3_over_5, 0.0, sqrt_3_over_5}; - static constexpr std::array gl_weights = {five_over_nine, eight_over_nine, five_over_nine}; - - basis.quadrature_nodes.clear(); - basis.quadrature_weights.clear(); - - for (size_t i = basis.degree; i < basis.knots.size() - basis.degree - 1; ++i) { - double a = basis.knots[i]; - double b = basis.knots[i+1]; - - if ( b - a < 1e-14) continue; - - double mid = 0.5 * (a + b); - double half_width = 0.5 * (b - a); - - for (size_t j = 0; j < gl_nodes.size(); ++j) { - double phys_node = mid + gl_nodes[j] * half_width; - double phys_weight = gl_weights[j] * half_width; - - basis.quadrature_nodes.push_back(phys_node); - basis.quadrature_weights.push_back(phys_weight); - - auto [start, phi] = evaluate_bspline(phys_node, basis); - basis.quad_evals.push_back({start, phi}); - } - } - - return basis; - } - - - SpectralSolverStrategy::GridPoint SpectralSolverStrategy::reconstruct_at_quadrature( - const N_Vector y_coeffs, - const size_t quad_index, - const SplineBasis &basis - ) const { - auto [start_idx, vals] = basis.quad_evals[quad_index]; - - const sunrealtype* T_ptr = N_VGetArrayPointer(m_T_coeffs); - const sunrealtype* rho_ptr = N_VGetArrayPointer(m_rho_coeffs); - const sunrealtype* y_data = N_VGetArrayPointer(y_coeffs); - - const size_t num_basis_funcs = basis.knots.size() - basis.degree - 1; - const size_t num_species = m_global_species_list.size(); - - double logT = 0.0; - double logRho = 0.0; - - for (size_t k = 0; k < vals.size(); ++k) { - size_t idx = start_idx + k; - logT += T_ptr[idx] * vals[k]; - logRho += rho_ptr[idx] * vals[k]; - } - - GridPoint result; - result.T9 = std::pow(10.0, logT) / 1e9; - result.rho = std::pow(10.0, logRho); - - for (size_t s = 0; s < num_species; ++s) { - const auto& species = m_global_species_list[s]; - - double abundance = 0.0; - const size_t offset = s * num_basis_funcs; - - for (size_t k = 0; k < vals.size(); ++k) { - abundance += y_data[offset + start_idx + k] * vals[k]; - } - - if (abundance < 0.0) abundance = 0.0; - - result.composition.registerSpecies(species); - result.composition.setMolarAbundance(species, abundance); - } - - return result; - } - - std::vector SpectralSolverStrategy::reconstruct_solution( - const std::vector& original_inputs, - const std::vector& mass_coordinates, - const N_Vector final_coeffs, - const SplineBasis& basis, - const double dt - ) const { - const size_t n_shells = original_inputs.size(); - const size_t num_basis_funcs = basis.knots.size() - basis.degree - 1; - const size_t num_species = m_global_species_list.size(); - - // Pre-allocate output vector so threads can write directly - std::vector outputs(n_shells); - - const sunrealtype* c_data = N_VGetArrayPointer(final_coeffs); - - GF_OMP(parallel for,) - for (size_t shellID = 0; shellID < n_shells; ++shellID) { - const double x = mass_coordinates[shellID]; - - auto [start_idx, vals] = evaluate_bspline(x, basis); - - auto reconstruct_var = [&](const size_t coeff_offset) -> double { - double result = 0.0; - for (size_t i = 0; i < vals.size(); ++i) { - result += c_data[coeff_offset + start_idx + i] * vals[i]; - } - return result; - }; - - fourdst::composition::Composition comp_new; - - for (size_t s_idx = 0; s_idx < num_species; ++s_idx) { - const auto& sp = m_global_species_list[s_idx]; - - const size_t current_offset = s_idx * num_basis_funcs; - double Y_val = reconstruct_var(current_offset); - - if (Y_val < 0.0 && Y_val >= -1e-16) { - Y_val = 0.0; - } - - comp_new.registerSpecies(sp); - comp_new.setMolarAbundance(sp, Y_val); - } - - const double energy = reconstruct_var(num_species * num_basis_funcs); - - NetOut netOut; - netOut.composition = comp_new; - netOut.energy = energy; - netOut.num_steps = -1; // Not tracked in spectral solver - - outputs[shellID] = std::move(netOut); - } - - return outputs; - } - - - void SpectralSolverStrategy::project_specific_variable( - const std::vector ¤t_shells, - const std::vector &mass_coordinates, - const std::vector &shell_cache, - const DenseLinearSolver &linear_solver, - N_Vector output_vec, - size_t output_offset, - const std::function &getter, - bool use_log - ) { - const size_t n_shells = current_shells.size(); - - sunrealtype* out_ptr = N_VGetArrayPointer(output_vec); - size_t basis_size = N_VGetLength(linear_solver.temp_vector); - - for (size_t i = 0; i < basis_size; ++i ) { - out_ptr[output_offset + i] = 0.0; - } - - for (size_t shellID = 0; shellID < n_shells; ++shellID) { - double val = getter(current_shells[shellID]); - if (use_log) val = std::log10(std::max(val, 1e-100)); - - const auto& eval = shell_cache[shellID]; - - for (size_t i = 0; i < eval.phi.size(); ++i) { - out_ptr[output_offset + eval.start_idx + i] += val * eval.phi[i]; - } - } - - sunrealtype* tmp_data = N_VGetArrayPointer(linear_solver.temp_vector); - for (size_t i = 0; i < basis_size; ++i) tmp_data[i] = out_ptr[output_offset + i]; - - utils::check_sundials_flag(SUNLinSolSolve(linear_solver.LS, linear_solver.A, linear_solver.temp_vector, linear_solver.temp_vector, 0.0), "SUNLinSolSolve - Projection Solver", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - - for (size_t i = 0; i < basis_size; ++i) out_ptr[output_offset + i] = tmp_data[i]; - } - - /////////////////////////////////////////////////////////////////////////////////// - /// Debugging Utilities - /////////////////////////////////////////////////////////////////////////////////// - // ReSharper disable once CppDFAUnreachableFunctionCall - void SpectralSolverStrategy::inspect_jacobian(SUNMatrix J, const std::string &context) const { - sunrealtype* data = SUNDenseMatrix_Data(J); - sunindextype rows = SUNDenseMatrix_Rows(J); - sunindextype cols = SUNDenseMatrix_Columns(J); - - // --- 1. Gather Statistics --- - size_t nan_count = 0; - size_t inf_count = 0; - size_t zero_diag_count = 0; - - double max_val = 0.0; - double min_val = 0.0; - double min_diag_val = std::numeric_limits::max(); - double max_diag_val = std::numeric_limits::lowest(); - - const bool matrix_is_empty = (rows == 0 || cols == 0); - bool non_zero_elements_found = false; - - // Iterate Column-Major (standard for SUNDIALS) - for (sunindextype j = 0; j < cols; ++j) { - // Diagonal Check - if (j < rows) { - const double diag = data[j * rows + j]; - const double abs_diag = std::abs(diag); - - if (abs_diag < 1e-100) zero_diag_count++; - - if (abs_diag < min_diag_val) min_diag_val = abs_diag; - if (abs_diag > max_diag_val) max_diag_val = abs_diag; - } - - for (sunindextype i = 0; i < rows; ++i) { - const double val = data[j * rows + i]; - - if (std::isnan(val)) { - nan_count++; - } else if (std::isinf(val)) { - inf_count++; - } else { - const double abs_val = std::abs(val); - if (abs_val > 0.0) { - if (!non_zero_elements_found) { - // First non-zero element initializes the range - max_val = abs_val; - min_val = abs_val; - non_zero_elements_found = true; - } else { - if (abs_val > max_val) max_val = abs_val; - if (abs_val < min_val) min_val = abs_val; - } - } - } - } - } - - if (!non_zero_elements_found) { - min_diag_val = 0.0; - max_diag_val = 0.0; - } - - // --- 2. Build Data Vectors --- - std::vector metrics = { - "Dimensions", - "NaN Count", - "Inf Count", - "Zero Diagonals", - "Global Range (abs)", - "Diagonal Range (abs)", - "Status" - }; - - std::vector values; - values.reserve(metrics.size()); - - // Dimensions - values.push_back(std::format("{} x {}", rows, cols)); - - // Errors - values.push_back(std::to_string(nan_count)); - values.push_back(std::to_string(inf_count)); - values.push_back(std::to_string(zero_diag_count)); - - // Ranges - values.push_back(std::format("[{:.2e}, {:.2e}]", min_val, max_val)); - values.push_back(std::format("[{:.2e}, {:.2e}]", min_diag_val, max_diag_val)); - - // Status - const bool failed = (nan_count > 0 || inf_count > 0 || zero_diag_count > 0 || matrix_is_empty); - values.emplace_back(failed ? "FAIL" : "OK"); - - // --- 3. Construct Columns Manually --- - std::vector> columns; - - columns.push_back(std::make_unique>("Metric", metrics)); - columns.push_back(std::make_unique>("Value", values)); - - // --- 4. Format and Log --- - std::string table_name = std::format("Jacobian Inspection: {}", context); - std::string report_str = gridfire::utils::format_table(table_name, columns); - - if (failed) { - std::println("{}", report_str); - LOG_CRITICAL(m_logger, "\n{}", report_str); - } else { - std::println("{}", report_str); - LOG_INFO(m_logger, "\n{}", report_str); - } - } - - /////////////////////////////////////////////////////////////////////////////// - /// SpectralSolverStrategy::MassMatrixSolver Implementation - /////////////////////////////////////////////////////////////////////////////// - - SpectralSolverStrategy::DenseLinearSolver::DenseLinearSolver( - const size_t size, - const SUNContext sun_ctx - ) : ctx(sun_ctx) { - A = SUNDenseMatrix(static_cast(size), static_cast(size), sun_ctx); - temp_vector = N_VNew_Serial(static_cast(size), sun_ctx); - - LS = SUNLinSol_Dense(temp_vector, A, sun_ctx); - - if (!A || !temp_vector || !LS) { - throw std::runtime_error("Failed to create MassMatrixSolver components."); - } - - zero(); - } - - SpectralSolverStrategy::DenseLinearSolver::~DenseLinearSolver() { - if (LS) SUNLinSolFree(LS); - if (A) SUNMatDestroy(A); - if (temp_vector) N_VDestroy(temp_vector); - } - - void SpectralSolverStrategy::DenseLinearSolver::zero() const { - SUNMatZero(A); - } - - void SpectralSolverStrategy::DenseLinearSolver::init_from_cache( - const size_t num_basis_funcs, - const std::vector &shell_cache - ) const { - sunrealtype* a_data = SUNDenseMatrix_Data(A); - for (const auto&[start_idx, phi] : shell_cache) { - for (size_t i = 0; i < phi.size(); ++i) { - const size_t row = start_idx + i; - - for (size_t j = 0; j < phi.size(); ++j) { - const size_t col = start_idx + j; - - a_data[col * num_basis_funcs + row] += phi[i] * phi[j]; - } - } - } - - // Apply a small diagonal perturbation for numerical stability - for (size_t i = 0; i < num_basis_funcs; ++i) { - constexpr double epsilon = 1e-14; - a_data[i * num_basis_funcs + i] += epsilon; - } - - setup(); - } - - void SpectralSolverStrategy::DenseLinearSolver::init_from_basis( - const size_t num_basis_funcs, - const SplineBasis &basis - ) const { - sunrealtype* m_data = SUNDenseMatrix_Data(A); - for (size_t q = 0; q < basis.quadrature_nodes.size(); ++q) { - const double w_q = basis.quadrature_weights[q]; - const auto&[start_idx, phi] = basis.quad_evals[q]; - - for (size_t i = 0; i < phi.size(); ++i) { - const size_t row = start_idx + i; - for (size_t j = 0; j < phi.size(); ++j) { - const size_t col = start_idx + j; - m_data[col * num_basis_funcs + row] += w_q * phi[j] * phi[i]; - } - } - } - setup(); - } - - void SpectralSolverStrategy::DenseLinearSolver::setup() const { - utils::check_sundials_flag(SUNLinSolSetup(LS, A), "SUNLinSolSetup - Mass Matrix Solver", utils::SUNDIALS_RET_CODE_TYPES::CVODE); - } - - // ReSharper disable once CppMemberFunctionMayBeConst - void SpectralSolverStrategy::DenseLinearSolver::solve_inplace(const N_Vector x, const size_t num_vars, const size_t basis_size) const { - sunrealtype* x_data = N_VGetArrayPointer(x); - sunrealtype* tmp_data = N_VGetArrayPointer(temp_vector); - - for (size_t v = 0; v < num_vars; ++v) { - const size_t offset = v * basis_size; - - for (size_t i = 0; i < basis_size; ++i) { - tmp_data[i] = x_data[offset + i]; - } - SUNLinSolSolve(LS, A, temp_vector, temp_vector, 0.0); - - for (size_t i = 0; i < basis_size; ++i) { - x_data[offset + i] = tmp_data[i]; - } - } - } - - void SpectralSolverStrategy::DenseLinearSolver::solve_inplace_ptr( - sunrealtype *data_ptr, - const size_t num_vars, - const size_t basis_size - ) const { - sunrealtype* tmp_data = N_VGetArrayPointer(temp_vector); - for (size_t v = 0; v < num_vars; ++v) { - const size_t offset = v * basis_size; - sunrealtype* var_segment = data_ptr + offset; - - for (size_t i = 0; i < basis_size; ++i) { - tmp_data[i] = var_segment[i]; - } - - const int flag = SUNLinSolSolve(LS, A, temp_vector, temp_vector, 0.0); - - if (flag != 0) { - GF_OMP(critical,) { - std::cerr << "Mass matrix inversion failed in SpectralSolverStrategy::DenseLinearSolver::solve_inplace_ptr. This is a critical error which cannot, at present, be recovered from. If you see this message discard any results reported after it and please report this to the GridFire developers." << flag << std::endl; - std::abort(); - } - // TODO: We cannot trivially throw here if this function fails since it may be run in parallel. For now we trust init_from_cache to not - // generate a singular matrix; however, in the future we may want to implement a more robust error handling strategy. - } - - for (size_t i = 0; i < basis_size; ++i) { - var_segment[i] = tmp_data[i]; - } - } - } -} diff --git a/src/lib/solver/strategies/triggers/engine_partitioning_trigger.cpp b/src/lib/solver/strategies/triggers/engine_partitioning_trigger.cpp index 763782ec..0bbb7878 100644 --- a/src/lib/solver/strategies/triggers/engine_partitioning_trigger.cpp +++ b/src/lib/solver/strategies/triggers/engine_partitioning_trigger.cpp @@ -1,5 +1,5 @@ #include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h" -#include "gridfire/solver/strategies/CVODE_solver_strategy.h" +#include "gridfire/solver/strategies/PointSolver.h" #include "gridfire/trigger/trigger_logical.h" #include "gridfire/trigger/trigger_abstract.h" @@ -28,7 +28,7 @@ namespace gridfire::trigger::solver::CVODE { } } - bool SimulationTimeTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { + bool SimulationTimeTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const { if (ctx.t - m_last_trigger_time >= m_interval) { m_hits++; LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta); @@ -38,7 +38,7 @@ namespace gridfire::trigger::solver::CVODE { return false; } - void SimulationTimeTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { + void SimulationTimeTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) { if (check(ctx)) { m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval; m_last_trigger_time = ctx.t; @@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE { } void SimulationTimeTrigger::step( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) { // --- SimulationTimeTrigger::step does nothing and is intentionally left blank --- // } @@ -65,7 +65,7 @@ namespace gridfire::trigger::solver::CVODE { return "Simulation Time Trigger"; } - TriggerResult SimulationTimeTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { + TriggerResult SimulationTimeTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const { TriggerResult result; result.name = name(); if (check(ctx)) { @@ -99,18 +99,18 @@ namespace gridfire::trigger::solver::CVODE { } } - bool OffDiagonalTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { + bool OffDiagonalTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const { //TODO : This currently does nothing return false; } - void OffDiagonalTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { + void OffDiagonalTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) { m_updates++; } void OffDiagonalTrigger::step( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) { // --- OffDiagonalTrigger::step does nothing and is intentionally left blank --- // } @@ -126,7 +126,7 @@ namespace gridfire::trigger::solver::CVODE { return "Off-Diagonal Trigger"; } - TriggerResult OffDiagonalTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { + TriggerResult OffDiagonalTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const { TriggerResult result; result.name = name(); @@ -173,7 +173,7 @@ namespace gridfire::trigger::solver::CVODE { } } - bool TimestepCollapseTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { + bool TimestepCollapseTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const { if (m_timestep_window.size() < m_windowSize) { m_misses++; return false; @@ -201,13 +201,13 @@ namespace gridfire::trigger::solver::CVODE { return false; } - void TimestepCollapseTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { + void TimestepCollapseTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) { m_updates++; m_timestep_window.clear(); } void TimestepCollapseTrigger::step( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) { push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize); // --- TimestepCollapseTrigger::step does nothing and is intentionally left blank --- // @@ -226,7 +226,7 @@ namespace gridfire::trigger::solver::CVODE { } TriggerResult TimestepCollapseTrigger::why( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) const { TriggerResult result; result.name = name(); @@ -263,7 +263,7 @@ namespace gridfire::trigger::solver::CVODE { m_windowSize(windowSize) {} bool ConvergenceFailureTrigger::check( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) const { if (m_window.size() != m_windowSize) { m_misses++; @@ -278,13 +278,13 @@ namespace gridfire::trigger::solver::CVODE { } void ConvergenceFailureTrigger::update( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) { m_window.clear(); } void ConvergenceFailureTrigger::step( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) { push_to_fixed_deque(m_window, ctx.currentConvergenceFailures, m_windowSize); m_updates++; @@ -306,7 +306,7 @@ namespace gridfire::trigger::solver::CVODE { return "ConvergenceFailureTrigger(abs_failure_threshold=" + std::to_string(m_totalFailures) + ", rel_failure_threshold=" + std::to_string(m_relativeFailureRate) + ", windowSize=" + std::to_string(m_windowSize) + ")"; } - TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { + TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const { TriggerResult result; result.name = name(); @@ -348,7 +348,7 @@ namespace gridfire::trigger::solver::CVODE { } bool ConvergenceFailureTrigger::abs_failure( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) const { if (ctx.currentConvergenceFailures > m_totalFailures) { return true; @@ -357,7 +357,7 @@ namespace gridfire::trigger::solver::CVODE { } bool ConvergenceFailureTrigger::rel_failure( - const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx + const gridfire::solver::PointSolverTimestepContext &ctx ) const { const float mean = current_mean(); if (mean < 10) { @@ -369,13 +369,13 @@ namespace gridfire::trigger::solver::CVODE { return false; } - std::unique_ptr> makeEnginePartitioningTrigger( + std::unique_ptr> makeEnginePartitioningTrigger( const double simulationTimeInterval, const double offDiagonalThreshold, const double timestepCollapseRatio, const size_t maxConvergenceFailures ) { - using ctx_t = gridfire::solver::CVODESolverStrategy::TimestepContext; + using ctx_t = gridfire::solver::PointSolverTimestepContext; // 1. INSTABILITY TRIGGERS (High Priority) auto convergenceFailureTrigger = std::make_unique( diff --git a/src/meson.build b/src/meson.build index f7af84f3..712793ac 100644 --- a/src/meson.build +++ b/src/meson.build @@ -15,8 +15,8 @@ gridfire_sources = files( 'lib/reaction/weak/weak_interpolator.cpp', 'lib/io/network_file.cpp', 'lib/io/generative/python.cpp', - 'lib/solver/strategies/CVODE_solver_strategy.cpp', - 'lib/solver/strategies/SpectralSolverStrategy.cpp', + 'lib/solver/strategies/PointSolver.cpp', + 'lib/solver/strategies/GridSolver.cpp', 'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp', 'lib/screening/screening_types.cpp', 'lib/screening/screening_weak.cpp', diff --git a/tests/graphnet_sandbox/main.cpp b/tests/graphnet_sandbox/main.cpp index e1c667ad..f53eee32 100644 --- a/tests/graphnet_sandbox/main.cpp +++ b/tests/graphnet_sandbox/main.cpp @@ -19,7 +19,7 @@ #include -#include "gridfire/reaction/reaclib.h" +#include "gridfire/utils/gf_omp.h" static std::terminate_handler g_previousHandler = nullptr; @@ -31,7 +31,7 @@ gridfire::NetIn init(const double temp, const double rho, const double tMax) { std::setlocale(LC_ALL, ""); g_previousHandler = std::set_terminate(quill_terminate_handler); quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log"); - logger->set_log_level(quill::LogLevel::TraceL2); + logger->set_log_level(quill::LogLevel::Info); using namespace gridfire; const std::vector X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4}; @@ -143,7 +143,7 @@ void log_results(const gridfire::NetOut& netOut, const gridfire::NetIn& netIn) { } -void record_abundance_history_callback(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) { +void record_abundance_history_callback(const gridfire::solver::PointSolverTimestepContext& ctx) { s_wrote_abundance_history = true; const auto& engine = ctx.engine; // std::unordered_map> abundances; @@ -224,11 +224,12 @@ void quill_terminate_handler() std::abort(); } -void callback_main(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) { +void callback_main(const gridfire::solver::PointSolverTimestepContext& ctx) { record_abundance_history_callback(ctx); } int main() { + GF_PAR_INIT(); using namespace gridfire; constexpr size_t breaks = 1; @@ -239,98 +240,20 @@ int main() { const NetIn netIn = init(temp, rho, tMax); policy::MainSequencePolicy stellarPolicy(netIn.composition); - policy::ConstructionResults construct = stellarPolicy.construct(); + auto [engine, ctx_template] = stellarPolicy.construct(); std::println("Sandbox Engine Stack: {}", stellarPolicy); - std::println("Scratch Blob State: {}", *construct.scratch_blob); + std::println("Scratch Blob State: {}", *ctx_template); - - constexpr size_t runs = 1000; - auto startTime = std::chrono::high_resolution_clock::now(); - - // arrays to store timings - std::array, runs> setup_times; - std::array, runs> eval_times; - std::array serial_results; - for (size_t i = 0; i < runs; ++i) { - auto start_setup_time = std::chrono::high_resolution_clock::now(); - std::print("Run {}/{}\r", i + 1, runs); - solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob); - // solver.set_callback(solver::CVODESolverStrategy::TimestepCallback(callback_main)); - solver.set_stdout_logging_enabled(false); - auto end_setup_time = std::chrono::high_resolution_clock::now(); - std::chrono::duration setup_elapsed = end_setup_time - start_setup_time; - setup_times[i] = setup_elapsed; - - auto start_eval_time = std::chrono::high_resolution_clock::now(); - const NetOut netOut = solver.evaluate(netIn); - auto end_eval_time = std::chrono::high_resolution_clock::now(); - serial_results[i] = netOut; - std::chrono::duration eval_elapsed = end_eval_time - start_eval_time; - eval_times[i] = eval_elapsed; - - // log_results(netOut, netIn); - } - auto endTime = std::chrono::high_resolution_clock::now(); - std::chrono::duration elapsed = endTime - startTime; - std::println(""); - - // Summarize serial timings - double total_setup_time = 0.0; - double total_eval_time = 0.0; - for (size_t i = 0; i < runs; ++i) { - total_setup_time += setup_times[i].count(); - total_eval_time += eval_times[i].count(); - } - std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs); - std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs); - std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count()); - std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1)); - - // OPTIONAL: Prevent CppAD from returning memory to the system - // during execution to reduce overhead (can speed up tight loops) - CppAD::thread_alloc::hold_memory(true); - - std::array parallelResults; - std::array, runs> setupTimes; - std::array, runs> evalTimes; - std::array, runs> workspaces; - for (size_t i = 0; i < runs; ++i) { - workspaces[i] = construct.scratch_blob->clone_structure(); + constexpr size_t nZones = 100; + std::array netIns; + for (size_t zone = 0; zone < nZones; ++zone) { + netIns[zone] = netIn; + netIns[zone].temperature = 1.0e7; } + const solver::PointSolver localSolver(engine); + solver::GridSolverContext solverCtx(*ctx_template); + const solver::GridSolver gridSolver(engine, localSolver); - // Parallel runs - startTime = std::chrono::high_resolution_clock::now(); - for (size_t i = 0; i < runs; ++i) { - auto start_setup_time = std::chrono::high_resolution_clock::now(); - solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]); - solver.set_stdout_logging_enabled(false); - auto end_setup_time = std::chrono::high_resolution_clock::now(); - std::chrono::duration setup_elapsed = end_setup_time - start_setup_time; - setupTimes[i] = setup_elapsed; - auto start_eval_time = std::chrono::high_resolution_clock::now(); - parallelResults[i] = solver.evaluate(netIn); - auto end_eval_time = std::chrono::high_resolution_clock::now(); - std::chrono::duration eval_elapsed = end_eval_time - start_eval_time; - evalTimes[i] = eval_elapsed; - } - endTime = std::chrono::high_resolution_clock::now(); - elapsed = endTime - startTime; - std::println(""); - - // Summarize parallel timings - total_setup_time = 0.0; - total_eval_time = 0.0; - for (size_t i = 0; i < runs; ++i) { - total_setup_time += setupTimes[i].count(); - total_eval_time += evalTimes[i].count(); - } - - std::println("Average Parallel Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs); - std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs); - std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count()); - - std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) { - return result.composition.getMolarAbundance(fourdst::atomic::H_1); - })); + std::vector netOuts = gridSolver.evaluate(solverCtx, netIns | std::ranges::to()); } \ No newline at end of file diff --git a/tests/graphnet_sandbox/meson.build b/tests/graphnet_sandbox/meson.build index 65054041..0a1c5dc5 100644 --- a/tests/graphnet_sandbox/meson.build +++ b/tests/graphnet_sandbox/meson.build @@ -4,8 +4,8 @@ executable( dependencies: [gridfire_dep, cli11_dep], ) -executable( - 'spectral_sandbox', - 'spectral_main.cpp', - dependencies: [gridfire_dep, cli11_dep] -) +#executable( +# 'spectral_sandbox', +# 'spectral_main.cpp', +# dependencies: [gridfire_dep, cli11_dep] +#) diff --git a/tests/graphnet_sandbox/spectral_main.cpp b/tests/graphnet_sandbox/spectral_main.cpp deleted file mode 100644 index b6ff519b..00000000 --- a/tests/graphnet_sandbox/spectral_main.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include -#include -#include -#include - -#include "gridfire/gridfire.h" - -#include "fourdst/composition/composition.h" -#include "fourdst/logging/logging.h" -#include "fourdst/atomic/species.h" -#include "fourdst/composition/utils.h" - -#include "quill/Logger.h" -#include "quill/Backend.h" -#include "CLI/CLI.hpp" - -#include - - -static std::terminate_handler g_previousHandler = nullptr; -static std::vector>>> g_callbackHistory; -static bool s_wrote_abundance_history = false; -void quill_terminate_handler(); - -std::vector linspace(const double start, const double end, const size_t num) { - std::vector result; - if (num == 0) return result; - if (num == 1) { - result.push_back(start); - return result; - } - const double step = (end - start) / static_cast(num - 1); - for (size_t i = 0; i < num; ++i) { - result.push_back(start + i * step); - } - return result; -} - -std::vector init(const double tMin, const double tMax, const double rhoMin, const double rhoMax, const double nShells, const double evolveTime) { - std::setlocale(LC_ALL, ""); - g_previousHandler = std::set_terminate(quill_terminate_handler); - quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log"); - logger->set_log_level(quill::LogLevel::TraceL2); - LOG_INFO(logger, "Initializing GridFire Spectral Solver Sandbox..."); - - using namespace gridfire; - const std::vector X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4}; - const std::vector symbols = {"H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"}; - - - const fourdst::composition::Composition composition = fourdst::composition::buildCompositionFromMassFractions(symbols, X); - - std::vector netIns; - for (const auto& [T, ρ]: std::views::zip(linspace(tMin, tMax, nShells), linspace(rhoMax, rhoMin, nShells))) { - NetIn netIn; - netIn.composition = composition; - netIn.temperature = T; - netIn.density = ρ; - netIn.energy = 0; - - netIn.tMax = evolveTime; - netIn.dt0 = 1e-12; - netIns.push_back(netIn); - } - - return netIns; -} -void quill_terminate_handler() -{ - quill::Backend::stop(); - if (g_previousHandler) - g_previousHandler(); - else - std::abort(); -} - -int main(int argc, char** argv) { - using namespace gridfire; - - CLI::App app{"GridFire Sandbox Application."}; - - double tMin = 1.5e7; - double tMax = 1.7e7; - double rhoMin = 1.5e2; - double rhoMax = 1.5e2; - double nShells = 15; - double evolveTime = 3.1536e+16; - - app.add_option("--tMin", tMin, "Minimum time in seconds"); - app.add_option("--tMax", tMax, "Maximum time in seconds"); - app.add_option("--rhoMin", rhoMin, "Minimum density in g/cm^3"); - app.add_option("--rhoMax", rhoMax, "Maximum density in g/cm^3"); - app.add_option("--nShells", nShells, "Number of shells"); - app.add_option("--evolveTime", evolveTime, "Maximum time in seconds"); - - CLI11_PARSE(app, argc, argv); - - const std::vector netIns = init(tMin, tMax, rhoMin, rhoMax, nShells, evolveTime); - - policy::MainSequencePolicy stellarPolicy(netIns[0].composition); - stellarPolicy.construct(); - policy::ConstructionResults construct = stellarPolicy.construct(); - - solver::SpectralSolverStrategy solver(construct.engine); - std::vector mass_coords = linspace(1e-5, 1.0, nShells); - - std::vector results = solver.evaluate(netIns, mass_coords, *construct.scratch_blob); -} diff --git a/tools/cli/gf_quick/main.cpp b/tools/cli/gf_quick/main.cpp new file mode 100644 index 00000000..333f4a44 --- /dev/null +++ b/tools/cli/gf_quick/main.cpp @@ -0,0 +1,373 @@ +// ReSharper disable CppUnusedIncludeDirective +#include +#include +#include +#include +#include + +#include "gridfire/gridfire.h" +#include // Required for parallel_setup + +#include "fourdst/composition/composition.h" +#include "fourdst/logging/logging.h" +#include "fourdst/atomic/species.h" +#include "fourdst/composition/utils.h" + +#include "quill/Logger.h" +#include "quill/Backend.h" +#include "CLI/CLI.hpp" + +#include + +#include "gridfire/utils/gf_omp.h" + + +static std::terminate_handler g_previousHandler = nullptr; +static std::vector>>> g_callbackHistory; +static bool s_wrote_abundance_history = false; +void quill_terminate_handler(); + +using namespace fourdst::composition; +Composition rescale(const Composition& comp, double target_X, double target_Z) { + // 1. Validate inputs + if (target_X < 0.0 || target_Z < 0.0 || (target_X + target_Z) > 1.0 + 1e-14) { + throw std::invalid_argument("Target mass fractions X and Z must be non-negative and sum to <= 1.0"); + } + + // Force high precision for the target Y to ensure X+Y+Z = 1.0 exactly in our logic + long double ld_target_X = static_cast(target_X); + long double ld_target_Z = static_cast(target_Z); + long double ld_target_Y = 1.0L - ld_target_X - ld_target_Z; + + // Clamp Y to 0 if it dipped slightly below due to precision (e.g. X+Z=1.0000000001) + if (ld_target_Y < 0.0L) ld_target_Y = 0.0L; + + // 2. Manually calculate current Mass Totals (bypass getCanonicalComposition to avoid crashes) + long double total_mass_H = 0.0L; + long double total_mass_He = 0.0L; + long double total_mass_Z = 0.0L; + + // We need to iterate and identify species types manually + // Standard definition: H (z=1), He (z=2), Metals (z>2) + // Note: We use long double accumulators to prevent summation drift + for (const auto& [spec, molar_abundance] : comp) { + // Retrieve atomic properties. + // Note: usage assumes fourdst::atomic::Species has .z() and .mass() + // consistent with the provided composition.cpp + int z = spec.z(); + double a = spec.mass(); + + long double mass_contribution = static_cast(molar_abundance) * static_cast(a); + + if (z == 1) { + total_mass_H += mass_contribution; + } else if (z == 2) { + total_mass_He += mass_contribution; + } else { + total_mass_Z += mass_contribution; + } + } + + long double total_mass_current = total_mass_H + total_mass_He + total_mass_Z; + + // Edge case: Empty composition + if (total_mass_current <= 0.0L) { + // Return empty or throw? If input was empty, return empty. + if (comp.size() == 0) return comp; + throw std::runtime_error("Input composition has zero total mass."); + } + + // 3. Calculate Scaling Factors + // Factor = (Target_Mass_Fraction / Old_Mass_Fraction) + // = (Target_Mass_Fraction) / (Old_Group_Mass / Total_Mass) + // = (Target_Mass_Fraction * Total_Mass) / Old_Group_Mass + + long double scale_H = 0.0L; + long double scale_He = 0.0L; + long double scale_Z = 0.0L; + + if (ld_target_X > 1e-16L) { + if (total_mass_H <= 1e-19L) { + throw std::runtime_error("Cannot rescale Hydrogen to " + std::to_string(target_X) + + " because input has no Hydrogen."); + } + scale_H = (ld_target_X * total_mass_current) / total_mass_H; + } + + if (ld_target_Y > 1e-16L) { + if (total_mass_He <= 1e-19L) { + throw std::runtime_error("Cannot rescale Helium to " + std::to_string((double)ld_target_Y) + + " because input has no Helium."); + } + scale_He = (ld_target_Y * total_mass_current) / total_mass_He; + } + + if (ld_target_Z > 1e-16L) { + if (total_mass_Z <= 1e-19L) { + throw std::runtime_error("Cannot rescale Metals to " + std::to_string(target_Z) + + " because input has no Metals."); + } + scale_Z = (ld_target_Z * total_mass_current) / total_mass_Z; + } + + // 4. Apply Scaling and Construct New Vectors + std::vector new_species; + std::vector new_abundances; + new_species.reserve(comp.size()); + new_abundances.reserve(comp.size()); + + for (const auto& [spec, abundance] : comp) { + new_species.push_back(spec); + + long double factor = 0.0L; + int z = spec.z(); + + if (z == 1) { + factor = scale_H; + } else if (z == 2) { + factor = scale_He; + } else { + factor = scale_Z; + } + + // Calculate new abundance in long double then cast back + long double new_val_ld = static_cast(abundance) * factor; + new_abundances.push_back(static_cast(new_val_ld)); + } + + return Composition(new_species, new_abundances); +} + +gridfire::NetIn init(const double temp, const double rho, const double tMax) { + std::setlocale(LC_ALL, ""); + g_previousHandler = std::set_terminate(quill_terminate_handler); + quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log"); + logger->set_log_level(quill::LogLevel::Info); + + using namespace gridfire; + const std::vector X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4}; + const std::vector symbols = {"H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"}; + + + const fourdst::composition::Composition composition = fourdst::composition::buildCompositionFromMassFractions(symbols, X); + + NetIn netIn; + netIn.composition = composition; + netIn.temperature = temp; + netIn.density = rho; + netIn.energy = 0; + + netIn.tMax = tMax; + netIn.dt0 = 1e-12; + + return netIn; +} + +void log_results(const gridfire::NetOut& netOut, const gridfire::NetIn& netIn) { + std::vector logSpecies = { + fourdst::atomic::H_1, + fourdst::atomic::He_3, + fourdst::atomic::He_4, + fourdst::atomic::C_12, + fourdst::atomic::N_14, + fourdst::atomic::O_16, + fourdst::atomic::Ne_20, + fourdst::atomic::Mg_24 + }; + + std::vector initial; + std::vector final; + std::vector delta; + std::vector fractional; + for (const auto& species : logSpecies) { + double initial_X = netIn.composition.getMassFraction(species); + double final_X = netOut.composition.getMassFraction(species); + double delta_X = final_X - initial_X; + double fractionalChange = (delta_X) / initial_X * 100.0; + + initial.push_back(initial_X); + final.push_back(final_X); + delta.push_back(delta_X); + fractional.push_back(fractionalChange); + } + + initial.push_back(0.0); // Placeholder for energy + final.push_back(netOut.energy); + delta.push_back(netOut.energy); + fractional.push_back(0.0); // Placeholder for energy + + initial.push_back(0.0); + final.push_back(netOut.dEps_dT); + delta.push_back(netOut.dEps_dT); + fractional.push_back(0.0); + + initial.push_back(0.0); + final.push_back(netOut.dEps_dRho); + delta.push_back(netOut.dEps_dRho); + fractional.push_back(0.0); + + initial.push_back(0.0); + final.push_back(netOut.specific_neutrino_energy_loss); + delta.push_back(netOut.specific_neutrino_energy_loss); + fractional.push_back(0.0); + + initial.push_back(0.0); + final.push_back(netOut.specific_neutrino_flux); + delta.push_back(netOut.specific_neutrino_flux); + fractional.push_back(0.0); + + initial.push_back(netIn.composition.getMeanParticleMass()); + final.push_back(netOut.composition.getMeanParticleMass()); + delta.push_back(final.back() - initial.back()); + fractional.push_back((final.back() - initial.back()) / initial.back() * 100.0); + + std::vector rowLabels = [&]() -> std::vector { + std::vector labels; + for (const auto& species : logSpecies) { + labels.emplace_back(species.name()); + } + labels.emplace_back("ε"); + labels.emplace_back("dε/dT"); + labels.emplace_back("dε/dρ"); + labels.emplace_back("Eν"); + labels.emplace_back("Fν"); + labels.emplace_back("<μ>"); + return labels; + }(); + + + gridfire::utils::Column paramCol("Parameter", rowLabels); + gridfire::utils::Column initialCol("Initial", initial); + gridfire::utils::Column finalCol ("Final", final); + gridfire::utils::Column deltaCol ("δ", delta); + gridfire::utils::Column percentCol("% Change", fractional); + + std::vector> columns; + columns.push_back(std::make_unique>(paramCol)); + columns.push_back(std::make_unique>(initialCol)); + columns.push_back(std::make_unique>(finalCol)); + columns.push_back(std::make_unique>(deltaCol)); + columns.push_back(std::make_unique>(percentCol)); + + + gridfire::utils::print_table("Simulation Results", columns); +} + + +void record_abundance_history_callback(const gridfire::solver::PointSolverTimestepContext& ctx) { + s_wrote_abundance_history = true; + const auto& engine = ctx.engine; + // std::unordered_map> abundances; + std::vector Y; + for (const auto& species : engine.getNetworkSpecies(ctx.state_ctx)) { + const size_t sid = engine.getSpeciesIndex(ctx.state_ctx, species); + double y = N_VGetArrayPointer(ctx.state)[sid]; + Y.push_back(y > 0.0 ? y : 0.0); // Regularize tiny negative abundances to zero + } + + fourdst::composition::Composition comp(engine.getNetworkSpecies(ctx.state_ctx), Y); + + + std::unordered_map> abundances; + for (const auto& sp : comp | std::views::keys) { + abundances.emplace(std::string(sp.name()), std::make_pair(sp.mass(), comp.getMolarAbundance(sp))); + } + g_callbackHistory.emplace_back(ctx.t, abundances); +} + + +void save_callback_data(const std::string_view filename) { + std::set unique_species; + for (const auto &abundances: g_callbackHistory | std::views::values) { + for (const auto &species_name: abundances | std::views::keys) { + unique_species.insert(species_name); + } + } + std::ofstream csvFile(filename.data(), std::ios::out); + csvFile << "t,"; + + size_t i = 0; + for (const auto& species_name : unique_species) { + csvFile << species_name; + if (i < unique_species.size() - 1) { + csvFile << ","; + } + i++; + } + + csvFile << "\n"; + + for (const auto& [time, data] : g_callbackHistory) { + csvFile << time << ","; + size_t j = 0; + for (const auto& species_name : unique_species) { + if (!data.contains(species_name)) { + csvFile << "0.0"; + } else { + csvFile << data.at(species_name).second; + } + if (j < unique_species.size() - 1) { + csvFile << ","; + } + ++j; + } + csvFile << "\n"; + } + + csvFile.close(); +} + +void log_callback_data(const double temp) { + if (s_wrote_abundance_history) { + std::cout << "Saving abundance history to abundance_history.csv" << std::endl; + save_callback_data("abundance_history_" + std::to_string(temp) + ".csv"); + } + +} + +void quill_terminate_handler() +{ + log_callback_data(1.5e7); + quill::Backend::stop(); + if (g_previousHandler) + g_previousHandler(); + else + std::abort(); +} + +void callback_main(const gridfire::solver::PointSolverTimestepContext& ctx) { + record_abundance_history_callback(ctx); +} + +int main(int argc, char** argv) { + GF_PAR_INIT(); + using namespace gridfire; + + double temp = 1.5e7; + double rho = 1.5e2; + double tMax = 3.1536e+16; + double X = 0.7; + double Z = 0.02; + + + CLI::App app("GridFire Quick CLI Test"); + // Add temp, rho, and tMax as options if desired + app.add_option("--temp", temp, "Initial Temperature")->default_val(std::format("{:5.2E}", temp)); + app.add_option("--rho", rho, "Initial Density")->default_val(std::format("{:5.2E}", rho)); + app.add_option("--tmax", tMax, "Maximum Time")->default_val(std::format("{:5.2E}", tMax)); + // app.add_option("--X", X, "Target Hydrogen Mass Fraction")->default_val(std::format("{:5.2f}", X)); + // app.add_option("--Z", Z, "Target Metal Mass Fraction")->default_val(std::format("{:5.2f}", Z)); + + CLI11_PARSE(app, argc, argv); + NetIn netIn = init(temp, rho, tMax); + // netIn.composition = rescale(netIn.composition, X, Z); + + policy::MainSequencePolicy stellarPolicy(netIn.composition); + auto [engine, ctx_template] = stellarPolicy.construct(); + + solver::PointSolverContext solver_context(*ctx_template); + solver::PointSolver solver(engine); + + NetOut result = solver.evaluate(solver_context, netIn); + log_results(result, netIn); +} \ No newline at end of file diff --git a/tools/cli/gf_quick/meson.build b/tools/cli/gf_quick/meson.build new file mode 100644 index 00000000..6e520a11 --- /dev/null +++ b/tools/cli/gf_quick/meson.build @@ -0,0 +1 @@ +executable('gf_quick', 'main.cpp', dependencies: [gridfire_dep, cli11_dep]) \ No newline at end of file diff --git a/tools/cli/meson.build b/tools/cli/meson.build new file mode 100644 index 00000000..336435f3 --- /dev/null +++ b/tools/cli/meson.build @@ -0,0 +1 @@ +subdir('gf_quick') \ No newline at end of file diff --git a/tools/meson.build b/tools/meson.build index 75f2745b..93d1e5a2 100644 --- a/tools/meson.build +++ b/tools/meson.build @@ -1,3 +1,4 @@ if get_option('build_tools') subdir('config') + subdir('cli') endif \ No newline at end of file