feat(weak): major weak rate progress

Major weak rate progress which includes: A refactor of many of the public interfaces for GridFire Engines to use composition objects as opposed to raw abundance vectors. This helps prevent index mismatch errors. Further, the weak reaction class has been expanded with the majority of an implimentation, including an atomic_base derived class to allow for proper auto diff tracking of the interpolated table results. Some additional changes are that the version of fourdst and libcomposition have been bumped to versions with smarter caching of intermediate vectors and a few bug fixes.
This commit is contained in:
2025-10-07 15:16:03 -04:00
parent 4f1c260444
commit 8a0b5b2c36
53 changed files with 2310 additions and 1759 deletions

View File

@@ -1,14 +1,8 @@
#pragma once
#include "gridfire/engine/engine_graph.h"
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/network.h"
#include "fourdst/logging/logging.h"
#include "fourdst/config/config.h"
#include "quill/Logger.h"
#include <functional>
#include <any>
#include <vector>
@@ -101,249 +95,4 @@ namespace gridfire::solver {
* @brief Type alias for a network solver strategy that uses a DynamicEngine.
*/
using DynamicNetworkSolverStrategy = NetworkSolverStrategy<DynamicEngine>;
/**
* @class DirectNetworkSolver
* @brief A network solver that directly integrates the reaction network ODEs.
*
* This solver uses a Runge-Kutta method to directly integrate the reaction network
* ODEs. It is simpler than the QSENetworkSolver, but it can be less efficient for
* stiff networks with disparate timescales.
*
* @implements DynamicNetworkSolverStrategy
*/
class DirectNetworkSolver final : public DynamicNetworkSolverStrategy {
public:
/**
* @brief Constructor for the DirectNetworkSolver.
* @param engine The dynamic engine to use for evaluating the network.
*/
using DynamicNetworkSolverStrategy::DynamicNetworkSolverStrategy;
/**
* @struct TimestepContext
* @brief Context for the timestep callback function for the DirectNetworkSolver.
*
* This struct contains the context that will be passed to the callback function at the end of each timestep.
* It includes the current time, state, timestep size, cached results, and other relevant information.
*
* This type should be used when defining a callback function
*
* **Example:**
* @code
* #include "gridfire/solver/solver.h"
*
* #include <ofstream>
* #include <ranges>
*
* static std::ofstream consumptionFile("consumption.txt");
* void callback(const gridfire::solver::DirectNetworkSolver::TimestepContext& context) {
* int H1Index = context.engine.getSpeciesIndex(fourdst::atomic::H_1);
* int He4Index = context.engine.getSpeciesIndex(fourdst::atomic::He_4);
*
* consumptionFile << context.t << "," << context.state(H1Index) << "," << context.state(He4Index) << "\n";
* }
*
* int main() {
* ... // Code to set up engine and solvers...
* solver.set_callback(callback);
* solver.evaluate(netIn);
* consumptionFile.close();
* }
* @endcode
*/
struct TimestepContext final : public SolverContextBase {
const double t; ///< Current time.
const boost::numeric::ublas::vector<double>& state; ///< Current state of the system.
const double dt; ///< Time step size.
const double cached_time; ///< Cached time for the last observed state.
const double last_observed_time; ///< Last time the state was observed.
const double last_step_time; ///< Last step time.
const double T9; ///< Temperature in units of 10^9 K.
const double rho; ///< Density in g/cm^3.
const std::optional<StepDerivatives<double>>& cached_result; ///< Cached result of the step derivatives.
const int num_steps; ///< Total number of steps taken.
const DynamicEngine& engine; ///< Reference to the dynamic engine.
const std::vector<fourdst::atomic::Species>& networkSpecies;
TimestepContext(
const double t,
const boost::numeric::ublas::vector<double> &state,
const double dt,
const double cached_time,
const double last_observed_time,
const double last_step_time,
const double t9,
const double rho,
const std::optional<StepDerivatives<double>> &cached_result,
const int num_steps,
const DynamicEngine &engine,
const std::vector<fourdst::atomic::Species>& networkSpecies
);
/**
* @brief Describe the context for callback functions.
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
*
* This method provides a description of the context that will be passed to the callback function.
* The intent is that an end user can investigate the context and use this information to craft their own
* callback function.
*
* @implements SolverContextBase::describe
*/
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
/**
* @brief Type alias for a timestep callback function.
*
* @brief The type alias for the callback function that will be called at the end of each timestep.
*
*/
using TimestepCallback = std::function<void(const TimestepContext& context)>; ///< Type alias for a timestep callback function.
/**
* @brief Evaluates the network for a given timestep using direct integration.
* @param netIn The input conditions for the network.
* @return The output conditions after the timestep.
*/
NetOut evaluate(const NetIn& netIn) override;
/**
* @brief Sets the callback function to be called at the end of each timestep.
* @param callback The callback function to be called at the end of each timestep.
*
* This function allows the user to set a callback function that will be called at the end of each timestep.
* The callback function will receive a gridfire::solver::DirectNetworkSolver::TimestepContext object.
*/
void set_callback(const std::any &callback) override;
/**
* @brief Describe the context that will be passed to the callback function.
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
*
* This method provides a description of the context that will be passed to the callback function.
* The intent is that an end user can investigate the context and use this information to craft their own
* callback function.
*
* @implements SolverContextBase::describe
*/
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
private:
/**
* @struct RHSManager
* @brief Functor for calculating the right-hand side of the ODEs.
*
* This functor is used by the ODE solver to calculate the time derivatives of the
* species abundances. It takes the current abundances as input and returns the
* time derivatives.
*/
struct RHSManager {
DynamicEngine& m_engine; ///< The engine used to evaluate the network.
const double m_T9; ///< Temperature in units of 10^9 K.
const double m_rho; ///< Density in g/cm^3.
mutable double m_cached_time;
mutable std::optional<StepDerivatives<double>> m_cached_result;
mutable double m_last_observed_time = 0.0; ///< Last time the state was observed.
quill::Logger* m_logger = LogManager::getInstance().newFileLogger("integration.log", "GridFire"); ///< Logger instance.
mutable int m_num_steps = 0;
mutable double m_last_step_time = 1e-20;
TimestepCallback& m_callback;
const std::vector<fourdst::atomic::Species>& m_networkSpecies;
/**
* @brief Constructor for the RHSFunctor.
* @param engine The engine used to evaluate the network.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
* @param callback callback function to be called at the end of each timestep.
* @param networkSpecies vector of species in the network in the correct order.
*/
RHSManager(
DynamicEngine& engine,
const double T9,
const double rho,
TimestepCallback& callback,
const std::vector<fourdst::atomic::Species>& networkSpecies
) :
m_engine(engine),
m_T9(T9),
m_rho(rho),
m_cached_time(0),
m_callback(callback),
m_networkSpecies(networkSpecies){}
/**
* @brief Calculates the time derivatives of the species abundances.
* @param Y Vector of current abundances.
* @param dYdt Vector to store the time derivatives.
* @param t Current time.
*/
void operator()(
const boost::numeric::ublas::vector<double>& Y,
boost::numeric::ublas::vector<double>& dYdt,
double t
) const;
void observe(const boost::numeric::ublas::vector<double>& state, double t) const;
void compute_and_cache(const boost::numeric::ublas::vector<double>& state, double t) const;
};
/**
* @struct JacobianFunctor
* @brief Functor for calculating the Jacobian matrix.
*
* This functor is used by the ODE solver to calculate the Jacobian matrix of the
* ODEs. It takes the current abundances as input and returns the Jacobian matrix.
*/
struct JacobianFunctor {
DynamicEngine& m_engine; ///< The engine used to evaluate the network.
const double m_T9; ///< Temperature in units of 10^9 K.
const double m_rho; ///< Density in g/cm^3.
/**
* @brief Constructor for the JacobianFunctor.
* @param engine The engine used to evaluate the network.
* @param T9 Temperature in units of 10^9 K.
* @param rho Density in g/cm^3.
*/
JacobianFunctor(
DynamicEngine& engine,
const double T9,
const double rho
) :
m_engine(engine),
m_T9(T9),
m_rho(rho) {}
/**
* @brief Calculates the Jacobian matrix.
* @param Y Vector of current abundances.
* @param J Matrix to store the Jacobian matrix.
* @param t Current time.
* @param dfdt Vector to store the time derivatives (not used).
*/
void operator()(
const boost::numeric::ublas::vector<double>& Y,
boost::numeric::ublas::matrix<double>& J,
double t,
boost::numeric::ublas::vector<double>& dfdt
) const;
};
private:
quill::Logger* m_logger = LogManager::getInstance().getLogger("log"); ///< Logger instance.
Config& m_config = Config::getInstance(); ///< Configuration instance.
TimestepCallback m_callback;
};
}

View File

@@ -55,7 +55,7 @@ namespace gridfire::solver {
void set_callback(const std::any &callback) override;
bool get_stdout_logging_enabled() const;
[[nodiscard]] bool get_stdout_logging_enabled() const;
void set_stdout_logging_enabled(const bool value);
@@ -69,14 +69,14 @@ namespace gridfire::solver {
const double last_step_time;
const double T9;
const double rho;
const int num_steps;
const size_t num_steps;
const DynamicEngine& engine;
const std::vector<fourdst::atomic::Species>& networkSpecies;
// Constructor
TimestepContext(
double t, const N_Vector& state, double dt, double last_step_time,
double t9, double rho, int num_steps, const DynamicEngine& engine,
double t9, double rho, size_t num_steps, const DynamicEngine& engine,
const std::vector<fourdst::atomic::Species>& networkSpecies
);
@@ -104,8 +104,8 @@ namespace gridfire::solver {
};
private:
Config& m_config = Config::getInstance();
quill::Logger* m_logger = LogManager::getInstance().getLogger("log");
fourdst::config::Config& m_config = fourdst::config::Config::getInstance();
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, void *user_data);
static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);

View File

@@ -23,7 +23,7 @@ namespace gridfire::trigger::solver::CVODE {
size_t numTriggers() const override;
size_t numMisses() const override;
private:
quill::Logger* m_logger = LogManager::getInstance().getLogger("log");
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
@@ -48,7 +48,7 @@ namespace gridfire::trigger::solver::CVODE {
size_t numTriggers() const override;
size_t numMisses() const override;
private:
quill::Logger* m_logger = LogManager::getInstance().getLogger("log");
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;
@@ -71,7 +71,7 @@ namespace gridfire::trigger::solver::CVODE {
size_t numTriggers() const override;
size_t numMisses() const override;
private:
quill::Logger* m_logger = LogManager::getInstance().getLogger("log");
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;
mutable size_t m_updates = 0;