perf(multi): Simple parallel multi zone solver
Added a simple parallel multi-zone solver
This commit is contained in:
@@ -63,7 +63,7 @@ int main() {
|
|||||||
std::println("Scratch Blob State: {}", *construct.scratch_blob);
|
std::println("Scratch Blob State: {}", *construct.scratch_blob);
|
||||||
|
|
||||||
|
|
||||||
constexpr size_t runs = 1000;
|
constexpr size_t runs = 10;
|
||||||
auto startTime = std::chrono::high_resolution_clock::now();
|
auto startTime = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
// arrays to store timings
|
// arrays to store timings
|
||||||
@@ -72,14 +72,15 @@ int main() {
|
|||||||
std::array<NetOut, runs> serial_results;
|
std::array<NetOut, runs> serial_results;
|
||||||
for (size_t i = 0; i < runs; ++i) {
|
for (size_t i = 0; i < runs; ++i) {
|
||||||
auto start_setup_time = std::chrono::high_resolution_clock::now();
|
auto start_setup_time = std::chrono::high_resolution_clock::now();
|
||||||
solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob);
|
solver::PointSolverContext solverCtx(*construct.scratch_blob);
|
||||||
solver.set_stdout_logging_enabled(false);
|
solverCtx.set_stdout_logging(false);
|
||||||
|
solver::PointSolver solver(construct.engine);
|
||||||
auto end_setup_time = std::chrono::high_resolution_clock::now();
|
auto end_setup_time = std::chrono::high_resolution_clock::now();
|
||||||
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
|
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
|
||||||
setup_times[i] = setup_elapsed;
|
setup_times[i] = setup_elapsed;
|
||||||
|
|
||||||
auto start_eval_time = std::chrono::high_resolution_clock::now();
|
auto start_eval_time = std::chrono::high_resolution_clock::now();
|
||||||
const NetOut netOut = solver.evaluate(netIn);
|
const NetOut netOut = solver.evaluate(solverCtx, netIn);
|
||||||
auto end_eval_time = std::chrono::high_resolution_clock::now();
|
auto end_eval_time = std::chrono::high_resolution_clock::now();
|
||||||
serial_results[i] = netOut;
|
serial_results[i] = netOut;
|
||||||
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
|
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
|
||||||
@@ -99,7 +100,6 @@ int main() {
|
|||||||
std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
|
std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
|
||||||
std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
|
std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
|
||||||
std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count());
|
std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count());
|
||||||
std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1));
|
|
||||||
|
|
||||||
|
|
||||||
std::array<NetOut, runs> parallelResults;
|
std::array<NetOut, runs> parallelResults;
|
||||||
@@ -114,16 +114,16 @@ int main() {
|
|||||||
// Parallel runs
|
// Parallel runs
|
||||||
startTime = std::chrono::high_resolution_clock::now();
|
startTime = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
GF_OMP(parallel for,)
|
GF_OMP(parallel for, for (size_t i = 0; i < runs; ++i)) {
|
||||||
for (size_t i = 0; i < runs; ++i) {
|
|
||||||
auto start_setup_time = std::chrono::high_resolution_clock::now();
|
auto start_setup_time = std::chrono::high_resolution_clock::now();
|
||||||
solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]);
|
solver::PointSolverContext solverCtx(*construct.scratch_blob);
|
||||||
solver.set_stdout_logging_enabled(false);
|
solverCtx.set_stdout_logging(false);
|
||||||
|
solver::PointSolver solver(construct.engine);
|
||||||
auto end_setup_time = std::chrono::high_resolution_clock::now();
|
auto end_setup_time = std::chrono::high_resolution_clock::now();
|
||||||
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
|
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
|
||||||
setupTimes[i] = setup_elapsed;
|
setupTimes[i] = setup_elapsed;
|
||||||
auto start_eval_time = std::chrono::high_resolution_clock::now();
|
auto start_eval_time = std::chrono::high_resolution_clock::now();
|
||||||
parallelResults[i] = solver.evaluate(netIn);
|
parallelResults[i] = solver.evaluate(solverCtx, netIn);
|
||||||
auto end_eval_time = std::chrono::high_resolution_clock::now();
|
auto end_eval_time = std::chrono::high_resolution_clock::now();
|
||||||
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
|
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
|
||||||
evalTimes[i] = eval_elapsed;
|
evalTimes[i] = eval_elapsed;
|
||||||
@@ -144,10 +144,6 @@ int main() {
|
|||||||
std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
|
std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
|
||||||
std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count());
|
std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count());
|
||||||
|
|
||||||
std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) {
|
|
||||||
return result.composition.getMolarAbundance(fourdst::atomic::H_1);
|
|
||||||
}));
|
|
||||||
|
|
||||||
std::println("========== Summary ==========");
|
std::println("========== Summary ==========");
|
||||||
std::println("Serial Runs:");
|
std::println("Serial Runs:");
|
||||||
std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs);
|
std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs);
|
||||||
|
|||||||
@@ -35,3 +35,16 @@ endif
|
|||||||
if get_option('openmp_support')
|
if get_option('openmp_support')
|
||||||
add_project_arguments('-DGF_USE_OPENMP', language: 'cpp')
|
add_project_arguments('-DGF_USE_OPENMP', language: 'cpp')
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
if get_option('asan') and get_option('buildtype') != 'debug' and get_option('buildtype') != 'debugoptimized'
|
||||||
|
error('AddressSanitizer (ASan) can only be enabled for debug or debugoptimized builds')
|
||||||
|
endif
|
||||||
|
|
||||||
|
if get_option('asan') and (get_option('buildtype') == 'debugoptimized' or get_option('buildtype') == 'debug')
|
||||||
|
message('enabling AddressSanitizer (ASan) support')
|
||||||
|
add_project_arguments('-fsanitize=address,undefined', language: 'cpp')
|
||||||
|
add_project_arguments('-fno-omit-frame-pointer', language: 'cpp')
|
||||||
|
|
||||||
|
add_project_link_arguments('-fsanitize=address,undefined', language: 'cpp')
|
||||||
|
add_project_link_arguments('-fno-omit-frame-pointer', language: 'cpp')
|
||||||
|
endif
|
||||||
|
|||||||
@@ -11,4 +11,5 @@ option('build_c_api', type: 'boolean', value: true, description: 'compile the C
|
|||||||
option('build_tools', type: 'boolean', value: true, description: 'build the GridFire command line tools')
|
option('build_tools', type: 'boolean', value: true, description: 'build the GridFire command line tools')
|
||||||
option('openmp_support', type: 'boolean', value: false, description: 'Enable OpenMP support for parallelization')
|
option('openmp_support', type: 'boolean', value: false, description: 'Enable OpenMP support for parallelization')
|
||||||
option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.')
|
option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.')
|
||||||
option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite')
|
option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite')
|
||||||
|
option('asan', type: 'boolean', value: false, description: 'Enable AddressSanitizer (ASan) support for detecting memory errors')
|
||||||
@@ -11,10 +11,10 @@ namespace gridfire::config {
|
|||||||
|
|
||||||
struct SpectralSolverConfig {
|
struct SpectralSolverConfig {
|
||||||
struct Trigger {
|
struct Trigger {
|
||||||
double simulationTimeInterval = 1.0e12;
|
|
||||||
double offDiagonalThreshold = 1.0e10;
|
|
||||||
double timestepCollapseRatio = 0.5;
|
double timestepCollapseRatio = 0.5;
|
||||||
size_t maxConvergenceFailures = 2;
|
size_t maxConvergenceFailures = 2;
|
||||||
|
double relativeFailureRate = 0.5;
|
||||||
|
size_t windowSize = 10;
|
||||||
};
|
};
|
||||||
struct MonitorFunctionConfig {
|
struct MonitorFunctionConfig {
|
||||||
double structure_weight = 1.0;
|
double structure_weight = 1.0;
|
||||||
|
|||||||
@@ -807,8 +807,6 @@ namespace gridfire::engine {
|
|||||||
|
|
||||||
CppAD::ADFun<double> m_authoritativeADFun;
|
CppAD::ADFun<double> m_authoritativeADFun;
|
||||||
|
|
||||||
const size_t m_state_blob_offset;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @brief Synchronizes the internal maps.
|
* @brief Synchronizes the internal maps.
|
||||||
|
|||||||
43
src/include/gridfire/solver/strategies/GridSolver.h
Normal file
43
src/include/gridfire/solver/strategies/GridSolver.h
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "gridfire/solver/strategies/strategy_abstract.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
namespace gridfire::solver {
|
||||||
|
struct GridSolverContext final : SolverContextBase {
|
||||||
|
std::vector<std::unique_ptr<SolverContextBase>> solver_workspaces;
|
||||||
|
std::vector<std::function<void(const TimestepContextBase&)>> timestep_callbacks;
|
||||||
|
const engine::scratch::StateBlob& ctx_template;
|
||||||
|
|
||||||
|
bool zone_completion_logging = true;
|
||||||
|
bool zone_stdout_logging = false;
|
||||||
|
bool zone_detailed_logging = false;
|
||||||
|
|
||||||
|
void init() override;
|
||||||
|
void reset();
|
||||||
|
|
||||||
|
void set_callback(const std::function<void(const TimestepContextBase&)> &callback);
|
||||||
|
void set_callback(const std::function<void(const TimestepContextBase&)> &callback, size_t zone_idx);
|
||||||
|
|
||||||
|
void set_stdout_logging(bool enable) override;
|
||||||
|
void set_detailed_logging(bool enable) override;
|
||||||
|
|
||||||
|
explicit GridSolverContext(const engine::scratch::StateBlob& ctx_template);
|
||||||
|
};
|
||||||
|
|
||||||
|
class GridSolver final : public MultiZoneDynamicNetworkSolver {
|
||||||
|
public:
|
||||||
|
GridSolver(
|
||||||
|
const engine::DynamicEngine& engine,
|
||||||
|
const SingleZoneDynamicNetworkSolver& solver
|
||||||
|
);
|
||||||
|
|
||||||
|
~GridSolver() override = default;
|
||||||
|
|
||||||
|
std::vector<NetOut> evaluate(
|
||||||
|
SolverContextBase& ctx,
|
||||||
|
const std::vector<NetIn>& netIns
|
||||||
|
) const override;
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -44,8 +44,88 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace gridfire::solver {
|
namespace gridfire::solver {
|
||||||
|
struct PointSolverTimestepContext final : TimestepContextBase {
|
||||||
|
const double t; ///< Current integration time [s].
|
||||||
|
const N_Vector& state; ///< Current CVODE state vector (N_Vector).
|
||||||
|
const double dt; ///< Last step size [s].
|
||||||
|
const double last_step_time; ///< Time at last callback [s].
|
||||||
|
const double T9; ///< Temperature in GK.
|
||||||
|
const double rho; ///< Density [g cm^-3].
|
||||||
|
const size_t num_steps; ///< Number of CVODE steps taken so far.
|
||||||
|
const engine::DynamicEngine& engine; ///< Reference to the engine.
|
||||||
|
const std::vector<fourdst::atomic::Species>& networkSpecies; ///< Species layout.
|
||||||
|
const size_t currentConvergenceFailures; ///< Total number of convergence failures
|
||||||
|
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
|
||||||
|
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
|
||||||
|
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
|
||||||
|
|
||||||
|
PointSolverTimestepContext(
|
||||||
|
double t,
|
||||||
|
const N_Vector& state,
|
||||||
|
double dt,
|
||||||
|
double last_step_time,
|
||||||
|
double t9,
|
||||||
|
double rho,
|
||||||
|
size_t num_steps,
|
||||||
|
const engine::DynamicEngine& engine,
|
||||||
|
const std::vector<fourdst::atomic::Species>& networkSpecies,
|
||||||
|
size_t currentConvergenceFailure,
|
||||||
|
size_t currentNonlinearIterations,
|
||||||
|
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
|
||||||
|
engine::scratch::StateBlob& state_ctx
|
||||||
|
);
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
using TimestepCallback = std::function<void(const PointSolverTimestepContext& context)>; ///< Type alias for a timestep callback function.
|
||||||
|
|
||||||
|
struct PointSolverContext final : SolverContextBase {
|
||||||
|
SUNContext sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
|
||||||
|
void* cvode_mem = nullptr; ///< CVODE memory block.
|
||||||
|
N_Vector Y = nullptr; ///< CVODE state vector (species + energy accumulator).
|
||||||
|
N_Vector YErr = nullptr; ///< Estimated local errors.
|
||||||
|
SUNMatrix J = nullptr; ///< Dense Jacobian matrix.
|
||||||
|
SUNLinearSolver LS = nullptr; ///< Dense linear solver.
|
||||||
|
|
||||||
|
std::unique_ptr<engine::scratch::StateBlob> engine_ctx;
|
||||||
|
|
||||||
|
|
||||||
|
std::optional<TimestepCallback> callback; ///< Optional per-step callback.
|
||||||
|
int num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
|
||||||
|
|
||||||
|
bool stdout_logging = true; ///< If true, print per-step logs and use CV_ONE_STEP.
|
||||||
|
|
||||||
|
N_Vector constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
|
||||||
|
|
||||||
|
std::optional<double> abs_tol; ///< User-specified absolute tolerance.
|
||||||
|
std::optional<double> rel_tol; ///< User-specified relative tolerance.
|
||||||
|
|
||||||
|
bool detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
|
||||||
|
|
||||||
|
size_t last_size = 0;
|
||||||
|
size_t last_composition_hash = 0ULL;
|
||||||
|
sunrealtype last_good_time_step = 0ULL;
|
||||||
|
|
||||||
|
void init() override;
|
||||||
|
void set_stdout_logging(bool enable) override;
|
||||||
|
void set_detailed_logging(bool enable) override;
|
||||||
|
|
||||||
|
void reset_all();
|
||||||
|
void reset_user();
|
||||||
|
void reset_cvode();
|
||||||
|
void clear_context();
|
||||||
|
void init_context();
|
||||||
|
|
||||||
|
[[nodiscard]] bool has_context() const;
|
||||||
|
|
||||||
|
explicit PointSolverContext(const engine::scratch::StateBlob& engine_ctx);
|
||||||
|
~PointSolverContext() override;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @class CVODESolverStrategy
|
* @class PointSolver
|
||||||
* @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy.
|
* @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy.
|
||||||
*
|
*
|
||||||
* Integrates the nuclear network abundances along with a final accumulator entry for specific
|
* Integrates the nuclear network abundances along with a final accumulator entry for specific
|
||||||
@@ -78,27 +158,16 @@ namespace gridfire::solver {
|
|||||||
* std::cout << "Final energy: " << out.energy << " erg/g\n";
|
* std::cout << "Final energy: " << out.energy << " erg/g\n";
|
||||||
* @endcode
|
* @endcode
|
||||||
*/
|
*/
|
||||||
class CVODESolverStrategy final : public SingleZoneDynamicNetworkSolver {
|
class PointSolver final : public SingleZoneDynamicNetworkSolver {
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Construct the CVODE strategy and create a SUNDIALS context.
|
* @brief Construct the CVODE strategy and create a SUNDIALS context.
|
||||||
* @param engine DynamicEngine used for RHS/Jacobian evaluation and network access.
|
* @param engine DynamicEngine used for RHS/Jacobian evaluation and network access.
|
||||||
* @throws std::runtime_error If SUNContext_Create fails.
|
* @throws std::runtime_error If SUNContext_Create fails.
|
||||||
*/
|
*/
|
||||||
explicit CVODESolverStrategy(
|
explicit PointSolver(
|
||||||
const engine::DynamicEngine& engine,
|
const engine::DynamicEngine& engine
|
||||||
const engine::scratch::StateBlob& ctx
|
|
||||||
);
|
);
|
||||||
/**
|
|
||||||
* @brief Destructor: cleans CVODE/SUNDIALS resources and frees SUNContext.
|
|
||||||
*/
|
|
||||||
~CVODESolverStrategy() override;
|
|
||||||
|
|
||||||
// Make the class non-copyable and non-movable to prevent shallow copies of CVODE pointers
|
|
||||||
CVODESolverStrategy(const CVODESolverStrategy&) = delete;
|
|
||||||
CVODESolverStrategy& operator=(const CVODESolverStrategy&) = delete;
|
|
||||||
CVODESolverStrategy(CVODESolverStrategy&&) = delete;
|
|
||||||
CVODESolverStrategy& operator=(CVODESolverStrategy&&) = delete;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Integrate from t=0 to netIn.tMax and return final composition and energy.
|
* @brief Integrate from t=0 to netIn.tMax and return final composition and energy.
|
||||||
@@ -114,6 +183,7 @@ namespace gridfire::solver {
|
|||||||
* - At the end, converts molar abundances to mass fractions and assembles NetOut,
|
* - At the end, converts molar abundances to mass fractions and assembles NetOut,
|
||||||
* including derivatives of energy w.r.t. T and rho from the engine.
|
* including derivatives of energy w.r.t. T and rho from the engine.
|
||||||
*
|
*
|
||||||
|
* @param solver_ctx
|
||||||
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
|
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
|
||||||
* @return NetOut containing final Composition, accumulated energy [erg/g], step count,
|
* @return NetOut containing final Composition, accumulated energy [erg/g], step count,
|
||||||
* and dEps/dT, dEps/dRho.
|
* and dEps/dT, dEps/dRho.
|
||||||
@@ -122,10 +192,14 @@ namespace gridfire::solver {
|
|||||||
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
|
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
|
||||||
* during RHS evaluation (captured in the wrapper then rethrown here).
|
* during RHS evaluation (captured in the wrapper then rethrown here).
|
||||||
*/
|
*/
|
||||||
NetOut evaluate(const NetIn& netIn) override;
|
NetOut evaluate(
|
||||||
|
SolverContextBase& solver_ctx,
|
||||||
|
const NetIn& netIn
|
||||||
|
) const override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Call to evaluate which will let the user control if the trigger reasoning is displayed
|
* @brief Call to evaluate which will let the user control if the trigger reasoning is displayed
|
||||||
|
* @param solver_ctx
|
||||||
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
|
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
|
||||||
* @param displayTrigger Boolean flag to control if trigger reasoning is displayed
|
* @param displayTrigger Boolean flag to control if trigger reasoning is displayed
|
||||||
* @param forceReinitialize Boolean flag to force reinitialization of CVODE resources at the start
|
* @param forceReinitialize Boolean flag to force reinitialization of CVODE resources at the start
|
||||||
@@ -136,89 +210,13 @@ namespace gridfire::solver {
|
|||||||
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
|
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
|
||||||
* during RHS evaluation (captured in the wrapper then rethrown here).
|
* during RHS evaluation (captured in the wrapper then rethrown here).
|
||||||
*/
|
*/
|
||||||
NetOut evaluate(const NetIn& netIn, bool displayTrigger, bool forceReinitialize = false);
|
NetOut evaluate(
|
||||||
|
SolverContextBase& solver_ctx,
|
||||||
|
const NetIn& netIn,
|
||||||
|
bool displayTrigger,
|
||||||
|
bool forceReinitialize = false
|
||||||
|
) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Install a timestep callback.
|
|
||||||
* @param callback std::any containing TimestepCallback (std::function<void(const TimestepContext&)>).
|
|
||||||
* @throws std::bad_any_cast If callback is not of the expected type.
|
|
||||||
*/
|
|
||||||
void set_callback(const std::any &callback) override;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Whether per-step logs are printed to stdout and CVode is stepped with CV_ONE_STEP.
|
|
||||||
*/
|
|
||||||
[[nodiscard]] bool get_stdout_logging_enabled() const;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Enable/disable per-step stdout logging.
|
|
||||||
* @param logging_enabled Flag to control if a timestep summary is written to standard output or not
|
|
||||||
*/
|
|
||||||
void set_stdout_logging_enabled(bool logging_enabled);
|
|
||||||
|
|
||||||
void set_absTol(double absTol);
|
|
||||||
void set_relTol(double relTol);
|
|
||||||
|
|
||||||
double get_absTol() const;
|
|
||||||
double get_relTol() const;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Schema of fields exposed to the timestep callback context.
|
|
||||||
*/
|
|
||||||
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @struct TimestepContext
|
|
||||||
* @brief Immutable view of the current integration state passed to callbacks.
|
|
||||||
*
|
|
||||||
* Fields capture CVODE time/state, step size, thermodynamic state, the engine reference,
|
|
||||||
* and the list of network species used to interpret the state vector layout.
|
|
||||||
*/
|
|
||||||
struct TimestepContext final : public SolverContextBase {
|
|
||||||
// This struct can be identical to the one in DirectNetworkSolver
|
|
||||||
const double t; ///< Current integration time [s].
|
|
||||||
const N_Vector& state; ///< Current CVODE state vector (N_Vector).
|
|
||||||
const double dt; ///< Last step size [s].
|
|
||||||
const double last_step_time; ///< Time at last callback [s].
|
|
||||||
const double T9; ///< Temperature in GK.
|
|
||||||
const double rho; ///< Density [g cm^-3].
|
|
||||||
const size_t num_steps; ///< Number of CVODE steps taken so far.
|
|
||||||
const engine::DynamicEngine& engine; ///< Reference to the engine.
|
|
||||||
const std::vector<fourdst::atomic::Species>& networkSpecies; ///< Species layout.
|
|
||||||
const size_t currentConvergenceFailures; ///< Total number of convergence failures
|
|
||||||
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
|
|
||||||
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
|
|
||||||
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Construct a context snapshot.
|
|
||||||
*/
|
|
||||||
TimestepContext(
|
|
||||||
double t,
|
|
||||||
const N_Vector& state,
|
|
||||||
double dt,
|
|
||||||
double last_step_time,
|
|
||||||
double t9,
|
|
||||||
double rho,
|
|
||||||
size_t num_steps,
|
|
||||||
const engine::DynamicEngine& engine,
|
|
||||||
const std::vector<fourdst::atomic::Species>& networkSpecies,
|
|
||||||
size_t currentConvergenceFailure,
|
|
||||||
size_t currentNonlinearIterations,
|
|
||||||
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
|
|
||||||
engine::scratch::StateBlob& state_ctx
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Human-readable description of the context fields.
|
|
||||||
*/
|
|
||||||
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Type alias for a timestep callback.
|
|
||||||
*/
|
|
||||||
using TimestepCallback = std::function<void(const TimestepContext& context)>; ///< Type alias for a timestep callback function.
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @struct CVODEUserData
|
* @struct CVODEUserData
|
||||||
@@ -230,7 +228,8 @@ namespace gridfire::solver {
|
|||||||
* to CVODE, then the driver loop inspects and rethrows.
|
* to CVODE, then the driver loop inspects and rethrows.
|
||||||
*/
|
*/
|
||||||
struct CVODEUserData {
|
struct CVODEUserData {
|
||||||
CVODESolverStrategy* solver_instance{}; // Pointer back to the class instance
|
const PointSolver* solver_instance{}; // Pointer back to the class instance
|
||||||
|
PointSolverContext* sctx; // Pointer to the solver context
|
||||||
engine::scratch::StateBlob& ctx;
|
engine::scratch::StateBlob& ctx;
|
||||||
const engine::DynamicEngine* engine{};
|
const engine::DynamicEngine* engine{};
|
||||||
double T9{};
|
double T9{};
|
||||||
@@ -283,6 +282,7 @@ namespace gridfire::solver {
|
|||||||
* step size, creates a dense matrix and dense linear solver, and registers the Jacobian.
|
* step size, creates a dense matrix and dense linear solver, and registers the Jacobian.
|
||||||
*/
|
*/
|
||||||
void initialize_cvode_integration_resources(
|
void initialize_cvode_integration_resources(
|
||||||
|
PointSolverContext* ctx,
|
||||||
uint64_t N,
|
uint64_t N,
|
||||||
size_t numSpecies,
|
size_t numSpecies,
|
||||||
double current_time,
|
double current_time,
|
||||||
@@ -290,15 +290,7 @@ namespace gridfire::solver {
|
|||||||
double absTol,
|
double absTol,
|
||||||
double relTol,
|
double relTol,
|
||||||
double accumulatedEnergy
|
double accumulatedEnergy
|
||||||
);
|
) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Destroy CVODE vectors/linear algebra and optionally the CVODE memory block.
|
|
||||||
* @param memFree If true, also calls CVodeFree on m_cvode_mem.
|
|
||||||
*/
|
|
||||||
void cleanup_cvode_resources(bool memFree);
|
|
||||||
|
|
||||||
void set_detailed_step_logging(bool enabled);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -308,31 +300,13 @@ namespace gridfire::solver {
|
|||||||
* sorted table of species with the highest error ratios; then invokes diagnostic routines to
|
* sorted table of species with the highest error ratios; then invokes diagnostic routines to
|
||||||
* inspect Jacobian stiffness and species balance.
|
* inspect Jacobian stiffness and species balance.
|
||||||
*/
|
*/
|
||||||
void log_step_diagnostics(engine::scratch::StateBlob &ctx, const CVODEUserData& user_data, bool displayJacobianStiffness, bool
|
void log_step_diagnostics(
|
||||||
displaySpeciesBalance, bool to_file, std::optional<std::string> filename) const;
|
PointSolverContext* sctx_p,
|
||||||
private:
|
engine::scratch::StateBlob &ctx,
|
||||||
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
|
const CVODEUserData& user_data,
|
||||||
void* m_cvode_mem = nullptr; ///< CVODE memory block.
|
bool displayJacobianStiffness,
|
||||||
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator).
|
bool displaySpeciesBalance,
|
||||||
N_Vector m_YErr = nullptr; ///< Estimated local errors.
|
bool to_file, std::optional<std::string> filename
|
||||||
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix.
|
) const;
|
||||||
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
|
|
||||||
|
|
||||||
|
|
||||||
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
|
|
||||||
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
|
|
||||||
|
|
||||||
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
|
|
||||||
|
|
||||||
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
|
|
||||||
|
|
||||||
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
|
|
||||||
std::optional<double> m_relTol; ///< User-specified relative tolerance.
|
|
||||||
|
|
||||||
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
|
|
||||||
|
|
||||||
mutable size_t m_last_size = 0;
|
|
||||||
mutable size_t m_last_composition_hash = 0ULL;
|
|
||||||
mutable sunrealtype m_last_good_time_step = 0ULL;
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -1,221 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "gridfire/solver/strategies/strategy_abstract.h"
|
|
||||||
#include "gridfire/engine/engine_abstract.h"
|
|
||||||
#include "gridfire/types/types.h"
|
|
||||||
#include "gridfire/config/config.h"
|
|
||||||
|
|
||||||
#include "fourdst/logging/logging.h"
|
|
||||||
#include "fourdst/constants/const.h"
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <functional>
|
|
||||||
#include <cvode/cvode.h>
|
|
||||||
#include <sundials/sundials_types.h>
|
|
||||||
|
|
||||||
#include "gridfire/exceptions/error_engine.h"
|
|
||||||
|
|
||||||
#ifdef SUNDIALS_HAVE_OPENMP
|
|
||||||
#include <nvector/nvector_openmp.h>
|
|
||||||
#endif
|
|
||||||
#ifdef SUNDIALS_HAVE_PTHREADS
|
|
||||||
#include <nvector/nvector_pthreads.hh>
|
|
||||||
#endif
|
|
||||||
#ifndef SUNDIALS_HAVE_OPENMP
|
|
||||||
#ifndef SUNDIALS_HAVE_PTHREADS
|
|
||||||
#include <nvector/nvector_serial.h>
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace gridfire::solver {
|
|
||||||
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolver {
|
|
||||||
public:
|
|
||||||
explicit SpectralSolverStrategy(const engine::DynamicEngine& engine);
|
|
||||||
~SpectralSolverStrategy() override;
|
|
||||||
|
|
||||||
std::vector<NetOut> evaluate(
|
|
||||||
const std::vector<NetIn> &netIns,
|
|
||||||
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
|
|
||||||
) override;
|
|
||||||
|
|
||||||
void set_callback(const std::any &callback) override;
|
|
||||||
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
|
|
||||||
|
|
||||||
[[nodiscard]] bool get_stdout_logging_enabled() const;
|
|
||||||
void set_stdout_logging_enabled(bool logging_enabled);
|
|
||||||
|
|
||||||
public:
|
|
||||||
struct TimestepContext final : public SolverContextBase {
|
|
||||||
TimestepContext(
|
|
||||||
const double t,
|
|
||||||
const N_Vector &state,
|
|
||||||
const double dt,
|
|
||||||
const double last_time_step,
|
|
||||||
const engine::DynamicEngine &engine
|
|
||||||
) :
|
|
||||||
t(t),
|
|
||||||
state(state),
|
|
||||||
dt(dt),
|
|
||||||
last_time_step(last_time_step),
|
|
||||||
engine(engine) {}
|
|
||||||
|
|
||||||
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
|
|
||||||
|
|
||||||
const double t;
|
|
||||||
const N_Vector& state;
|
|
||||||
const double dt;
|
|
||||||
const double last_time_step;
|
|
||||||
const engine::DynamicEngine& engine;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct BasisEval {
|
|
||||||
size_t start_idx;
|
|
||||||
std::vector<double> phi;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct SplineBasis {
|
|
||||||
std::vector<double> knots;
|
|
||||||
std::vector<double> quadrature_nodes;
|
|
||||||
std::vector<double> quadrature_weights;
|
|
||||||
int degree = 3;
|
|
||||||
|
|
||||||
|
|
||||||
std::vector<BasisEval> quad_evals;
|
|
||||||
};
|
|
||||||
public:
|
|
||||||
using TimestepCallback = std::function<void(const TimestepContext&)>;
|
|
||||||
private:
|
|
||||||
|
|
||||||
enum class ParallelInitializationResult : uint8_t {
|
|
||||||
SUCCESS,
|
|
||||||
FAILURE
|
|
||||||
};
|
|
||||||
|
|
||||||
struct SpectralCoefficients {
|
|
||||||
size_t num_sets;
|
|
||||||
size_t num_coefficients;
|
|
||||||
std::vector<double> coefficients;
|
|
||||||
|
|
||||||
double operator()(size_t i, size_t j) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct GridPoint {
|
|
||||||
double T9;
|
|
||||||
double rho;
|
|
||||||
fourdst::composition::Composition composition;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Constants {
|
|
||||||
const double c = fourdst::constant::Constants::getInstance().get("c").value;
|
|
||||||
const double N_a = fourdst::constant::Constants::getInstance().get("N_a").value;
|
|
||||||
const double c2 = c * c;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DenseLinearSolver {
|
|
||||||
SUNMatrix A;
|
|
||||||
SUNLinearSolver LS;
|
|
||||||
N_Vector temp_vector;
|
|
||||||
SUNContext ctx;
|
|
||||||
|
|
||||||
DenseLinearSolver(size_t size, SUNContext sun_ctx);
|
|
||||||
~DenseLinearSolver();
|
|
||||||
|
|
||||||
DenseLinearSolver(const DenseLinearSolver&) = delete;
|
|
||||||
DenseLinearSolver& operator=(const DenseLinearSolver&) = delete;
|
|
||||||
|
|
||||||
void setup() const;
|
|
||||||
void zero() const;
|
|
||||||
|
|
||||||
void init_from_cache(size_t num_basis_funcs, const std::vector<BasisEval>& shell_cache) const;
|
|
||||||
void init_from_basis(size_t num_basis_funcs, const SplineBasis& basis) const;
|
|
||||||
|
|
||||||
void solve_inplace(N_Vector x, size_t num_vars, size_t basis_size) const;
|
|
||||||
void solve_inplace_ptr(sunrealtype* data_ptr, size_t num_vars, size_t basis_size) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct CVODEUserData {
|
|
||||||
SpectralSolverStrategy* solver_instance{};
|
|
||||||
std::vector<std::reference_wrapper<engine::scratch::StateBlob>> workspaces;
|
|
||||||
const engine::DynamicEngine* engine{};
|
|
||||||
std::unique_ptr<exceptions::EngineError> captured_exception{};
|
|
||||||
|
|
||||||
std::vector<double> T9{};
|
|
||||||
std::vector<double> rho{};
|
|
||||||
double energy{};
|
|
||||||
|
|
||||||
double neutrino_energy_loss_rate = 0.0;
|
|
||||||
double total_neutrino_flux = 0.0;
|
|
||||||
|
|
||||||
DenseLinearSolver* mass_matrix_solver_instance{};
|
|
||||||
const SplineBasis* basis{};
|
|
||||||
};
|
|
||||||
|
|
||||||
private:
|
|
||||||
fourdst::config::Config<config::GridFireConfig> m_config;
|
|
||||||
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
|
||||||
|
|
||||||
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
|
|
||||||
void* m_cvode_mem = nullptr; ///< CVODE memory block.
|
|
||||||
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator).
|
|
||||||
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix.
|
|
||||||
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
|
|
||||||
|
|
||||||
|
|
||||||
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
|
|
||||||
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
|
|
||||||
|
|
||||||
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
|
|
||||||
|
|
||||||
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
|
|
||||||
|
|
||||||
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
|
|
||||||
std::optional<double> m_relTol; ///< User-specified relative tolerance.
|
|
||||||
|
|
||||||
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
|
|
||||||
|
|
||||||
mutable size_t m_last_size = 0;
|
|
||||||
mutable size_t m_last_composition_hash = 0ULL;
|
|
||||||
mutable sunrealtype m_last_good_time_step = 0ULL;
|
|
||||||
|
|
||||||
SplineBasis m_current_basis;
|
|
||||||
|
|
||||||
Constants m_constants;
|
|
||||||
|
|
||||||
N_Vector m_T_coeffs = nullptr;
|
|
||||||
N_Vector m_rho_coeffs = nullptr;
|
|
||||||
|
|
||||||
std::vector<fourdst::atomic::Species> m_global_species_list;
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::vector<double> evaluate_monitor_function(const std::vector<NetIn>& current_shells) const;
|
|
||||||
|
|
||||||
static SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates, size_t actual_elements);
|
|
||||||
|
|
||||||
GridPoint reconstruct_at_quadrature(const N_Vector y_coeffs, size_t quad_index, const SplineBasis &basis) const;
|
|
||||||
|
|
||||||
std::vector<NetOut> reconstruct_solution(const std::vector<NetIn>& original_inputs, const std::vector<double>& mass_coordinates, const N_Vector final_coeffs, const SplineBasis& basis, double dt) const;
|
|
||||||
|
|
||||||
static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector, void* user_data);
|
|
||||||
static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
|
|
||||||
|
|
||||||
int calculate_rhs(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, CVODEUserData* data) const;
|
|
||||||
int calculate_jacobian(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, SUNMatrix J, const CVODEUserData *data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) const;
|
|
||||||
|
|
||||||
static size_t nyquist_elements(size_t requested_elements, size_t num_shells) ;
|
|
||||||
|
|
||||||
static void project_specific_variable(
|
|
||||||
const std::vector<NetIn>& current_shells,
|
|
||||||
const std::vector<double>& mass_coordinates,
|
|
||||||
const std::vector<BasisEval>& shell_cache,
|
|
||||||
const DenseLinearSolver& linear_solver,
|
|
||||||
N_Vector output_vec,
|
|
||||||
size_t output_offset,
|
|
||||||
const std::function<double(const NetIn&)> &getter,
|
|
||||||
bool use_log
|
|
||||||
);
|
|
||||||
|
|
||||||
void inspect_jacobian(SUNMatrix J, const std::string& context) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -2,5 +2,5 @@
|
|||||||
|
|
||||||
#include "gridfire/solver/strategies/triggers/triggers.h"
|
#include "gridfire/solver/strategies/triggers/triggers.h"
|
||||||
#include "gridfire/solver/strategies/strategy_abstract.h"
|
#include "gridfire/solver/strategies/strategy_abstract.h"
|
||||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
#include "gridfire/solver/strategies/PointSolver.h"
|
||||||
#include "gridfire/solver/strategies/SpectralSolverStrategy.h"
|
#include "gridfire/solver/strategies/GridSolver.h"
|
||||||
@@ -13,17 +13,24 @@ namespace gridfire::solver {
|
|||||||
template <typename EngineT>
|
template <typename EngineT>
|
||||||
concept IsEngine = std::is_base_of_v<engine::Engine, EngineT>;
|
concept IsEngine = std::is_base_of_v<engine::Engine, EngineT>;
|
||||||
|
|
||||||
|
struct SolverContextBase {
|
||||||
|
virtual void init() = 0;
|
||||||
|
virtual void set_stdout_logging(bool enable) = 0;
|
||||||
|
virtual void set_detailed_logging(bool enable) = 0;
|
||||||
|
virtual ~SolverContextBase() = default;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @struct SolverContextBase
|
* @struct TimestepContextBase
|
||||||
* @brief Base class for solver callback contexts.
|
* @brief Base class for solver callback contexts.
|
||||||
*
|
*
|
||||||
* This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces
|
* This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces
|
||||||
* that derived classes implement a `describe` method that returns a vector of tuples describing
|
* that derived classes implement a `describe` method that returns a vector of tuples describing
|
||||||
* the context that a callback will receive when called.
|
* the context that a callback will receive when called.
|
||||||
*/
|
*/
|
||||||
class SolverContextBase {
|
class TimestepContextBase {
|
||||||
public:
|
public:
|
||||||
virtual ~SolverContextBase() = default;
|
virtual ~TimestepContextBase() = default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Describe the context for callback functions.
|
* @brief Describe the context for callback functions.
|
||||||
@@ -54,11 +61,9 @@ namespace gridfire::solver {
|
|||||||
* @param engine The engine to use for evaluating the network.
|
* @param engine The engine to use for evaluating the network.
|
||||||
*/
|
*/
|
||||||
explicit SingleZoneNetworkSolver(
|
explicit SingleZoneNetworkSolver(
|
||||||
const EngineT& engine,
|
const EngineT& engine
|
||||||
const engine::scratch::StateBlob& ctx
|
|
||||||
) :
|
) :
|
||||||
m_engine(engine),
|
m_engine(engine) {};
|
||||||
m_scratch_blob(ctx.clone_structure()) {};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Virtual destructor.
|
* @brief Virtual destructor.
|
||||||
@@ -67,58 +72,39 @@ namespace gridfire::solver {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Evaluates the network for a given timestep.
|
* @brief Evaluates the network for a given timestep.
|
||||||
|
* @param solver_ctx
|
||||||
|
* @param engine_ctx
|
||||||
* @param netIn The input conditions for the network.
|
* @param netIn The input conditions for the network.
|
||||||
* @return The output conditions after the timestep.
|
* @return The output conditions after the timestep.
|
||||||
*/
|
*/
|
||||||
virtual NetOut evaluate(const NetIn& netIn) = 0;
|
virtual NetOut evaluate(
|
||||||
|
SolverContextBase& solver_ctx,
|
||||||
|
const NetIn& netIn
|
||||||
|
) const = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief set the callback function to be called at the end of each timestep.
|
|
||||||
*
|
|
||||||
* This function allows the user to set a callback function that will be called at the end of each timestep.
|
|
||||||
* The callback function will receive a gridfire::solver::<SOMESOLVER>::TimestepContext object. Note that
|
|
||||||
* depending on the solver, this context may contain different information. Further, the exact
|
|
||||||
* signature of the callback function is left up to each solver. Every solver should provide a type or type alias
|
|
||||||
* TimestepCallback that defines the signature of the callback function so that the user can easily
|
|
||||||
* get that type information.
|
|
||||||
*
|
|
||||||
* @param callback The callback function to be called at the end of each timestep.
|
|
||||||
*/
|
|
||||||
virtual void set_callback(const std::any& callback) = 0;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Describe the context that will be passed to the callback function.
|
|
||||||
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
|
|
||||||
*
|
|
||||||
* This method should be overridden by derived classes to provide a description of the context
|
|
||||||
* that will be passed to the callback function. The intent of this method is that an end user can investigate
|
|
||||||
* the context that will be passed to the callback function, and use this information to craft their own
|
|
||||||
* callback function.
|
|
||||||
*/
|
|
||||||
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
|
|
||||||
protected:
|
protected:
|
||||||
const EngineT& m_engine; ///< The engine used by this solver strategy.
|
const EngineT& m_engine; ///< The engine used by this solver strategy.
|
||||||
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <IsEngine EngineT>
|
template <IsEngine EngineT>
|
||||||
class MultiZoneNetworkSolver {
|
class MultiZoneNetworkSolver {
|
||||||
public:
|
public:
|
||||||
explicit MultiZoneNetworkSolver(
|
explicit MultiZoneNetworkSolver(
|
||||||
const EngineT& engine
|
const EngineT& engine,
|
||||||
|
const SingleZoneNetworkSolver<EngineT>& solver
|
||||||
) :
|
) :
|
||||||
m_engine(engine) {};
|
m_engine(engine),
|
||||||
|
m_solver(solver) {};
|
||||||
|
|
||||||
virtual ~MultiZoneNetworkSolver() = default;
|
virtual ~MultiZoneNetworkSolver() = default;
|
||||||
|
|
||||||
virtual std::vector<NetOut> evaluate(
|
virtual std::vector<NetOut> evaluate(
|
||||||
const std::vector<NetIn>& netIns,
|
SolverContextBase& solver_ctx,
|
||||||
const std::vector<double>& mass_coords, const engine::scratch::StateBlob &ctx_template
|
const std::vector<NetIn>& netIns
|
||||||
) = 0;
|
) const = 0;
|
||||||
virtual void set_callback(const std::any& callback) = 0;
|
|
||||||
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
|
|
||||||
protected:
|
protected:
|
||||||
const EngineT& m_engine; ///< The engine used by this solver strategy.
|
const EngineT& m_engine; ///< The engine used by this solver strategy.
|
||||||
|
const SingleZoneNetworkSolver<EngineT>& m_solver;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include "gridfire/trigger/trigger_abstract.h"
|
#include "gridfire/trigger/trigger_abstract.h"
|
||||||
#include "gridfire/trigger/trigger_result.h"
|
#include "gridfire/trigger/trigger_result.h"
|
||||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
#include "gridfire/solver/strategies/PointSolver.h"
|
||||||
#include "fourdst/logging/logging.h"
|
#include "fourdst/logging/logging.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
@@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
*
|
*
|
||||||
* See also: engine_partitioning_trigger.cpp for the concrete logic and logging.
|
* See also: engine_partitioning_trigger.cpp for the concrete logic and logging.
|
||||||
*/
|
*/
|
||||||
class SimulationTimeTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
|
class SimulationTimeTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Construct with a positive time interval between firings.
|
* @brief Construct with a positive time interval between firings.
|
||||||
@@ -62,7 +62,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
*
|
*
|
||||||
* @post increments hit/miss counters and may emit trace logs.
|
* @post increments hit/miss counters and may emit trace logs.
|
||||||
*/
|
*/
|
||||||
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
|
bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
|
||||||
/**
|
/**
|
||||||
* @brief Update internal state; if check(ctx) is true, advance last_trigger_time.
|
* @brief Update internal state; if check(ctx) is true, advance last_trigger_time.
|
||||||
* @param ctx CVODE timestep context.
|
* @param ctx CVODE timestep context.
|
||||||
@@ -70,9 +70,9 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
* @note update() calls check(ctx) and, on success, records the overshoot delta
|
* @note update() calls check(ctx) and, on success, records the overshoot delta
|
||||||
* (ctx.t - last_trigger_time) - interval for diagnostics.
|
* (ctx.t - last_trigger_time) - interval for diagnostics.
|
||||||
*/
|
*/
|
||||||
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
|
void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
|
||||||
|
|
||||||
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
|
void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
|
||||||
/**
|
/**
|
||||||
* @brief Reset counters and last trigger bookkeeping (time and delta) to zero.
|
* @brief Reset counters and last trigger bookkeeping (time and delta) to zero.
|
||||||
*/
|
*/
|
||||||
@@ -85,7 +85,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
* @param ctx CVODE timestep context.
|
* @param ctx CVODE timestep context.
|
||||||
* @return TriggerResult including name, value, and description.
|
* @return TriggerResult including name, value, and description.
|
||||||
*/
|
*/
|
||||||
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
|
TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
|
||||||
/** @brief Textual description including configured interval. */
|
/** @brief Textual description including configured interval. */
|
||||||
std::string describe() const override;
|
std::string describe() const override;
|
||||||
/** @brief Number of true evaluations since last reset. */
|
/** @brief Number of true evaluations since last reset. */
|
||||||
@@ -130,7 +130,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
* @par See also
|
* @par See also
|
||||||
* - engine_partitioning_trigger.cpp for concrete logic and trace logging.
|
* - engine_partitioning_trigger.cpp for concrete logic and trace logging.
|
||||||
*/
|
*/
|
||||||
class OffDiagonalTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
|
class OffDiagonalTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Construct with a non-negative magnitude threshold.
|
* @brief Construct with a non-negative magnitude threshold.
|
||||||
@@ -145,13 +145,13 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
*
|
*
|
||||||
* @post increments hit/miss counters and may emit trace logs.
|
* @post increments hit/miss counters and may emit trace logs.
|
||||||
*/
|
*/
|
||||||
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
|
bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
|
||||||
/**
|
/**
|
||||||
* @brief Record an update; does not mutate any Jacobian-related state.
|
* @brief Record an update; does not mutate any Jacobian-related state.
|
||||||
* @param ctx CVODE timestep context (unused except for symmetry with interface).
|
* @param ctx CVODE timestep context (unused except for symmetry with interface).
|
||||||
*/
|
*/
|
||||||
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
|
void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
|
||||||
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
|
void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
|
||||||
/** @brief Reset counters to zero. */
|
/** @brief Reset counters to zero. */
|
||||||
void reset() override;
|
void reset() override;
|
||||||
|
|
||||||
@@ -161,7 +161,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
* @brief Structured explanation of the evaluation outcome.
|
* @brief Structured explanation of the evaluation outcome.
|
||||||
* @param ctx CVODE timestep context.
|
* @param ctx CVODE timestep context.
|
||||||
*/
|
*/
|
||||||
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
|
TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
|
||||||
/** @brief Textual description including configured threshold. */
|
/** @brief Textual description including configured threshold. */
|
||||||
std::string describe() const override;
|
std::string describe() const override;
|
||||||
/** @brief Number of true evaluations since last reset. */
|
/** @brief Number of true evaluations since last reset. */
|
||||||
@@ -206,7 +206,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
*
|
*
|
||||||
* See also: engine_partitioning_trigger.cpp for exact logic and logging.
|
* See also: engine_partitioning_trigger.cpp for exact logic and logging.
|
||||||
*/
|
*/
|
||||||
class TimestepCollapseTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
|
class TimestepCollapseTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Construct with threshold and relative/absolute mode; window size defaults to 1.
|
* @brief Construct with threshold and relative/absolute mode; window size defaults to 1.
|
||||||
@@ -230,20 +230,20 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
*
|
*
|
||||||
* @post increments hit/miss counters and may emit trace logs.
|
* @post increments hit/miss counters and may emit trace logs.
|
||||||
*/
|
*/
|
||||||
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
|
bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
|
||||||
/**
|
/**
|
||||||
* @brief Update sliding window with the most recent dt and increment update counter.
|
* @brief Update sliding window with the most recent dt and increment update counter.
|
||||||
* @param ctx CVODE timestep context.
|
* @param ctx CVODE timestep context.
|
||||||
*/
|
*/
|
||||||
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
|
void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
|
||||||
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
|
void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
|
||||||
/** @brief Reset counters and clear the dt window. */
|
/** @brief Reset counters and clear the dt window. */
|
||||||
void reset() override;
|
void reset() override;
|
||||||
|
|
||||||
/** @brief Stable human-readable name. */
|
/** @brief Stable human-readable name. */
|
||||||
std::string name() const override;
|
std::string name() const override;
|
||||||
/** @brief Structured explanation of the evaluation outcome. */
|
/** @brief Structured explanation of the evaluation outcome. */
|
||||||
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
|
TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
|
||||||
/** @brief Textual description including threshold, mode, and window size. */
|
/** @brief Textual description including threshold, mode, and window size. */
|
||||||
std::string describe() const override;
|
std::string describe() const override;
|
||||||
/** @brief Number of true evaluations since last reset. */
|
/** @brief Number of true evaluations since last reset. */
|
||||||
@@ -272,15 +272,15 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
std::deque<double> m_timestep_window;
|
std::deque<double> m_timestep_window;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ConvergenceFailureTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> {
|
class ConvergenceFailureTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
|
||||||
public:
|
public:
|
||||||
explicit ConvergenceFailureTrigger(size_t totalFailures, float relativeFailureRate, size_t windowSize);
|
explicit ConvergenceFailureTrigger(size_t totalFailures, float relativeFailureRate, size_t windowSize);
|
||||||
|
|
||||||
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
|
bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
|
||||||
|
|
||||||
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
|
void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
|
||||||
|
|
||||||
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override;
|
void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
|
||||||
|
|
||||||
void reset() override;
|
void reset() override;
|
||||||
|
|
||||||
@@ -288,7 +288,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
|
|
||||||
[[nodiscard]] std::string describe() const override;
|
[[nodiscard]] std::string describe() const override;
|
||||||
|
|
||||||
[[nodiscard]] TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override;
|
[[nodiscard]] TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
|
||||||
|
|
||||||
[[nodiscard]] size_t numTriggers() const override;
|
[[nodiscard]] size_t numTriggers() const override;
|
||||||
|
|
||||||
@@ -312,8 +312,8 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
float current_mean() const;
|
float current_mean() const;
|
||||||
bool abs_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const;
|
bool abs_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const;
|
||||||
bool rel_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const;
|
bool rel_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -337,10 +337,10 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
*
|
*
|
||||||
* @note The exact policy is subject to change; this function centralizes that decision.
|
* @note The exact policy is subject to change; this function centralizes that decision.
|
||||||
*/
|
*/
|
||||||
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger(
|
std::unique_ptr<Trigger<gridfire::solver::PointSolverTimestepContext>> makeEnginePartitioningTrigger(
|
||||||
const double simulationTimeInterval,
|
double simulationTimeInterval,
|
||||||
const double offDiagonalThreshold,
|
double offDiagonalThreshold,
|
||||||
const double timestepCollapseRatio,
|
double timestepCollapseRatio,
|
||||||
const size_t maxConvergenceFailures
|
size_t maxConvergenceFailures
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ namespace gridfire::omp {
|
|||||||
);
|
);
|
||||||
|
|
||||||
CppAD::thread_alloc::hold_memory(true);
|
CppAD::thread_alloc::hold_memory(true);
|
||||||
|
CppAD::CheckSimpleVector<double, std::vector<double>>(0, 1);
|
||||||
s_par_mode_initialized = true;
|
s_par_mode_initialized = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -118,8 +118,7 @@ namespace gridfire::engine {
|
|||||||
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
|
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
|
||||||
m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)),
|
m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)),
|
||||||
m_partitionFunction(partitionFunction.clone()),
|
m_partitionFunction(partitionFunction.clone()),
|
||||||
m_depth(buildDepth),
|
m_depth(buildDepth)
|
||||||
m_state_blob_offset(0) // For a base engine the offset is always 0
|
|
||||||
{
|
{
|
||||||
syncInternalMaps();
|
syncInternalMaps();
|
||||||
}
|
}
|
||||||
@@ -128,8 +127,7 @@ namespace gridfire::engine {
|
|||||||
const reaction::ReactionSet &reactions
|
const reaction::ReactionSet &reactions
|
||||||
) :
|
) :
|
||||||
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
|
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
|
||||||
m_reactions(reactions),
|
m_reactions(reactions)
|
||||||
m_state_blob_offset(0)
|
|
||||||
{
|
{
|
||||||
syncInternalMaps();
|
syncInternalMaps();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "fourdst/atomic/species.h"
|
#include "fourdst/atomic/species.h"
|
||||||
#include "fourdst/composition/utils.h"
|
#include "fourdst/composition/utils.h"
|
||||||
#include "gridfire/engine/views/engine_priming.h"
|
|
||||||
#include "gridfire/solver/solver.h"
|
#include "gridfire/solver/solver.h"
|
||||||
|
|
||||||
#include "gridfire/engine/engine_abstract.h"
|
#include "gridfire/engine/engine_abstract.h"
|
||||||
@@ -13,7 +12,7 @@
|
|||||||
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
|
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
|
||||||
|
|
||||||
#include "fourdst/logging/logging.h"
|
#include "fourdst/logging/logging.h"
|
||||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
#include "gridfire/solver/strategies/PointSolver.h"
|
||||||
#include "quill/Logger.h"
|
#include "quill/Logger.h"
|
||||||
#include "quill/LogMacros.h"
|
#include "quill/LogMacros.h"
|
||||||
|
|
||||||
@@ -28,13 +27,12 @@ namespace gridfire::engine {
|
|||||||
const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
|
const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
|
||||||
) {
|
) {
|
||||||
const auto logger = LogManager::getInstance().getLogger("log");
|
const auto logger = LogManager::getInstance().getLogger("log");
|
||||||
solver::CVODESolverStrategy integrator(engine, ctx);
|
solver::PointSolver integrator(engine);
|
||||||
|
solver::PointSolverContext solverCtx(ctx);
|
||||||
|
solverCtx.abs_tol = 1e-3;
|
||||||
|
solverCtx.rel_tol = 1e-3;
|
||||||
|
solverCtx.stdout_logging = false;
|
||||||
|
|
||||||
// Do not need high precision for priming
|
|
||||||
integrator.set_absTol(1e-3);
|
|
||||||
integrator.set_relTol(1e-3);
|
|
||||||
|
|
||||||
integrator.set_stdout_logging_enabled(false);
|
|
||||||
NetIn solverInput(netIn);
|
NetIn solverInput(netIn);
|
||||||
|
|
||||||
solverInput.tMax = 1e-15;
|
solverInput.tMax = 1e-15;
|
||||||
@@ -43,7 +41,7 @@ namespace gridfire::engine {
|
|||||||
LOG_INFO(logger, "Short timescale ({}) network ignition started.", solverInput.tMax);
|
LOG_INFO(logger, "Short timescale ({}) network ignition started.", solverInput.tMax);
|
||||||
PrimingReport report;
|
PrimingReport report;
|
||||||
try {
|
try {
|
||||||
const NetOut netOut = integrator.evaluate(solverInput, false);
|
const NetOut netOut = integrator.evaluate(solverCtx, solverInput);
|
||||||
LOG_INFO(logger, "Network ignition completed.");
|
LOG_INFO(logger, "Network ignition completed.");
|
||||||
LOG_TRACE_L2(
|
LOG_TRACE_L2(
|
||||||
logger,
|
logger,
|
||||||
|
|||||||
@@ -2005,7 +2005,32 @@ namespace gridfire::engine {
|
|||||||
LOG_INFO(getLogger(), "KINSol failed to converge within the maximum number of iterations, but achieved acceptable accuracy with function norm {} < {}. Proceeding with solution.",
|
LOG_INFO(getLogger(), "KINSol failed to converge within the maximum number of iterations, but achieved acceptable accuracy with function norm {} < {}. Proceeding with solution.",
|
||||||
fnorm, ACCEPTABLE_FTOL);
|
fnorm, ACCEPTABLE_FTOL);
|
||||||
} else {
|
} else {
|
||||||
LOG_WARNING(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Error {}", utils::kinsol_ret_code_map.at(flag), fnorm);
|
LOG_CRITICAL(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Flag No.: {}, Error (fNorm): {}", utils::kinsol_ret_code_map.at(flag), flag, fnorm);
|
||||||
|
LOG_CRITICAL(getLogger(), "State prior to failure: {}",
|
||||||
|
[&comp, &data]() -> std::string {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "Solve species: <";
|
||||||
|
size_t count = 0;
|
||||||
|
for (const auto& species : data.qse_solve_species) {
|
||||||
|
oss << species.name();
|
||||||
|
if (count < data.qse_solve_species.size() - 1) {
|
||||||
|
oss << ", ";
|
||||||
|
}
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
oss << "> | Abundances and rates at failure: ";
|
||||||
|
count = 0;
|
||||||
|
for (const auto& [species, abundance] : comp) {
|
||||||
|
oss << species.name() << ": Y = " << abundance;
|
||||||
|
if (count < comp.size() - 1) {
|
||||||
|
oss << ", ";
|
||||||
|
}
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
oss << " | Temperature: " << data.T9 << ", Density: " << data.rho;
|
||||||
|
return oss.str();
|
||||||
|
}()
|
||||||
|
);
|
||||||
throw exceptions::InvalidQSESolutionError("KINSol failed to converge while solving QSE abundances. " + utils::kinsol_ret_code_map.at(flag));
|
throw exceptions::InvalidQSESolutionError("KINSol failed to converge while solving QSE abundances. " + utils::kinsol_ret_code_map.at(flag));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
94
src/lib/solver/strategies/GridSolver.cpp
Normal file
94
src/lib/solver/strategies/GridSolver.cpp
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
#include "gridfire/solver/strategies/GridSolver.h"
|
||||||
|
|
||||||
|
#include "gridfire/exceptions/error_solver.h"
|
||||||
|
#include "gridfire/solver/strategies/PointSolver.h"
|
||||||
|
#include "gridfire/utils/macros.h"
|
||||||
|
#include "gridfire/utils/gf_omp.h"
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <print>
|
||||||
|
|
||||||
|
namespace gridfire::solver {
|
||||||
|
void GridSolverContext::init() {}
|
||||||
|
void GridSolverContext::reset() {
|
||||||
|
solver_workspaces.clear();
|
||||||
|
timestep_callbacks.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GridSolverContext::set_callback(const std::function<void(const TimestepContextBase &)> &callback) {
|
||||||
|
for (auto &cb : timestep_callbacks) {
|
||||||
|
cb = callback;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GridSolverContext::set_callback(const std::function<void(const TimestepContextBase &)> &callback, const size_t zone_idx) {
|
||||||
|
if (zone_idx >= timestep_callbacks.size()) {
|
||||||
|
throw exceptions::SolverError("GridSolverContext::set_callback: zone_idx out of range.");
|
||||||
|
}
|
||||||
|
timestep_callbacks[zone_idx] = callback;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GridSolverContext::set_stdout_logging(const bool enable) {
|
||||||
|
zone_stdout_logging = enable;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GridSolverContext::set_detailed_logging(const bool enable) {
|
||||||
|
zone_detailed_logging = enable;
|
||||||
|
}
|
||||||
|
|
||||||
|
GridSolverContext::GridSolverContext(
|
||||||
|
const engine::scratch::StateBlob &ctx_template
|
||||||
|
) :
|
||||||
|
ctx_template(ctx_template) {}
|
||||||
|
|
||||||
|
|
||||||
|
GridSolver::GridSolver(
|
||||||
|
const engine::DynamicEngine &engine,
|
||||||
|
const SingleZoneDynamicNetworkSolver &solver
|
||||||
|
) :
|
||||||
|
MultiZoneNetworkSolver(engine, solver) {
|
||||||
|
GF_PAR_INIT();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<NetOut> GridSolver::evaluate(
|
||||||
|
SolverContextBase& ctx,
|
||||||
|
const std::vector<NetIn>& netIns
|
||||||
|
) const {
|
||||||
|
auto* sctx_p = dynamic_cast<GridSolverContext*>(&ctx);
|
||||||
|
if (!sctx_p) {
|
||||||
|
throw exceptions::SolverError("GridSolver::evaluate: SolverContextBase is not of type GridSolverContext.");
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t n_zones = netIns.size();
|
||||||
|
if (n_zones == 0) { return {}; }
|
||||||
|
|
||||||
|
std::vector<NetOut> results(n_zones);
|
||||||
|
|
||||||
|
sctx_p->solver_workspaces.resize(n_zones);
|
||||||
|
|
||||||
|
GF_OMP(
|
||||||
|
parallel for default(none) shared(sctx_p, n_zones),
|
||||||
|
for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) {
|
||||||
|
sctx_p->solver_workspaces[zone_idx] = std::make_unique<PointSolverContext>(sctx_p->ctx_template);
|
||||||
|
sctx_p->solver_workspaces[zone_idx]->set_stdout_logging(sctx_p->zone_stdout_logging);
|
||||||
|
sctx_p->solver_workspaces[zone_idx]->set_detailed_logging(sctx_p->zone_detailed_logging);
|
||||||
|
}
|
||||||
|
|
||||||
|
GF_OMP(
|
||||||
|
parallel for default(none) shared(results, sctx_p, netIns, n_zones),
|
||||||
|
for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) {
|
||||||
|
try {
|
||||||
|
results[zone_idx] = m_solver.evaluate(
|
||||||
|
*sctx_p->solver_workspaces[zone_idx],
|
||||||
|
netIns[zone_idx]
|
||||||
|
);
|
||||||
|
} catch (exceptions::GridFireError& e) {
|
||||||
|
std::println("CVODE Solver Failure in zone {}: {}", zone_idx, e.what());
|
||||||
|
}
|
||||||
|
if (sctx_p->zone_completion_logging) {
|
||||||
|
std::println("Thread {} completed zone {}", GF_OMP_THREAD_NUM, zone_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
#include "gridfire/solver/strategies/PointSolver.h"
|
||||||
|
|
||||||
#include "gridfire/types/types.h"
|
#include "gridfire/types/types.h"
|
||||||
#include "gridfire/utils/table_format.h"
|
#include "gridfire/utils/table_format.h"
|
||||||
@@ -28,7 +28,7 @@
|
|||||||
namespace gridfire::solver {
|
namespace gridfire::solver {
|
||||||
using namespace gridfire::engine;
|
using namespace gridfire::engine;
|
||||||
|
|
||||||
CVODESolverStrategy::TimestepContext::TimestepContext(
|
PointSolverTimestepContext::PointSolverTimestepContext(
|
||||||
const double t,
|
const double t,
|
||||||
const N_Vector &state,
|
const N_Vector &state,
|
||||||
const double dt,
|
const double dt,
|
||||||
@@ -58,7 +58,7 @@ namespace gridfire::solver {
|
|||||||
state_ctx(ctx)
|
state_ctx(ctx)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const {
|
std::vector<std::tuple<std::string, std::string>> PointSolverTimestepContext::describe() const {
|
||||||
std::vector<std::tuple<std::string, std::string>> description;
|
std::vector<std::tuple<std::string, std::string>> description;
|
||||||
description.emplace_back("t", "Current Time");
|
description.emplace_back("t", "Current Time");
|
||||||
description.emplace_back("state", "Current State Vector (N_Vector)");
|
description.emplace_back("state", "Current State Vector (N_Vector)");
|
||||||
@@ -74,36 +74,112 @@ namespace gridfire::solver {
|
|||||||
return description;
|
return description;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PointSolverContext::init() {
|
||||||
|
reset_all();
|
||||||
|
init_context();
|
||||||
|
}
|
||||||
|
|
||||||
CVODESolverStrategy::CVODESolverStrategy(
|
void PointSolverContext::set_stdout_logging(const bool enable) {
|
||||||
const DynamicEngine &engine,
|
stdout_logging = enable;
|
||||||
const scratch::StateBlob& ctx
|
}
|
||||||
): SingleZoneNetworkSolver<DynamicEngine>(engine, ctx) {
|
|
||||||
// PERF: In order to support MPI this function must be changed
|
void PointSolverContext::set_detailed_logging(const bool enable) {
|
||||||
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
|
detailed_step_logging = enable;
|
||||||
if (flag < 0) {
|
}
|
||||||
throw std::runtime_error("Failed to create SUNDIALS context (SUNDIALS Errno: " + std::to_string(flag) + ")");
|
|
||||||
|
void PointSolverContext::reset_all() {
|
||||||
|
reset_user();
|
||||||
|
reset_cvode();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PointSolverContext::reset_user() {
|
||||||
|
callback.reset();
|
||||||
|
num_steps = 0;
|
||||||
|
stdout_logging = true;
|
||||||
|
abs_tol.reset();
|
||||||
|
rel_tol.reset();
|
||||||
|
detailed_step_logging = false;
|
||||||
|
last_size = 0;
|
||||||
|
last_composition_hash = 0ULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PointSolverContext::reset_cvode() {
|
||||||
|
if (LS) {
|
||||||
|
SUNLinSolFree(LS);
|
||||||
|
LS = nullptr;
|
||||||
|
}
|
||||||
|
if (J) {
|
||||||
|
SUNMatDestroy(J);
|
||||||
|
J = nullptr;
|
||||||
|
}
|
||||||
|
if (Y) {
|
||||||
|
N_VDestroy(Y);
|
||||||
|
Y = nullptr;
|
||||||
|
}
|
||||||
|
if (YErr) {
|
||||||
|
N_VDestroy(YErr);
|
||||||
|
YErr = nullptr;
|
||||||
|
}
|
||||||
|
if (constraints) {
|
||||||
|
N_VDestroy(constraints);
|
||||||
|
constraints = nullptr;
|
||||||
|
}
|
||||||
|
if (cvode_mem) {
|
||||||
|
CVodeFree(&cvode_mem);
|
||||||
|
cvode_mem = nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CVODESolverStrategy::~CVODESolverStrategy() {
|
void PointSolverContext::clear_context() {
|
||||||
LOG_TRACE_L1(m_logger, "Cleaning up CVODE resources...");
|
if (sun_ctx) {
|
||||||
cleanup_cvode_resources(true);
|
SUNContext_Free(&sun_ctx);
|
||||||
|
sun_ctx = nullptr;
|
||||||
if (m_sun_ctx) {
|
|
||||||
SUNContext_Free(&m_sun_ctx);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) {
|
void PointSolverContext::init_context() {
|
||||||
return evaluate(netIn, false);
|
if (!sun_ctx) {
|
||||||
|
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
NetOut CVODESolverStrategy::evaluate(
|
bool PointSolverContext::has_context() const {
|
||||||
|
return sun_ctx != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
PointSolverContext::PointSolverContext(
|
||||||
|
const scratch::StateBlob& engine_ctx
|
||||||
|
) :
|
||||||
|
engine_ctx(engine_ctx.clone_structure())
|
||||||
|
{
|
||||||
|
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
|
||||||
|
}
|
||||||
|
|
||||||
|
PointSolverContext::~PointSolverContext() {
|
||||||
|
reset_cvode();
|
||||||
|
clear_context();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PointSolver::PointSolver(
|
||||||
|
const DynamicEngine &engine
|
||||||
|
): SingleZoneNetworkSolver(engine) {}
|
||||||
|
|
||||||
|
NetOut PointSolver::evaluate(
|
||||||
|
SolverContextBase& solver_ctx,
|
||||||
|
const NetIn& netIn
|
||||||
|
) const {
|
||||||
|
return evaluate(solver_ctx, netIn, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
NetOut PointSolver::evaluate(
|
||||||
|
SolverContextBase& solver_ctx,
|
||||||
const NetIn &netIn,
|
const NetIn &netIn,
|
||||||
bool displayTrigger,
|
bool displayTrigger,
|
||||||
bool forceReinitialize
|
bool forceReinitialize
|
||||||
) {
|
) const {
|
||||||
|
auto* sctx_p = dynamic_cast<PointSolverContext*>(&solver_ctx);
|
||||||
|
|
||||||
LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density);
|
LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density);
|
||||||
LOG_TRACE_L1(m_logger, "Building engine update trigger....");
|
LOG_TRACE_L1(m_logger, "Building engine update trigger....");
|
||||||
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 0.5, 2);
|
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 0.5, 2);
|
||||||
@@ -117,23 +193,24 @@ namespace gridfire::solver {
|
|||||||
// 2. If the user has set tolerances in code, those override the config
|
// 2. If the user has set tolerances in code, those override the config
|
||||||
// 3. If the user has not set tolerances in code and the config does not have them, use hardcoded defaults
|
// 3. If the user has not set tolerances in code and the config does not have them, use hardcoded defaults
|
||||||
|
|
||||||
auto absTol = m_config->solver.cvode.absTol;
|
if (!sctx_p->abs_tol.has_value()) {
|
||||||
auto relTol = m_config->solver.cvode.relTol;
|
sctx_p->abs_tol = m_config->solver.cvode.absTol;
|
||||||
|
|
||||||
if (m_absTol) {
|
|
||||||
absTol = *m_absTol;
|
|
||||||
}
|
}
|
||||||
if (m_relTol) {
|
if (!sctx_p->rel_tol.has_value()) {
|
||||||
relTol = *m_relTol;
|
sctx_p->rel_tol = m_config->solver.cvode.relTol;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool resourcesExist = (m_cvode_mem != nullptr) && (m_Y != nullptr);
|
|
||||||
|
|
||||||
bool inconsistentComposition = netIn.composition.hash() != m_last_composition_hash;
|
bool resourcesExist = (sctx_p->cvode_mem != nullptr) && (sctx_p->Y != nullptr);
|
||||||
|
|
||||||
|
bool inconsistentComposition = netIn.composition.hash() != sctx_p->last_composition_hash;
|
||||||
fourdst::composition::Composition equilibratedComposition;
|
fourdst::composition::Composition equilibratedComposition;
|
||||||
|
|
||||||
if (forceReinitialize || !resourcesExist || inconsistentComposition) {
|
if (forceReinitialize || !resourcesExist || inconsistentComposition) {
|
||||||
cleanup_cvode_resources(true);
|
sctx_p->reset_cvode();
|
||||||
|
if (!sctx_p->has_context()) {
|
||||||
|
sctx_p->init_context();
|
||||||
|
}
|
||||||
LOG_INFO(
|
LOG_INFO(
|
||||||
m_logger,
|
m_logger,
|
||||||
"Preforming full CVODE initialization (Reason: {})",
|
"Preforming full CVODE initialization (Reason: {})",
|
||||||
@@ -141,26 +218,24 @@ namespace gridfire::solver {
|
|||||||
(!resourcesExist ? "CVODE resources do not exist" :
|
(!resourcesExist ? "CVODE resources do not exist" :
|
||||||
"Input composition inconsistent with previous state"));
|
"Input composition inconsistent with previous state"));
|
||||||
LOG_TRACE_L1(m_logger, "Starting engine update chain...");
|
LOG_TRACE_L1(m_logger, "Starting engine update chain...");
|
||||||
equilibratedComposition = m_engine.project(*m_scratch_blob, netIn);
|
equilibratedComposition = m_engine.project(*sctx_p->engine_ctx, netIn);
|
||||||
LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!");
|
LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!");
|
||||||
|
|
||||||
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
|
size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
|
||||||
uint64_t N = numSpecies + 1;
|
uint64_t N = numSpecies + 1;
|
||||||
|
|
||||||
LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N);
|
LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N);
|
||||||
LOG_TRACE_L1(m_logger, "Initializing CVODE resources");
|
LOG_TRACE_L1(m_logger, "Initializing CVODE resources");
|
||||||
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
|
|
||||||
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
|
||||||
|
|
||||||
initialize_cvode_integration_resources(N, numSpecies, 0.0, equilibratedComposition, absTol, relTol, 0.0);
|
initialize_cvode_integration_resources(sctx_p, N, numSpecies, 0.0, equilibratedComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), 0.0);
|
||||||
m_last_size = N;
|
sctx_p->last_size = N;
|
||||||
} else {
|
} else {
|
||||||
LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", m_last_size);
|
LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", sctx_p->last_size);
|
||||||
|
|
||||||
const size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
|
const size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
|
||||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||||
for (size_t i = 0; i < numSpecies; i++) {
|
for (size_t i = 0; i < numSpecies; i++) {
|
||||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||||
if (netIn.composition.contains(species)) {
|
if (netIn.composition.contains(species)) {
|
||||||
y_data[i] = netIn.composition.getMolarAbundance(species);
|
y_data[i] = netIn.composition.getMolarAbundance(species);
|
||||||
} else {
|
} else {
|
||||||
@@ -168,16 +243,17 @@ namespace gridfire::solver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
y_data[numSpecies] = 0.0; // Reset energy accumulator
|
y_data[numSpecies] = 0.0; // Reset energy accumulator
|
||||||
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
|
utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, sctx_p->rel_tol.value(), sctx_p->abs_tol.value()), "CVodeSStolerances");
|
||||||
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, 0.0, m_Y), "CVodeReInit");
|
utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, 0.0, sctx_p->Y), "CVodeReInit");
|
||||||
|
|
||||||
equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state
|
equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
|
size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
|
||||||
CVODEUserData user_data {
|
CVODEUserData user_data {
|
||||||
.solver_instance = this,
|
.solver_instance = this,
|
||||||
.ctx = *m_scratch_blob,
|
.sctx = sctx_p,
|
||||||
|
.ctx = *sctx_p->engine_ctx,
|
||||||
.engine = &m_engine,
|
.engine = &m_engine,
|
||||||
};
|
};
|
||||||
LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!");
|
LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!");
|
||||||
@@ -185,7 +261,7 @@ namespace gridfire::solver {
|
|||||||
double current_time = 0;
|
double current_time = 0;
|
||||||
// ReSharper disable once CppTooWideScope
|
// ReSharper disable once CppTooWideScope
|
||||||
[[maybe_unused]] double last_callback_time = 0;
|
[[maybe_unused]] double last_callback_time = 0;
|
||||||
m_num_steps = 0;
|
sctx_p->num_steps = 0;
|
||||||
double accumulated_energy = 0.0;
|
double accumulated_energy = 0.0;
|
||||||
|
|
||||||
double accumulated_neutrino_energy_loss = 0.0;
|
double accumulated_neutrino_energy_loss = 0.0;
|
||||||
@@ -205,13 +281,13 @@ namespace gridfire::solver {
|
|||||||
while (current_time < netIn.tMax) {
|
while (current_time < netIn.tMax) {
|
||||||
user_data.T9 = T9;
|
user_data.T9 = T9;
|
||||||
user_data.rho = netIn.density;
|
user_data.rho = netIn.density;
|
||||||
user_data.networkSpecies = &m_engine.getNetworkSpecies(*m_scratch_blob);
|
user_data.networkSpecies = &m_engine.getNetworkSpecies(*sctx_p->engine_ctx);
|
||||||
user_data.captured_exception.reset();
|
user_data.captured_exception.reset();
|
||||||
|
|
||||||
utils::check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
|
utils::check_cvode_flag(CVodeSetUserData(sctx_p->cvode_mem, &user_data), "CVodeSetUserData");
|
||||||
|
|
||||||
LOG_TRACE_L2(m_logger, "Taking one CVODE step...");
|
LOG_TRACE_L2(m_logger, "Taking one CVODE step...");
|
||||||
int flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP);
|
int flag = CVode(sctx_p->cvode_mem, netIn.tMax, sctx_p->Y, ¤t_time, CV_ONE_STEP);
|
||||||
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag));
|
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag));
|
||||||
|
|
||||||
if (user_data.captured_exception){
|
if (user_data.captured_exception){
|
||||||
@@ -223,13 +299,13 @@ namespace gridfire::solver {
|
|||||||
|
|
||||||
long int n_steps;
|
long int n_steps;
|
||||||
double last_step_size;
|
double last_step_size;
|
||||||
CVodeGetNumSteps(m_cvode_mem, &n_steps);
|
CVodeGetNumSteps(sctx_p->cvode_mem, &n_steps);
|
||||||
CVodeGetLastStep(m_cvode_mem, &last_step_size);
|
CVodeGetLastStep(sctx_p->cvode_mem, &last_step_size);
|
||||||
long int nliters, nlcfails;
|
long int nliters, nlcfails;
|
||||||
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
|
CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nliters);
|
||||||
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
|
CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nlcfails);
|
||||||
|
|
||||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||||
const double current_energy = y_data[numSpecies]; // Specific energy rate
|
const double current_energy = y_data[numSpecies]; // Specific energy rate
|
||||||
|
|
||||||
// TODO: Accumulate neutrino loss through the state vector directly which will allow CVODE to properly integrate it
|
// TODO: Accumulate neutrino loss through the state vector directly which will allow CVODE to properly integrate it
|
||||||
@@ -238,7 +314,7 @@ namespace gridfire::solver {
|
|||||||
|
|
||||||
size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations;
|
size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations;
|
||||||
size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures;
|
size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures;
|
||||||
if (m_stdout_logging_enabled) {
|
if (sctx_p->stdout_logging) {
|
||||||
std::println(
|
std::println(
|
||||||
"Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})",
|
"Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})",
|
||||||
total_steps + n_steps,
|
total_steps + n_steps,
|
||||||
@@ -253,20 +329,16 @@ namespace gridfire::solver {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < numSpecies; ++i) {
|
for (size_t i = 0; i < numSpecies; ++i) {
|
||||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||||
if (y_data[i] > 0.0) {
|
if (y_data[i] > 0.0) {
|
||||||
postStep.setMolarAbundance(species, y_data[i]);
|
postStep.setMolarAbundance(species, y_data[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// fourdst::composition::Composition collectedComposition = m_engine.collectComposition(postStep, netIn.temperature/1e9, netIn.density);
|
|
||||||
// for (size_t i = 0; i < numSpecies; ++i) {
|
|
||||||
// y_data[i] = collectedComposition.getMolarAbundance(m_engine.getNetworkSpecies()[i]);
|
|
||||||
// }
|
|
||||||
LOG_INFO(m_logger, "Completed {:5} steps to time {:10.4E} [s] (dt = {:15.6E} [s]). Current specific energy: {:15.6E} [erg/g]", total_steps + n_steps, current_time, last_step_size, current_energy);
|
LOG_INFO(m_logger, "Completed {:5} steps to time {:10.4E} [s] (dt = {:15.6E} [s]). Current specific energy: {:15.6E} [erg/g]", total_steps + n_steps, current_time, last_step_size, current_energy);
|
||||||
LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string {
|
LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
for (size_t i = 0; i < numSpecies; ++i) {
|
for (size_t i = 0; i < numSpecies; ++i) {
|
||||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||||
ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")";
|
ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")";
|
||||||
if (i < numSpecies - 1) {
|
if (i < numSpecies - 1) {
|
||||||
ss << ", ";
|
ss << ", ";
|
||||||
@@ -282,36 +354,44 @@ namespace gridfire::solver {
|
|||||||
? user_data.reaction_contribution_map.value()
|
? user_data.reaction_contribution_map.value()
|
||||||
: kEmptyMap;
|
: kEmptyMap;
|
||||||
|
|
||||||
auto ctx = TimestepContext(
|
auto ctx = PointSolverTimestepContext(
|
||||||
current_time,
|
current_time,
|
||||||
m_Y,
|
sctx_p->Y,
|
||||||
last_step_size,
|
last_step_size,
|
||||||
last_callback_time,
|
last_callback_time,
|
||||||
T9,
|
T9,
|
||||||
netIn.density,
|
netIn.density,
|
||||||
n_steps,
|
n_steps,
|
||||||
m_engine,
|
m_engine,
|
||||||
m_engine.getNetworkSpecies(*m_scratch_blob),
|
m_engine.getNetworkSpecies(*sctx_p->engine_ctx),
|
||||||
convFail_diff,
|
convFail_diff,
|
||||||
iter_diff,
|
iter_diff,
|
||||||
rcMap,
|
rcMap,
|
||||||
*m_scratch_blob
|
*sctx_p->engine_ctx
|
||||||
);
|
);
|
||||||
|
|
||||||
prev_nonlinear_iterations = nliters + total_nonlinear_iterations;
|
prev_nonlinear_iterations = nliters + total_nonlinear_iterations;
|
||||||
prev_convergence_failures = nlcfails + total_convergence_failures;
|
prev_convergence_failures = nlcfails + total_convergence_failures;
|
||||||
|
|
||||||
if (m_callback.has_value()) {
|
if (sctx_p->callback.has_value()) {
|
||||||
m_callback.value()(ctx);
|
sctx_p->callback.value()(ctx);
|
||||||
}
|
}
|
||||||
trigger->step(ctx);
|
trigger->step(ctx);
|
||||||
|
|
||||||
if (m_detailed_step_logging) {
|
if (sctx_p->detailed_step_logging) {
|
||||||
log_step_diagnostics(*m_scratch_blob, user_data, true, true, true, "step_" + std::to_string(total_steps + n_steps) + ".json");
|
log_step_diagnostics(
|
||||||
|
sctx_p,
|
||||||
|
*sctx_p->engine_ctx,
|
||||||
|
user_data,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
"step_" + std::to_string(total_steps + n_steps) + ".json"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (trigger->check(ctx)) {
|
if (trigger->check(ctx)) {
|
||||||
if (m_stdout_logging_enabled && displayTrigger) {
|
if (sctx_p->stdout_logging && displayTrigger) {
|
||||||
trigger::printWhy(trigger->why(ctx));
|
trigger::printWhy(trigger->why(ctx));
|
||||||
}
|
}
|
||||||
trigger->update(ctx);
|
trigger->update(ctx);
|
||||||
@@ -333,20 +413,20 @@ namespace gridfire::solver {
|
|||||||
|
|
||||||
fourdst::composition::Composition temp_comp;
|
fourdst::composition::Composition temp_comp;
|
||||||
std::vector<double> mass_fractions;
|
std::vector<double> mass_fractions;
|
||||||
auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*m_scratch_blob).size());
|
auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size());
|
||||||
|
|
||||||
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) {
|
if (num_species_at_stop > sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1) {
|
||||||
LOG_ERROR(
|
LOG_ERROR(
|
||||||
m_logger,
|
m_logger,
|
||||||
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
|
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
|
||||||
num_species_at_stop,
|
num_species_at_stop,
|
||||||
m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index
|
sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1 // -1 due to energy in the last index
|
||||||
);
|
);
|
||||||
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
|
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& species: m_engine.getNetworkSpecies(*m_scratch_blob)) {
|
for (const auto& species: m_engine.getNetworkSpecies(*sctx_p->engine_ctx)) {
|
||||||
const size_t sid = m_engine.getSpeciesIndex(*m_scratch_blob, species);
|
const size_t sid = m_engine.getSpeciesIndex(*sctx_p->engine_ctx, species);
|
||||||
temp_comp.registerSpecies(species);
|
temp_comp.registerSpecies(species);
|
||||||
double y = end_of_step_abundances[sid];
|
double y = end_of_step_abundances[sid];
|
||||||
if (y > 0.0) {
|
if (y > 0.0) {
|
||||||
@@ -356,7 +436,7 @@ namespace gridfire::solver {
|
|||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
for (long int i = 0; i < num_species_at_stop; ++i) {
|
for (long int i = 0; i < num_species_at_stop; ++i) {
|
||||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||||
if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) {
|
if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) {
|
||||||
throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification.");
|
throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification.");
|
||||||
}
|
}
|
||||||
@@ -391,7 +471,7 @@ namespace gridfire::solver {
|
|||||||
"Prior to Engine Update active reactions are: {}",
|
"Prior to Engine Update active reactions are: {}",
|
||||||
[&]() -> std::string {
|
[&]() -> std::string {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob);
|
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx);
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
for (const auto& reaction : reactions) {
|
for (const auto& reaction : reactions) {
|
||||||
ss << reaction -> id();
|
ss << reaction -> id();
|
||||||
@@ -403,7 +483,7 @@ namespace gridfire::solver {
|
|||||||
return ss.str();
|
return ss.str();
|
||||||
}()
|
}()
|
||||||
);
|
);
|
||||||
fourdst::composition::Composition currentComposition = m_engine.project(*m_scratch_blob, netInTemp);
|
fourdst::composition::Composition currentComposition = m_engine.project(*sctx_p->engine_ctx, netInTemp);
|
||||||
LOG_DEBUG(
|
LOG_DEBUG(
|
||||||
m_logger,
|
m_logger,
|
||||||
"After to Engine update composition is (molar abundance) {}",
|
"After to Engine update composition is (molar abundance) {}",
|
||||||
@@ -450,7 +530,7 @@ namespace gridfire::solver {
|
|||||||
"After Engine Update active reactions are: {}",
|
"After Engine Update active reactions are: {}",
|
||||||
[&]() -> std::string {
|
[&]() -> std::string {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob);
|
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx);
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
for (const auto& reaction : reactions) {
|
for (const auto& reaction : reactions) {
|
||||||
ss << reaction -> id();
|
ss << reaction -> id();
|
||||||
@@ -466,34 +546,29 @@ namespace gridfire::solver {
|
|||||||
m_logger,
|
m_logger,
|
||||||
"Due to a triggered engine update the composition was updated from size {} to {} species.",
|
"Due to a triggered engine update the composition was updated from size {} to {} species.",
|
||||||
num_species_at_stop,
|
num_species_at_stop,
|
||||||
m_engine.getNetworkSpecies(*m_scratch_blob).size()
|
m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size()
|
||||||
);
|
);
|
||||||
|
|
||||||
numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size();
|
numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
|
||||||
size_t N = numSpecies + 1;
|
size_t N = numSpecies + 1;
|
||||||
|
|
||||||
LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update...");
|
LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update...");
|
||||||
cleanup_cvode_resources(true);
|
sctx_p->reset_cvode();
|
||||||
|
initialize_cvode_integration_resources(sctx_p, N, numSpecies, current_time, currentComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), accumulated_energy);
|
||||||
|
|
||||||
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
|
utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, current_time, sctx_p->Y), "CVodeReInit");
|
||||||
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
|
||||||
|
|
||||||
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
|
|
||||||
|
|
||||||
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
|
|
||||||
// throw exceptions::DebugException("Debug");
|
|
||||||
LOG_INFO(m_logger, "Done reinitializing CVODE after engine update. The next log messages will be from the first step after reinitialization...");
|
LOG_INFO(m_logger, "Done reinitializing CVODE after engine update. The next log messages will be from the first step after reinitialization...");
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (m_stdout_logging_enabled) { // Flush the buffer if standard out logging is enabled
|
if (sctx_p->stdout_logging) { // Flush the buffer if standard out logging is enabled
|
||||||
std::cout << std::flush;
|
std::cout << std::flush;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_INFO(m_logger, "CVODE iteration complete");
|
LOG_INFO(m_logger, "CVODE iteration complete");
|
||||||
|
|
||||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||||
accumulated_energy += y_data[numSpecies];
|
accumulated_energy += y_data[numSpecies];
|
||||||
std::vector<double> y_vec(y_data, y_data + numSpecies);
|
std::vector<double> y_vec(y_data, y_data + numSpecies);
|
||||||
|
|
||||||
@@ -505,7 +580,7 @@ namespace gridfire::solver {
|
|||||||
|
|
||||||
LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies);
|
LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies);
|
||||||
|
|
||||||
fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec);
|
fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*sctx_p->engine_ctx), y_vec);
|
||||||
LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string {
|
LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string {
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
@@ -520,7 +595,7 @@ namespace gridfire::solver {
|
|||||||
}());
|
}());
|
||||||
|
|
||||||
LOG_INFO(m_logger, "Collecting final composition...");
|
LOG_INFO(m_logger, "Collecting final composition...");
|
||||||
fourdst::composition::Composition outputComposition = m_engine.collectComposition(*m_scratch_blob, topLevelComposition, netIn.temperature/1e9, netIn.density);
|
fourdst::composition::Composition outputComposition = m_engine.collectComposition(*sctx_p->engine_ctx, topLevelComposition, netIn.temperature/1e9, netIn.density);
|
||||||
|
|
||||||
assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size());
|
assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size());
|
||||||
|
|
||||||
@@ -541,11 +616,11 @@ namespace gridfire::solver {
|
|||||||
NetOut netOut;
|
NetOut netOut;
|
||||||
netOut.composition = outputComposition;
|
netOut.composition = outputComposition;
|
||||||
netOut.energy = accumulated_energy;
|
netOut.energy = accumulated_energy;
|
||||||
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
|
utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
|
||||||
|
|
||||||
LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives...");
|
LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives...");
|
||||||
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
|
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
|
||||||
*m_scratch_blob,
|
*sctx_p->engine_ctx,
|
||||||
outputComposition,
|
outputComposition,
|
||||||
T9,
|
T9,
|
||||||
netIn.density
|
netIn.density
|
||||||
@@ -559,53 +634,13 @@ namespace gridfire::solver {
|
|||||||
LOG_TRACE_L2(m_logger, "Output data built!");
|
LOG_TRACE_L2(m_logger, "Output data built!");
|
||||||
LOG_TRACE_L2(m_logger, "Solver evaluation complete!.");
|
LOG_TRACE_L2(m_logger, "Solver evaluation complete!.");
|
||||||
|
|
||||||
m_last_composition_hash = netOut.composition.hash();
|
sctx_p->last_composition_hash = netOut.composition.hash();
|
||||||
m_last_size = netOut.composition.size() + 1;
|
sctx_p->last_size = netOut.composition.size() + 1;
|
||||||
CVodeGetLastStep(m_cvode_mem, &m_last_good_time_step);
|
CVodeGetLastStep(sctx_p->cvode_mem, &sctx_p->last_good_time_step);
|
||||||
return netOut;
|
return netOut;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CVODESolverStrategy::set_callback(const std::any &callback) {
|
int PointSolver::cvode_rhs_wrapper(
|
||||||
m_callback = std::any_cast<TimestepCallback>(callback);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CVODESolverStrategy::get_stdout_logging_enabled() const {
|
|
||||||
return m_stdout_logging_enabled;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CVODESolverStrategy::set_stdout_logging_enabled(const bool logging_enabled) {
|
|
||||||
m_stdout_logging_enabled = logging_enabled;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CVODESolverStrategy::set_absTol(double absTol) {
|
|
||||||
m_absTol = absTol;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CVODESolverStrategy::set_relTol(double relTol) {
|
|
||||||
m_relTol = relTol;
|
|
||||||
}
|
|
||||||
|
|
||||||
double CVODESolverStrategy::get_absTol() const {
|
|
||||||
if (m_absTol.has_value()) {
|
|
||||||
return m_absTol.value();
|
|
||||||
} else {
|
|
||||||
return -1.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
double CVODESolverStrategy::get_relTol() const {
|
|
||||||
if (m_relTol.has_value()) {
|
|
||||||
return m_relTol.value();
|
|
||||||
} else {
|
|
||||||
return -1.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::describe_callback_context() const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
int CVODESolverStrategy::cvode_rhs_wrapper(
|
|
||||||
const sunrealtype t,
|
const sunrealtype t,
|
||||||
const N_Vector y,
|
const N_Vector y,
|
||||||
const N_Vector ydot,
|
const N_Vector ydot,
|
||||||
@@ -633,7 +668,7 @@ namespace gridfire::solver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int CVODESolverStrategy::cvode_jac_wrapper(
|
int PointSolver::cvode_jac_wrapper(
|
||||||
sunrealtype t,
|
sunrealtype t,
|
||||||
N_Vector y,
|
N_Vector y,
|
||||||
N_Vector ydot,
|
N_Vector ydot,
|
||||||
@@ -754,7 +789,7 @@ namespace gridfire::solver {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
CVODESolverStrategy::CVODERHSOutputData CVODESolverStrategy::calculate_rhs(
|
PointSolver::CVODERHSOutputData PointSolver::calculate_rhs(
|
||||||
const sunrealtype t,
|
const sunrealtype t,
|
||||||
N_Vector y,
|
N_Vector y,
|
||||||
N_Vector ydot,
|
N_Vector ydot,
|
||||||
@@ -772,10 +807,10 @@ namespace gridfire::solver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<double> y_vec(y_data, y_data + numSpecies);
|
std::vector<double> y_vec(y_data, y_data + numSpecies);
|
||||||
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec);
|
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(data->ctx), y_vec);
|
||||||
|
|
||||||
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size());
|
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size());
|
||||||
const auto result = m_engine.calculateRHSAndEnergy(*m_scratch_blob, composition, data->T9, data->rho, false);
|
const auto result = m_engine.calculateRHSAndEnergy(data->ctx, composition, data->T9, data->rho, false);
|
||||||
if (!result) {
|
if (!result) {
|
||||||
LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()));
|
LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()));
|
||||||
throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())));
|
throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())));
|
||||||
@@ -805,7 +840,7 @@ namespace gridfire::solver {
|
|||||||
}());
|
}());
|
||||||
|
|
||||||
for (size_t i = 0; i < numSpecies; ++i) {
|
for (size_t i = 0; i < numSpecies; ++i) {
|
||||||
fourdst::atomic::Species species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
fourdst::atomic::Species species = m_engine.getNetworkSpecies(data->ctx)[i];
|
||||||
ydot_data[i] = dydt.at(species);
|
ydot_data[i] = dydt.at(species);
|
||||||
}
|
}
|
||||||
ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate
|
ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate
|
||||||
@@ -813,7 +848,8 @@ namespace gridfire::solver {
|
|||||||
return {reactionContributions, result.value().neutrinoEnergyLossRate, result.value().totalNeutrinoFlux};
|
return {reactionContributions, result.value().neutrinoEnergyLossRate, result.value().totalNeutrinoFlux};
|
||||||
}
|
}
|
||||||
|
|
||||||
void CVODESolverStrategy::initialize_cvode_integration_resources(
|
void PointSolver::initialize_cvode_integration_resources(
|
||||||
|
PointSolverContext* sctx_p,
|
||||||
const uint64_t N,
|
const uint64_t N,
|
||||||
const size_t numSpecies,
|
const size_t numSpecies,
|
||||||
const double current_time,
|
const double current_time,
|
||||||
@@ -821,16 +857,18 @@ namespace gridfire::solver {
|
|||||||
const double absTol,
|
const double absTol,
|
||||||
const double relTol,
|
const double relTol,
|
||||||
const double accumulatedEnergy
|
const double accumulatedEnergy
|
||||||
) {
|
) const {
|
||||||
LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol);
|
LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol);
|
||||||
cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones
|
sctx_p->reset_cvode();
|
||||||
|
|
||||||
m_Y = utils::init_sun_vector(N, m_sun_ctx);
|
sctx_p->cvode_mem = CVodeCreate(CV_BDF, sctx_p->sun_ctx);
|
||||||
m_YErr = N_VClone(m_Y);
|
utils::check_cvode_flag(sctx_p->cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||||
|
sctx_p->Y = utils::init_sun_vector(N, sctx_p->sun_ctx);
|
||||||
|
sctx_p->YErr = N_VClone(sctx_p->Y);
|
||||||
|
|
||||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||||
for (size_t i = 0; i < numSpecies; i++) {
|
for (size_t i = 0; i < numSpecies; i++) {
|
||||||
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i];
|
const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
|
||||||
if (composition.contains(species)) {
|
if (composition.contains(species)) {
|
||||||
y_data[i] = composition.getMolarAbundance(species);
|
y_data[i] = composition.getMolarAbundance(species);
|
||||||
} else {
|
} else {
|
||||||
@@ -840,8 +878,8 @@ namespace gridfire::solver {
|
|||||||
y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero
|
y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero
|
||||||
|
|
||||||
|
|
||||||
utils::check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit");
|
utils::check_cvode_flag(CVodeInit(sctx_p->cvode_mem, cvode_rhs_wrapper, current_time, sctx_p->Y), "CVodeInit");
|
||||||
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
|
utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, relTol, absTol), "CVodeSStolerances");
|
||||||
|
|
||||||
// Constraints
|
// Constraints
|
||||||
// We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual
|
// We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual
|
||||||
@@ -854,53 +892,30 @@ namespace gridfire::solver {
|
|||||||
// -2.0: The corresponding component of y is constrained to be < 0
|
// -2.0: The corresponding component of y is constrained to be < 0
|
||||||
// Here we use 1.0 for all species to ensure they remain non-negative.
|
// Here we use 1.0 for all species to ensure they remain non-negative.
|
||||||
|
|
||||||
m_constraints = N_VClone(m_Y);
|
sctx_p->constraints = N_VClone(sctx_p->Y);
|
||||||
if (m_constraints == nullptr) {
|
if (sctx_p->constraints == nullptr) {
|
||||||
LOG_ERROR(m_logger, "Failed to create constraints vector for CVODE");
|
LOG_ERROR(m_logger, "Failed to create constraints vector for CVODE");
|
||||||
throw std::runtime_error("Failed to create constraints vector for CVODE");
|
throw std::runtime_error("Failed to create constraints vector for CVODE");
|
||||||
}
|
}
|
||||||
N_VConst(1.0, m_constraints); // Set all constraints to >= 0 (note this is where the flag values are set)
|
N_VConst(1.0, sctx_p->constraints); // Set all constraints to >= 0 (note this is where the flag values are set)
|
||||||
|
|
||||||
utils::check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints");
|
utils::check_cvode_flag(CVodeSetConstraints(sctx_p->cvode_mem, sctx_p->constraints), "CVodeSetConstraints");
|
||||||
|
|
||||||
utils::check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep");
|
utils::check_cvode_flag(CVodeSetMaxStep(sctx_p->cvode_mem, 1.0e20), "CVodeSetMaxStep");
|
||||||
|
|
||||||
m_J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), m_sun_ctx);
|
sctx_p->J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), sctx_p->sun_ctx);
|
||||||
utils::check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix");
|
utils::check_cvode_flag(sctx_p->J == nullptr ? -1 : 0, "SUNDenseMatrix");
|
||||||
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx);
|
sctx_p->LS = SUNLinSol_Dense(sctx_p->Y, sctx_p->J, sctx_p->sun_ctx);
|
||||||
utils::check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
|
utils::check_cvode_flag(sctx_p->LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
|
||||||
|
|
||||||
utils::check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver");
|
utils::check_cvode_flag(CVodeSetLinearSolver(sctx_p->cvode_mem, sctx_p->LS, sctx_p->J), "CVodeSetLinearSolver");
|
||||||
utils::check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
|
utils::check_cvode_flag(CVodeSetJacFn(sctx_p->cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
|
||||||
LOG_TRACE_L2(m_logger, "CVODE solver initialized");
|
LOG_TRACE_L2(m_logger, "CVODE solver initialized");
|
||||||
}
|
}
|
||||||
|
|
||||||
void CVODESolverStrategy::cleanup_cvode_resources(const bool memFree) {
|
|
||||||
LOG_TRACE_L2(m_logger, "Cleaning up cvode resources");
|
|
||||||
if (m_LS) SUNLinSolFree(m_LS);
|
|
||||||
if (m_J) SUNMatDestroy(m_J);
|
|
||||||
if (m_Y) N_VDestroy(m_Y);
|
|
||||||
if (m_YErr) N_VDestroy(m_YErr);
|
|
||||||
if (m_constraints) N_VDestroy(m_constraints);
|
|
||||||
|
|
||||||
m_LS = nullptr;
|
void PointSolver::log_step_diagnostics(
|
||||||
m_J = nullptr;
|
PointSolverContext* sctx_p,
|
||||||
m_Y = nullptr;
|
|
||||||
m_YErr = nullptr;
|
|
||||||
m_constraints = nullptr;
|
|
||||||
|
|
||||||
if (memFree) {
|
|
||||||
if (m_cvode_mem) CVodeFree(&m_cvode_mem);
|
|
||||||
m_cvode_mem = nullptr;
|
|
||||||
}
|
|
||||||
LOG_TRACE_L2(m_logger, "Done Cleaning up cvode resources");
|
|
||||||
}
|
|
||||||
|
|
||||||
void CVODESolverStrategy::set_detailed_step_logging(const bool enabled) {
|
|
||||||
m_detailed_step_logging = enabled;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CVODESolverStrategy::log_step_diagnostics(
|
|
||||||
scratch::StateBlob &ctx,
|
scratch::StateBlob &ctx,
|
||||||
const CVODEUserData &user_data,
|
const CVODEUserData &user_data,
|
||||||
bool displayJacobianStiffness,
|
bool displayJacobianStiffness,
|
||||||
@@ -916,10 +931,10 @@ namespace gridfire::solver {
|
|||||||
sunrealtype hlast, hcur, tcur;
|
sunrealtype hlast, hcur, tcur;
|
||||||
int qlast;
|
int qlast;
|
||||||
|
|
||||||
utils::check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep");
|
utils::check_cvode_flag(CVodeGetLastStep(sctx_p->cvode_mem, &hlast), "CVodeGetLastStep");
|
||||||
utils::check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep");
|
utils::check_cvode_flag(CVodeGetCurrentStep(sctx_p->cvode_mem, &hcur), "CVodeGetCurrentStep");
|
||||||
utils::check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder");
|
utils::check_cvode_flag(CVodeGetLastOrder(sctx_p->cvode_mem, &qlast), "CVodeGetLastOrder");
|
||||||
utils::check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime");
|
utils::check_cvode_flag(CVodeGetCurrentTime(sctx_p->cvode_mem, &tcur), "CVodeGetCurrentTime");
|
||||||
|
|
||||||
nlohmann::json j;
|
nlohmann::json j;
|
||||||
{
|
{
|
||||||
@@ -941,13 +956,13 @@ namespace gridfire::solver {
|
|||||||
// These are the CRITICAL counters for diagnosing your problem
|
// These are the CRITICAL counters for diagnosing your problem
|
||||||
long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails;
|
long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails;
|
||||||
|
|
||||||
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps");
|
utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, &nsteps), "CVodeGetNumSteps");
|
||||||
utils::check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
|
utils::check_cvode_flag(CVodeGetNumRhsEvals(sctx_p->cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
|
||||||
utils::check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
|
utils::check_cvode_flag(CVodeGetNumLinSolvSetups(sctx_p->cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
|
||||||
utils::check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails");
|
utils::check_cvode_flag(CVodeGetNumErrTestFails(sctx_p->cvode_mem, &netfails), "CVodeGetNumErrTestFails");
|
||||||
utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
|
utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
|
||||||
utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
|
utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
|
||||||
utils::check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
|
utils::check_cvode_flag(CVodeGetNumLinConvFails(sctx_p->cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
|
||||||
|
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -975,22 +990,26 @@ namespace gridfire::solver {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// --- 3. Get Estimated Local Errors (Your Original Logic) ---
|
// --- 3. Get Estimated Local Errors (Your Original Logic) ---
|
||||||
utils::check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors");
|
utils::check_cvode_flag(CVodeGetEstLocalErrors(sctx_p->cvode_mem, sctx_p->YErr), "CVodeGetEstLocalErrors");
|
||||||
|
|
||||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
|
||||||
sunrealtype *y_err_data = N_VGetArrayPointer(m_YErr);
|
sunrealtype *y_err_data = N_VGetArrayPointer(sctx_p->YErr);
|
||||||
|
|
||||||
const auto absTol = m_config->solver.cvode.absTol;
|
|
||||||
const auto relTol = m_config->solver.cvode.relTol;
|
|
||||||
|
|
||||||
std::vector<double> err_ratios;
|
std::vector<double> err_ratios;
|
||||||
const size_t num_components = N_VGetLength(m_Y);
|
const size_t num_components = N_VGetLength(sctx_p->Y);
|
||||||
err_ratios.resize(num_components - 1); // Assuming -1 is for Energy or similar
|
err_ratios.resize(num_components - 1); // Assuming -1 is for Energy or similar
|
||||||
|
|
||||||
std::vector<double> Y_full(y_data, y_data + num_components - 1);
|
std::vector<double> Y_full(y_data, y_data + num_components - 1);
|
||||||
std::vector<double> E_full(y_err_data, y_err_data + num_components - 1);
|
std::vector<double> E_full(y_err_data, y_err_data + num_components - 1);
|
||||||
|
|
||||||
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, relTol, absTol, 10, to_file);
|
if (!sctx_p->abs_tol.has_value()) {
|
||||||
|
sctx_p->abs_tol = m_config->solver.cvode.absTol;
|
||||||
|
}
|
||||||
|
if (!sctx_p->rel_tol.has_value()) {
|
||||||
|
sctx_p->rel_tol = m_config->solver.cvode.relTol;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, sctx_p->rel_tol.value(), sctx_p->abs_tol.value(), 10, to_file);
|
||||||
if (to_file && result.has_value()) {
|
if (to_file && result.has_value()) {
|
||||||
j["Limiting_Species"] = result.value();
|
j["Limiting_Species"] = result.value();
|
||||||
}
|
}
|
||||||
@@ -1003,8 +1022,9 @@ namespace gridfire::solver {
|
|||||||
0.0
|
0.0
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
for (size_t i = 0; i < num_components - 1; i++) {
|
for (size_t i = 0; i < num_components - 1; i++) {
|
||||||
const double weight = relTol * std::abs(y_data[i]) + absTol;
|
const double weight = sctx_p->rel_tol.value() * std::abs(y_data[i]) + sctx_p->abs_tol.value();
|
||||||
if (weight == 0.0) {
|
if (weight == 0.0) {
|
||||||
err_ratios[i] = 0.0; // Avoid division by zero
|
err_ratios[i] = 0.0; // Avoid division by zero
|
||||||
continue;
|
continue;
|
||||||
@@ -1013,11 +1033,11 @@ namespace gridfire::solver {
|
|||||||
err_ratios[i] = err_ratio;
|
err_ratios[i] = err_ratio;
|
||||||
}
|
}
|
||||||
|
|
||||||
fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*m_scratch_blob), Y_full);
|
fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*sctx_p->engine_ctx), Y_full);
|
||||||
fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*m_scratch_blob, composition, user_data.T9, user_data.rho);
|
fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*sctx_p->engine_ctx, composition, user_data.T9, user_data.rho);
|
||||||
|
|
||||||
auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho);
|
auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho);
|
||||||
auto netTimescales = user_data.engine->getSpeciesTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho);
|
auto netTimescales = user_data.engine->getSpeciesTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho);
|
||||||
|
|
||||||
bool timescaleOkay = false;
|
bool timescaleOkay = false;
|
||||||
if (destructionTimescales && netTimescales) timescaleOkay = true;
|
if (destructionTimescales && netTimescales) timescaleOkay = true;
|
||||||
@@ -1037,7 +1057,7 @@ namespace gridfire::solver {
|
|||||||
if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp));
|
if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp));
|
||||||
else destructionTimescales_list.emplace_back(std::numeric_limits<double>::infinity());
|
else destructionTimescales_list.emplace_back(std::numeric_limits<double>::infinity());
|
||||||
|
|
||||||
speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*m_scratch_blob, sp)));
|
speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*sctx_p->engine_ctx, sp)));
|
||||||
}
|
}
|
||||||
|
|
||||||
utils::Column<fourdst::atomic::Species> speciesColumn("Species", species_list);
|
utils::Column<fourdst::atomic::Species> speciesColumn("Species", species_list);
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
|||||||
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
|
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
|
||||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
#include "gridfire/solver/strategies/PointSolver.h"
|
||||||
|
|
||||||
#include "gridfire/trigger/trigger_logical.h"
|
#include "gridfire/trigger/trigger_logical.h"
|
||||||
#include "gridfire/trigger/trigger_abstract.h"
|
#include "gridfire/trigger/trigger_abstract.h"
|
||||||
@@ -28,7 +28,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SimulationTimeTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
bool SimulationTimeTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||||
if (ctx.t - m_last_trigger_time >= m_interval) {
|
if (ctx.t - m_last_trigger_time >= m_interval) {
|
||||||
m_hits++;
|
m_hits++;
|
||||||
LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta);
|
LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta);
|
||||||
@@ -38,7 +38,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SimulationTimeTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
void SimulationTimeTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
|
||||||
if (check(ctx)) {
|
if (check(ctx)) {
|
||||||
m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval;
|
m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval;
|
||||||
m_last_trigger_time = ctx.t;
|
m_last_trigger_time = ctx.t;
|
||||||
@@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SimulationTimeTrigger::step(
|
void SimulationTimeTrigger::step(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) {
|
) {
|
||||||
// --- SimulationTimeTrigger::step does nothing and is intentionally left blank --- //
|
// --- SimulationTimeTrigger::step does nothing and is intentionally left blank --- //
|
||||||
}
|
}
|
||||||
@@ -65,7 +65,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
return "Simulation Time Trigger";
|
return "Simulation Time Trigger";
|
||||||
}
|
}
|
||||||
|
|
||||||
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||||
TriggerResult result;
|
TriggerResult result;
|
||||||
result.name = name();
|
result.name = name();
|
||||||
if (check(ctx)) {
|
if (check(ctx)) {
|
||||||
@@ -99,18 +99,18 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OffDiagonalTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
bool OffDiagonalTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||||
//TODO : This currently does nothing
|
//TODO : This currently does nothing
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void OffDiagonalTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
void OffDiagonalTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
|
||||||
m_updates++;
|
m_updates++;
|
||||||
}
|
}
|
||||||
|
|
||||||
void OffDiagonalTrigger::step(
|
void OffDiagonalTrigger::step(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) {
|
) {
|
||||||
// --- OffDiagonalTrigger::step does nothing and is intentionally left blank --- //
|
// --- OffDiagonalTrigger::step does nothing and is intentionally left blank --- //
|
||||||
}
|
}
|
||||||
@@ -126,7 +126,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
return "Off-Diagonal Trigger";
|
return "Off-Diagonal Trigger";
|
||||||
}
|
}
|
||||||
|
|
||||||
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||||
TriggerResult result;
|
TriggerResult result;
|
||||||
result.name = name();
|
result.name = name();
|
||||||
|
|
||||||
@@ -173,7 +173,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TimestepCollapseTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
bool TimestepCollapseTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||||
if (m_timestep_window.size() < m_windowSize) {
|
if (m_timestep_window.size() < m_windowSize) {
|
||||||
m_misses++;
|
m_misses++;
|
||||||
return false;
|
return false;
|
||||||
@@ -201,13 +201,13 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TimestepCollapseTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) {
|
void TimestepCollapseTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
|
||||||
m_updates++;
|
m_updates++;
|
||||||
m_timestep_window.clear();
|
m_timestep_window.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TimestepCollapseTrigger::step(
|
void TimestepCollapseTrigger::step(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) {
|
) {
|
||||||
push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize);
|
push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize);
|
||||||
// --- TimestepCollapseTrigger::step does nothing and is intentionally left blank --- //
|
// --- TimestepCollapseTrigger::step does nothing and is intentionally left blank --- //
|
||||||
@@ -226,7 +226,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TriggerResult TimestepCollapseTrigger::why(
|
TriggerResult TimestepCollapseTrigger::why(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) const {
|
) const {
|
||||||
TriggerResult result;
|
TriggerResult result;
|
||||||
result.name = name();
|
result.name = name();
|
||||||
@@ -263,7 +263,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
m_windowSize(windowSize) {}
|
m_windowSize(windowSize) {}
|
||||||
|
|
||||||
bool ConvergenceFailureTrigger::check(
|
bool ConvergenceFailureTrigger::check(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) const {
|
) const {
|
||||||
if (m_window.size() != m_windowSize) {
|
if (m_window.size() != m_windowSize) {
|
||||||
m_misses++;
|
m_misses++;
|
||||||
@@ -278,13 +278,13 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ConvergenceFailureTrigger::update(
|
void ConvergenceFailureTrigger::update(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) {
|
) {
|
||||||
m_window.clear();
|
m_window.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConvergenceFailureTrigger::step(
|
void ConvergenceFailureTrigger::step(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) {
|
) {
|
||||||
push_to_fixed_deque(m_window, ctx.currentConvergenceFailures, m_windowSize);
|
push_to_fixed_deque(m_window, ctx.currentConvergenceFailures, m_windowSize);
|
||||||
m_updates++;
|
m_updates++;
|
||||||
@@ -306,7 +306,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
return "ConvergenceFailureTrigger(abs_failure_threshold=" + std::to_string(m_totalFailures) + ", rel_failure_threshold=" + std::to_string(m_relativeFailureRate) + ", windowSize=" + std::to_string(m_windowSize) + ")";
|
return "ConvergenceFailureTrigger(abs_failure_threshold=" + std::to_string(m_totalFailures) + ", rel_failure_threshold=" + std::to_string(m_relativeFailureRate) + ", windowSize=" + std::to_string(m_windowSize) + ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const {
|
TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
|
||||||
TriggerResult result;
|
TriggerResult result;
|
||||||
result.name = name();
|
result.name = name();
|
||||||
|
|
||||||
@@ -348,7 +348,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ConvergenceFailureTrigger::abs_failure(
|
bool ConvergenceFailureTrigger::abs_failure(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) const {
|
) const {
|
||||||
if (ctx.currentConvergenceFailures > m_totalFailures) {
|
if (ctx.currentConvergenceFailures > m_totalFailures) {
|
||||||
return true;
|
return true;
|
||||||
@@ -357,7 +357,7 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ConvergenceFailureTrigger::rel_failure(
|
bool ConvergenceFailureTrigger::rel_failure(
|
||||||
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx
|
const gridfire::solver::PointSolverTimestepContext &ctx
|
||||||
) const {
|
) const {
|
||||||
const float mean = current_mean();
|
const float mean = current_mean();
|
||||||
if (mean < 10) {
|
if (mean < 10) {
|
||||||
@@ -369,13 +369,13 @@ namespace gridfire::trigger::solver::CVODE {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger(
|
std::unique_ptr<Trigger<gridfire::solver::PointSolverTimestepContext>> makeEnginePartitioningTrigger(
|
||||||
const double simulationTimeInterval,
|
const double simulationTimeInterval,
|
||||||
const double offDiagonalThreshold,
|
const double offDiagonalThreshold,
|
||||||
const double timestepCollapseRatio,
|
const double timestepCollapseRatio,
|
||||||
const size_t maxConvergenceFailures
|
const size_t maxConvergenceFailures
|
||||||
) {
|
) {
|
||||||
using ctx_t = gridfire::solver::CVODESolverStrategy::TimestepContext;
|
using ctx_t = gridfire::solver::PointSolverTimestepContext;
|
||||||
|
|
||||||
// 1. INSTABILITY TRIGGERS (High Priority)
|
// 1. INSTABILITY TRIGGERS (High Priority)
|
||||||
auto convergenceFailureTrigger = std::make_unique<ConvergenceFailureTrigger>(
|
auto convergenceFailureTrigger = std::make_unique<ConvergenceFailureTrigger>(
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ gridfire_sources = files(
|
|||||||
'lib/reaction/weak/weak_interpolator.cpp',
|
'lib/reaction/weak/weak_interpolator.cpp',
|
||||||
'lib/io/network_file.cpp',
|
'lib/io/network_file.cpp',
|
||||||
'lib/io/generative/python.cpp',
|
'lib/io/generative/python.cpp',
|
||||||
'lib/solver/strategies/CVODE_solver_strategy.cpp',
|
'lib/solver/strategies/PointSolver.cpp',
|
||||||
'lib/solver/strategies/SpectralSolverStrategy.cpp',
|
'lib/solver/strategies/GridSolver.cpp',
|
||||||
'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp',
|
'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp',
|
||||||
'lib/screening/screening_types.cpp',
|
'lib/screening/screening_types.cpp',
|
||||||
'lib/screening/screening_weak.cpp',
|
'lib/screening/screening_weak.cpp',
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
|
|
||||||
#include <clocale>
|
#include <clocale>
|
||||||
|
|
||||||
#include "gridfire/reaction/reaclib.h"
|
#include "gridfire/utils/gf_omp.h"
|
||||||
|
|
||||||
|
|
||||||
static std::terminate_handler g_previousHandler = nullptr;
|
static std::terminate_handler g_previousHandler = nullptr;
|
||||||
@@ -31,7 +31,7 @@ gridfire::NetIn init(const double temp, const double rho, const double tMax) {
|
|||||||
std::setlocale(LC_ALL, "");
|
std::setlocale(LC_ALL, "");
|
||||||
g_previousHandler = std::set_terminate(quill_terminate_handler);
|
g_previousHandler = std::set_terminate(quill_terminate_handler);
|
||||||
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
||||||
logger->set_log_level(quill::LogLevel::TraceL2);
|
logger->set_log_level(quill::LogLevel::Info);
|
||||||
|
|
||||||
using namespace gridfire;
|
using namespace gridfire;
|
||||||
const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4};
|
const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4};
|
||||||
@@ -143,7 +143,7 @@ void log_results(const gridfire::NetOut& netOut, const gridfire::NetIn& netIn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void record_abundance_history_callback(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) {
|
void record_abundance_history_callback(const gridfire::solver::PointSolverTimestepContext& ctx) {
|
||||||
s_wrote_abundance_history = true;
|
s_wrote_abundance_history = true;
|
||||||
const auto& engine = ctx.engine;
|
const auto& engine = ctx.engine;
|
||||||
// std::unordered_map<std::string, std::pair<double, double>> abundances;
|
// std::unordered_map<std::string, std::pair<double, double>> abundances;
|
||||||
@@ -224,11 +224,12 @@ void quill_terminate_handler()
|
|||||||
std::abort();
|
std::abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
void callback_main(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) {
|
void callback_main(const gridfire::solver::PointSolverTimestepContext& ctx) {
|
||||||
record_abundance_history_callback(ctx);
|
record_abundance_history_callback(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
GF_PAR_INIT();
|
||||||
using namespace gridfire;
|
using namespace gridfire;
|
||||||
|
|
||||||
constexpr size_t breaks = 1;
|
constexpr size_t breaks = 1;
|
||||||
@@ -239,98 +240,20 @@ int main() {
|
|||||||
const NetIn netIn = init(temp, rho, tMax);
|
const NetIn netIn = init(temp, rho, tMax);
|
||||||
|
|
||||||
policy::MainSequencePolicy stellarPolicy(netIn.composition);
|
policy::MainSequencePolicy stellarPolicy(netIn.composition);
|
||||||
policy::ConstructionResults construct = stellarPolicy.construct();
|
auto [engine, ctx_template] = stellarPolicy.construct();
|
||||||
std::println("Sandbox Engine Stack: {}", stellarPolicy);
|
std::println("Sandbox Engine Stack: {}", stellarPolicy);
|
||||||
std::println("Scratch Blob State: {}", *construct.scratch_blob);
|
std::println("Scratch Blob State: {}", *ctx_template);
|
||||||
|
|
||||||
|
constexpr size_t nZones = 100;
|
||||||
constexpr size_t runs = 1000;
|
std::array<NetIn, nZones> netIns;
|
||||||
auto startTime = std::chrono::high_resolution_clock::now();
|
for (size_t zone = 0; zone < nZones; ++zone) {
|
||||||
|
netIns[zone] = netIn;
|
||||||
// arrays to store timings
|
netIns[zone].temperature = 1.0e7;
|
||||||
std::array<std::chrono::duration<double>, runs> setup_times;
|
|
||||||
std::array<std::chrono::duration<double>, runs> eval_times;
|
|
||||||
std::array<NetOut, runs> serial_results;
|
|
||||||
for (size_t i = 0; i < runs; ++i) {
|
|
||||||
auto start_setup_time = std::chrono::high_resolution_clock::now();
|
|
||||||
std::print("Run {}/{}\r", i + 1, runs);
|
|
||||||
solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob);
|
|
||||||
// solver.set_callback(solver::CVODESolverStrategy::TimestepCallback(callback_main));
|
|
||||||
solver.set_stdout_logging_enabled(false);
|
|
||||||
auto end_setup_time = std::chrono::high_resolution_clock::now();
|
|
||||||
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
|
|
||||||
setup_times[i] = setup_elapsed;
|
|
||||||
|
|
||||||
auto start_eval_time = std::chrono::high_resolution_clock::now();
|
|
||||||
const NetOut netOut = solver.evaluate(netIn);
|
|
||||||
auto end_eval_time = std::chrono::high_resolution_clock::now();
|
|
||||||
serial_results[i] = netOut;
|
|
||||||
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
|
|
||||||
eval_times[i] = eval_elapsed;
|
|
||||||
|
|
||||||
// log_results(netOut, netIn);
|
|
||||||
}
|
|
||||||
auto endTime = std::chrono::high_resolution_clock::now();
|
|
||||||
std::chrono::duration<double> elapsed = endTime - startTime;
|
|
||||||
std::println("");
|
|
||||||
|
|
||||||
// Summarize serial timings
|
|
||||||
double total_setup_time = 0.0;
|
|
||||||
double total_eval_time = 0.0;
|
|
||||||
for (size_t i = 0; i < runs; ++i) {
|
|
||||||
total_setup_time += setup_times[i].count();
|
|
||||||
total_eval_time += eval_times[i].count();
|
|
||||||
}
|
|
||||||
std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
|
|
||||||
std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
|
|
||||||
std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count());
|
|
||||||
std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1));
|
|
||||||
|
|
||||||
// OPTIONAL: Prevent CppAD from returning memory to the system
|
|
||||||
// during execution to reduce overhead (can speed up tight loops)
|
|
||||||
CppAD::thread_alloc::hold_memory(true);
|
|
||||||
|
|
||||||
std::array<NetOut, runs> parallelResults;
|
|
||||||
std::array<std::chrono::duration<double>, runs> setupTimes;
|
|
||||||
std::array<std::chrono::duration<double>, runs> evalTimes;
|
|
||||||
std::array<std::unique_ptr<gridfire::engine::scratch::StateBlob>, runs> workspaces;
|
|
||||||
for (size_t i = 0; i < runs; ++i) {
|
|
||||||
workspaces[i] = construct.scratch_blob->clone_structure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const solver::PointSolver localSolver(engine);
|
||||||
|
solver::GridSolverContext solverCtx(*ctx_template);
|
||||||
|
const solver::GridSolver gridSolver(engine, localSolver);
|
||||||
|
|
||||||
// Parallel runs
|
std::vector<NetOut> netOuts = gridSolver.evaluate(solverCtx, netIns | std::ranges::to<std::vector>());
|
||||||
startTime = std::chrono::high_resolution_clock::now();
|
|
||||||
for (size_t i = 0; i < runs; ++i) {
|
|
||||||
auto start_setup_time = std::chrono::high_resolution_clock::now();
|
|
||||||
solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]);
|
|
||||||
solver.set_stdout_logging_enabled(false);
|
|
||||||
auto end_setup_time = std::chrono::high_resolution_clock::now();
|
|
||||||
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
|
|
||||||
setupTimes[i] = setup_elapsed;
|
|
||||||
auto start_eval_time = std::chrono::high_resolution_clock::now();
|
|
||||||
parallelResults[i] = solver.evaluate(netIn);
|
|
||||||
auto end_eval_time = std::chrono::high_resolution_clock::now();
|
|
||||||
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
|
|
||||||
evalTimes[i] = eval_elapsed;
|
|
||||||
}
|
|
||||||
endTime = std::chrono::high_resolution_clock::now();
|
|
||||||
elapsed = endTime - startTime;
|
|
||||||
std::println("");
|
|
||||||
|
|
||||||
// Summarize parallel timings
|
|
||||||
total_setup_time = 0.0;
|
|
||||||
total_eval_time = 0.0;
|
|
||||||
for (size_t i = 0; i < runs; ++i) {
|
|
||||||
total_setup_time += setupTimes[i].count();
|
|
||||||
total_eval_time += evalTimes[i].count();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::println("Average Parallel Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
|
|
||||||
std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
|
|
||||||
std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count());
|
|
||||||
|
|
||||||
std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) {
|
|
||||||
return result.composition.getMolarAbundance(fourdst::atomic::H_1);
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
@@ -4,8 +4,8 @@ executable(
|
|||||||
dependencies: [gridfire_dep, cli11_dep],
|
dependencies: [gridfire_dep, cli11_dep],
|
||||||
)
|
)
|
||||||
|
|
||||||
executable(
|
#executable(
|
||||||
'spectral_sandbox',
|
# 'spectral_sandbox',
|
||||||
'spectral_main.cpp',
|
# 'spectral_main.cpp',
|
||||||
dependencies: [gridfire_dep, cli11_dep]
|
# dependencies: [gridfire_dep, cli11_dep]
|
||||||
)
|
#)
|
||||||
|
|||||||
@@ -1,108 +0,0 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <fstream>
|
|
||||||
#include <chrono>
|
|
||||||
#include <thread>
|
|
||||||
|
|
||||||
#include "gridfire/gridfire.h"
|
|
||||||
|
|
||||||
#include "fourdst/composition/composition.h"
|
|
||||||
#include "fourdst/logging/logging.h"
|
|
||||||
#include "fourdst/atomic/species.h"
|
|
||||||
#include "fourdst/composition/utils.h"
|
|
||||||
|
|
||||||
#include "quill/Logger.h"
|
|
||||||
#include "quill/Backend.h"
|
|
||||||
#include "CLI/CLI.hpp"
|
|
||||||
|
|
||||||
#include <clocale>
|
|
||||||
|
|
||||||
|
|
||||||
static std::terminate_handler g_previousHandler = nullptr;
|
|
||||||
static std::vector<std::pair<double, std::unordered_map<std::string, std::pair<double, double>>>> g_callbackHistory;
|
|
||||||
static bool s_wrote_abundance_history = false;
|
|
||||||
void quill_terminate_handler();
|
|
||||||
|
|
||||||
std::vector<double> linspace(const double start, const double end, const size_t num) {
|
|
||||||
std::vector<double> result;
|
|
||||||
if (num == 0) return result;
|
|
||||||
if (num == 1) {
|
|
||||||
result.push_back(start);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
const double step = (end - start) / static_cast<double>(num - 1);
|
|
||||||
for (size_t i = 0; i < num; ++i) {
|
|
||||||
result.push_back(start + i * step);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<gridfire::NetIn> init(const double tMin, const double tMax, const double rhoMin, const double rhoMax, const double nShells, const double evolveTime) {
|
|
||||||
std::setlocale(LC_ALL, "");
|
|
||||||
g_previousHandler = std::set_terminate(quill_terminate_handler);
|
|
||||||
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
|
||||||
logger->set_log_level(quill::LogLevel::TraceL2);
|
|
||||||
LOG_INFO(logger, "Initializing GridFire Spectral Solver Sandbox...");
|
|
||||||
|
|
||||||
using namespace gridfire;
|
|
||||||
const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4};
|
|
||||||
const std::vector<std::string> symbols = {"H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"};
|
|
||||||
|
|
||||||
|
|
||||||
const fourdst::composition::Composition composition = fourdst::composition::buildCompositionFromMassFractions(symbols, X);
|
|
||||||
|
|
||||||
std::vector<NetIn> netIns;
|
|
||||||
for (const auto& [T, ρ]: std::views::zip(linspace(tMin, tMax, nShells), linspace(rhoMax, rhoMin, nShells))) {
|
|
||||||
NetIn netIn;
|
|
||||||
netIn.composition = composition;
|
|
||||||
netIn.temperature = T;
|
|
||||||
netIn.density = ρ;
|
|
||||||
netIn.energy = 0;
|
|
||||||
|
|
||||||
netIn.tMax = evolveTime;
|
|
||||||
netIn.dt0 = 1e-12;
|
|
||||||
netIns.push_back(netIn);
|
|
||||||
}
|
|
||||||
|
|
||||||
return netIns;
|
|
||||||
}
|
|
||||||
void quill_terminate_handler()
|
|
||||||
{
|
|
||||||
quill::Backend::stop();
|
|
||||||
if (g_previousHandler)
|
|
||||||
g_previousHandler();
|
|
||||||
else
|
|
||||||
std::abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
using namespace gridfire;
|
|
||||||
|
|
||||||
CLI::App app{"GridFire Sandbox Application."};
|
|
||||||
|
|
||||||
double tMin = 1.5e7;
|
|
||||||
double tMax = 1.7e7;
|
|
||||||
double rhoMin = 1.5e2;
|
|
||||||
double rhoMax = 1.5e2;
|
|
||||||
double nShells = 15;
|
|
||||||
double evolveTime = 3.1536e+16;
|
|
||||||
|
|
||||||
app.add_option("--tMin", tMin, "Minimum time in seconds");
|
|
||||||
app.add_option("--tMax", tMax, "Maximum time in seconds");
|
|
||||||
app.add_option("--rhoMin", rhoMin, "Minimum density in g/cm^3");
|
|
||||||
app.add_option("--rhoMax", rhoMax, "Maximum density in g/cm^3");
|
|
||||||
app.add_option("--nShells", nShells, "Number of shells");
|
|
||||||
app.add_option("--evolveTime", evolveTime, "Maximum time in seconds");
|
|
||||||
|
|
||||||
CLI11_PARSE(app, argc, argv);
|
|
||||||
|
|
||||||
const std::vector<NetIn> netIns = init(tMin, tMax, rhoMin, rhoMax, nShells, evolveTime);
|
|
||||||
|
|
||||||
policy::MainSequencePolicy stellarPolicy(netIns[0].composition);
|
|
||||||
stellarPolicy.construct();
|
|
||||||
policy::ConstructionResults construct = stellarPolicy.construct();
|
|
||||||
|
|
||||||
solver::SpectralSolverStrategy solver(construct.engine);
|
|
||||||
std::vector<double> mass_coords = linspace(1e-5, 1.0, nShells);
|
|
||||||
|
|
||||||
std::vector<NetOut> results = solver.evaluate(netIns, mass_coords, *construct.scratch_blob);
|
|
||||||
}
|
|
||||||
373
tools/cli/gf_quick/main.cpp
Normal file
373
tools/cli/gf_quick/main.cpp
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
// ReSharper disable CppUnusedIncludeDirective
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <chrono>
|
||||||
|
#include <thread>
|
||||||
|
#include <format>
|
||||||
|
|
||||||
|
#include "gridfire/gridfire.h"
|
||||||
|
#include <cppad/utility/thread_alloc.hpp> // Required for parallel_setup
|
||||||
|
|
||||||
|
#include "fourdst/composition/composition.h"
|
||||||
|
#include "fourdst/logging/logging.h"
|
||||||
|
#include "fourdst/atomic/species.h"
|
||||||
|
#include "fourdst/composition/utils.h"
|
||||||
|
|
||||||
|
#include "quill/Logger.h"
|
||||||
|
#include "quill/Backend.h"
|
||||||
|
#include "CLI/CLI.hpp"
|
||||||
|
|
||||||
|
#include <clocale>
|
||||||
|
|
||||||
|
#include "gridfire/utils/gf_omp.h"
|
||||||
|
|
||||||
|
|
||||||
|
static std::terminate_handler g_previousHandler = nullptr;
|
||||||
|
static std::vector<std::pair<double, std::unordered_map<std::string, std::pair<double, double>>>> g_callbackHistory;
|
||||||
|
static bool s_wrote_abundance_history = false;
|
||||||
|
void quill_terminate_handler();
|
||||||
|
|
||||||
|
using namespace fourdst::composition;
|
||||||
|
Composition rescale(const Composition& comp, double target_X, double target_Z) {
|
||||||
|
// 1. Validate inputs
|
||||||
|
if (target_X < 0.0 || target_Z < 0.0 || (target_X + target_Z) > 1.0 + 1e-14) {
|
||||||
|
throw std::invalid_argument("Target mass fractions X and Z must be non-negative and sum to <= 1.0");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force high precision for the target Y to ensure X+Y+Z = 1.0 exactly in our logic
|
||||||
|
long double ld_target_X = static_cast<long double>(target_X);
|
||||||
|
long double ld_target_Z = static_cast<long double>(target_Z);
|
||||||
|
long double ld_target_Y = 1.0L - ld_target_X - ld_target_Z;
|
||||||
|
|
||||||
|
// Clamp Y to 0 if it dipped slightly below due to precision (e.g. X+Z=1.0000000001)
|
||||||
|
if (ld_target_Y < 0.0L) ld_target_Y = 0.0L;
|
||||||
|
|
||||||
|
// 2. Manually calculate current Mass Totals (bypass getCanonicalComposition to avoid crashes)
|
||||||
|
long double total_mass_H = 0.0L;
|
||||||
|
long double total_mass_He = 0.0L;
|
||||||
|
long double total_mass_Z = 0.0L;
|
||||||
|
|
||||||
|
// We need to iterate and identify species types manually
|
||||||
|
// Standard definition: H (z=1), He (z=2), Metals (z>2)
|
||||||
|
// Note: We use long double accumulators to prevent summation drift
|
||||||
|
for (const auto& [spec, molar_abundance] : comp) {
|
||||||
|
// Retrieve atomic properties.
|
||||||
|
// Note: usage assumes fourdst::atomic::Species has .z() and .mass()
|
||||||
|
// consistent with the provided composition.cpp
|
||||||
|
int z = spec.z();
|
||||||
|
double a = spec.mass();
|
||||||
|
|
||||||
|
long double mass_contribution = static_cast<long double>(molar_abundance) * static_cast<long double>(a);
|
||||||
|
|
||||||
|
if (z == 1) {
|
||||||
|
total_mass_H += mass_contribution;
|
||||||
|
} else if (z == 2) {
|
||||||
|
total_mass_He += mass_contribution;
|
||||||
|
} else {
|
||||||
|
total_mass_Z += mass_contribution;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
long double total_mass_current = total_mass_H + total_mass_He + total_mass_Z;
|
||||||
|
|
||||||
|
// Edge case: Empty composition
|
||||||
|
if (total_mass_current <= 0.0L) {
|
||||||
|
// Return empty or throw? If input was empty, return empty.
|
||||||
|
if (comp.size() == 0) return comp;
|
||||||
|
throw std::runtime_error("Input composition has zero total mass.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Calculate Scaling Factors
|
||||||
|
// Factor = (Target_Mass_Fraction / Old_Mass_Fraction)
|
||||||
|
// = (Target_Mass_Fraction) / (Old_Group_Mass / Total_Mass)
|
||||||
|
// = (Target_Mass_Fraction * Total_Mass) / Old_Group_Mass
|
||||||
|
|
||||||
|
long double scale_H = 0.0L;
|
||||||
|
long double scale_He = 0.0L;
|
||||||
|
long double scale_Z = 0.0L;
|
||||||
|
|
||||||
|
if (ld_target_X > 1e-16L) {
|
||||||
|
if (total_mass_H <= 1e-19L) {
|
||||||
|
throw std::runtime_error("Cannot rescale Hydrogen to " + std::to_string(target_X) +
|
||||||
|
" because input has no Hydrogen.");
|
||||||
|
}
|
||||||
|
scale_H = (ld_target_X * total_mass_current) / total_mass_H;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ld_target_Y > 1e-16L) {
|
||||||
|
if (total_mass_He <= 1e-19L) {
|
||||||
|
throw std::runtime_error("Cannot rescale Helium to " + std::to_string((double)ld_target_Y) +
|
||||||
|
" because input has no Helium.");
|
||||||
|
}
|
||||||
|
scale_He = (ld_target_Y * total_mass_current) / total_mass_He;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ld_target_Z > 1e-16L) {
|
||||||
|
if (total_mass_Z <= 1e-19L) {
|
||||||
|
throw std::runtime_error("Cannot rescale Metals to " + std::to_string(target_Z) +
|
||||||
|
" because input has no Metals.");
|
||||||
|
}
|
||||||
|
scale_Z = (ld_target_Z * total_mass_current) / total_mass_Z;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Apply Scaling and Construct New Vectors
|
||||||
|
std::vector<fourdst::atomic::Species> new_species;
|
||||||
|
std::vector<double> new_abundances;
|
||||||
|
new_species.reserve(comp.size());
|
||||||
|
new_abundances.reserve(comp.size());
|
||||||
|
|
||||||
|
for (const auto& [spec, abundance] : comp) {
|
||||||
|
new_species.push_back(spec);
|
||||||
|
|
||||||
|
long double factor = 0.0L;
|
||||||
|
int z = spec.z();
|
||||||
|
|
||||||
|
if (z == 1) {
|
||||||
|
factor = scale_H;
|
||||||
|
} else if (z == 2) {
|
||||||
|
factor = scale_He;
|
||||||
|
} else {
|
||||||
|
factor = scale_Z;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate new abundance in long double then cast back
|
||||||
|
long double new_val_ld = static_cast<long double>(abundance) * factor;
|
||||||
|
new_abundances.push_back(static_cast<double>(new_val_ld));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Composition(new_species, new_abundances);
|
||||||
|
}
|
||||||
|
|
||||||
|
gridfire::NetIn init(const double temp, const double rho, const double tMax) {
|
||||||
|
std::setlocale(LC_ALL, "");
|
||||||
|
g_previousHandler = std::set_terminate(quill_terminate_handler);
|
||||||
|
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
|
||||||
|
logger->set_log_level(quill::LogLevel::Info);
|
||||||
|
|
||||||
|
using namespace gridfire;
|
||||||
|
const std::vector<double> X = {0.7081145999999999, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4};
|
||||||
|
const std::vector<std::string> symbols = {"H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"};
|
||||||
|
|
||||||
|
|
||||||
|
const fourdst::composition::Composition composition = fourdst::composition::buildCompositionFromMassFractions(symbols, X);
|
||||||
|
|
||||||
|
NetIn netIn;
|
||||||
|
netIn.composition = composition;
|
||||||
|
netIn.temperature = temp;
|
||||||
|
netIn.density = rho;
|
||||||
|
netIn.energy = 0;
|
||||||
|
|
||||||
|
netIn.tMax = tMax;
|
||||||
|
netIn.dt0 = 1e-12;
|
||||||
|
|
||||||
|
return netIn;
|
||||||
|
}
|
||||||
|
|
||||||
|
void log_results(const gridfire::NetOut& netOut, const gridfire::NetIn& netIn) {
|
||||||
|
std::vector<fourdst::atomic::Species> logSpecies = {
|
||||||
|
fourdst::atomic::H_1,
|
||||||
|
fourdst::atomic::He_3,
|
||||||
|
fourdst::atomic::He_4,
|
||||||
|
fourdst::atomic::C_12,
|
||||||
|
fourdst::atomic::N_14,
|
||||||
|
fourdst::atomic::O_16,
|
||||||
|
fourdst::atomic::Ne_20,
|
||||||
|
fourdst::atomic::Mg_24
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<double> initial;
|
||||||
|
std::vector<double> final;
|
||||||
|
std::vector<double> delta;
|
||||||
|
std::vector<double> fractional;
|
||||||
|
for (const auto& species : logSpecies) {
|
||||||
|
double initial_X = netIn.composition.getMassFraction(species);
|
||||||
|
double final_X = netOut.composition.getMassFraction(species);
|
||||||
|
double delta_X = final_X - initial_X;
|
||||||
|
double fractionalChange = (delta_X) / initial_X * 100.0;
|
||||||
|
|
||||||
|
initial.push_back(initial_X);
|
||||||
|
final.push_back(final_X);
|
||||||
|
delta.push_back(delta_X);
|
||||||
|
fractional.push_back(fractionalChange);
|
||||||
|
}
|
||||||
|
|
||||||
|
initial.push_back(0.0); // Placeholder for energy
|
||||||
|
final.push_back(netOut.energy);
|
||||||
|
delta.push_back(netOut.energy);
|
||||||
|
fractional.push_back(0.0); // Placeholder for energy
|
||||||
|
|
||||||
|
initial.push_back(0.0);
|
||||||
|
final.push_back(netOut.dEps_dT);
|
||||||
|
delta.push_back(netOut.dEps_dT);
|
||||||
|
fractional.push_back(0.0);
|
||||||
|
|
||||||
|
initial.push_back(0.0);
|
||||||
|
final.push_back(netOut.dEps_dRho);
|
||||||
|
delta.push_back(netOut.dEps_dRho);
|
||||||
|
fractional.push_back(0.0);
|
||||||
|
|
||||||
|
initial.push_back(0.0);
|
||||||
|
final.push_back(netOut.specific_neutrino_energy_loss);
|
||||||
|
delta.push_back(netOut.specific_neutrino_energy_loss);
|
||||||
|
fractional.push_back(0.0);
|
||||||
|
|
||||||
|
initial.push_back(0.0);
|
||||||
|
final.push_back(netOut.specific_neutrino_flux);
|
||||||
|
delta.push_back(netOut.specific_neutrino_flux);
|
||||||
|
fractional.push_back(0.0);
|
||||||
|
|
||||||
|
initial.push_back(netIn.composition.getMeanParticleMass());
|
||||||
|
final.push_back(netOut.composition.getMeanParticleMass());
|
||||||
|
delta.push_back(final.back() - initial.back());
|
||||||
|
fractional.push_back((final.back() - initial.back()) / initial.back() * 100.0);
|
||||||
|
|
||||||
|
std::vector<std::string> rowLabels = [&]() -> std::vector<std::string> {
|
||||||
|
std::vector<std::string> labels;
|
||||||
|
for (const auto& species : logSpecies) {
|
||||||
|
labels.emplace_back(species.name());
|
||||||
|
}
|
||||||
|
labels.emplace_back("ε");
|
||||||
|
labels.emplace_back("dε/dT");
|
||||||
|
labels.emplace_back("dε/dρ");
|
||||||
|
labels.emplace_back("Eν");
|
||||||
|
labels.emplace_back("Fν");
|
||||||
|
labels.emplace_back("<μ>");
|
||||||
|
return labels;
|
||||||
|
}();
|
||||||
|
|
||||||
|
|
||||||
|
gridfire::utils::Column<std::string> paramCol("Parameter", rowLabels);
|
||||||
|
gridfire::utils::Column<double> initialCol("Initial", initial);
|
||||||
|
gridfire::utils::Column<double> finalCol ("Final", final);
|
||||||
|
gridfire::utils::Column<double> deltaCol ("δ", delta);
|
||||||
|
gridfire::utils::Column<double> percentCol("% Change", fractional);
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<gridfire::utils::ColumnBase>> columns;
|
||||||
|
columns.push_back(std::make_unique<gridfire::utils::Column<std::string>>(paramCol));
|
||||||
|
columns.push_back(std::make_unique<gridfire::utils::Column<double>>(initialCol));
|
||||||
|
columns.push_back(std::make_unique<gridfire::utils::Column<double>>(finalCol));
|
||||||
|
columns.push_back(std::make_unique<gridfire::utils::Column<double>>(deltaCol));
|
||||||
|
columns.push_back(std::make_unique<gridfire::utils::Column<double>>(percentCol));
|
||||||
|
|
||||||
|
|
||||||
|
gridfire::utils::print_table("Simulation Results", columns);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void record_abundance_history_callback(const gridfire::solver::PointSolverTimestepContext& ctx) {
|
||||||
|
s_wrote_abundance_history = true;
|
||||||
|
const auto& engine = ctx.engine;
|
||||||
|
// std::unordered_map<std::string, std::pair<double, double>> abundances;
|
||||||
|
std::vector<double> Y;
|
||||||
|
for (const auto& species : engine.getNetworkSpecies(ctx.state_ctx)) {
|
||||||
|
const size_t sid = engine.getSpeciesIndex(ctx.state_ctx, species);
|
||||||
|
double y = N_VGetArrayPointer(ctx.state)[sid];
|
||||||
|
Y.push_back(y > 0.0 ? y : 0.0); // Regularize tiny negative abundances to zero
|
||||||
|
}
|
||||||
|
|
||||||
|
fourdst::composition::Composition comp(engine.getNetworkSpecies(ctx.state_ctx), Y);
|
||||||
|
|
||||||
|
|
||||||
|
std::unordered_map<std::string, std::pair<double, double>> abundances;
|
||||||
|
for (const auto& sp : comp | std::views::keys) {
|
||||||
|
abundances.emplace(std::string(sp.name()), std::make_pair(sp.mass(), comp.getMolarAbundance(sp)));
|
||||||
|
}
|
||||||
|
g_callbackHistory.emplace_back(ctx.t, abundances);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void save_callback_data(const std::string_view filename) {
|
||||||
|
std::set<std::string> unique_species;
|
||||||
|
for (const auto &abundances: g_callbackHistory | std::views::values) {
|
||||||
|
for (const auto &species_name: abundances | std::views::keys) {
|
||||||
|
unique_species.insert(species_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::ofstream csvFile(filename.data(), std::ios::out);
|
||||||
|
csvFile << "t,";
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
for (const auto& species_name : unique_species) {
|
||||||
|
csvFile << species_name;
|
||||||
|
if (i < unique_species.size() - 1) {
|
||||||
|
csvFile << ",";
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
csvFile << "\n";
|
||||||
|
|
||||||
|
for (const auto& [time, data] : g_callbackHistory) {
|
||||||
|
csvFile << time << ",";
|
||||||
|
size_t j = 0;
|
||||||
|
for (const auto& species_name : unique_species) {
|
||||||
|
if (!data.contains(species_name)) {
|
||||||
|
csvFile << "0.0";
|
||||||
|
} else {
|
||||||
|
csvFile << data.at(species_name).second;
|
||||||
|
}
|
||||||
|
if (j < unique_species.size() - 1) {
|
||||||
|
csvFile << ",";
|
||||||
|
}
|
||||||
|
++j;
|
||||||
|
}
|
||||||
|
csvFile << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
csvFile.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
void log_callback_data(const double temp) {
|
||||||
|
if (s_wrote_abundance_history) {
|
||||||
|
std::cout << "Saving abundance history to abundance_history.csv" << std::endl;
|
||||||
|
save_callback_data("abundance_history_" + std::to_string(temp) + ".csv");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void quill_terminate_handler()
|
||||||
|
{
|
||||||
|
log_callback_data(1.5e7);
|
||||||
|
quill::Backend::stop();
|
||||||
|
if (g_previousHandler)
|
||||||
|
g_previousHandler();
|
||||||
|
else
|
||||||
|
std::abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
void callback_main(const gridfire::solver::PointSolverTimestepContext& ctx) {
|
||||||
|
record_abundance_history_callback(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
GF_PAR_INIT();
|
||||||
|
using namespace gridfire;
|
||||||
|
|
||||||
|
double temp = 1.5e7;
|
||||||
|
double rho = 1.5e2;
|
||||||
|
double tMax = 3.1536e+16;
|
||||||
|
double X = 0.7;
|
||||||
|
double Z = 0.02;
|
||||||
|
|
||||||
|
|
||||||
|
CLI::App app("GridFire Quick CLI Test");
|
||||||
|
// Add temp, rho, and tMax as options if desired
|
||||||
|
app.add_option("--temp", temp, "Initial Temperature")->default_val(std::format("{:5.2E}", temp));
|
||||||
|
app.add_option("--rho", rho, "Initial Density")->default_val(std::format("{:5.2E}", rho));
|
||||||
|
app.add_option("--tmax", tMax, "Maximum Time")->default_val(std::format("{:5.2E}", tMax));
|
||||||
|
// app.add_option("--X", X, "Target Hydrogen Mass Fraction")->default_val(std::format("{:5.2f}", X));
|
||||||
|
// app.add_option("--Z", Z, "Target Metal Mass Fraction")->default_val(std::format("{:5.2f}", Z));
|
||||||
|
|
||||||
|
CLI11_PARSE(app, argc, argv);
|
||||||
|
NetIn netIn = init(temp, rho, tMax);
|
||||||
|
// netIn.composition = rescale(netIn.composition, X, Z);
|
||||||
|
|
||||||
|
policy::MainSequencePolicy stellarPolicy(netIn.composition);
|
||||||
|
auto [engine, ctx_template] = stellarPolicy.construct();
|
||||||
|
|
||||||
|
solver::PointSolverContext solver_context(*ctx_template);
|
||||||
|
solver::PointSolver solver(engine);
|
||||||
|
|
||||||
|
NetOut result = solver.evaluate(solver_context, netIn);
|
||||||
|
log_results(result, netIn);
|
||||||
|
}
|
||||||
1
tools/cli/gf_quick/meson.build
Normal file
1
tools/cli/gf_quick/meson.build
Normal file
@@ -0,0 +1 @@
|
|||||||
|
executable('gf_quick', 'main.cpp', dependencies: [gridfire_dep, cli11_dep])
|
||||||
1
tools/cli/meson.build
Normal file
1
tools/cli/meson.build
Normal file
@@ -0,0 +1 @@
|
|||||||
|
subdir('gf_quick')
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
if get_option('build_tools')
|
if get_option('build_tools')
|
||||||
subdir('config')
|
subdir('config')
|
||||||
|
subdir('cli')
|
||||||
endif
|
endif
|
||||||
Reference in New Issue
Block a user