feat(dynamic-engine): added derivitves for energy generation rate. dε/dT and dε/dρ have been added to NetOut and computed with auto diff
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
#include "gridfire/solver/strategies/CVODE_solver_strategy.h"
|
||||
|
||||
#include "gridfire/network.h"
|
||||
#include "gridfire/utils/table_format.h"
|
||||
#include "gridfire/engine/diagnostics/dynamic_engine_diagnostics.h"
|
||||
|
||||
#include "quill/LogMacros.h"
|
||||
|
||||
#include "fourdst/composition/composition.h"
|
||||
|
||||
@@ -10,6 +14,9 @@
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <stdexcept>
|
||||
#include <algorithm>
|
||||
|
||||
#include "fourdst/composition/exceptions/exceptions_composition.h"
|
||||
|
||||
|
||||
namespace {
|
||||
@@ -75,11 +82,12 @@ namespace gridfire::solver {
|
||||
// TODO: In order to support MPI this function must be changed
|
||||
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx);
|
||||
if (flag < 0) {
|
||||
throw std::runtime_error("Failed to create SUNDIALS context (Errno: " + std::to_string(flag) + ")");
|
||||
throw std::runtime_error("Failed to create SUNDIALS context (SUNDIALS Errno: " + std::to_string(flag) + ")");
|
||||
}
|
||||
}
|
||||
|
||||
CVODESolverStrategy::~CVODESolverStrategy() {
|
||||
std::cout << "Cleaning up CVODE resources..." << std::endl;
|
||||
cleanup_cvode_resources(true);
|
||||
|
||||
if (m_sun_ctx) {
|
||||
@@ -111,7 +119,7 @@ namespace gridfire::solver {
|
||||
[[maybe_unused]] double last_callback_time = 0;
|
||||
m_num_steps = 0;
|
||||
double accumulated_energy = 0.0;
|
||||
|
||||
int total_update_stages_triggered = 0;
|
||||
while (current_time < netIn.tMax) {
|
||||
try {
|
||||
user_data.T9 = T9;
|
||||
@@ -138,22 +146,47 @@ namespace gridfire::solver {
|
||||
double last_step_size;
|
||||
CVodeGetNumSteps(m_cvode_mem, &n_steps);
|
||||
CVodeGetLastStep(m_cvode_mem, &last_step_size);
|
||||
std::cout << std::scientific << std::setprecision(3) << "Step: " << std::setw(6) << n_steps << " | Time: " << current_time << " | Last Step Size: " << last_step_size << std::endl;
|
||||
long int nliters, nlcfails;
|
||||
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters);
|
||||
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails);
|
||||
|
||||
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
|
||||
const double current_energy = y_data[numSpecies]; // Specific energy rate
|
||||
std::cout << std::scientific << std::setprecision(3)
|
||||
<< "Step: " << std::setw(6) << n_steps
|
||||
<< " | Time: " << current_time << " [s]"
|
||||
<< " | Last Step Size: " << last_step_size
|
||||
<< " | Accumulated Energy: " << current_energy << " [erg/g]"
|
||||
<< " | NonlinIters: " << std::setw(2) << nliters
|
||||
<< " | ConvFails: " << std::setw(2) << nlcfails
|
||||
<< std::endl;
|
||||
|
||||
// if (n_steps % 50 == 0) {
|
||||
// std::cout << "Logging step diagnostics at step " << n_steps << "..." << std::endl;
|
||||
// log_step_diagnostics(user_data);
|
||||
// }
|
||||
// if (n_steps == 300) {
|
||||
// log_step_diagnostics(user_data);
|
||||
// exit(0);
|
||||
// }
|
||||
// log_step_diagnostics(user_data);
|
||||
|
||||
} catch (const exceptions::StaleEngineTrigger& e) {
|
||||
exceptions::StaleEngineTrigger::state staleState = e.getState();
|
||||
accumulated_energy += e.energy(); // Add the specific energy rate to the accumulated energy
|
||||
// total_update_stages_triggered++;
|
||||
LOG_INFO(
|
||||
m_logger,
|
||||
"Engine Update Triggered due to StaleEngineTrigger exception at time {} ({} update{} triggered). Current total specific energy {} [erg/g]",
|
||||
current_time,
|
||||
total_update_stages_triggered,
|
||||
total_update_stages_triggered == 1 ? "" : "s",
|
||||
accumulated_energy);
|
||||
total_update_stages_triggered++;
|
||||
|
||||
fourdst::composition::Composition temp_comp;
|
||||
std::vector<double> mass_fractions;
|
||||
size_t num_species_at_stop = e.numSpecies();
|
||||
|
||||
if (num_species_at_stop != m_engine.getNetworkSpecies().size()) {
|
||||
throw std::runtime_error(
|
||||
"StaleEngineError state has a different number of species than the engine. This should not happen."
|
||||
);
|
||||
}
|
||||
mass_fractions.reserve(num_species_at_stop);
|
||||
|
||||
for (size_t i = 0; i < num_species_at_stop; ++i) {
|
||||
@@ -170,13 +203,22 @@ namespace gridfire::solver {
|
||||
netInTemp.composition = temp_comp;
|
||||
|
||||
fourdst::composition::Composition currentComposition = m_engine.update(netInTemp);
|
||||
LOG_INFO(
|
||||
m_logger,
|
||||
"Due to a triggered stale engine the composition was updated from size {} to {} species.",
|
||||
num_species_at_stop,
|
||||
m_engine.getNetworkSpecies().size()
|
||||
);
|
||||
|
||||
numSpecies = m_engine.getNetworkSpecies().size();
|
||||
N = numSpecies + 1;
|
||||
|
||||
initialize_cvode_integration_resources(N, numSpecies, current_time, temp_comp, absTol, relTol, accumulated_energy);
|
||||
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
|
||||
|
||||
check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
|
||||
} catch (fourdst::composition::exceptions::InvalidCompositionError& e) {
|
||||
log_step_diagnostics(user_data);
|
||||
std::rethrow_exception(std::make_exception_ptr(e));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,7 +237,7 @@ namespace gridfire::solver {
|
||||
std::vector<std::string> speciesNames;
|
||||
speciesNames.reserve(numSpecies);
|
||||
for (const auto& species : m_engine.getNetworkSpecies()) {
|
||||
speciesNames.push_back(std::string(species.name()));
|
||||
speciesNames.emplace_back(species.name());
|
||||
}
|
||||
|
||||
fourdst::composition::Composition outputComposition(speciesNames);
|
||||
@@ -203,9 +245,21 @@ namespace gridfire::solver {
|
||||
outputComposition.finalize(true);
|
||||
|
||||
NetOut netOut;
|
||||
netOut.composition = std::move(outputComposition);
|
||||
netOut.composition = outputComposition;
|
||||
netOut.energy = accumulated_energy;
|
||||
check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
|
||||
|
||||
outputComposition.setCompositionMode(false); // set to number fraction mode
|
||||
std::vector<double> Y = outputComposition.getNumberFractionVector(); // TODO need to ensure that the canonical vector representation is used throughout the code to make sure tracking does not get messed up
|
||||
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
|
||||
std::vector<double>(Y.begin(), Y.begin() + numSpecies), // TODO: This narrowing should probably be solved. Its possible unforeseen bugs will arise from this
|
||||
T9,
|
||||
netIn.density
|
||||
);
|
||||
|
||||
netOut.dEps_dT = dEps_dT;
|
||||
netOut.dEps_dRho = dEps_dRho;
|
||||
|
||||
return netOut;
|
||||
}
|
||||
|
||||
@@ -232,10 +286,11 @@ namespace gridfire::solver {
|
||||
void *user_data
|
||||
) {
|
||||
auto* data = static_cast<CVODEUserData*>(user_data);
|
||||
auto* instance = data->solver_instance;
|
||||
const auto* instance = data->solver_instance;
|
||||
|
||||
try {
|
||||
return instance->calculate_rhs(t, y, ydot, data);
|
||||
instance->calculate_rhs(t, y, ydot, data);
|
||||
return 0;
|
||||
} catch (const exceptions::StaleEngineTrigger& e) {
|
||||
data->captured_exception = std::make_unique<exceptions::StaleEngineTrigger>(e);
|
||||
return 1; // 1 Indicates a recoverable error, CVODE will retry the step
|
||||
@@ -279,7 +334,7 @@ namespace gridfire::solver {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int CVODESolverStrategy::calculate_rhs(
|
||||
void CVODESolverStrategy::calculate_rhs(
|
||||
const sunrealtype t,
|
||||
const N_Vector y,
|
||||
const N_Vector ydot,
|
||||
@@ -304,7 +359,6 @@ namespace gridfire::solver {
|
||||
ydot_data[i] = dydt[i];
|
||||
}
|
||||
ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate
|
||||
return 0;
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::initialize_cvode_integration_resources(
|
||||
@@ -319,6 +373,7 @@ namespace gridfire::solver {
|
||||
cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones
|
||||
|
||||
m_Y = init_sun_vector(N, m_sun_ctx);
|
||||
m_YErr = N_VClone(m_Y);
|
||||
|
||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
||||
for (size_t i = 0; i < numSpecies; i++) {
|
||||
@@ -335,6 +390,8 @@ namespace gridfire::solver {
|
||||
check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit");
|
||||
check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances");
|
||||
|
||||
check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep");
|
||||
|
||||
m_J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), m_sun_ctx);
|
||||
check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix");
|
||||
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx);
|
||||
@@ -348,10 +405,12 @@ namespace gridfire::solver {
|
||||
if (m_LS) SUNLinSolFree(m_LS);
|
||||
if (m_J) SUNMatDestroy(m_J);
|
||||
if (m_Y) N_VDestroy(m_Y);
|
||||
if (m_YErr) N_VDestroy(m_YErr);
|
||||
|
||||
m_LS = nullptr;
|
||||
m_J = nullptr;
|
||||
m_Y = nullptr;
|
||||
m_YErr = nullptr;
|
||||
|
||||
if (memFree) {
|
||||
if (m_cvode_mem) CVodeFree(&m_cvode_mem);
|
||||
@@ -359,4 +418,79 @@ namespace gridfire::solver {
|
||||
}
|
||||
}
|
||||
|
||||
void CVODESolverStrategy::log_step_diagnostics(const CVODEUserData &user_data) const {
|
||||
check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors");
|
||||
|
||||
sunrealtype *y_data = N_VGetArrayPointer(m_Y);
|
||||
sunrealtype *y_err_data = N_VGetArrayPointer(m_YErr);
|
||||
|
||||
|
||||
std::vector<double> err_ratios;
|
||||
std::vector<std::string> speciesNames;
|
||||
|
||||
const auto absTol = m_config.get<double>("gridfire:solver:CVODESolverStrategy:absTol", 1.0e-8);
|
||||
const auto relTol = m_config.get<double>("gridfire:solver:CVODESolverStrategy:relTol", 1.0e-8);
|
||||
|
||||
const size_t num_components = N_VGetLength(m_Y);
|
||||
err_ratios.resize(num_components - 1);
|
||||
|
||||
std::vector<double> Y_full(y_data, y_data + num_components - 1);
|
||||
|
||||
std::ranges::replace_if(
|
||||
Y_full,
|
||||
[](const double val) {
|
||||
return val < 0.0 && val > -1.0e-16;
|
||||
},
|
||||
0.0
|
||||
);
|
||||
|
||||
for (size_t i = 0; i < num_components - 1; i++) {
|
||||
const double weight = relTol * std::abs(y_data[i]) + absTol;
|
||||
if (weight == 0.0) continue; // Skip components with zero weight
|
||||
|
||||
const double err_ratio = std::abs(y_err_data[i]) / weight;
|
||||
|
||||
err_ratios[i] = err_ratio;
|
||||
speciesNames.push_back(std::string(user_data.networkSpecies->at(i).name()));
|
||||
}
|
||||
|
||||
if (err_ratios.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<size_t> indices(speciesNames.size());
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
indices[i] = i;
|
||||
}
|
||||
|
||||
std::ranges::sort(
|
||||
indices,
|
||||
[&err_ratios](const size_t i1, const size_t i2) {
|
||||
return err_ratios[i1] > err_ratios[i2];
|
||||
}
|
||||
);
|
||||
|
||||
std::vector<std::string> sorted_speciesNames;
|
||||
std::vector<double> sorted_err_ratios;
|
||||
|
||||
sorted_speciesNames.reserve(indices.size());
|
||||
sorted_err_ratios.reserve(indices.size());
|
||||
|
||||
for (const auto idx: indices) {
|
||||
sorted_speciesNames.push_back(speciesNames[idx]);
|
||||
sorted_err_ratios.push_back(err_ratios[idx]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::vector<std::unique_ptr<utils::ColumnBase>> columns;
|
||||
columns.push_back(std::make_unique<utils::Column<std::string>>("Species", sorted_speciesNames));
|
||||
columns.push_back(std::make_unique<utils::Column<double>>("Error Ratio", sorted_err_ratios));
|
||||
|
||||
std::cout << utils::format_table("Species Error Ratios", columns) << std::endl;
|
||||
diagnostics::inspect_jacobian_stiffness(*user_data.engine, Y_full, user_data.T9, user_data.rho);
|
||||
diagnostics::inspect_species_balance(*user_data.engine, "N-14", Y_full, user_data.T9, user_data.rho);
|
||||
diagnostics::inspect_species_balance(*user_data.engine, "n-1", Y_full, user_data.T9, user_data.rho);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user