perf(multi): Simple parallel multi zone solver

Added a simple parallel multi-zone solver
This commit is contained in:
2025-12-18 12:47:39 -05:00
parent 4e1edfc142
commit dcfd7b60aa
27 changed files with 1018 additions and 2193 deletions

View File

@@ -63,7 +63,7 @@ int main() {
std::println("Scratch Blob State: {}", *construct.scratch_blob); std::println("Scratch Blob State: {}", *construct.scratch_blob);
constexpr size_t runs = 1000; constexpr size_t runs = 10;
auto startTime = std::chrono::high_resolution_clock::now(); auto startTime = std::chrono::high_resolution_clock::now();
// arrays to store timings // arrays to store timings
@@ -72,14 +72,15 @@ int main() {
std::array<NetOut, runs> serial_results; std::array<NetOut, runs> serial_results;
for (size_t i = 0; i < runs; ++i) { for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now(); auto start_setup_time = std::chrono::high_resolution_clock::now();
solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob); solver::PointSolverContext solverCtx(*construct.scratch_blob);
solver.set_stdout_logging_enabled(false); solverCtx.set_stdout_logging(false);
solver::PointSolver solver(construct.engine);
auto end_setup_time = std::chrono::high_resolution_clock::now(); auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time; std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setup_times[i] = setup_elapsed; setup_times[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now(); auto start_eval_time = std::chrono::high_resolution_clock::now();
const NetOut netOut = solver.evaluate(netIn); const NetOut netOut = solver.evaluate(solverCtx, netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now(); auto end_eval_time = std::chrono::high_resolution_clock::now();
serial_results[i] = netOut; serial_results[i] = netOut;
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time; std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
@@ -99,7 +100,6 @@ int main() {
std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs); std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs); std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count()); std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1));
std::array<NetOut, runs> parallelResults; std::array<NetOut, runs> parallelResults;
@@ -114,16 +114,16 @@ int main() {
// Parallel runs // Parallel runs
startTime = std::chrono::high_resolution_clock::now(); startTime = std::chrono::high_resolution_clock::now();
GF_OMP(parallel for,) GF_OMP(parallel for, for (size_t i = 0; i < runs; ++i)) {
for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now(); auto start_setup_time = std::chrono::high_resolution_clock::now();
solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]); solver::PointSolverContext solverCtx(*construct.scratch_blob);
solver.set_stdout_logging_enabled(false); solverCtx.set_stdout_logging(false);
solver::PointSolver solver(construct.engine);
auto end_setup_time = std::chrono::high_resolution_clock::now(); auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time; std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setupTimes[i] = setup_elapsed; setupTimes[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now(); auto start_eval_time = std::chrono::high_resolution_clock::now();
parallelResults[i] = solver.evaluate(netIn); parallelResults[i] = solver.evaluate(solverCtx, netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now(); auto end_eval_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time; std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
evalTimes[i] = eval_elapsed; evalTimes[i] = eval_elapsed;
@@ -144,10 +144,6 @@ int main() {
std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs); std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count()); std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) {
return result.composition.getMolarAbundance(fourdst::atomic::H_1);
}));
std::println("========== Summary =========="); std::println("========== Summary ==========");
std::println("Serial Runs:"); std::println("Serial Runs:");
std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs); std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs);

View File

@@ -35,3 +35,16 @@ endif
if get_option('openmp_support') if get_option('openmp_support')
add_project_arguments('-DGF_USE_OPENMP', language: 'cpp') add_project_arguments('-DGF_USE_OPENMP', language: 'cpp')
endif endif
if get_option('asan') and get_option('buildtype') != 'debug' and get_option('buildtype') != 'debugoptimized'
error('AddressSanitizer (ASan) can only be enabled for debug or debugoptimized builds')
endif
if get_option('asan') and (get_option('buildtype') == 'debugoptimized' or get_option('buildtype') == 'debug')
message('enabling AddressSanitizer (ASan) support')
add_project_arguments('-fsanitize=address,undefined', language: 'cpp')
add_project_arguments('-fno-omit-frame-pointer', language: 'cpp')
add_project_link_arguments('-fsanitize=address,undefined', language: 'cpp')
add_project_link_arguments('-fno-omit-frame-pointer', language: 'cpp')
endif

View File

@@ -11,4 +11,5 @@ option('build_c_api', type: 'boolean', value: true, description: 'compile the C
option('build_tools', type: 'boolean', value: true, description: 'build the GridFire command line tools') option('build_tools', type: 'boolean', value: true, description: 'build the GridFire command line tools')
option('openmp_support', type: 'boolean', value: false, description: 'Enable OpenMP support for parallelization') option('openmp_support', type: 'boolean', value: false, description: 'Enable OpenMP support for parallelization')
option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.') option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.')
option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite') option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite')
option('asan', type: 'boolean', value: false, description: 'Enable AddressSanitizer (ASan) support for detecting memory errors')

View File

