feat(KINSOL): Switch from Eigen to KINSOL

Previously QSE solving was done using Eigen. While this worked, we were
limited in our ability to use previous iterations to speed up later
steps. We have switched to KINSOL, from SUNDIALS, for linear solving.
This has drastically sped up the process of solving for QSE abundances,
primarily because the Jacobian matrix does not need to be generated
every single time a QSE abundance is requested.
This commit is contained in:
2025-11-19 12:06:21 -05:00
parent f7fbc6c1da
commit 442d4ed86c
12 changed files with 506 additions and 386 deletions

View File

@@ -22,65 +22,8 @@
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
#include "gridfire/trigger/procedures/trigger_pprint.h"
#include "gridfire/exceptions/error_solver.h"
#include "gridfire/utils/logging.h"
#include "gridfire/utils/sundials.h"
namespace {
/// Lookup table from CVODE integer return codes to human-readable descriptions.
///
/// Non-negative codes (0, 1, 2, 99) are success or warning returns; negative
/// codes are failures. The table is only ever read, so it is declared const.
///
/// Note on +/-99: in cvode.h CV_WARNING is +99, while -99 is
/// CV_UNRECOGNIZED_ERR. Both are listed so that looking up either value
/// succeeds and reports the right meaning.
const std::unordered_map<int, std::string> cvode_ret_code_map {
    // --- Success / informational returns ---
    {0, "CV_SUCCESS: The solver succeeded."},
    {1, "CV_TSTOP_RETURN: The solver reached the specified stopping time."},
    {2, "CV_ROOT_RETURN: A root was found."},
    {99, "CV_WARNING: CVODE succeeded but in an unusual manner"},
    // --- Failure returns ---
    {-1, "CV_TOO_MUCH_WORK: The solver took too many internal steps."},
    {-2, "CV_TOO_MUCH_ACC: The solver could not satisfy the accuracy requested."},
    {-3, "CV_ERR_FAILURE: The solver encountered a non-recoverable error."},
    {-4, "CV_CONV_FAILURE: The solver failed to converge."},
    {-5, "CV_LINIT_FAIL: The linear solver's initialization function failed."},
    {-6, "CV_LSETUP_FAIL: The linear solver's setup function failed."},
    {-7, "CV_LSOLVE_FAIL: The linear solver's solve function failed."},
    {-8, "CV_RHSFUNC_FAIL: The right-hand side function failed in an unrecoverable manner."},
    {-9, "CV_FIRST_RHSFUNC_ERR: The right-hand side function failed at the first call."},
    {-10, "CV_REPTD_RHSFUNC_ERR: The right-hand side function repeatedly failed recoverable."},
    {-11, "CV_UNREC_RHSFUNC_ERR: The right-hand side function failed unrecoverably."},
    {-12, "CV_RTFUNC_FAIL: The rootfinding function failed in an unrecoverable manner."},
    {-13, "CV_NLS_INIT_FAIL: The nonlinear solver's initialization function failed."},
    {-14, "CV_NLS_SETUP_FAIL: The nonlinear solver's setup function failed."},
    {-15, "CV_CONSTR_FAIL : The inequality constraint was violated and the solver was unable to recover."},
    {-16, "CV_NLS_FAIL: The nonlinear solver's solve function failed."},
    {-20, "CV_MEM_FAIL: Memory allocation failed."},
    {-21, "CV_MEM_NULL: The CVODE memory structure is NULL."},
    {-22, "CV_ILL_INPUT: An illegal input was detected."},
    {-23, "CV_NO_MALLOC: The CVODE memory structure has not been allocated."},
    {-24, "CV_BAD_K: The value of k is invalid."},
    {-25, "CV_BAD_T: The value of t is invalid."},
    {-26, "CV_BAD_DKY: The value of dky is invalid."},
    {-27, "CV_TOO_CLOSE: The time points are too close together."},
    {-28, "CV_VECTOROP_ERR: A vector operation failed."},
    {-29, "CV_PROJ_MEM_NULL: The projection memory structure is NULL."},
    {-30, "CV_PROJFUNC_FAIL: The projection function failed in an unrecoverable manner."},
    {-31, "CV_REPTD_PROJFUNC_ERR: The projection function has repeated recoverable errors."},
    {-99, "CV_UNRECOGNIZED_ERR: An unrecognized error occurred."}
};
void check_cvode_flag(const int flag, const std::string& func_name) {
if (flag < 0) {
if (!cvode_ret_code_map.contains(flag)) {
throw gridfire::exceptions::CVODESolverFailureError("CVODE error in " + func_name + ": Unknown error code: " + std::to_string(flag));
}
throw gridfire::exceptions::CVODESolverFailureError("CVODE error in " + func_name + ": " + cvode_ret_code_map.at(flag));
}
}
/// Allocates a SUNDIALS N_Vector of `size` entries using the vector backend
/// selected at compile time (OpenMP, Pthreads, or serial).
///
/// @param size    Number of entries in the vector.
/// @param sun_ctx SUNDIALS context the vector is created in.
/// @return The newly created vector. NOTE(review): nothing here releases it, so
///         the caller is presumably expected to destroy it — confirm at call sites.
/// @throws gridfire::exceptions::CVODESolverFailureError (via check_cvode_flag)
///         if the constructor returns nullptr.
N_Vector init_sun_vector(uint64_t size, SUNContext sun_ctx) {
#if defined(SUNDIALS_HAVE_OPENMP)
    // NOTE(review): the second argument is the thread count; 0 is kept from the
    // original code — confirm this is the intended value for the OpenMP backend.
    const N_Vector vec = N_VNew_OpenMP(size, 0, sun_ctx);
#elif defined(SUNDIALS_HAVE_PTHREADS)
    // N_VNew_Pthreads takes (length, n_threads, ctx); the thread count was missing.
    // A single thread is used as a conservative default.
    const N_Vector vec = N_VNew_Pthreads(size, 1, sun_ctx);
#else
    const N_Vector vec = N_VNew_Serial(static_cast<long long>(size), sun_ctx);
#endif
    // A null vector means allocation failed: report it as -20 (CV_MEM_FAIL in
    // cvode_ret_code_map) rather than -1, which is CV_TOO_MUCH_WORK.
    check_cvode_flag(vec == nullptr ? -20 : 0, "N_VNew");
    return vec;
}
}
namespace gridfire::solver {
@@ -175,7 +118,7 @@ namespace gridfire::solver {
LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N);
LOG_TRACE_L1(m_logger, "Initializing CVODE resources");
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
initialize_cvode_integration_resources(N, numSpecies, 0.0, equilibratedComposition, absTol, relTol, 0.0);
@@ -206,11 +149,11 @@ namespace gridfire::solver {
user_data.networkSpecies = &m_engine.getNetworkSpecies();
user_data.captured_exception.reset();
check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
utils::check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData");
LOG_TRACE_L2(m_logger, "Taking one CVODE step...");
int flag = CVode(m_cvode_mem, netIn.tMax, m_Y, &current_time, CV_ONE_STEP);
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, cvode_ret_code_map.at(flag));
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag));
if (user_data.captured_exception){
std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception));
@@ -220,7 +163,7 @@ namespace gridfire::solver {
// TODO: Come up with some way to save these to a file rather than spamming stdout. JSON maybe? OPAT?
// log_step_diagnostics(user_data, true, false, false);
// exit(0);
check_cvode_flag(flag, "CVode");
utils::check_cvode_flag(flag, "CVode");
long int n_steps;
double last_step_size;
@@ -449,11 +392,11 @@ namespace gridfire::solver {
cleanup_cvode_resources(true);
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
}
}
@@ -487,7 +430,7 @@ namespace gridfire::solver {
NetOut netOut;
netOut.composition = outputComposition;
netOut.energy = accumulated_energy;
check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives...");
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
@@ -688,7 +631,7 @@ namespace gridfire::solver {
std::vector<double> y_vec(y_data, y_data + numSpecies);
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(), y_vec);
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition (mean molecular mass: {})", t, composition.size(), composition.getMeanParticleMass());
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size());
const auto result = m_engine.calculateRHSAndEnergy(composition, data->T9, data->rho);
if (!result) {
LOG_WARNING(m_logger, "StaleEngineTrigger thrown during RHS calculation at time {}", t);
@@ -732,7 +675,7 @@ namespace gridfire::solver {
LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol);
cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones
m_Y = init_sun_vector(N, m_sun_ctx);
m_Y = utils::init_sun_vector(N, m_sun_ctx);
m_YErr = N_VClone(m_Y);
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
@@ -747,8 +690,8 @@ namespace gridfire::solver {
y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero
check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit");
check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
utils::check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit");
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
// Constraints
// We constrain the solution vector using CVODE's built-in constraint flags as outlined on page 53 of the CVODE manual
@@ -768,17 +711,17 @@ namespace gridfire::solver {
}
N_VConst(1.0, m_constraints); // Set all constraints to >= 0 (note this is where the flag values are set)
check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints");
utils::check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints");
check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep");
utils::check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep");
m_J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), m_sun_ctx);
check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix");
utils::check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix");
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx);
check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
utils::check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver");
check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
utils::check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver");
utils::check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
LOG_TRACE_L2(m_logger, "CVODE solver initialized");
}
@@ -814,10 +757,10 @@ namespace gridfire::solver {
sunrealtype hlast, hcur, tcur;
int qlast;
check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep");
check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep");
check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder");
check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime");
utils::check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep");
utils::check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep");
utils::check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder");
utils::check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime");
{
std::vector<std::string> labels = {"Current Time (tcur)", "Last Step (hlast)", "Current Step (hcur)", "Last Order (qlast)"};
@@ -834,13 +777,13 @@ namespace gridfire::solver {
// These are the CRITICAL counters for diagnosing your problem
long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails;
check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps");
check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails");
check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps");
utils::check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
utils::check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
utils::check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails");
utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
utils::check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
{
@@ -864,7 +807,7 @@ namespace gridfire::solver {
}
// --- 3. Get Estimated Local Errors (Your Original Logic) ---
check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors");
utils::check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors");
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
sunrealtype *y_err_data = N_VGetArrayPointer(m_YErr);