feat(KINSOL): Switch from Eigen to KINSOL
Previously QSE solving was done using Eigen. While this worked we were limited in the ability to use previous iterations to speed up later steps. We have switched to KINSOL, from SUNDIALS, for linear solving. This has drastically speed up the process of solving for QSE abundances, primarily because the jacobian matrix does not need to be generated every single time time a QSE abundance is requested.
This commit is contained in:
@@ -22,65 +22,8 @@
|
||||
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
|
||||
#include "gridfire/trigger/procedures/trigger_pprint.h"
|
||||
#include "gridfire/exceptions/error_solver.h"
|
||||
#include "gridfire/utils/logging.h"
|
||||
#include "gridfire/utils/sundials.h"
|
||||
|
||||
namespace {
|
||||
|
||||
std::unordered_map<int, std::string> cvode_ret_code_map {
|
||||
{0, "CV_SUCCESS: The solver succeeded."},
|
||||
{1, "CV_TSTOP_RETURN: The solver reached the specified stopping time."},
|
||||
{2, "CV_ROOT_RETURN: A root was found."},
|
||||
{-99, "CV_WARNING: CVODE succeeded but in an unusual manner"},
|
||||
{-1, "CV_TOO_MUCH_WORK: The solver took too many internal steps."},
|
||||
{-2, "CV_TOO_MUCH_ACC: The solver could not satisfy the accuracy requested."},
|
||||
{-3, "CV_ERR_FAILURE: The solver encountered a non-recoverable error."},
|
||||
{-4, "CV_CONV_FAILURE: The solver failed to converge."},
|
||||
{-5, "CV_LINIT_FAIL: The linear solver's initialization function failed."},
|
||||
{-6, "CV_LSETUP_FAIL: The linear solver's setup function failed."},
|
||||
{-7, "CV_LSOLVE_FAIL: The linear solver's solve function failed."},
|
||||
{-8, "CV_RHSFUNC_FAIL: The right-hand side function failed in an unrecoverable manner."},
|
||||
{-9, "CV_FIRST_RHSFUNC_ERR: The right-hand side function failed at the first call."},
|
||||
{-10, "CV_REPTD_RHSFUNC_ERR: The right-hand side function repeatedly failed recoverable."},
|
||||
{-11, "CV_UNREC_RHSFUNC_ERR: The right-hand side function failed unrecoverably."},
|
||||
{-12, "CV_RTFUNC_FAIL: The rootfinding function failed in an unrecoverable manner."},
|
||||
{-13, "CV_NLS_INIT_FAIL: The nonlinear solver's initialization function failed."},
|
||||
{-14, "CV_NLS_SETUP_FAIL: The nonlinear solver's setup function failed."},
|
||||
{-15, "CV_CONSTR_FAIL : The inequality constraint was violated and the solver was unable to recover."},
|
||||
{-16, "CV_NLS_FAIL: The nonlinear solver's solve function failed."},
|
||||
{-20, "CV_MEM_FAIL: Memory allocation failed."},
|
||||
{-21, "CV_MEM_NULL: The CVODE memory structure is NULL."},
|
||||
{-22, "CV_ILL_INPUT: An illegal input was detected."},
|
||||
{-23, "CV_NO_MALLOC: The CVODE memory structure has not been allocated."},
|
||||
{-24, "CV_BAD_K: The value of k is invalid."},
|
||||
{-25, "CV_BAD_T: The value of t is invalid."},
|
||||
{-26, "CV_BAD_DKY: The value of dky is invalid."},
|
||||
{-27, "CV_TOO_CLOSE: The time points are too close together."},
|
||||
{-28, "CV_VECTOROP_ERR: A vector operation failed."},
|
||||
{-29, "CV_PROJ_MEM_NULL: The projection memory structure is NULL."},
|
||||
{-30, "CV_PROJFUNC_FAIL: The projection function failed in an unrecoverable manner."},
|
||||
{-31, "CV_REPTD_PROJFUNC_ERR: The projection function has repeated recoverable errors."}
|
||||
};
|
||||
void check_cvode_flag(const int flag, const std::string& func_name) {
|
||||
if (flag < 0) {
|
||||
if (!cvode_ret_code_map.contains(flag)) {
|
||||
throw gridfire::exceptions::CVODESolverFailureError("CVODE error in " + func_name + ": Unknown error code: " + std::to_string(flag));
|
||||
}
|
||||
throw gridfire::exceptions::CVODESolverFailureError("CVODE error in " + func_name + ": " + cvode_ret_code_map.at(flag));
|
||||
}
|
||||
}
|
||||
|
||||
N_Vector init_sun_vector(uint64_t size, SUNContext sun_ctx) {
|
||||
#ifdef SUNDIALS_HAVE_OPENMP
|
||||
const N_Vector vec = N_VNew_OpenMP(size, 0, sun_ctx);
|
||||
#elif SUNDIALS_HAVE_PTHREADS
|
||||
const N_Vector vec = N_VNew_Pthreads(size, sun_ctx);
|
||||
#else
|
||||
const N_Vector vec = N_VNew_Serial(static_cast<long long>(size), sun_ctx);
|
||||
#endif
|
||||
check_cvode_flag(vec == nullptr ? -1 : 0, "N_VNew");
|
||||
return vec;
|
||||
}
|
||||
}
|
||||
|
||||
namespace gridfire::solver {
|
||||
|
||||
@@ -175,7 +118,7 @@ namespace gridfire::solver {
|
||||
LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N);
|
||||
LOG_TRACE_L1(m_logger, "Initializing CVODE resources");
|
||||
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
|
||||
check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
|
||||
initialize_cvode_integration_resources(N, numSpecies, 0.0, equilibratedComposition, absTol, relTol, 0.0);
|
||||
|
||||
@@ -206,11 +149,11 @@ namespace gridfire::solver {
|
||||
user_data.networkSpecies = &m_engine.getNetworkSpecies();
|
||||
user_data.captured_exception.reset();
|
||||
|
||||
check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
|
||||
utils::check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
|
||||
|
||||
LOG_TRACE_L2(m_logger, "Taking one CVODE step...");
|
||||
int flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP);
|
||||
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, cvode_ret_code_map.at(flag));
|
||||
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag));
|
||||
|
||||
if (user_data.captured_exception){
|
||||
std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception));
|
||||
@@ -220,7 +163,7 @@ namespace gridfire::solver {
|
||||
// TODO: Come up with some way to save these to a file rather than spamming stdout. JSON maybe? OPAT?
|
||||
// log_step_diagnostics(user_data, true, false, false);
|
||||
// exit(0);
|
||||
check_cvode_flag(flag, "CVode");
|
||||
utils::check_cvode_flag(flag, "CVode");
|
||||
|
||||
long int n_steps;
|
||||
double last_step_size;
|
||||
@@ -449,11 +392,11 @@ namespace gridfire::solver {
|
||||
cleanup_cvode_resources(true);
|
||||
|
||||
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
|
||||
check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
|
||||
|
||||
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
|
||||
|
||||
check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
|
||||
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -487,7 +430,7 @@ namespace gridfire::solver {
|
||||
NetOut netOut;
|
||||
netOut.composition = outputComposition;
|
||||
netOut.energy = accumulated_energy;
|
||||
check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
|
||||
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
|
||||
|
||||
LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives...");
|
||||
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
|
||||
@@ -688,7 +631,7 @@ namespace gridfire::solver {
|
||||
std::vector<double> y_vec(y_data, y_data + numSpecies);
|
||||
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(), y_vec);
|
||||
|
||||
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition (mean molecular mass: {})", t, composition.size(), composition.getMeanParticleMass());
|
||||
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size());
|
||||
const auto result = m_engine.calculateRHSAndEnergy(composition, data->T9, data->rho);
|
||||
if (!result) {
|
||||
LOG_WARNING(m_logger, "StaleEngineTrigger thrown during RHS calculation at time {}", t);
|
||||
@@ -732,7 +675,7 @@ namespace gridfire::solver {
|
||||
LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol);
|
||||
cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones
|
||||
|
||||
m_Y = init_sun_vector(N, m_sun_ctx);
|
||||
m_Y = utils::init_sun_vector(N, m_sun_ctx);
|
||||
m_YErr = N_VClone(m_Y);
|
||||
|
||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
||||
@@ -747,8 +690,8 @@ namespace gridfire::solver {
|
||||
y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero
|
||||
|
||||
|
||||
check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit");
|
||||
check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
|
||||
utils::check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit");
|
||||
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
|
||||
|
||||
// Constraints
|
||||
// We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual
|
||||
@@ -768,17 +711,17 @@ namespace gridfire::solver {
|
||||
}
|
||||
N_VConst(1.0, m_constraints); // Set all constraints to >= 0 (note this is where the flag values are set)
|
||||
|
||||
check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints");
|
||||
utils::check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints");
|
||||
|
||||
check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep");
|
||||
utils::check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep");
|
||||
|
||||
m_J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), m_sun_ctx);
|
||||
check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix");
|
||||
utils::check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix");
|
||||
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx);
|
||||
check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
|
||||
utils::check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
|
||||
|
||||
check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver");
|
||||
check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
|
||||
utils::check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver");
|
||||
utils::check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
|
||||
LOG_TRACE_L2(m_logger, "CVODE solver initialized");
|
||||
}
|
||||
|
||||
@@ -814,10 +757,10 @@ namespace gridfire::solver {
|
||||
sunrealtype hlast, hcur, tcur;
|
||||
int qlast;
|
||||
|
||||
check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep");
|
||||
check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep");
|
||||
check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder");
|
||||
check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime");
|
||||
utils::check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep");
|
||||
utils::check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep");
|
||||
utils::check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder");
|
||||
utils::check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime");
|
||||
|
||||
{
|
||||
std::vector<std::string> labels = {"Current Time (tcur)", "Last Step (hlast)", "Current Step (hcur)", "Last Order (qlast)"};
|
||||
@@ -834,13 +777,13 @@ namespace gridfire::solver {
|
||||
// These are the CRITICAL counters for diagnosing your problem
|
||||
long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails;
|
||||
|
||||
check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps");
|
||||
check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
|
||||
check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
|
||||
check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails");
|
||||
check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
|
||||
check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
|
||||
check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
|
||||
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps");
|
||||
utils::check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
|
||||
utils::check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
|
||||
utils::check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails");
|
||||
utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
|
||||
utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
|
||||
utils::check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
|
||||
|
||||
|
||||
{
|
||||
@@ -864,7 +807,7 @@ namespace gridfire::solver {
|
||||
}
|
||||
|
||||
// --- 3. Get Estimated Local Errors (Your Original Logic) ---
|
||||
check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors");
|
||||
utils::check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors");
|
||||
|
||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
||||
sunrealtype *y_err_data = N_VGetArrayPointer(m_YErr);
|
||||
|
||||
Reference in New Issue
Block a user