@@ -11,10 +11,10 @@ namespace gridfire::config {
struct SpectralSolverConfig { struct SpectralSolverConfig {
struct Trigger { struct Trigger {
double simulationTimeInterval = 1.0e12;
double offDiagonalThreshold = 1.0e10;
double timestepCollapseRatio = 0.5; double timestepCollapseRatio = 0.5;
size_t maxConvergenceFailures = 2; size_t maxConvergenceFailures = 2;
double relativeFailureRate = 0.5;
size_t windowSize = 10;
}; };
struct MonitorFunctionConfig { struct MonitorFunctionConfig {
double structure_weight = 1.0; double structure_weight = 1.0;

View File

@@ -807,8 +807,6 @@ namespace gridfire::engine {
CppAD::ADFun<double> m_authoritativeADFun; CppAD::ADFun<double> m_authoritativeADFun;
const size_t m_state_blob_offset;
private: private:
/** /**
* @brief Synchronizes the internal maps. * @brief Synchronizes the internal maps.

View File

@@ -0,0 +1,43 @@
#pragma once
#include "gridfire/solver/strategies/strategy_abstract.h"
#include <functional>
namespace gridfire::solver {
struct GridSolverContext final : SolverContextBase {
std::vector<std::unique_ptr<SolverContextBase>> solver_workspaces;
std::vector<std::function<void(const TimestepContextBase&)>> timestep_callbacks;
const engine::scratch::StateBlob& ctx_template;
bool zone_completion_logging = true;
bool zone_stdout_logging = false;
bool zone_detailed_logging = false;
void init() override;
void reset();
void set_callback(const std::function<void(const TimestepContextBase&)> &callback);
void set_callback(const std::function<void(const TimestepContextBase&)> &callback, size_t zone_idx);
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;
explicit GridSolverContext(const engine::scratch::StateBlob& ctx_template);
};
class GridSolver final : public MultiZoneDynamicNetworkSolver {
public:
GridSolver(
const engine::DynamicEngine& engine,
const SingleZoneDynamicNetworkSolver& solver
);
~GridSolver() override = default;
std::vector<NetOut> evaluate(
SolverContextBase& ctx,
const std::vector<NetIn>& netIns
) const override;
};
}

View File

@@ -44,8 +44,88 @@
#endif #endif
namespace gridfire::solver { namespace gridfire::solver {
struct PointSolverTimestepContext final : TimestepContextBase {
const double t; ///< Current integration time [s].
const N_Vector& state; ///< Current CVODE state vector (N_Vector).
const double dt; ///< Last step size [s].
const double last_step_time; ///< Time at last callback [s].
const double T9; ///< Temperature in GK.
const double rho; ///< Density [g cm^-3].
const size_t num_steps; ///< Number of CVODE steps taken so far.
const engine::DynamicEngine& engine; ///< Reference to the engine.
const std::vector<fourdst::atomic::Species>& networkSpecies; ///< Species layout.
const size_t currentConvergenceFailures; ///< Total number of convergence failures
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
PointSolverTimestepContext(
double t,
const N_Vector& state,
double dt,
double last_step_time,
double t9,
double rho,
size_t num_steps,
const engine::DynamicEngine& engine,
const std::vector<fourdst::atomic::Species>& networkSpecies,
size_t currentConvergenceFailure,
size_t currentNonlinearIterations,
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
engine::scratch::StateBlob& state_ctx
);
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
using TimestepCallback = std::function<void(const PointSolverTimestepContext& context)>; ///< Type alias for a timestep callback function.
struct PointSolverContext final : SolverContextBase {
SUNContext sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
void* cvode_mem = nullptr; ///< CVODE memory block.
N_Vector Y = nullptr; ///< CVODE state vector (species + energy accumulator).
N_Vector YErr = nullptr; ///< Estimated local errors.
SUNMatrix J = nullptr; ///< Dense Jacobian matrix.
SUNLinearSolver LS = nullptr; ///< Dense linear solver.
std::unique_ptr<engine::scratch::StateBlob> engine_ctx;
std::optional<TimestepCallback> callback; ///< Optional per-step callback.
int num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool stdout_logging = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> abs_tol; ///< User-specified absolute tolerance.
std::optional<double> rel_tol; ///< User-specified relative tolerance.
bool detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
size_t last_size = 0;
size_t last_composition_hash = 0ULL;
sunrealtype last_good_time_step = 0ULL;
void init() override;
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;
void reset_all();
void reset_user();
void reset_cvode();
void clear_context();
void init_context();
[[nodiscard]] bool has_context() const;
explicit PointSolverContext(const engine::scratch::StateBlob& engine_ctx);
~PointSolverContext() override;
};
/** /**
* @class CVODESolverStrategy * @class PointSolver
* @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy. * @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy.
* *
* Integrates the nuclear network abundances along with a final accumulator entry for specific * Integrates the nuclear network abundances along with a final accumulator entry for specific
@@ -78,27 +158,16 @@ namespace gridfire::solver {
* std::cout << "Final energy: " << out.energy << " erg/g\n"; * std::cout << "Final energy: " << out.energy << " erg/g\n";
* @endcode * @endcode
*/ */
class CVODESolverStrategy final : public SingleZoneDynamicNetworkSolver { class PointSolver final : public SingleZoneDynamicNetworkSolver {
public: public:
/** /**
* @brief Construct the CVODE strategy and create a SUNDIALS context. * @brief Construct the CVODE strategy and create a SUNDIALS context.
* @param engine DynamicEngine used for RHS/Jacobian evaluation and network access. * @param engine DynamicEngine used for RHS/Jacobian evaluation and network access.
* @throws std::runtime_error If SUNContext_Create fails. * @throws std::runtime_error If SUNContext_Create fails.
*/ */
explicit CVODESolverStrategy( explicit PointSolver(
const engine::DynamicEngine& engine, const engine::DynamicEngine& engine
const engine::scratch::StateBlob& ctx
); );
/**
* @brief Destructor: cleans CVODE/SUNDIALS resources and frees SUNContext.
*/
~CVODESolverStrategy() override;
// Make the class non-copyable and non-movable to prevent shallow copies of CVODE pointers
CVODESolverStrategy(const CVODESolverStrategy&) = delete;
CVODESolverStrategy& operator=(const CVODESolverStrategy&) = delete;
CVODESolverStrategy(CVODESolverStrategy&&) = delete;
CVODESolverStrategy& operator=(CVODESolverStrategy&&) = delete;
/** /**
* @brief Integrate from t=0 to netIn.tMax and return final composition and energy. * @brief Integrate from t=0 to netIn.tMax and return final composition and energy.
@@ -114,6 +183,7 @@ namespace gridfire::solver {
* - At the end, converts molar abundances to mass fractions and assembles NetOut, * - At the end, converts molar abundances to mass fractions and assembles NetOut,
* including derivatives of energy w.r.t. T and rho from the engine. * including derivatives of energy w.r.t. T and rho from the engine.
* *
* @param solver_ctx
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition. * @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
* @return NetOut containing final Composition, accumulated energy [erg/g], step count, * @return NetOut containing final Composition, accumulated energy [erg/g], step count,
* and dEps/dT, dEps/dRho. * and dEps/dT, dEps/dRho.
@@ -122,10 +192,14 @@ namespace gridfire::solver {
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state * @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
* during RHS evaluation (captured in the wrapper then rethrown here). * during RHS evaluation (captured in the wrapper then rethrown here).
*/ */
NetOut evaluate(const NetIn& netIn) override; NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn
) const override;
/** /**
* @brief Call to evaluate which will let the user control if the trigger reasoning is displayed * @brief Call to evaluate which will let the user control if the trigger reasoning is displayed
* @param solver_ctx
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition. * @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
* @param displayTrigger Boolean flag to control if trigger reasoning is displayed * @param displayTrigger Boolean flag to control if trigger reasoning is displayed
* @param forceReinitialize Boolean flag to force reinitialization of CVODE resources at the start * @param forceReinitialize Boolean flag to force reinitialization of CVODE resources at the start
@@ -136,89 +210,13 @@ namespace gridfire::solver {
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state * @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
* during RHS evaluation (captured in the wrapper then rethrown here). * during RHS evaluation (captured in the wrapper then rethrown here).
*/ */
NetOut evaluate(const NetIn& netIn, bool displayTrigger, bool forceReinitialize = false); NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn,
bool displayTrigger,
bool forceReinitialize = false
) const;
/**
* @brief Install a timestep callback.
* @param callback std::any containing TimestepCallback (std::function<void(const TimestepContext&)>).
* @throws std::bad_any_cast If callback is not of the expected type.
*/
void set_callback(const std::any &callback) override;
/**
* @brief Whether per-step logs are printed to stdout and CVode is stepped with CV_ONE_STEP.
*/
[[nodiscard]] bool get_stdout_logging_enabled() const;
/**
* @brief Enable/disable per-step stdout logging.
* @param logging_enabled Flag to control if a timestep summary is written to standard output or not
*/
void set_stdout_logging_enabled(bool logging_enabled);
void set_absTol(double absTol);
void set_relTol(double relTol);
double get_absTol() const;
double get_relTol() const;
/**
* @brief Schema of fields exposed to the timestep callback context.
*/
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
/**
* @struct TimestepContext
* @brief Immutable view of the current integration state passed to callbacks.
*
* Fields capture CVODE time/state, step size, thermodynamic state, the engine reference,
* and the list of network species used to interpret the state vector layout.
*/
struct TimestepContext final : public SolverContextBase {
// This struct can be identical to the one in DirectNetworkSolver
const double t; ///< Current integration time [s].
const N_Vector& state; ///< Current CVODE state vector (N_Vector).
const double dt; ///< Last step size [s].
const double last_step_time; ///< Time at last callback [s].
const double T9; ///< Temperature in GK.
const double rho; ///< Density [g cm^-3].
const size_t num_steps; ///< Number of CVODE steps taken so far.
const engine::DynamicEngine& engine; ///< Reference to the engine.
const std::vector<fourdst::atomic::Species>& networkSpecies; ///< Species layout.
const size_t currentConvergenceFailures; ///< Total number of convergence failures
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
/**
* @brief Construct a context snapshot.
*/
TimestepContext(
double t,
const N_Vector& state,
double dt,
double last_step_time,
double t9,
double rho,
size_t num_steps,
const engine::DynamicEngine& engine,
const std::vector<fourdst::atomic::Species>& networkSpecies,
size_t currentConvergenceFailure,
size_t currentNonlinearIterations,
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
engine::scratch::StateBlob& state_ctx
);
/**
* @brief Human-readable description of the context fields.
*/
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
/**
* @brief Type alias for a timestep callback.
*/
using TimestepCallback = std::function<void(const TimestepContext& context)>; ///< Type alias for a timestep callback function.
private: private:
/** /**
* @struct CVODEUserData * @struct CVODEUserData
@@ -230,7 +228,8 @@ namespace gridfire::solver {
* to CVODE, then the driver loop inspects and rethrows. * to CVODE, then the driver loop inspects and rethrows.
*/ */
struct CVODEUserData { struct CVODEUserData {
CVODESolverStrategy* solver_instance{}; // Pointer back to the class instance const PointSolver* solver_instance{}; // Pointer back to the class instance
PointSolverContext* sctx; // Pointer to the solver context
engine::scratch::StateBlob& ctx; engine::scratch::StateBlob& ctx;
const engine::DynamicEngine* engine{}; const engine::DynamicEngine* engine{};
double T9{}; double T9{};
@@ -283,6 +282,7 @@ namespace gridfire::solver {
* step size, creates a dense matrix and dense linear solver, and registers the Jacobian. * step size, creates a dense matrix and dense linear solver, and registers the Jacobian.
*/ */
void initialize_cvode_integration_resources( void initialize_cvode_integration_resources(
PointSolverContext* ctx,
uint64_t N, uint64_t N,
size_t numSpecies, size_t numSpecies,
double current_time, double current_time,
@@ -290,15 +290,7 @@ namespace gridfire::solver {
double absTol, double absTol,
double relTol, double relTol,
double accumulatedEnergy double accumulatedEnergy
); ) const;
/**
* @brief Destroy CVODE vectors/linear algebra and optionally the CVODE memory block.
* @param memFree If true, also calls CVodeFree on m_cvode_mem.
*/
void cleanup_cvode_resources(bool memFree);
void set_detailed_step_logging(bool enabled);
/** /**
@@ -308,31 +300,13 @@ namespace gridfire::solver {
* sorted table of species with the highest error ratios; then invokes diagnostic routines to * sorted table of species with the highest error ratios; then invokes diagnostic routines to
* inspect Jacobian stiffness and species balance. * inspect Jacobian stiffness and species balance.
*/ */
void log_step_diagnostics(engine::scratch::StateBlob &ctx, const CVODEUserData& user_data, bool displayJacobianStiffness, bool void log_step_diagnostics(
displaySpeciesBalance, bool to_file, std::optional<std::string> filename) const; PointSolverContext* sctx_p,
private: engine::scratch::StateBlob &ctx,
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver). const CVODEUserData& user_data,
void* m_cvode_mem = nullptr; ///< CVODE memory block. bool displayJacobianStiffness,
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator). bool displaySpeciesBalance,
N_Vector m_YErr = nullptr; ///< Estimated local errors. bool to_file, std::optional<std::string> filename
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix. ) const;
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
std::optional<double> m_relTol; ///< User-specified relative tolerance.
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
mutable size_t m_last_size = 0;
mutable size_t m_last_composition_hash = 0ULL;
mutable sunrealtype m_last_good_time_step = 0ULL;
}; };
} }

View File

@@ -1,221 +0,0 @@
#pragma once
#include "gridfire/solver/strategies/strategy_abstract.h"
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/types/types.h"
#include "gridfire/config/config.h"
#include "fourdst/logging/logging.h"
#include "fourdst/constants/const.h"
#include <vector>
#include <functional>
#include <cvode/cvode.h>
#include <sundials/sundials_types.h>
#include "gridfire/exceptions/error_engine.h"
#ifdef SUNDIALS_HAVE_OPENMP
#include <nvector/nvector_openmp.h>
#endif
#ifdef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_pthreads.hh>
#endif
#ifndef SUNDIALS_HAVE_OPENMP
#ifndef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_serial.h>
#endif
#endif
namespace gridfire::solver {
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolver {
public:
explicit SpectralSolverStrategy(const engine::DynamicEngine& engine);
~SpectralSolverStrategy() override;
std::vector<NetOut> evaluate(
const std::vector<NetIn> &netIns,
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
) override;
void set_callback(const std::any &callback) override;
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
[[nodiscard]] bool get_stdout_logging_enabled() const;
void set_stdout_logging_enabled(bool logging_enabled);
public:
struct TimestepContext final : public SolverContextBase {
TimestepContext(
const double t,
const N_Vector &state,
const double dt,
const double last_time_step,
const engine::DynamicEngine &engine
) :
t(t),
state(state),
dt(dt),
last_time_step(last_time_step),
engine(engine) {}
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
const double t;
const N_Vector& state;
const double dt;
const double last_time_step;
const engine::DynamicEngine& engine;
};
struct BasisEval {
size_t start_idx;
std::vector<double> phi;
};
struct SplineBasis {
std::vector<double> knots;
std::vector<double> quadrature_nodes;
std::vector<double> quadrature_weights;
int degree = 3;
std::vector<BasisEval> quad_evals;
};
public:
using TimestepCallback = std::function<void(const TimestepContext&)>;
private:
enum class ParallelInitializationResult : uint8_t {
SUCCESS,
FAILURE
};
struct SpectralCoefficients {
size_t num_sets;
size_t num_coefficients;
std::vector<double> coefficients;
double operator()(size_t i, size_t j) const;
};
struct GridPoint {
double T9;
double rho;
fourdst::composition::Composition composition;
};
struct Constants {
const double c = fourdst::constant::Constants::getInstance().get("c").value;
const double N_a = fourdst::constant::Constants::getInstance().get("N_a").value;
const double c2 = c * c;
};
struct DenseLinearSolver {
SUNMatrix A;
SUNLinearSolver LS;
N_Vector temp_vector;
SUNContext ctx;
DenseLinearSolver(size_t size, SUNContext sun_ctx);
~DenseLinearSolver();
DenseLinearSolver(const DenseLinearSolver&) = delete;
DenseLinearSolver& operator=(const DenseLinearSolver&) = delete;
void setup() const;
void zero() const;
void init_from_cache(size_t num_basis_funcs, const std::vector<BasisEval>& shell_cache) const;
void init_from_basis(size_t num_basis_funcs, const SplineBasis& basis) const;
void solve_inplace(N_Vector x, size_t num_vars, size_t basis_size) const;
void solve_inplace_ptr(sunrealtype* data_ptr, size_t num_vars, size_t basis_size) const;
};
struct CVODEUserData {
SpectralSolverStrategy* solver_instance{};
std::vector<std::reference_wrapper<engine::scratch::StateBlob>> workspaces;
const engine::DynamicEngine* engine{};
std::unique_ptr<exceptions::EngineError> captured_exception{};
std::vector<double> T9{};
std::vector<double> rho{};
double energy{};
double neutrino_energy_loss_rate = 0.0;
double total_neutrino_flux = 0.0;
DenseLinearSolver* mass_matrix_solver_instance{};
const SplineBasis* basis{};
};
private:
fourdst::config::Config<config::GridFireConfig> m_config;
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
void* m_cvode_mem = nullptr; ///< CVODE memory block.
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator).
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix.
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
std::optional<double> m_relTol; ///< User-specified relative tolerance.
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
mutable size_t m_last_size = 0;
mutable size_t m_last_composition_hash = 0ULL;
mutable sunrealtype m_last_good_time_step = 0ULL;
SplineBasis m_current_basis;
Constants m_constants;
N_Vector m_T_coeffs = nullptr;
N_Vector m_rho_coeffs = nullptr;
std::vector<fourdst::atomic::Species> m_global_species_list;
private:
std::vector<double> evaluate_monitor_function(const std::vector<NetIn>& current_shells) const;
static SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates, size_t actual_elements);
GridPoint reconstruct_at_quadrature(const N_Vector y_coeffs, size_t quad_index, const SplineBasis &basis) const;
std::vector<NetOut> reconstruct_solution(const std::vector<NetIn>& original_inputs, const std::vector<double>& mass_coordinates, const N_Vector final_coeffs, const SplineBasis& basis, double dt) const;
static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector, void* user_data);
static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
int calculate_rhs(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, CVODEUserData* data) const;
int calculate_jacobian(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, SUNMatrix J, const CVODEUserData *data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) const;
static size_t nyquist_elements(size_t requested_elements, size_t num_shells) ;
static void project_specific_variable(
const std::vector<NetIn>& current_shells,
const std::vector<double>& mass_coordinates,
const std::vector<BasisEval>& shell_cache,
const DenseLinearSolver& linear_solver,
N_Vector output_vec,
size_t output_offset,
const std::function<double(const NetIn&)> &getter,
bool use_log
);
void inspect_jacobian(SUNMatrix J, const std::string& context) const;
};
}

View File

@@ -2,5 +2,5 @@
#include "gridfire/solver/strategies/triggers/triggers.h" #include "gridfire/solver/strategies/triggers/triggers.h"
#include "gridfire/solver/strategies/strategy_abstract.h" #include "gridfire/solver/strategies/strategy_abstract.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/solver/strategies/SpectralSolverStrategy.h" #include "gridfire/solver/strategies/GridSolver.h"

View File

@@ -13,17 +13,24 @@ namespace gridfire::solver {
template <typename EngineT> template <typename EngineT>
concept IsEngine = std::is_base_of_v<engine::Engine, EngineT>; concept IsEngine = std::is_base_of_v<engine::Engine, EngineT>;
struct SolverContextBase {
virtual void init() = 0;
virtual void set_stdout_logging(bool enable) = 0;
virtual void set_detailed_logging(bool enable) = 0;
virtual ~SolverContextBase() = default;
};
/** /**
* @struct SolverContextBase * @struct TimestepContextBase
* @brief Base class for solver callback contexts. * @brief Base class for solver callback contexts.
* *
* This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces * This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces
* that derived classes implement a `describe` method that returns a vector of tuples describing * that derived classes implement a `describe` method that returns a vector of tuples describing
* the context that a callback will receive when called. * the context that a callback will receive when called.
*/ */
class SolverContextBase { class TimestepContextBase {
public: public:
virtual ~SolverContextBase() = default; virtual ~TimestepContextBase() = default;
/** /**
* @brief Describe the context for callback functions. * @brief Describe the context for callback functions.
@@ -54,11 +61,9 @@ namespace gridfire::solver {
* @param engine The engine to use for evaluating the network. * @param engine The engine to use for evaluating the network.
*/ */
explicit SingleZoneNetworkSolver( explicit SingleZoneNetworkSolver(
const EngineT& engine, const EngineT& engine
const engine::scratch::StateBlob& ctx
) : ) :
m_engine(engine), m_engine(engine) {};
m_scratch_blob(ctx.clone_structure()) {};
/** /**
* @brief Virtual destructor. * @brief Virtual destructor.
@@ -67,58 +72,39 @@ namespace gridfire::solver {
/** /**
* @brief Evaluates the network for a given timestep. * @brief Evaluates the network for a given timestep.
* @param solver_ctx
* @param engine_ctx
* @param netIn The input conditions for the network. * @param netIn The input conditions for the network.
* @return The output conditions after the timestep. * @return The output conditions after the timestep.
*/ */
virtual NetOut evaluate(const NetIn& netIn) = 0; virtual NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn
) const = 0;
/**
* @brief set the callback function to be called at the end of each timestep.
*
* This function allows the user to set a callback function that will be called at the end of each timestep.
* The callback function will receive a gridfire::solver::<SOMESOLVER>::TimestepContext object. Note that
* depending on the solver, this context may contain different information. Further, the exact
* signature of the callback function is left up to each solver. Every solver should provide a type or type alias
* TimestepCallback that defines the signature of the callback function so that the user can easily
* get that type information.
*
* @param callback The callback function to be called at the end of each timestep.
*/
virtual void set_callback(const std::any& callback) = 0;
/**
* @brief Describe the context that will be passed to the callback function.
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
*
* This method should be overridden by derived classes to provide a description of the context
* that will be passed to the callback function. The intent of this method is that an end user can investigate
* the context that will be passed to the callback function, and use this information to craft their own
* callback function.
*/
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected: protected:
const EngineT& m_engine; ///< The engine used by this solver strategy. const EngineT& m_engine; ///< The engine used by this solver strategy.
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob;
}; };
template <IsEngine EngineT> template <IsEngine EngineT>
class MultiZoneNetworkSolver { class MultiZoneNetworkSolver {
public: public:
explicit MultiZoneNetworkSolver( explicit MultiZoneNetworkSolver(
const EngineT& engine const EngineT& engine,
const SingleZoneNetworkSolver<EngineT>& solver
) : ) :
m_engine(engine) {}; m_engine(engine),
m_solver(solver) {};
virtual ~MultiZoneNetworkSolver() = default; virtual ~MultiZoneNetworkSolver() = default;
virtual std::vector<NetOut> evaluate( virtual std::vector<NetOut> evaluate(
const std::vector<NetIn>& netIns, SolverContextBase& solver_ctx,
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template const std::vector<NetIn>& netIns
) = 0; ) const = 0;
virtual void set_callback(const std::any& callback) = 0;
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected: protected:
const EngineT& m_engine; ///< The engine used by this solver strategy. const EngineT& m_engine; ///< The engine used by this solver strategy.
const SingleZoneNetworkSolver<EngineT>& m_solver;
}; };
/** /**

View File

@@ -2,7 +2,7 @@
#include "gridfire/trigger/trigger_abstract.h" #include "gridfire/trigger/trigger_abstract.h"
#include "gridfire/trigger/trigger_result.h" #include "gridfire/trigger/trigger_result.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "fourdst/logging/logging.h" #include "fourdst/logging/logging.h"
#include <string> #include <string>
@@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE {
* *
* See also: engine_partitioning_trigger.cpp for the concrete logic and logging. * See also: engine_partitioning_trigger.cpp for the concrete logic and logging.
*/ */
class SimulationTimeTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> { class SimulationTimeTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public: public:
/** /**
* @brief Construct with a positive time interval between firings. * @brief Construct with a positive time interval between firings.
@@ -62,7 +62,7 @@ namespace gridfire::trigger::solver::CVODE {
* *
* @post increments hit/miss counters and may emit trace logs. * @post increments hit/miss counters and may emit trace logs.
*/ */
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** /**
* @brief Update internal state; if check(ctx) is true, advance last_trigger_time. * @brief Update internal state; if check(ctx) is true, advance last_trigger_time.
* @param ctx CVODE timestep context. * @param ctx CVODE timestep context.
@@ -70,9 +70,9 @@ namespace gridfire::trigger::solver::CVODE {
* @note update() calls check(ctx) and, on success, records the overshoot delta * @note update() calls check(ctx) and, on success, records the overshoot delta
* (ctx.t - last_trigger_time) - interval for diagnostics. * (ctx.t - last_trigger_time) - interval for diagnostics.
*/ */
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/** /**
* @brief Reset counters and last trigger bookkeeping (time and delta) to zero. * @brief Reset counters and last trigger bookkeeping (time and delta) to zero.
*/ */
@@ -85,7 +85,7 @@ namespace gridfire::trigger::solver::CVODE {
* @param ctx CVODE timestep context. * @param ctx CVODE timestep context.
* @return TriggerResult including name, value, and description. * @return TriggerResult including name, value, and description.
*/ */
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including configured interval. */ /** @brief Textual description including configured interval. */
std::string describe() const override; std::string describe() const override;
/** @brief Number of true evaluations since last reset. */ /** @brief Number of true evaluations since last reset. */
@@ -130,7 +130,7 @@ namespace gridfire::trigger::solver::CVODE {
* @par See also * @par See also
* - engine_partitioning_trigger.cpp for concrete logic and trace logging. * - engine_partitioning_trigger.cpp for concrete logic and trace logging.
*/ */
class OffDiagonalTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> { class OffDiagonalTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public: public:
/** /**
* @brief Construct with a non-negative magnitude threshold. * @brief Construct with a non-negative magnitude threshold.
@@ -145,13 +145,13 @@ namespace gridfire::trigger::solver::CVODE {
* *
* @post increments hit/miss counters and may emit trace logs. * @post increments hit/miss counters and may emit trace logs.
*/ */
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** /**
* @brief Record an update; does not mutate any Jacobian-related state. * @brief Record an update; does not mutate any Jacobian-related state.
* @param ctx CVODE timestep context (unused except for symmetry with interface). * @param ctx CVODE timestep context (unused except for symmetry with interface).
*/ */
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/** @brief Reset counters to zero. */ /** @brief Reset counters to zero. */
void reset() override; void reset() override;
@@ -161,7 +161,7 @@ namespace gridfire::trigger::solver::CVODE {
* @brief Structured explanation of the evaluation outcome. * @brief Structured explanation of the evaluation outcome.
* @param ctx CVODE timestep context. * @param ctx CVODE timestep context.
*/ */
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including configured threshold. */ /** @brief Textual description including configured threshold. */
std::string describe() const override; std::string describe() const override;
/** @brief Number of true evaluations since last reset. */ /** @brief Number of true evaluations since last reset. */
@@ -206,7 +206,7 @@ namespace gridfire::trigger::solver::CVODE {
* *
* See also: engine_partitioning_trigger.cpp for exact logic and logging. * See also: engine_partitioning_trigger.cpp for exact logic and logging.
*/ */
class TimestepCollapseTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> { class TimestepCollapseTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public: public:
/** /**
* @brief Construct with threshold and relative/absolute mode; window size defaults to 1. * @brief Construct with threshold and relative/absolute mode; window size defaults to 1.
@@ -230,20 +230,20 @@ namespace gridfire::trigger::solver::CVODE {
* *
* @post increments hit/miss counters and may emit trace logs. * @post increments hit/miss counters and may emit trace logs.
*/ */
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** /**
* @brief Update sliding window with the most recent dt and increment update counter. * @brief Update sliding window with the most recent dt and increment update counter.
* @param ctx CVODE timestep context. * @param ctx CVODE timestep context.
*/ */
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/** @brief Reset counters and clear the dt window. */ /** @brief Reset counters and clear the dt window. */
void reset() override; void reset() override;
/** @brief Stable human-readable name. */ /** @brief Stable human-readable name. */
std::string name() const override; std::string name() const override;
/** @brief Structured explanation of the evaluation outcome. */ /** @brief Structured explanation of the evaluation outcome. */
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including threshold, mode, and window size. */ /** @brief Textual description including threshold, mode, and window size. */
std::string describe() const override; std::string describe() const override;
/** @brief Number of true evaluations since last reset. */ /** @brief Number of true evaluations since last reset. */
@@ -272,15 +272,15 @@ namespace gridfire::trigger::solver::CVODE {
std::deque<double> m_timestep_window; std::deque<double> m_timestep_window;
}; };
class ConvergenceFailureTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> { class ConvergenceFailureTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public: public:
explicit ConvergenceFailureTrigger(size_t totalFailures, float relativeFailureRate, size_t windowSize); explicit ConvergenceFailureTrigger(size_t totalFailures, float relativeFailureRate, size_t windowSize);
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void reset() override; void reset() override;
@@ -288,7 +288,7 @@ namespace gridfire::trigger::solver::CVODE {
[[nodiscard]] std::string describe() const override; [[nodiscard]] std::string describe() const override;
[[nodiscard]] TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; [[nodiscard]] TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
[[nodiscard]] size_t numTriggers() const override; [[nodiscard]] size_t numTriggers() const override;
@@ -312,8 +312,8 @@ namespace gridfire::trigger::solver::CVODE {
private: private:
float current_mean() const; float current_mean() const;
bool abs_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const; bool abs_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const;
bool rel_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const; bool rel_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const;
}; };
/** /**
@@ -337,10 +337,10 @@ namespace gridfire::trigger::solver::CVODE {
* *
* @note The exact policy is subject to change; this function centralizes that decision. * @note The exact policy is subject to change; this function centralizes that decision.
*/ */
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger( std::unique_ptr<Trigger<gridfire::solver::PointSolverTimestepContext>> makeEnginePartitioningTrigger(
const double simulationTimeInterval, double simulationTimeInterval,
const double offDiagonalThreshold, double offDiagonalThreshold,
const double timestepCollapseRatio, double timestepCollapseRatio,
const size_t maxConvergenceFailures size_t maxConvergenceFailures
); );
} }

View File

@@ -30,6 +30,7 @@ namespace gridfire::omp {
); );
CppAD::thread_alloc::hold_memory(true); CppAD::thread_alloc::hold_memory(true);
CppAD::CheckSimpleVector<double, std::vector<double>>(0, 1);
s_par_mode_initialized = true; s_par_mode_initialized = true;
} }
} }

View File

@@ -118,8 +118,7 @@ namespace gridfire::engine {
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA), m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)), m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)),
m_partitionFunction(partitionFunction.clone()), m_partitionFunction(partitionFunction.clone()),
m_depth(buildDepth), m_depth(buildDepth)
m_state_blob_offset(0) // For a base engine the offset is always 0
{ {
syncInternalMaps(); syncInternalMaps();
} }
@@ -128,8 +127,7 @@ namespace gridfire::engine {
const reaction::ReactionSet &reactions const reaction::ReactionSet &reactions
) : ) :
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA), m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
m_reactions(reactions), m_reactions(reactions)
m_state_blob_offset(0)
{ {
syncInternalMaps(); syncInternalMaps();
} }

View File

@@ -2,7 +2,6 @@
#include "fourdst/atomic/species.h" #include "fourdst/atomic/species.h"
#include "fourdst/composition/utils.h" #include "fourdst/composition/utils.h"
#include "gridfire/engine/views/engine_priming.h"
#include "gridfire/solver/solver.h" #include "gridfire/solver/solver.h"
#include "gridfire/engine/engine_abstract.h" #include "gridfire/engine/engine_abstract.h"
@@ -13,7 +12,7 @@
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h" #include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
#include "fourdst/logging/logging.h" #include "fourdst/logging/logging.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "quill/Logger.h" #include "quill/Logger.h"
#include "quill/LogMacros.h" #include "quill/LogMacros.h"
@@ -28,13 +27,12 @@ namespace gridfire::engine {
const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
) { ) {
const auto logger = LogManager::getInstance().getLogger("log"); const auto logger = LogManager::getInstance().getLogger("log");
solver::CVODESolverStrategy integrator(engine, ctx); solver::PointSolver integrator(engine);
solver::PointSolverContext solverCtx(ctx);
solverCtx.abs_tol = 1e-3;
solverCtx.rel_tol = 1e-3;
solverCtx.stdout_logging = false;
// Do not need high precision for priming
integrator.set_absTol(1e-3);
integrator.set_relTol(1e-3);
integrator.set_stdout_logging_enabled(false);
NetIn solverInput(netIn); NetIn solverInput(netIn);
solverInput.tMax = 1e-15; solverInput.tMax = 1e-15;
@@ -43,7 +41,7 @@ namespace gridfire::engine {
LOG_INFO(logger, "Short timescale ({}) network ignition started.", solverInput.tMax); LOG_INFO(logger, "Short timescale ({}) network ignition started.", solverInput.tMax);
PrimingReport report; PrimingReport report;
try { try {
const NetOut netOut = integrator.evaluate(solverInput, false); const NetOut netOut = integrator.evaluate(solverCtx, solverInput);
LOG_INFO(logger, "Network ignition completed."); LOG_INFO(logger, "Network ignition completed.");
LOG_TRACE_L2( LOG_TRACE_L2(
logger, logger,

View File

@@ -2005,7 +2005,32 @@ namespace gridfire::engine {
LOG_INFO(getLogger(), "KINSol failed to converge within the maximum number of iterations, but achieved acceptable accuracy with function norm {} < {}. Proceeding with solution.", LOG_INFO(getLogger(), "KINSol failed to converge within the maximum number of iterations, but achieved acceptable accuracy with function norm {} < {}. Proceeding with solution.",
fnorm, ACCEPTABLE_FTOL); fnorm, ACCEPTABLE_FTOL);
} else { } else {
LOG_WARNING(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Error {}", utils::kinsol_ret_code_map.at(flag), fnorm); LOG_CRITICAL(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Flag No.: {}, Error (fNorm): {}", utils::kinsol_ret_code_map.at(flag), flag, fnorm);
LOG_CRITICAL(getLogger(), "State prior to failure: {}",
[&comp, &data]() -> std::string {
std::ostringstream oss;
oss << "Solve species: <";
size_t count = 0;
for (const auto& species : data.qse_solve_species) {
oss << species.name();
if (count < data.qse_solve_species.size() - 1) {
oss << ", ";
}
count++;
}
oss << "> | Abundances and rates at failure: ";
count = 0;
for (const auto& [species, abundance] : comp) {
oss << species.name() << ": Y = " << abundance;
if (count < comp.size() - 1) {
oss << ", ";
}
count++;
}
oss << " | Temperature: " << data.T9 << ", Density: " << data.rho;
return oss.str();
}()
);
throw exceptions::InvalidQSESolutionError("KINSol failed to converge while solving QSE abundances. " + utils::kinsol_ret_code_map.at(flag)); throw exceptions::InvalidQSESolutionError("KINSol failed to converge while solving QSE abundances. " + utils::kinsol_ret_code_map.at(flag));
} }
} }

View File

@@ -0,0 +1,94 @@
#include "gridfire/solver/strategies/GridSolver.h"
#include "gridfire/exceptions/error_solver.h"
#include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/utils/macros.h"
#include "gridfire/utils/gf_omp.h"
#include <cstdio>
#include <print>
namespace gridfire::solver {
void GridSolverContext::init() {}
void GridSolverContext::reset() {
solver_workspaces.clear();
timestep_callbacks.clear();
}
void GridSolverContext::set_callback(const std::function<void(const TimestepContextBase &)> &callback) {
for (auto &cb : timestep_callbacks) {
cb = callback;
}
}
void GridSolverContext::set_callback(const std::function<void(const TimestepContextBase &)> &callback, const size_t zone_idx) {
if (zone_idx >= timestep_callbacks.size()) {
throw exceptions::SolverError("GridSolverContext::set_callback: zone_idx out of range.");
}
timestep_callbacks[zone_idx] = callback;
}
void GridSolverContext::set_stdout_logging(const bool enable) {
zone_stdout_logging = enable;
}
void GridSolverContext::set_detailed_logging(const bool enable) {
zone_detailed_logging = enable;
}
GridSolverContext::GridSolverContext(
const engine::scratch::StateBlob &ctx_template
) :
ctx_template(ctx_template) {}
GridSolver::GridSolver(
const engine::DynamicEngine &engine,
const SingleZoneDynamicNetworkSolver &solver
) :
MultiZoneNetworkSolver(engine, solver) {
GF_PAR_INIT();
}
std::vector<NetOut> GridSolver::evaluate(
SolverContextBase& ctx,
const std::vector<NetIn>& netIns
) const {
auto* sctx_p = dynamic_cast<GridSolverContext*>(&ctx);
if (!sctx_p) {
throw exceptions::SolverError("GridSolver::evaluate: SolverContextBase is not of type GridSolverContext.");
}
const size_t n_zones = netIns.size();
if (n_zones == 0) { return {}; }
std::vector<NetOut> results(n_zones);
sctx_p->solver_workspaces.resize(n_zones);
GF_OMP(
parallel for default(none) shared(sctx_p, n_zones),
for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) {
sctx_p->solver_workspaces[zone_idx] = std::make_unique<PointSolverContext>(sctx_p->ctx_template);
sctx_p->solver_workspaces[zone_idx]->set_stdout_logging(sctx_p->zone_stdout_logging);
sctx_p->solver_workspaces[zone_idx]->set_detailed_logging(sctx_p->zone_detailed_logging);
}
GF_OMP(
parallel for default(none) shared(results, sctx_p, netIns, n_zones),
for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) {
try {
results[zone_idx] = m_solver.evaluate(
*sctx_p->solver_workspaces[zone_idx],
netIns[zone_idx]
);
} catch (exceptions::GridFireError& e) {
std::println("CVODE Solver Failure in zone {}: {}", zone_idx, e.what());
}
if (sctx_p->zone_completion_logging) {
std::println("Thread {} completed zone {}", GF_OMP_THREAD_NUM, zone_idx);
}
}
return results;
}
}

View File

@@ -1,4 +1,4 @@
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/types/types.h" #include "gridfire/types/types.h"
#include "gridfire/utils/table_format.h" #include "gridfire/utils/table_format.h"
@@ -28,7 +28,7 @@
namespace gridfire::solver { namespace gridfire::solver {
using namespace gridfire::engine; using namespace gridfire::engine;
CVODESolverStrategy::TimestepContext::TimestepContext( PointSolverTimestepContext::PointSolverTimestepContext(
const double t, const double t,
const N_Vector &state, const N_Vector &state,
const double dt, const double dt,
@@ -58,7 +58,7 @@ namespace gridfire::solver {
state_ctx(ctx) state_ctx(ctx)
{} {}
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const { std::vector<std::tuple<std::string, std::string>> PointSolverTimestepContext::describe() const {
std::vector<std::tuple<std::string, std::string>> description; std::vector<std::tuple<std::string, std::string>> description;
description.emplace_back("t", "Current Time"); description.emplace_back("t", "Current Time");
description.emplace_back("state", "Current State Vector (N_Vector)"); description.emplace_back("state", "Current State Vector (N_Vector)");
@@ -74,36 +74,112 @@ namespace gridfire::solver {
return description; return description;
} }
void PointSolverContext::init() {
reset_all();
init_context();
}
CVODESolverStrategy::CVODESolverStrategy( void PointSolverContext::set_stdout_logging(const bool enable) {
const DynamicEngine &engine, stdout_logging = enable;
const scratch::StateBlob& ctx }
): SingleZoneNetworkSolver<DynamicEngine>(engine, ctx) {
// PERF: In order to support MPI this function must be changed void PointSolverContext::set_detailed_logging(const bool enable) {
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx); detailed_step_logging = enable;
if (flag < 0) { }
throw std::runtime_error("Failed to create SUNDIALS context (SUNDIALS Errno: " + std::to_string(flag) + ")");
void PointSolverContext::reset_all() {
reset_user();
reset_cvode();
}
void PointSolverContext::reset_user() {
callback.reset();
num_steps = 0;
stdout_logging = true;
abs_tol.reset();
rel_tol.reset();
detailed_step_logging = false;
last_size = 0;
last_composition_hash = 0ULL;
}
void PointSolverContext::reset_cvode() {
if (LS) {
SUNLinSolFree(LS);
LS = nullptr;
}
if (J) {
SUNMatDestroy(J);
J = nullptr;
}
if (Y) {
N_VDestroy(Y);
Y = nullptr;
}
if (YErr) {
N_VDestroy(YErr);
YErr = nullptr;
}
if (constraints) {
N_VDestroy(constraints);
constraints = nullptr;
}
if (cvode_mem) {
CVodeFree(&cvode_mem);
cvode_mem = nullptr;
} }
} }
CVODESolverStrategy::~CVODESolverStrategy() { void PointSolverContext::clear_context() {
LOG_TRACE_L1(m_logger, "Cleaning up CVODE resources..."); if (sun_ctx) {
cleanup_cvode_resources(true); SUNContext_Free(&sun_ctx);
sun_ctx = nullptr;
if (m_sun_ctx) {
SUNContext_Free(&m_sun_ctx);
} }
} }
NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) { void PointSolverContext::init_context() {
return evaluate(netIn, false); if (!sun_ctx) {
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
}
} }
NetOut CVODESolverStrategy::evaluate( bool PointSolverContext::has_context() const {
return sun_ctx != nullptr;
}
PointSolverContext::PointSolverContext(
const scratch::StateBlob& engine_ctx
) :
engine_ctx(engine_ctx.clone_structure())
{
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
}
PointSolverContext::~PointSolverContext() {
reset_cvode();
clear_context();
}
PointSolver::PointSolver(
const DynamicEngine &engine
): SingleZoneNetworkSolver(engine) {}
NetOut PointSolver::evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn
) const {
return evaluate(solver_ctx, netIn, false);
}
NetOut PointSolver::evaluate(
SolverContextBase& solver_ctx,
const NetIn &netIn, const NetIn &netIn,
bool displayTrigger, bool displayTrigger,
bool forceReinitialize bool forceReinitialize
) { ) const {
auto* sctx_p = dynamic_cast<PointSolverContext*>(&solver_ctx);
LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density); LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density);
LOG_TRACE_L1(m_logger, "Building engine update trigger...."); LOG_TRACE_L1(m_logger, "Building engine update trigger....");
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 0.5, 2); auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 0.5, 2);
@@ -117,23 +193,24 @@ namespace gridfire::solver {
// 2. If the user has set tolerances in code, those override the config // 2. If the user has set tolerances in code, those override the config
// 3. If the user has not set tolerances in code and the config does not have them, use hardcoded defaults // 3. If the user has not set tolerances in code and the config does not have them, use hardcoded defaults
auto absTol = m_config->solver.cvode.absTol; if (!sctx_p->abs_tol.has_value()) {
auto relTol = m_config->solver.cvode.relTol; sctx_p->abs_tol = m_config->solver.cvode.absTol;
if (m_absTol) {
absTol = *m_absTol;
} }
if (m_relTol) { if (!sctx_p->rel_tol.has_value()) {
relTol = *m_relTol; sctx_p->rel_tol = m_config->solver.cvode.relTol;
} }
bool resourcesExist = (m_cvode_mem != nullptr) && (m_Y != nullptr);
bool inconsistentComposition = netIn.composition.hash() != m_last_composition_hash; bool resourcesExist = (sctx_p->cvode_mem != nullptr) && (sctx_p->Y != nullptr);
bool inconsistentComposition = netIn.composition.hash() != sctx_p->last_composition_hash;
fourdst::composition::Composition equilibratedComposition; fourdst::composition::Composition equilibratedComposition;
if (forceReinitialize || !resourcesExist || inconsistentComposition) { if (forceReinitialize || !resourcesExist || inconsistentComposition) {
cleanup_cvode_resources(true); sctx_p->reset_cvode();
if (!sctx_p->has_context()) {
sctx_p->init_context();
}
LOG_INFO( LOG_INFO(
m_logger, m_logger,
"Preforming full CVODE initialization (Reason: {})", "Preforming full CVODE initialization (Reason: {})",
@@ -141,26 +218,24 @@ namespace gridfire::solver {
(!resourcesExist ? "CVODE resources do not exist" : (!resourcesExist ? "CVODE resources do not exist" :
"Input composition inconsistent with previous state")); "Input composition inconsistent with previous state"));
LOG_TRACE_L1(m_logger, "Starting engine update chain..."); LOG_TRACE_L1(m_logger, "Starting engine update chain...");
equilibratedComposition = m_engine.project(*m_scratch_blob, netIn); equilibratedComposition = m_engine.project(*sctx_p->engine_ctx, netIn);
LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!"); LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!");
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
uint64_t N = numSpecies + 1; uint64_t N = numSpecies + 1;
LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N); LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N);
LOG_TRACE_L1(m_logger, "Initializing CVODE resources"); LOG_TRACE_L1(m_logger, "Initializing CVODE resources");
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
initialize_cvode_integration_resources(N, numSpecies, 0.0, equilibratedComposition, absTol, relTol, 0.0); initialize_cvode_integration_resources(sctx_p, N, numSpecies, 0.0, equilibratedComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), 0.0);
m_last_size = N; sctx_p->last_size = N;
} else { } else {
LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", m_last_size); LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", sctx_p->last_size);
const size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); const size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
sunrealtype *y_data = N_VGetArrayPointer(m_Y); sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
for (size_t i = 0; i < numSpecies; i++) { for (size_t i = 0; i < numSpecies; i++) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
if (netIn.composition.contains(species)) { if (netIn.composition.contains(species)) {
y_data[i] = netIn.composition.getMolarAbundance(species); y_data[i] = netIn.composition.getMolarAbundance(species);
} else { } else {
@@ -168,16 +243,17 @@ namespace gridfire::solver {
} }
} }
y_data[numSpecies] = 0.0; // Reset energy accumulator y_data[numSpecies] = 0.0; // Reset energy accumulator
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances"); utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, sctx_p->rel_tol.value(), sctx_p->abs_tol.value()), "CVodeSStolerances");
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, 0.0, m_Y), "CVodeReInit"); utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, 0.0, sctx_p->Y), "CVodeReInit");
equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state
} }
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
CVODEUserData user_data { CVODEUserData user_data {
.solver_instance = this, .solver_instance = this,
.ctx = *m_scratch_blob, .sctx = sctx_p,
.ctx = *sctx_p->engine_ctx,
.engine = &m_engine, .engine = &m_engine,
}; };
LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!"); LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!");
@@ -185,7 +261,7 @@ namespace gridfire::solver {
double current_time = 0; double current_time = 0;
// ReSharper disable once CppTooWideScope // ReSharper disable once CppTooWideScope
[[maybe_unused]] double last_callback_time = 0; [[maybe_unused]] double last_callback_time = 0;
m_num_steps = 0; sctx_p->num_steps = 0;
double accumulated_energy = 0.0; double accumulated_energy = 0.0;
double accumulated_neutrino_energy_loss = 0.0; double accumulated_neutrino_energy_loss = 0.0;
@@ -205,13 +281,13 @@ namespace gridfire::solver {
while (current_time < netIn.tMax) { while (current_time < netIn.tMax) {
user_data.T9 = T9; user_data.T9 = T9;
user_data.rho = netIn.density; user_data.rho = netIn.density;
user_data.networkSpecies = &m_engine.getNetworkSpecies(*m_scratch_blob); user_data.networkSpecies = &m_engine.getNetworkSpecies(*sctx_p->engine_ctx);
user_data.captured_exception.reset(); user_data.captured_exception.reset();
utils::check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData"); utils::check_cvode_flag(CVodeSetUserData(sctx_p->cvode_mem, &user_data), "CVodeSetUserData");
LOG_TRACE_L2(m_logger, "Taking one CVODE step..."); LOG_TRACE_L2(m_logger, "Taking one CVODE step...");
int flag = CVode(m_cvode_mem, netIn.tMax, m_Y, &current_time, CV_ONE_STEP); int flag = CVode(sctx_p->cvode_mem, netIn.tMax, sctx_p->Y, &current_time, CV_ONE_STEP);
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag)); LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag));
if (user_data.captured_exception){ if (user_data.captured_exception){
@@ -223,13 +299,13 @@ namespace gridfire::solver {
long int n_steps; long int n_steps;
double last_step_size; double last_step_size;
CVodeGetNumSteps(m_cvode_mem, &n_steps); CVodeGetNumSteps(sctx_p->cvode_mem, &n_steps);
CVodeGetLastStep(m_cvode_mem, &last_step_size); CVodeGetLastStep(sctx_p->cvode_mem, &last_step_size);
long int nliters, nlcfails; long int nliters, nlcfails;
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters); CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nliters);
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails); CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nlcfails);
sunrealtype* y_data = N_VGetArrayPointer(m_Y); sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y);
const double current_energy = y_data[numSpecies]; // Specific energy rate const double current_energy = y_data[numSpecies]; // Specific energy rate
// TODO: Accumulate neutrino loss through the state vector directly which will allow CVODE to properly integrate it // TODO: Accumulate neutrino loss through the state vector directly which will allow CVODE to properly integrate it
@@ -238,7 +314,7 @@ namespace gridfire::solver {
size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations; size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations;
size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures; size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures;
if (m_stdout_logging_enabled) { if (sctx_p->stdout_logging) {
std::println( std::println(
"Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})", "Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})",
total_steps + n_steps, total_steps + n_steps,
@@ -253,20 +329,16 @@ namespace gridfire::solver {
); );
} }
for (size_t i = 0; i < numSpecies; ++i) { for (size_t i = 0; i < numSpecies; ++i) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
if (y_data[i] > 0.0) { if (y_data[i] > 0.0) {
postStep.setMolarAbundance(species, y_data[i]); postStep.setMolarAbundance(species, y_data[i]);
} }
} }
// fourdst::composition::Composition collectedComposition = m_engine.collectComposition(postStep, netIn.temperature/1e9, netIn.density);
// for (size_t i = 0; i < numSpecies; ++i) {
// y_data[i] = collectedComposition.getMolarAbundance(m_engine.getNetworkSpecies()[i]);
// }
LOG_INFO(m_logger, "Completed {:5} steps to time {:10.4E} [s] (dt = {:15.6E} [s]). Current specific energy: {:15.6E} [erg/g]", total_steps + n_steps, current_time, last_step_size, current_energy); LOG_INFO(m_logger, "Completed {:5} steps to time {:10.4E} [s] (dt = {:15.6E} [s]). Current specific energy: {:15.6E} [erg/g]", total_steps + n_steps, current_time, last_step_size, current_energy);
LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string { LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string {
std::stringstream ss; std::stringstream ss;
for (size_t i = 0; i < numSpecies; ++i) { for (size_t i = 0; i < numSpecies; ++i) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")"; ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")";
if (i < numSpecies - 1) { if (i < numSpecies - 1) {
ss << ", "; ss << ", ";
@@ -282,36 +354,44 @@ namespace gridfire::solver {
? user_data.reaction_contribution_map.value() ? user_data.reaction_contribution_map.value()
: kEmptyMap; : kEmptyMap;
auto ctx = TimestepContext( auto ctx = PointSolverTimestepContext(
current_time, current_time,
m_Y, sctx_p->Y,
last_step_size, last_step_size,
last_callback_time, last_callback_time,
T9, T9,
netIn.density, netIn.density,
n_steps, n_steps,
m_engine, m_engine,
m_engine.getNetworkSpecies(*m_scratch_blob), m_engine.getNetworkSpecies(*sctx_p->engine_ctx),
convFail_diff, convFail_diff,
iter_diff, iter_diff,
rcMap, rcMap,
*m_scratch_blob *sctx_p->engine_ctx
); );
prev_nonlinear_iterations = nliters + total_nonlinear_iterations; prev_nonlinear_iterations = nliters + total_nonlinear_iterations;
prev_convergence_failures = nlcfails + total_convergence_failures; prev_convergence_failures = nlcfails + total_convergence_failures;
if (m_callback.has_value()) { if (sctx_p->callback.has_value()) {
m_callback.value()(ctx); sctx_p->callback.value()(ctx);
} }
trigger->step(ctx); trigger->step(ctx);
if (m_detailed_step_logging) { if (sctx_p->detailed_step_logging) {
log_step_diagnostics(*m_scratch_blob, user_data, true, true, true, "step_" + std::to_string(total_steps + n_steps) + ".json"); log_step_diagnostics(
sctx_p,
*sctx_p->engine_ctx,
user_data,
true,
true,
true,
"step_" + std::to_string(total_steps + n_steps) + ".json"
);
} }
if (trigger->check(ctx)) { if (trigger->check(ctx)) {
if (m_stdout_logging_enabled && displayTrigger) { if (sctx_p->stdout_logging && displayTrigger) {
trigger::printWhy(trigger->why(ctx)); trigger::printWhy(trigger->why(ctx));
} }
trigger->update(ctx); trigger->update(ctx);
@@ -333,20 +413,20 @@ namespace gridfire::solver {
fourdst::composition::Composition temp_comp; fourdst::composition::Composition temp_comp;
std::vector<double> mass_fractions; std::vector<double> mass_fractions;
auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*m_scratch_blob).size()); auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size());
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) { if (num_species_at_stop > sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1) {
LOG_ERROR( LOG_ERROR(
m_logger, m_logger,
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.", "Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
num_species_at_stop, num_species_at_stop,
m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1 // -1 due to energy in the last index
); );
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen."); throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
} }
for (const auto& species: m_engine.getNetworkSpecies(*m_scratch_blob)) { for (const auto& species: m_engine.getNetworkSpecies(*sctx_p->engine_ctx)) {
const size_t sid = m_engine.getSpeciesIndex(*m_scratch_blob, species); const size_t sid = m_engine.getSpeciesIndex(*sctx_p->engine_ctx, species);
temp_comp.registerSpecies(species); temp_comp.registerSpecies(species);
double y = end_of_step_abundances[sid]; double y = end_of_step_abundances[sid];
if (y > 0.0) { if (y > 0.0) {
@@ -356,7 +436,7 @@ namespace gridfire::solver {
#ifndef NDEBUG #ifndef NDEBUG
for (long int i = 0; i < num_species_at_stop; ++i) { for (long int i = 0; i < num_species_at_stop; ++i) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) { if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) {
throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification."); throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification.");
} }
@@ -391,7 +471,7 @@ namespace gridfire::solver {
"Prior to Engine Update active reactions are: {}", "Prior to Engine Update active reactions are: {}",
[&]() -> std::string { [&]() -> std::string {
std::stringstream ss; std::stringstream ss;
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob); const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx);
size_t count = 0; size_t count = 0;
for (const auto& reaction : reactions) { for (const auto& reaction : reactions) {
ss << reaction -> id(); ss << reaction -> id();
@@ -403,7 +483,7 @@ namespace gridfire::solver {
return ss.str(); return ss.str();
}() }()
); );
fourdst::composition::Composition currentComposition = m_engine.project(*m_scratch_blob, netInTemp); fourdst::composition::Composition currentComposition = m_engine.project(*sctx_p->engine_ctx, netInTemp);
LOG_DEBUG( LOG_DEBUG(
m_logger, m_logger,
"After to Engine update composition is (molar abundance) {}", "After to Engine update composition is (molar abundance) {}",
@@ -450,7 +530,7 @@ namespace gridfire::solver {
"After Engine Update active reactions are: {}", "After Engine Update active reactions are: {}",
[&]() -> std::string { [&]() -> std::string {
std::stringstream ss; std::stringstream ss;
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob); const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx);
size_t count = 0; size_t count = 0;
for (const auto& reaction : reactions) { for (const auto& reaction : reactions) {
ss << reaction -> id(); ss << reaction -> id();
@@ -466,34 +546,29 @@ namespace gridfire::solver {
m_logger, m_logger,
"Due to a triggered engine update the composition was updated from size {} to {} species.", "Due to a triggered engine update the composition was updated from size {} to {} species.",
num_species_at_stop, num_species_at_stop,
m_engine.getNetworkSpecies(*m_scratch_blob).size() m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size()
); );
numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
size_t N = numSpecies + 1; size_t N = numSpecies + 1;
LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update..."); LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update...");
cleanup_cvode_resources(true); sctx_p->reset_cvode();
initialize_cvode_integration_resources(sctx_p, N, numSpecies, current_time, currentComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), accumulated_energy);
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx); utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, current_time, sctx_p->Y), "CVodeReInit");
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
// throw exceptions::DebugException("Debug");
LOG_INFO(m_logger, "Done reinitializing CVODE after engine update. The next log messages will be from the first step after reinitialization..."); LOG_INFO(m_logger, "Done reinitializing CVODE after engine update. The next log messages will be from the first step after reinitialization...");
} }
} }
if (m_stdout_logging_enabled) { // Flush the buffer if standard out logging is enabled if (sctx_p->stdout_logging) { // Flush the buffer if standard out logging is enabled
std::cout << std::flush; std::cout << std::flush;
} }
LOG_INFO(m_logger, "CVODE iteration complete"); LOG_INFO(m_logger, "CVODE iteration complete");
sunrealtype* y_data = N_VGetArrayPointer(m_Y); sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y);
accumulated_energy += y_data[numSpecies]; accumulated_energy += y_data[numSpecies];
std::vector<double> y_vec(y_data, y_data + numSpecies); std::vector<double> y_vec(y_data, y_data + numSpecies);
@@ -505,7 +580,7 @@ namespace gridfire::solver {
LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies); LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies);
fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec); fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*sctx_p->engine_ctx), y_vec);
LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string { LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string {
std::ostringstream ss; std::ostringstream ss;
size_t i = 0; size_t i = 0;
@@ -520,7 +595,7 @@ namespace gridfire::solver {
}()); }());
LOG_INFO(m_logger, "Collecting final composition..."); LOG_INFO(m_logger, "Collecting final composition...");
fourdst::composition::Composition outputComposition = m_engine.collectComposition(*m_scratch_blob, topLevelComposition, netIn.temperature/1e9, netIn.density); fourdst::composition::Composition outputComposition = m_engine.collectComposition(*sctx_p->engine_ctx, topLevelComposition, netIn.temperature/1e9, netIn.density);
assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size()); assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size());
@@ -541,11 +616,11 @@ namespace gridfire::solver {
NetOut netOut; NetOut netOut;
netOut.composition = outputComposition; netOut.composition = outputComposition;
netOut.energy = accumulated_energy; netOut.energy = accumulated_energy;
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps"); utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives..."); LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives...");
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives( auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
*m_scratch_blob, *sctx_p->engine_ctx,
outputComposition, outputComposition,
T9, T9,
netIn.density netIn.density
@@ -559,53 +634,13 @@ namespace gridfire::solver {
LOG_TRACE_L2(m_logger, "Output data built!"); LOG_TRACE_L2(m_logger, "Output data built!");
LOG_TRACE_L2(m_logger, "Solver evaluation complete!."); LOG_TRACE_L2(m_logger, "Solver evaluation complete!.");
m_last_composition_hash = netOut.composition.hash(); sctx_p->last_composition_hash = netOut.composition.hash();
m_last_size = netOut.composition.size() + 1; sctx_p->last_size = netOut.composition.size() + 1;
CVodeGetLastStep(m_cvode_mem, &m_last_good_time_step); CVodeGetLastStep(sctx_p->cvode_mem, &sctx_p->last_good_time_step);
return netOut; return netOut;
} }
void CVODESolverStrategy::set_callback(const std::any &callback) { int PointSolver::cvode_rhs_wrapper(
m_callback = std::any_cast<TimestepCallback>(callback);
}
bool CVODESolverStrategy::get_stdout_logging_enabled() const {
return m_stdout_logging_enabled;
}
void CVODESolverStrategy::set_stdout_logging_enabled(const bool logging_enabled) {
m_stdout_logging_enabled = logging_enabled;
}
void CVODESolverStrategy::set_absTol(double absTol) {
m_absTol = absTol;
}
void CVODESolverStrategy::set_relTol(double relTol) {
m_relTol = relTol;
}
double CVODESolverStrategy::get_absTol() const {
if (m_absTol.has_value()) {
return m_absTol.value();
} else {
return -1.0;
}
}
double CVODESolverStrategy::get_relTol() const {
if (m_relTol.has_value()) {
return m_relTol.value();
} else {
return -1.0;
}
}
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::describe_callback_context() const {
return {};
}
int CVODESolverStrategy::cvode_rhs_wrapper(
const sunrealtype t, const sunrealtype t,
const N_Vector y, const N_Vector y,
const N_Vector ydot, const N_Vector ydot,
@@ -633,7 +668,7 @@ namespace gridfire::solver {
} }
} }
int CVODESolverStrategy::cvode_jac_wrapper( int PointSolver::cvode_jac_wrapper(
sunrealtype t, sunrealtype t,
N_Vector y, N_Vector y,
N_Vector ydot, N_Vector ydot,
@@ -754,7 +789,7 @@ namespace gridfire::solver {
return 0; return 0;
} }
CVODESolverStrategy::CVODERHSOutputData CVODESolverStrategy::calculate_rhs( PointSolver::CVODERHSOutputData PointSolver::calculate_rhs(
const sunrealtype t, const sunrealtype t,
N_Vector y, N_Vector y,
N_Vector ydot, N_Vector ydot,
@@ -772,10 +807,10 @@ namespace gridfire::solver {
} }
} }
std::vector<double> y_vec(y_data, y_data + numSpecies); std::vector<double> y_vec(y_data, y_data + numSpecies);
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec); fourdst::composition::Composition composition(m_engine.getNetworkSpecies(data->ctx), y_vec);
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size()); LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size());
const auto result = m_engine.calculateRHSAndEnergy(*m_scratch_blob, composition, data->T9, data->rho, false); const auto result = m_engine.calculateRHSAndEnergy(data->ctx, composition, data->T9, data->rho, false);
if (!result) { if (!result) {
LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())); LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()));
throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()))); throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())));
@@ -805,7 +840,7 @@ namespace gridfire::solver {
}()); }());
for (size_t i = 0; i < numSpecies; ++i) { for (size_t i = 0; i < numSpecies; ++i) {
fourdst::atomic::Species species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; fourdst::atomic::Species species = m_engine.getNetworkSpecies(data->ctx)[i];
ydot_data[i] = dydt.at(species); ydot_data[i] = dydt.at(species);
} }
ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate
@@ -813,7 +848,8 @@ namespace gridfire::solver {
return {reactionContributions, result.value().neutrinoEnergyLossRate, result.value().totalNeutrinoFlux}; return {reactionContributions, result.value().neutrinoEnergyLossRate, result.value().totalNeutrinoFlux};
} }
void CVODESolverStrategy::initialize_cvode_integration_resources( void PointSolver::initialize_cvode_integration_resources(
PointSolverContext* sctx_p,
const uint64_t N, const uint64_t N,
const size_t numSpecies, const size_t numSpecies,
const double current_time, const double current_time,
@@ -821,16 +857,18 @@ namespace gridfire::solver {
const double absTol, const double absTol,
const double relTol, const double relTol,
const double accumulatedEnergy const double accumulatedEnergy
) { ) const {
LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol); LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol);
cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones sctx_p->reset_cvode();
m_Y = utils::init_sun_vector(N, m_sun_ctx); sctx_p->cvode_mem = CVodeCreate(CV_BDF, sctx_p->sun_ctx);
m_YErr = N_VClone(m_Y); utils::check_cvode_flag(sctx_p->cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
sctx_p->Y = utils::init_sun_vector(N, sctx_p->sun_ctx);
sctx_p->YErr = N_VClone(sctx_p->Y);
sunrealtype *y_data = N_VGetArrayPointer(m_Y); sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
for (size_t i = 0; i < numSpecies; i++) { for (size_t i = 0; i < numSpecies; i++) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
if (composition.contains(species)) { if (composition.contains(species)) {
y_data[i] = composition.getMolarAbundance(species); y_data[i] = composition.getMolarAbundance(species);
} else { } else {
@@ -840,8 +878,8 @@ namespace gridfire::solver {
y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero
utils::check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit"); utils::check_cvode_flag(CVodeInit(sctx_p->cvode_mem, cvode_rhs_wrapper, current_time, sctx_p->Y), "CVodeInit");
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances"); utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, relTol, absTol), "CVodeSStolerances");
// Constraints // Constraints
// We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual // We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual
@@ -854,53 +892,30 @@ namespace gridfire::solver {
// -2.0: The corresponding component of y is constrained to be < 0 // -2.0: The corresponding component of y is constrained to be < 0
// Here we use 1.0 for all species to ensure they remain non-negative. // Here we use 1.0 for all species to ensure they remain non-negative.
m_constraints = N_VClone(m_Y); sctx_p->constraints = N_VClone(sctx_p->Y);
if (m_constraints == nullptr) { if (sctx_p->constraints == nullptr) {
LOG_ERROR(m_logger, "Failed to create constraints vector for CVODE"); LOG_ERROR(m_logger, "Failed to create constraints vector for CVODE");
throw std::runtime_error("Failed to create constraints vector for CVODE"); throw std::runtime_error("Failed to create constraints vector for CVODE");
} }
N_VConst(1.0, m_constraints); // Set all constraints to >= 0 (note this is where the flag values are set) N_VConst(1.0, sctx_p->constraints); // Set all constraints to >= 0 (note this is where the flag values are set)
utils::check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints"); utils::check_cvode_flag(CVodeSetConstraints(sctx_p->cvode_mem, sctx_p->constraints), "CVodeSetConstraints");
utils::check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep"); utils::check_cvode_flag(CVodeSetMaxStep(sctx_p->cvode_mem, 1.0e20), "CVodeSetMaxStep");
m_J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), m_sun_ctx); sctx_p->J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), sctx_p->sun_ctx);
utils::check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix"); utils::check_cvode_flag(sctx_p->J == nullptr ? -1 : 0, "SUNDenseMatrix");
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx); sctx_p->LS = SUNLinSol_Dense(sctx_p->Y, sctx_p->J, sctx_p->sun_ctx);
utils::check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense"); utils::check_cvode_flag(sctx_p->LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
utils::check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver"); utils::check_cvode_flag(CVodeSetLinearSolver(sctx_p->cvode_mem, sctx_p->LS, sctx_p->J), "CVodeSetLinearSolver");
utils::check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn"); utils::check_cvode_flag(CVodeSetJacFn(sctx_p->cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
LOG_TRACE_L2(m_logger, "CVODE solver initialized"); LOG_TRACE_L2(m_logger, "CVODE solver initialized");
} }
void CVODESolverStrategy::cleanup_cvode_resources(const bool memFree) {
LOG_TRACE_L2(m_logger, "Cleaning up cvode resources");
if (m_LS) SUNLinSolFree(m_LS);
if (m_J) SUNMatDestroy(m_J);
if (m_Y) N_VDestroy(m_Y);
if (m_YErr) N_VDestroy(m_YErr);
if (m_constraints) N_VDestroy(m_constraints);
m_LS = nullptr; void PointSolver::log_step_diagnostics(
m_J = nullptr; PointSolverContext* sctx_p,
m_Y = nullptr;
m_YErr = nullptr;
m_constraints = nullptr;
if (memFree) {
if (m_cvode_mem) CVodeFree(&m_cvode_mem);
m_cvode_mem = nullptr;
}
LOG_TRACE_L2(m_logger, "Done Cleaning up cvode resources");
}
void CVODESolverStrategy::set_detailed_step_logging(const bool enabled) {
m_detailed_step_logging = enabled;
}
void CVODESolverStrategy::log_step_diagnostics(
scratch::StateBlob &ctx, scratch::StateBlob &ctx,
const CVODEUserData &user_data, const CVODEUserData &user_data,
bool displayJacobianStiffness, bool displayJacobianStiffness,
@@ -916,10 +931,10 @@ namespace gridfire::solver {
sunrealtype hlast, hcur, tcur; sunrealtype hlast, hcur, tcur;
int qlast; int qlast;
utils::check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep"); utils::check_cvode_flag(CVodeGetLastStep(sctx_p->cvode_mem, &hlast), "CVodeGetLastStep");
utils::check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep"); utils::check_cvode_flag(CVodeGetCurrentStep(sctx_p->cvode_mem, &hcur), "CVodeGetCurrentStep");
utils::check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder"); utils::check_cvode_flag(CVodeGetLastOrder(sctx_p->cvode_mem, &qlast), "CVodeGetLastOrder");
utils::check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime"); utils::check_cvode_flag(CVodeGetCurrentTime(sctx_p->cvode_mem, &tcur), "CVodeGetCurrentTime");
nlohmann::json j; nlohmann::json j;
{ {
@@ -941,13 +956,13 @@ namespace gridfire::solver {
// These are the CRITICAL counters for diagnosing your problem // These are the CRITICAL counters for diagnosing your problem
long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails; long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails;
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps"); utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, &nsteps), "CVodeGetNumSteps");
utils::check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals"); utils::check_cvode_flag(CVodeGetNumRhsEvals(sctx_p->cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
utils::check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups"); utils::check_cvode_flag(CVodeGetNumLinSolvSetups(sctx_p->cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
utils::check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails"); utils::check_cvode_flag(CVodeGetNumErrTestFails(sctx_p->cvode_mem, &netfails), "CVodeGetNumErrTestFails");
utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters"); utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails"); utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
utils::check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails"); utils::check_cvode_flag(CVodeGetNumLinConvFails(sctx_p->cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
{ {
@@ -975,22 +990,26 @@ namespace gridfire::solver {
} }
// --- 3. Get Estimated Local Errors (Your Original Logic) --- // --- 3. Get Estimated Local Errors (Your Original Logic) ---
utils::check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors"); utils::check_cvode_flag(CVodeGetEstLocalErrors(sctx_p->cvode_mem, sctx_p->YErr), "CVodeGetEstLocalErrors");
sunrealtype *y_data = N_VGetArrayPointer(m_Y); sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
sunrealtype *y_err_data = N_VGetArrayPointer(m_YErr); sunrealtype *y_err_data = N_VGetArrayPointer(sctx_p->YErr);
const auto absTol = m_config->solver.cvode.absTol;
const auto relTol = m_config->solver.cvode.relTol;
std::vector<double> err_ratios; std::vector<double> err_ratios;
const size_t num_components = N_VGetLength(m_Y); const size_t num_components = N_VGetLength(sctx_p->Y);
err_ratios.resize(num_components - 1); // Assuming -1 is for Energy or similar err_ratios.resize(num_components - 1); // Assuming -1 is for Energy or similar
std::vector<double> Y_full(y_data, y_data + num_components - 1); std::vector<double> Y_full(y_data, y_data + num_components - 1);
std::vector<double> E_full(y_err_data, y_err_data + num_components - 1); std::vector<double> E_full(y_err_data, y_err_data + num_components - 1);
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, relTol, absTol, 10, to_file); if (!sctx_p->abs_tol.has_value()) {
sctx_p->abs_tol = m_config->solver.cvode.absTol;
}
if (!sctx_p->rel_tol.has_value()) {
sctx_p->rel_tol = m_config->solver.cvode.relTol;
}
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, sctx_p->rel_tol.value(), sctx_p->abs_tol.value(), 10, to_file);
if (to_file && result.has_value()) { if (to_file && result.has_value()) {
j["Limiting_Species"] = result.value(); j["Limiting_Species"] = result.value();
} }
@@ -1003,8 +1022,9 @@ namespace gridfire::solver {
0.0 0.0
); );
for (size_t i = 0; i < num_components - 1; i++) { for (size_t i = 0; i < num_components - 1; i++) {
const double weight = relTol * std::abs(y_data[i]) + absTol; const double weight = sctx_p->rel_tol.value() * std::abs(y_data[i]) + sctx_p->abs_tol.value();
if (weight == 0.0) { if (weight == 0.0) {
err_ratios[i] = 0.0; // Avoid division by zero err_ratios[i] = 0.0; // Avoid division by zero
continue; continue;
@@ -1013,11 +1033,11 @@ namespace gridfire::solver {
err_ratios[i] = err_ratio; err_ratios[i] = err_ratio;
} }
fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*m_scratch_blob), Y_full); fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*sctx_p->engine_ctx), Y_full);
fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*m_scratch_blob, composition, user_data.T9, user_data.rho); fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*sctx_p->engine_ctx, composition, user_data.T9, user_data.rho);
auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho); auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho);
auto netTimescales = user_data.engine->getSpeciesTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho); auto netTimescales = user_data.engine->getSpeciesTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho);
bool timescaleOkay = false; bool timescaleOkay = false;
if (destructionTimescales && netTimescales) timescaleOkay = true; if (destructionTimescales && netTimescales) timescaleOkay = true;
@@ -1037,7 +1057,7 @@ namespace gridfire::solver {
if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp)); if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp));
else destructionTimescales_list.emplace_back(std::numeric_limits<double>::infinity()); else destructionTimescales_list.emplace_back(std::numeric_limits<double>::infinity());
speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*m_scratch_blob, sp))); speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*sctx_p->engine_ctx, sp)));
} }
utils::Column<fourdst::atomic::Species> speciesColumn("Species", species_list); utils::Column<fourdst::atomic::Species> speciesColumn("Species", species_list);

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h" #include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/trigger/trigger_logical.h" #include "gridfire/trigger/trigger_logical.h"
#include "gridfire/trigger/trigger_abstract.h" #include "gridfire/trigger/trigger_abstract.h"
@@ -28,7 +28,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
} }
bool SimulationTimeTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { bool SimulationTimeTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
if (ctx.t - m_last_trigger_time >= m_interval) { if (ctx.t - m_last_trigger_time >= m_interval) {
m_hits++; m_hits++;
LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta); LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta);
@@ -38,7 +38,7 @@ namespace gridfire::trigger::solver::CVODE {
return false; return false;
} }
void SimulationTimeTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { void SimulationTimeTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
if (check(ctx)) { if (check(ctx)) {
m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval; m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval;
m_last_trigger_time = ctx.t; m_last_trigger_time = ctx.t;
@@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
void SimulationTimeTrigger::step( void SimulationTimeTrigger::step(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
// --- SimulationTimeTrigger::step does nothing and is intentionally left blank --- // // --- SimulationTimeTrigger::step does nothing and is intentionally left blank --- //
} }
@@ -65,7 +65,7 @@ namespace gridfire::trigger::solver::CVODE {
return "Simulation Time Trigger"; return "Simulation Time Trigger";
} }
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { TriggerResult SimulationTimeTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
TriggerResult result; TriggerResult result;
result.name = name(); result.name = name();
if (check(ctx)) { if (check(ctx)) {
@@ -99,18 +99,18 @@ namespace gridfire::trigger::solver::CVODE {
} }
} }
bool OffDiagonalTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { bool OffDiagonalTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
//TODO : This currently does nothing //TODO : This currently does nothing
return false; return false;
} }
void OffDiagonalTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { void OffDiagonalTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
m_updates++; m_updates++;
} }
void OffDiagonalTrigger::step( void OffDiagonalTrigger::step(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
// --- OffDiagonalTrigger::step does nothing and is intentionally left blank --- // // --- OffDiagonalTrigger::step does nothing and is intentionally left blank --- //
} }
@@ -126,7 +126,7 @@ namespace gridfire::trigger::solver::CVODE {
return "Off-Diagonal Trigger"; return "Off-Diagonal Trigger";
} }
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { TriggerResult OffDiagonalTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
TriggerResult result; TriggerResult result;
result.name = name(); result.name = name();
@@ -173,7 +173,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
} }
bool TimestepCollapseTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { bool TimestepCollapseTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
if (m_timestep_window.size() < m_windowSize) { if (m_timestep_window.size() < m_windowSize) {
m_misses++; m_misses++;
return false; return false;
@@ -201,13 +201,13 @@ namespace gridfire::trigger::solver::CVODE {
return false; return false;
} }
void TimestepCollapseTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { void TimestepCollapseTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
m_updates++; m_updates++;
m_timestep_window.clear(); m_timestep_window.clear();
} }
void TimestepCollapseTrigger::step( void TimestepCollapseTrigger::step(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize); push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize);
// --- TimestepCollapseTrigger::step does nothing and is intentionally left blank --- // // --- TimestepCollapseTrigger::step does nothing and is intentionally left blank --- //
@@ -226,7 +226,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
TriggerResult TimestepCollapseTrigger::why( TriggerResult TimestepCollapseTrigger::why(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) const { ) const {
TriggerResult result; TriggerResult result;
result.name = name(); result.name = name();
@@ -263,7 +263,7 @@ namespace gridfire::trigger::solver::CVODE {
m_windowSize(windowSize) {} m_windowSize(windowSize) {}
bool ConvergenceFailureTrigger::check( bool ConvergenceFailureTrigger::check(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) const { ) const {
if (m_window.size() != m_windowSize) { if (m_window.size() != m_windowSize) {
m_misses++; m_misses++;
@@ -278,13 +278,13 @@ namespace gridfire::trigger::solver::CVODE {
} }
void ConvergenceFailureTrigger::update( void ConvergenceFailureTrigger::update(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
m_window.clear(); m_window.clear();
} }
void ConvergenceFailureTrigger::step( void ConvergenceFailureTrigger::step(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
push_to_fixed_deque(m_window, ctx.currentConvergenceFailures, m_windowSize); push_to_fixed_deque(m_window, ctx.currentConvergenceFailures, m_windowSize);
m_updates++; m_updates++;
@@ -306,7 +306,7 @@ namespace gridfire::trigger::solver::CVODE {
return "ConvergenceFailureTrigger(abs_failure_threshold=" + std::to_string(m_totalFailures) + ", rel_failure_threshold=" + std::to_string(m_relativeFailureRate) + ", windowSize=" + std::to_string(m_windowSize) + ")"; return "ConvergenceFailureTrigger(abs_failure_threshold=" + std::to_string(m_totalFailures) + ", rel_failure_threshold=" + std::to_string(m_relativeFailureRate) + ", windowSize=" + std::to_string(m_windowSize) + ")";
} }
TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
TriggerResult result; TriggerResult result;
result.name = name(); result.name = name();
@@ -348,7 +348,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
bool ConvergenceFailureTrigger::abs_failure( bool ConvergenceFailureTrigger::abs_failure(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) const { ) const {
if (ctx.currentConvergenceFailures > m_totalFailures) { if (ctx.currentConvergenceFailures > m_totalFailures) {
return true; return true;
@@ -357,7 +357,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
bool ConvergenceFailureTrigger::rel_failure( bool ConvergenceFailureTrigger::rel_failure(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) const { ) const {
const float mean = current_mean(); const float mean = current_mean();
if (mean < 10) { if (mean < 10) {
@@ -369,13 +369,13 @@ namespace gridfire::trigger::solver::CVODE {
return false; return false;
} }
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger( std::unique_ptr<Trigger<gridfire::solver::PointSolverTimestepContext>> makeEnginePartitioningTrigger(
const double simulationTimeInterval, const double simulationTimeInterval,
const double offDiagonalThreshold, const double offDiagonalThreshold,
const double timestepCollapseRatio, const double timestepCollapseRatio,
const size_t maxConvergenceFailures const size_t maxConvergenceFailures
) { ) {
using ctx_t = gridfire::solver::CVODESolverStrategy::TimestepContext; using ctx_t = gridfire::solver::PointSolverTimestepContext;
// 1. INSTABILITY TRIGGERS (High Priority) // 1. INSTABILITY TRIGGERS (High Priority)
auto convergenceFailureTrigger = std::make_unique<ConvergenceFailureTrigger>( auto convergenceFailureTrigger = std::make_unique<ConvergenceFailureTrigger>(

View File

@@ -15,8 +15,8 @@ gridfire_sources = files(
'lib/reaction/weak/weak_interpolator.cpp', 'lib/reaction/weak/weak_interpolator.cpp',
'lib/io/network_file.cpp', 'lib/io/network_file.cpp',
'lib/io/generative/python.cpp', 'lib/io/generative/python.cpp',
'lib/solver/strategies/CVODE_solver_strategy.cpp', 'lib/solver/strategies/PointSolver.cpp',
'lib/solver/strategies/SpectralSolverStrategy.cpp', 'lib/solver/strategies/GridSolver.cpp',
'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp', 'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp',
'lib/screening/screening_types.cpp', 'lib/screening/screening_types.cpp',
'lib/screening/screening_weak.cpp', 'lib/screening/screening_weak.cpp',

View File

@@ -19,7 +19,7 @@
#include <clocale> #include <clocale>
#include "gridfire/reaction/reaclib.h" #include "gridfire/utils/gf_omp.h"
static std::terminate_handler g_previousHandler = nullptr; static std::terminate_handler g_previousHandler = nullptr;
@@ -31,7 +31,7 @@ gridfire::NetIn init(const double temp, const double rho, const double tMax) {
std::setlocale(LC_ALL, ""); std::setlocale(LC_ALL, "");
g_previousHandler = std::set_terminate(quill_terminate_handler); g_previousHandler = std::set_terminate(quill_terminate_handler);
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log"); quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
logger->set_log_level(quill::LogLevel::TraceL2); logger->set_log_level(quill::LogLevel::Info);
using namespace gridfire; using namespace gridfire;
const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4}; const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4};
@@ -143,7 +143,7 @@ void log_results(const gridfire::NetOut& netOut, const gridfire::NetIn& netIn) {
} }
void record_abundance_history_callback(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) { void record_abundance_history_callback(const gridfire::solver::PointSolverTimestepContext& ctx) {
s_wrote_abundance_history = true; s_wrote_abundance_history = true;
const auto& engine = ctx.engine; const auto& engine = ctx.engine;
// std::unordered_map<std::string, std::pair<double, double>> abundances; // std::unordered_map<std::string, std::pair<double, double>> abundances;
@@ -224,11 +224,12 @@ void quill_terminate_handler()
std::abort(); std::abort();
} }
void callback_main(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) { void callback_main(const gridfire::solver::PointSolverTimestepContext& ctx) {
record_abundance_history_callback(ctx); record_abundance_history_callback(ctx);
} }
int main() { int main() {
GF_PAR_INIT();
using namespace gridfire; using namespace gridfire;
constexpr size_t breaks = 1; constexpr size_t breaks = 1;
@@ -239,98 +240,20 @@ int main() {
const NetIn netIn = init(temp, rho, tMax); const NetIn netIn = init(temp, rho, tMax);
policy::MainSequencePolicy stellarPolicy(netIn.composition); policy::MainSequencePolicy stellarPolicy(netIn.composition);
policy::ConstructionResults construct = stellarPolicy.construct(); auto [engine, ctx_template] = stellarPolicy.construct();
std::println("Sandbox Engine Stack: {}", stellarPolicy); std::println("Sandbox Engine Stack: {}", stellarPolicy);
std::println("Scratch Blob State: {}", *construct.scratch_blob); std::println("Scratch Blob State: {}", *ctx_template);
constexpr size_t nZones = 100;
constexpr size_t runs = 1000; std::array<NetIn, nZones> netIns;
auto startTime = std::chrono::high_resolution_clock::now(); for (size_t zone = 0; zone < nZones; ++zone) {
netIns[zone] = netIn;
// arrays to store timings netIns[zone].temperature = 1.0e7;
std::array<std::chrono::duration<double>, runs> setup_times;
std::array<std::chrono::duration<double>, runs> eval_times;
std::array<NetOut, runs> serial_results;
for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now();
std::print("Run {}/{}\r", i + 1, runs);
solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob);
// solver.set_callback(solver::CVODESolverStrategy::TimestepCallback(callback_main));
solver.set_stdout_logging_enabled(false);
auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setup_times[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now();
const NetOut netOut = solver.evaluate(netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now();
serial_results[i] = netOut;
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
eval_times[i] = eval_elapsed;
// log_results(netOut, netIn);
}
auto endTime = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = endTime - startTime;
std::println("");
// Summarize serial timings
double total_setup_time = 0.0;
double total_eval_time = 0.0;
for (size_t i = 0; i < runs; ++i) {
total_setup_time += setup_times[i].count();
total_eval_time += eval_times[i].count();
}
std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1));
// OPTIONAL: Prevent CppAD from returning memory to the system
// during execution to reduce overhead (can speed up tight loops)
CppAD::thread_alloc::hold_memory(true);
std::array<NetOut, runs> parallelResults;
std::array<std::chrono::duration<double>, runs> setupTimes;
std::array<std::chrono::duration<double>, runs> evalTimes;
std::array<std::unique_ptr<gridfire::engine::scratch::StateBlob>, runs> workspaces;
for (size_t i = 0; i < runs; ++i) {
workspaces[i] = construct.scratch_blob->clone_structure();
} }
const solver::PointSolver localSolver(engine);
solver::GridSolverContext solverCtx(*ctx_template);
const solver::GridSolver gridSolver(engine, localSolver);
// Parallel runs std::vector<NetOut> netOuts = gridSolver.evaluate(solverCtx, netIns | std::ranges::to<std::vector>());
startTime = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now();
solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]);
solver.set_stdout_logging_enabled(false);
auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setupTimes[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now();
parallelResults[i] = solver.evaluate(netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
evalTimes[i] = eval_elapsed;
}
endTime = std::chrono::high_resolution_clock::now();
elapsed = endTime - startTime;
std::println("");
// Summarize parallel timings
total_setup_time = 0.0;
total_eval_time = 0.0;
for (size_t i = 0; i < runs; ++i) {
total_setup_time += setupTimes[i].count();
total_eval_time += evalTimes[i].count();
}
std::println("Average Parallel Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) {
return result.composition.getMolarAbundance(fourdst::atomic::H_1);
}));
} }

View File

@@ -4,8 +4,8 @@ executable(
dependencies: [gridfire_dep, cli11_dep], dependencies: [gridfire_dep, cli11_dep],
) )
executable( #executable(
'spectral_sandbox', # 'spectral_sandbox',
'spectral_main.cpp', # 'spectral_main.cpp',
dependencies: [gridfire_dep, cli11_dep] # dependencies: [gridfire_dep, cli11_dep]
) #)

View File

@@ -1,108 +0,0 @@
#include <iostream>
#include <fstream>
#include <chrono>
#include <thread>
#include "gridfire/gridfire.h"
#include "fourdst/composition/composition.h"
#include "fourdst/logging/logging.h"
#include "fourdst/atomic/species.h"
#include "fourdst/composition/utils.h"
#include "quill/Logger.h"
#include "quill/Backend.h"
#include "CLI/CLI.hpp"
#include <clocale>
static std::terminate_handler g_previousHandler = nullptr;
static std::vector<std::pair<double, std::unordered_map<std::string, std::pair<double, double>>>> g_callbackHistory;
static bool s_wrote_abundance_history = false;
void quill_terminate_handler();
std::vector<double> linspace(const double start, const double end, const size_t num) {
std::vector<double> result;
if (num == 0) return result;
if (num == 1) {
result.push_back(start);
return result;
}
const double step = (end - start) / static_cast<double>(num - 1);
for (size_t i = 0; i < num; ++i) {
result.push_back(start + i * step);
}
return result;
}
std::vector<gridfire::NetIn> init(const double tMin, const double tMax, const double rhoMin, const double rhoMax, const double nShells, const double evolveTime) {
std::setlocale(LC_ALL, "");
g_previousHandler = std::set_terminate(quill_terminate_handler);
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
logger->set_log_level(quill::LogLevel::TraceL2);
LOG_INFO(logger, "Initializing GridFire Spectral Solver Sandbox...");
using namespace gridfire;
const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4};
const std::vector<std::string> symbols = {"H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"};
const fourdst::composition::Composition composition = fourdst::composition::buildCompositionFromMassFractions(symbols, X);
std::vector<NetIn> netIns;
for (const auto& [T, ρ]: std::views::zip(linspace(tMin, tMax, nShells), linspace(rhoMax, rhoMin, nShells))) {
NetIn netIn;
netIn.composition = composition;
netIn.temperature = T;
netIn.density = ρ;
netIn.energy = 0;
netIn.tMax = evolveTime;
netIn.dt0 = 1e-12;
netIns.push_back(netIn);
}
return netIns;
}
void quill_terminate_handler()
{
quill::Backend::stop();
if (g_previousHandler)
g_previousHandler();
else
std::abort();
}
int main(int argc, char** argv) {
using namespace gridfire;
CLI::App app{"GridFire Sandbox Application."};
double tMin = 1.5e7;
double tMax = 1.7e7;
double rhoMin = 1.5e2;
double rhoMax = 1.5e2;
double nShells = 15;
double evolveTime = 3.1536e+16;
app.add_option("--tMin", tMin, "Minimum time in seconds");
app.add_option("--tMax", tMax, "Maximum time in seconds");
app.add_option("--rhoMin", rhoMin, "Minimum density in g/cm^3");
app.add_option("--rhoMax", rhoMax, "Maximum density in g/cm^3");
app.add_option("--nShells", nShells, "Number of shells");
app.add_option("--evolveTime", evolveTime, "Maximum time in seconds");
CLI11_PARSE(app, argc, argv);
const std::vector<NetIn> netIns = init(tMin, tMax, rhoMin, rhoMax, nShells, evolveTime);
policy::MainSequencePolicy stellarPolicy(netIns[0].composition);
stellarPolicy.construct();
policy::ConstructionResults construct = stellarPolicy.construct();
solver::SpectralSolverStrategy solver(construct.engine);
std::vector<double> mass_coords = linspace(1e-5, 1.0, nShells);
std::vector<NetOut> results = solver.evaluate(netIns, mass_coords, *construct.scratch_blob);
}

373
tools/cli/gf_quick/main.cpp Normal file
View File

@@ -0,0 +1,373 @@
// ReSharper disable CppUnusedIncludeDirective
#include <iostream>
#include <fstream>
#include <chrono>
#include <thread>
#include <format>
#include "gridfire/gridfire.h"
#include <cppad/utility/thread_alloc.hpp> // Required for parallel_setup
#include "fourdst/composition/composition.h"
#include "fourdst/logging/logging.h"
#include "fourdst/atomic/species.h"
#include "fourdst/composition/utils.h"
#include "quill/Logger.h"
#include "quill/Backend.h"
#include "CLI/CLI.hpp"
#include <clocale>
#include "gridfire/utils/gf_omp.h"
static std::terminate_handler g_previousHandler = nullptr;
static std::vector<std::pair<double, std::unordered_map<std::string, std::pair<double, double>>>> g_callbackHistory;
static bool s_wrote_abundance_history = false;
void quill_terminate_handler();
using namespace fourdst::composition;
Composition rescale(const Composition& comp, double target_X, double target_Z) {
// 1. Validate inputs
if (target_X < 0.0 || target_Z < 0.0 || (target_X + target_Z) > 1.0 + 1e-14) {
throw std::invalid_argument("Target mass fractions X and Z must be non-negative and sum to <= 1.0");
}
// Force high precision for the target Y to ensure X+Y+Z = 1.0 exactly in our logic
long double ld_target_X = static_cast<long double>(target_X);
long double ld_target_Z = static_cast<long double>(target_Z);
long double ld_target_Y = 1.0L - ld_target_X - ld_target_Z;
// Clamp Y to 0 if it dipped slightly below due to precision (e.g. X+Z=1.0000000001)
if (ld_target_Y < 0.0L) ld_target_Y = 0.0L;
// 2. Manually calculate current Mass Totals (bypass getCanonicalComposition to avoid crashes)
long double total_mass_H = 0.0L;
long double total_mass_He = 0.0L;
long double total_mass_Z = 0.0L;
// We need to iterate and identify species types manually
// Standard definition: H (z=1), He (z=2), Metals (z>2)
// Note: We use long double accumulators to prevent summation drift
for (const auto& [spec, molar_abundance] : comp) {
// Retrieve atomic properties.
// Note: usage assumes fourdst::atomic::Species has .z() and .mass()
// consistent with the provided composition.cpp
int z = spec.z();
double a = spec.mass();
long double mass_contribution = static_cast<long double>(molar_abundance) * static_cast<long double>(a);
if (z == 1) {
total_mass_H += mass_contribution;
} else if (z == 2) {
total_mass_He += mass_contribution;
} else {
total_mass_Z += mass_contribution;
}
}
long double total_mass_current = total_mass_H + total_mass_He + total_mass_Z;
// Edge case: Empty composition
if (total_mass_current <= 0.0L) {
// Return empty or throw? If input was empty, return empty.
if (comp.size() == 0) return comp;
throw std::runtime_error("Input composition has zero total mass.");
}
// 3. Calculate Scaling Factors
// Factor = (Target_Mass_Fraction / Old_Mass_Fraction)
// = (Target_Mass_Fraction) / (Old_Group_Mass / Total_Mass)
// = (Target_Mass_Fraction * Total_Mass) / Old_Group_Mass
long double scale_H = 0.0L;
long double scale_He = 0.0L;
long double scale_Z = 0.0L;
if (ld_target_X > 1e-16L) {
if (total_mass_H <= 1e-19L) {
throw std::runtime_error("Cannot rescale Hydrogen to " + std::to_string(target_X) +
" because input has no Hydrogen.");
}
scale_H = (ld_target_X * total_mass_current) / total_mass_H;
}
if (ld_target_Y > 1e-16L) {
if (total_mass_He <= 1e-19L) {
throw std::runtime_error("Cannot rescale Helium to " + std::to_string((double)ld_target_Y) +
" because input has no Helium.");
}
scale_He = (ld_target_Y * total_mass_current) / total_mass_He;
}
if (ld_target_Z > 1e-16L) {
if (total_mass_Z <= 1e-19L) {
throw std::runtime_error("Cannot rescale Metals to " + std::to_string(target_Z) +
" because input has no Metals.");
}
scale_Z = (ld_target_Z * total_mass_current) / total_mass_Z;
}
// 4. Apply Scaling and Construct New Vectors
std::vector<fourdst::atomic::Species> new_species;
std::vector<double> new_abundances;
new_species.reserve(comp.size());
new_abundances.reserve(comp.size());
for (const auto& [spec, abundance] : comp) {
new_species.push_back(spec);
long double factor = 0.0L;
int z = spec.z();
if (z == 1) {
factor = scale_H;
} else if (z == 2) {
factor = scale_He;
} else {
factor = scale_Z;
}
// Calculate new abundance in long double then cast back
long double new_val_ld = static_cast<long double>(abundance) * factor;
new_abundances.push_back(static_cast<double>(new_val_ld));
}
return Composition(new_species, new_abundances);
}
gridfire::NetIn init(const double temp, const double rho, const double tMax) {
std::setlocale(LC_ALL, "");
g_previousHandler = std::set_terminate(quill_terminate_handler);
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
logger->set_log_level(quill::LogLevel::Info);
using namespace gridfire;
const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4};
const std::vector<std::string> symbols = {"H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"};
const fourdst::composition::Composition composition = fourdst::composition::buildCompositionFromMassFractions(symbols, X);
NetIn netIn;
netIn.composition = composition;
netIn.temperature = temp;
netIn.density = rho;
netIn.energy = 0;
netIn.tMax = tMax;
netIn.dt0 = 1e-12;
return netIn;
}
void log_results(const gridfire::NetOut& netOut, const gridfire::NetIn& netIn) {
std::vector<fourdst::atomic::Species> logSpecies = {
fourdst::atomic::H_1,
fourdst::atomic::He_3,
fourdst::atomic::He_4,
fourdst::atomic::C_12,
fourdst::atomic::N_14,
fourdst::atomic::O_16,
fourdst::atomic::Ne_20,
fourdst::atomic::Mg_24
};
std::vector<double> initial;
std::vector<double> final;
std::vector<double> delta;
std::vector<double> fractional;
for (const auto& species : logSpecies) {
double initial_X = netIn.composition.getMassFraction(species);
double final_X = netOut.composition.getMassFraction(species);
double delta_X = final_X - initial_X;
double fractionalChange = (delta_X) / initial_X * 100.0;
initial.push_back(initial_X);
final.push_back(final_X);
delta.push_back(delta_X);
fractional.push_back(fractionalChange);
}
initial.push_back(0.0); // Placeholder for energy
final.push_back(netOut.energy);
delta.push_back(netOut.energy);
fractional.push_back(0.0); // Placeholder for energy
initial.push_back(0.0);
final.push_back(netOut.dEps_dT);
delta.push_back(netOut.dEps_dT);
fractional.push_back(0.0);
initial.push_back(0.0);
final.push_back(netOut.dEps_dRho);
delta.push_back(netOut.dEps_dRho);
fractional.push_back(0.0);
initial.push_back(0.0);
final.push_back(netOut.specific_neutrino_energy_loss);
delta.push_back(netOut.specific_neutrino_energy_loss);
fractional.push_back(0.0);
initial.push_back(0.0);
final.push_back(netOut.specific_neutrino_flux);
delta.push_back(netOut.specific_neutrino_flux);
fractional.push_back(0.0);
initial.push_back(netIn.composition.getMeanParticleMass());
final.push_back(netOut.composition.getMeanParticleMass());
delta.push_back(final.back() - initial.back());
fractional.push_back((final.back() - initial.back()) / initial.back() * 100.0);
std::vector<std::string> rowLabels = [&]() -> std::vector<std::string> {
std::vector<std::string> labels;
for (const auto& species : logSpecies) {
labels.emplace_back(species.name());
}
labels.emplace_back("ε");
labels.emplace_back("dε/dT");
labels.emplace_back("dε/dρ");
labels.emplace_back("Eν");
labels.emplace_back("Fν");
labels.emplace_back("<μ>");
return labels;
}();
gridfire::utils::Column<std::string> paramCol("Parameter", rowLabels);
gridfire::utils::Column<double> initialCol("Initial", initial);
gridfire::utils::Column<double> finalCol ("Final", final);
gridfire::utils::Column<double> deltaCol ("δ", delta);
gridfire::utils::Column<double> percentCol("% Change", fractional);
std::vector<std::unique_ptr<gridfire::utils::ColumnBase>> columns;
columns.push_back(std::make_unique<gridfire::utils::Column<std::string>>(paramCol));
columns.push_back(std::make_unique<gridfire::utils::Column<double>>(initialCol));
columns.push_back(std::make_unique<gridfire::utils::Column<double>>(finalCol));
columns.push_back(std::make_unique<gridfire::utils::Column<double>>(deltaCol));
columns.push_back(std::make_unique<gridfire::utils::Column<double>>(percentCol));
gridfire::utils::print_table("Simulation Results", columns);
}
void record_abundance_history_callback(const gridfire::solver::PointSolverTimestepContext& ctx) {
s_wrote_abundance_history = true;
const auto& engine = ctx.engine;
// std::unordered_map<std::string, std::pair<double, double>> abundances;
std::vector<double> Y;
for (const auto& species : engine.getNetworkSpecies(ctx.state_ctx)) {
const size_t sid = engine.getSpeciesIndex(ctx.state_ctx, species);
double y = N_VGetArrayPointer(ctx.state)[sid];
Y.push_back(y > 0.0 ? y : 0.0); // Regularize tiny negative abundances to zero
}
fourdst::composition::Composition comp(engine.getNetworkSpecies(ctx.state_ctx), Y);
std::unordered_map<std::string, std::pair<double, double>> abundances;
for (const auto& sp : comp | std::views::keys) {
abundances.emplace(std::string(sp.name()), std::make_pair(sp.mass(), comp.getMolarAbundance(sp)));
}
g_callbackHistory.emplace_back(ctx.t, abundances);
}
void save_callback_data(const std::string_view filename) {
std::set<std::string> unique_species;
for (const auto &abundances: g_callbackHistory | std::views::values) {
for (const auto &species_name: abundances | std::views::keys) {
unique_species.insert(species_name);
}
}
std::ofstream csvFile(filename.data(), std::ios::out);
csvFile << "t,";
size_t i = 0;
for (const auto& species_name : unique_species) {
csvFile << species_name;
if (i < unique_species.size() - 1) {
csvFile << ",";
}
i++;
}
csvFile << "\n";
for (const auto& [time, data] : g_callbackHistory) {
csvFile << time << ",";
size_t j = 0;
for (const auto& species_name : unique_species) {
if (!data.contains(species_name)) {
csvFile << "0.0";
} else {
csvFile << data.at(species_name).second;
}
if (j < unique_species.size() - 1) {
csvFile << ",";
}
++j;
}
csvFile << "\n";
}
csvFile.close();
}
void log_callback_data(const double temp) {
if (s_wrote_abundance_history) {
std::cout << "Saving abundance history to abundance_history.csv" << std::endl;
save_callback_data("abundance_history_" + std::to_string(temp) + ".csv");
}
}
void quill_terminate_handler()
{
log_callback_data(1.5e7);
quill::Backend::stop();
if (g_previousHandler)
g_previousHandler();
else
std::abort();
}
void callback_main(const gridfire::solver::PointSolverTimestepContext& ctx) {
record_abundance_history_callback(ctx);
}
int main(int argc, char** argv) {
GF_PAR_INIT();
using namespace gridfire;
double temp = 1.5e7;
double rho = 1.5e2;
double tMax = 3.1536e+16;
double X = 0.7;
double Z = 0.02;
CLI::App app("GridFire Quick CLI Test");
// Add temp, rho, and tMax as options if desired
app.add_option("--temp", temp, "Initial Temperature")->default_val(std::format("{:5.2E}", temp));
app.add_option("--rho", rho, "Initial Density")->default_val(std::format("{:5.2E}", rho));
app.add_option("--tmax", tMax, "Maximum Time")->default_val(std::format("{:5.2E}", tMax));
// app.add_option("--X", X, "Target Hydrogen Mass Fraction")->default_val(std::format("{:5.2f}", X));
// app.add_option("--Z", Z, "Target Metal Mass Fraction")->default_val(std::format("{:5.2f}", Z));
CLI11_PARSE(app, argc, argv);
NetIn netIn = init(temp, rho, tMax);
// netIn.composition = rescale(netIn.composition, X, Z);
policy::MainSequencePolicy stellarPolicy(netIn.composition);
auto [engine, ctx_template] = stellarPolicy.construct();
solver::PointSolverContext solver_context(*ctx_template);
solver::PointSolver solver(engine);
NetOut result = solver.evaluate(solver_context, netIn);
log_results(result, netIn);
}

View File

@@ -0,0 +1 @@
executable('gf_quick', 'main.cpp', dependencies: [gridfire_dep, cli11_dep])

1
tools/cli/meson.build Normal file
View File

@@ -0,0 +1 @@
subdir('gf_quick')

View File

@@ -1,3 +1,4 @@
if get_option('build_tools') if get_option('build_tools')
subdir('config') subdir('config')
subdir('cli')
endif endif