From 7fded59814e4f6985bd114b06ac28b0cb70fd027 Mon Sep 17 00:00:00 2001 From: Emily Boudreaux Date: Thu, 30 Oct 2025 15:05:08 -0400 Subject: [PATCH] fix(python-bindings): Updated python bindings to new interface The python bindings now work with the polymorphic reaction class and the CVODE solver --- src/include/gridfire/engine/engine_abstract.h | 19 +- src/include/gridfire/engine/engine_graph.h | 13 +- .../gridfire/engine/procedures/construction.h | 5 +- .../gridfire/engine/procedures/priming.h | 7 +- .../gridfire/engine/views/engine_adaptive.h | 3 + .../gridfire/engine/views/engine_defined.h | 2 + .../gridfire/engine/views/engine_multiscale.h | 14 +- .../gridfire/exceptions/error_engine.h | 8 + src/include/gridfire/solver/solver.h | 5 +- .../solver/strategies/CVODE_solver_strategy.h | 16 +- .../gridfire/utils/general_composition.h | 12 +- src/lib/engine/engine_graph.cpp | 22 +- src/lib/engine/views/engine_adaptive.cpp | 15 + src/lib/engine/views/engine_defined.cpp | 14 + src/lib/engine/views/engine_multiscale.cpp | 37 ++ .../strategies/CVODE_solver_strategy.cpp | 67 ++- src/python/engine/bindings.cpp | 442 +++++++++++++----- src/python/engine/bindings.h | 10 + src/python/engine/trampoline/py_engine.cpp | 90 +++- src/python/engine/trampoline/py_engine.h | 130 +++++- src/python/exceptions/bindings.cpp | 17 + src/python/reaction/bindings.cpp | 54 ++- src/python/solver/bindings.cpp | 120 +++-- src/python/solver/trampoline/py_solver.cpp | 9 + src/python/solver/trampoline/py_solver.h | 5 + src/python/types/bindings.cpp | 17 +- src/python/utils/bindings.cpp | 64 +++ 27 files changed, 962 insertions(+), 255 deletions(-) diff --git a/src/include/gridfire/engine/engine_abstract.h b/src/include/gridfire/engine/engine_abstract.h index 3af6aad9..03c37e9c 100644 --- a/src/include/gridfire/engine/engine_abstract.h +++ b/src/include/gridfire/engine/engine_abstract.h @@ -112,7 +112,8 @@ namespace gridfire { * @param comp Composition object containing current abundances. * @param T9 Temperature in units of 10^9 K. * @param rho Density in g/cm^3. - * @return StepDerivatives containing dY/dt and energy generation rate. + * @return expected> containing either dY/dt and energy generation rate or a stale engine + * error indicating that the engine must be updated * * This function must be implemented by derived classes to compute the * time derivatives of all species and the specific nuclear energy generation @@ -394,5 +395,21 @@ namespace gridfire { // ReSharper disable once CppDFAUnreachableCode } + /** + * @brief Recursively collect composition from current engine and any sub engines if they exist. + * @details If species i is defined in comp and in any sub engine or self composition then the molar abundance of + * species i in the returned composition will be that defined in comp. If there are species defined in sub engine + * compositions which are not defined in comp then their molar abundances will be based on the reported values + * from each sub engine. + * @note It is up to each engine to decide how to handle filling in the return composition. + * @note These methods return an unfinalized composition which must then be finalized by the caller + * @param comp Input composition to "normalize". + * @return An updated composition which is a superset of comp. This may contain species which were culled, for + * example, by either QSE partitioning or reaction flow rate culling + */ + virtual fourdst::composition::Composition collectComposition( + fourdst::composition::Composition& comp + ) const = 0; + }; } \ No newline at end of file diff --git a/src/include/gridfire/engine/engine_graph.h b/src/include/gridfire/engine/engine_graph.h index 692f7459..9be39e30 100644 --- a/src/include/gridfire/engine/engine_graph.h +++ b/src/include/gridfire/engine/engine_graph.h @@ -730,7 +730,18 @@ namespace gridfire { BuildDepthType depth ) override; - void lumpReactions(); + /** + * @brief This will return the input comp with the molar abundances of any species not registered in that but + * registered in the engine active species set to 0.0. + * @note Effectively this method does not change input composition; rather it ensures that all species which + * can be tracked by an instance of GraphEngine are registered in the composition object. + * @note If a species is in the input comp but not in the network + * @param comp Input Composition + * @return A new composition where all members of the active species set are registered. And any members not in comp + * have a molar abundance set to 0. + * @throws BadCollectionError If the input composition contains species not present in the network species set + */ + fourdst::composition::Composition collectComposition(fourdst::composition::Composition &comp) const override; private: diff --git a/src/include/gridfire/engine/procedures/construction.h b/src/include/gridfire/engine/procedures/construction.h index 654f1197..accd7675 100644 --- a/src/include/gridfire/engine/procedures/construction.h +++ b/src/include/gridfire/engine/procedures/construction.h @@ -23,7 +23,7 @@ namespace gridfire { * * @param composition Mapping of isotopic species to their mass fractions; species with positive * mass fraction seed the network. - * @param weakInterpolator + * @param weakInterpolator Interpolator to build weak rates from. Must be constructed and owned by the caller. * @param maxLayers Variant specifying either a predefined NetworkBuildDepth or a custom integer depth; * negative depth (Full) collects all reactions, zero is invalid. * @param reverse If true, collects reverse reactions (decays or back-reactions); if false, uses forward reactions. @@ -36,6 +36,7 @@ namespace gridfire { reaction::ReactionSet build_nuclear_network( const fourdst::composition::Composition &composition, const rates::weak::WeakRateInterpolator &weakInterpolator, - BuildDepthType maxLayers = NetworkBuildDepth::Full, bool reverse = false + BuildDepthType maxLayers = NetworkBuildDepth::Full, + bool reverse = false ); } diff --git a/src/include/gridfire/engine/procedures/priming.h b/src/include/gridfire/engine/procedures/priming.h index ac1858a8..f4cc37aa 100644 --- a/src/include/gridfire/engine/procedures/priming.h +++ b/src/include/gridfire/engine/procedures/priming.h @@ -53,8 +53,8 @@ namespace gridfire { const fourdst::atomic::Species& species, const fourdst::composition::Composition& composition, double T9, - double rho, const std::optional> & - reactionTypesToIgnore + double rho, + const std::optional> &reactionTypesToIgnore ); /** @@ -78,6 +78,7 @@ namespace gridfire { const fourdst::atomic::Species& species, const fourdst::composition::Composition& composition, double T9, - double rho, const std::optional> &reactionTypesToIgnore + double rho, + const std::optional> &reactionTypesToIgnore ); } \ No newline at end of file diff --git a/src/include/gridfire/engine/views/engine_adaptive.h b/src/include/gridfire/engine/views/engine_adaptive.h index b1645d69..2a501ca7 100644 --- a/src/include/gridfire/engine/views/engine_adaptive.h +++ b/src/include/gridfire/engine/views/engine_adaptive.h @@ -302,6 +302,8 @@ namespace gridfire { [[nodiscard]] std::vector mapNetInToMolarAbundanceVector(const NetIn &netIn) const override; [[nodiscard]] PrimingReport primeEngine(const NetIn &netIn) override; + + fourdst::composition::Composition collectComposition(fourdst::composition::Composition &comp) const override; private: using Config = fourdst::config::Config; using LogManager = fourdst::logging::LogManager; @@ -315,6 +317,7 @@ namespace gridfire { /** @brief The set of species that are currently active in the network. */ std::vector m_activeSpecies; + /** @brief The set of reactions that are currently active in the network. */ reaction::ReactionSet m_activeReactions; diff --git a/src/include/gridfire/engine/views/engine_defined.h b/src/include/gridfire/engine/views/engine_defined.h index 68210683..7a5dbd50 100644 --- a/src/include/gridfire/engine/views/engine_defined.h +++ b/src/include/gridfire/engine/views/engine_defined.h @@ -220,6 +220,8 @@ namespace gridfire{ [[nodiscard]] std::vector mapNetInToMolarAbundanceVector(const NetIn &netIn) const override; [[nodiscard]] PrimingReport primeEngine(const NetIn &netIn) override; + + fourdst::composition::Composition collectComposition(fourdst::composition::Composition &comp) const override; protected: bool m_isStale = true; GraphEngine& m_baseEngine; diff --git a/src/include/gridfire/engine/views/engine_multiscale.h b/src/include/gridfire/engine/views/engine_multiscale.h index 4ffb608c..0c62a7c5 100644 --- a/src/include/gridfire/engine/views/engine_multiscale.h +++ b/src/include/gridfire/engine/views/engine_multiscale.h @@ -650,7 +650,7 @@ namespace gridfire { * @brief Exports the network to a DOT file for visualization. * * @param filename The name of the DOT file to create. - * @param Y Vector of current molar abundances for the full network. + * @param comp Composition object * @param T9 Temperature in units of 10^9 K. * @param rho Density in g/cm^3. * @@ -663,7 +663,7 @@ namespace gridfire { */ void exportToDot( const std::string& filename, - const fourdst::composition::Composition &Y, + const fourdst::composition::Composition &comp, double T9, double rho ) const; @@ -784,6 +784,16 @@ namespace gridfire { bool involvesSpeciesInDynamic(const fourdst::atomic::Species &species) const; + /** + * @brief Collect the composition from this and sub engines. + * @details This method operates by injecting the current equilibrium abundances for algebraic species into + * the composition object so that they can be bubbled up to the caller. + * @param comp Input Composition + * @return New composition which is comp + any edits from lower levels + the equilibrium abundances of all algebraic species. + * @throws BadCollectionError: if there is a species in the algebraic species set which does not show up in the reported composition from the base engine.:w + */ + fourdst::composition::Composition collectComposition(fourdst::composition::Composition &comp) const override; + private: /** diff --git a/src/include/gridfire/exceptions/error_engine.h b/src/include/gridfire/exceptions/error_engine.h index 3937a760..d7875d92 100644 --- a/src/include/gridfire/exceptions/error_engine.h +++ b/src/include/gridfire/exceptions/error_engine.h @@ -143,4 +143,12 @@ namespace gridfire::exceptions { std::string m_message; }; + class BadCollectionError final : public EngineError { + public: + explicit BadCollectionError(std::string message): m_message(std::move(message)) {} + [[nodiscard]] const char* what() const noexcept override { return m_message.c_str(); } + private: + std::string m_message; + }; + } \ No newline at end of file diff --git a/src/include/gridfire/solver/solver.h b/src/include/gridfire/solver/solver.h index 060a0448..bee5a32a 100644 --- a/src/include/gridfire/solver/solver.h +++ b/src/include/gridfire/solver/solver.h @@ -14,11 +14,12 @@ namespace gridfire::solver { * @struct SolverContextBase * @brief Base class for solver callback contexts. * - * This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces + * This struct serves as a base class for contexts that can be papubl;ssed to solver callbacks, it enforces * that derived classes implement a `describe` method that returns a vector of tuples describing * the context that a callback will receive when called. */ - struct SolverContextBase { + class SolverContextBase { + public: virtual ~SolverContextBase() = default; /** diff --git a/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h b/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h index 34c891bb..f7d302b1 100644 --- a/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h +++ b/src/include/gridfire/solver/strategies/CVODE_solver_strategy.h @@ -121,6 +121,19 @@ namespace gridfire::solver { */ NetOut evaluate(const NetIn& netIn) override; + /** + * @brief Call to evaluate which will let the user control if the trigger reasoning is displayed + * @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition. + * @param displayTrigger Boolean flag to control if trigger reasoning is displayed + * @return NetOut containing final Composition, accumulated energy [erg/g], step count, + * and dEps/dT, dEps/dRho. + * @throws std::runtime_error If any CVODE or SUNDIALS call fails (negative return codes), + * or if internal consistency checks fail during engine updates. + * @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state + * during RHS evaluation (captured in the wrapper then rethrown here). + */ + NetOut evaluate(const NetIn& netIn, bool displayTrigger); + /** * @brief Install a timestep callback. * @param callback std::any containing TimestepCallback (std::function). @@ -135,8 +148,9 @@ namespace gridfire::solver { /** * @brief Enable/disable per-step stdout logging. + * @param logging_enabled Flag to control if a timestep summary is written to standard output or not */ - void set_stdout_logging_enabled(const bool value); + void set_stdout_logging_enabled(bool logging_enabled); /** * @brief Schema of fields exposed to the timestep callback context. diff --git a/src/include/gridfire/utils/general_composition.h b/src/include/gridfire/utils/general_composition.h index e222d645..bc696313 100644 --- a/src/include/gridfire/utils/general_composition.h +++ b/src/include/gridfire/utils/general_composition.h @@ -2,6 +2,8 @@ #include "fourdst/composition/composition.h" #include "fourdst/composition/atomicSpecies.h" +#include + namespace gridfire::utils { inline double massFractionFromMolarAbundanceAndComposition ( const fourdst::composition::Composition& composition, @@ -65,10 +67,14 @@ namespace gridfire::utils { inline std::vector molarMassVectorFromComposition( const fourdst::composition::Composition& composition ) { - std::vector molarMassVector; - molarMassVector.reserve(composition.getRegisteredSymbols().size()); + std::map molarMassMap; for (const auto &entry: composition | std::views::values) { - molarMassVector.push_back(entry.isotope().mass()); + molarMassMap.emplace(entry.isotope(), entry.isotope().mass()); + } + std::vector molarMassVector; + molarMassVector.reserve(molarMassMap.size()); + for (const auto molarMass : molarMassMap | std::views::values) { + molarMassVector.push_back(molarMass); } return molarMassVector; } diff --git a/src/lib/engine/engine_graph.cpp b/src/lib/engine/engine_graph.cpp index 25ff7068..092a2724 100644 --- a/src/lib/engine/engine_graph.cpp +++ b/src/lib/engine/engine_graph.cpp @@ -176,7 +176,7 @@ namespace gridfire { recordADTape(); // Record the AD tape for the RHS of the ODE (dY/di and dEps/di) for all independent variables i - const size_t inputSize = m_rhsADFun.Domain(); + [[maybe_unused]] const size_t inputSize = m_rhsADFun.Domain(); const size_t outputSize = m_rhsADFun.Range(); // Create a range x range identity pattern @@ -584,6 +584,26 @@ namespace gridfire { } } + fourdst::composition::Composition GraphEngine::collectComposition( + fourdst::composition::Composition &comp + ) const { + for (const auto &speciesName: comp | std::views::keys) { + if (!m_networkSpeciesMap.contains(speciesName)) { + throw exceptions::BadCollectionError("Cannot collect composition from GraphEngine as " + speciesName + " present in input composition does not exist in the network species map"); + } + } + fourdst::composition::Composition result; + for (const auto& species : m_networkSpecies ) { + result.registerSpecies(species); + if (comp.hasSpecies(species)) { + result.setMassFraction(species, comp.getMassFraction(species)); + } else { + result.setMassFraction(species, 0.0); + } + } + return result; + } + StepDerivatives GraphEngine::calculateAllDerivativesUsingPrecomputation( const fourdst::composition::Composition& comp, const std::vector &bare_rates, diff --git a/src/lib/engine/views/engine_adaptive.cpp b/src/lib/engine/views/engine_adaptive.cpp index 5987e2cd..6d9d5ffb 100644 --- a/src/lib/engine/views/engine_adaptive.cpp +++ b/src/lib/engine/views/engine_adaptive.cpp @@ -318,6 +318,21 @@ namespace gridfire { return m_baseEngine.primeEngine(netIn); } + fourdst::composition::Composition AdaptiveEngineView::collectComposition( + fourdst::composition::Composition &comp + ) const { + fourdst::composition::Composition result = m_baseEngine.collectComposition(comp); // Step one is to bubble the results from lower levels of the engine chain up + + for (const auto& species : m_activeSpecies) { + if (!result.hasSpecies(species)) { + result.registerSpecies(species); + result.setMassFraction(species, 0.0); + } + } + + return result; + } + size_t AdaptiveEngineView::getSpeciesIndex(const fourdst::atomic::Species &species) const { const auto it = std::ranges::find(m_activeSpecies, species); if (it != m_activeSpecies.end()) { diff --git a/src/lib/engine/views/engine_defined.cpp b/src/lib/engine/views/engine_defined.cpp index 51369f80..17657caf 100644 --- a/src/lib/engine/views/engine_defined.cpp +++ b/src/lib/engine/views/engine_defined.cpp @@ -295,6 +295,20 @@ namespace gridfire { return m_baseEngine.primeEngine(netIn); } + fourdst::composition::Composition DefinedEngineView::collectComposition( + fourdst::composition::Composition &comp + ) const { + fourdst::composition::Composition result = m_baseEngine.collectComposition(comp); + + for (const auto& species : m_activeSpecies) { + if (!result.hasSpecies(species)) { + result.registerSpecies(species); + result.setMassFraction(species, 0.0); + } + } + return result; + } + std::vector DefinedEngineView::constructSpeciesIndexMap() const { LOG_TRACE_L3(m_logger, "Constructing species index map for DefinedEngineView..."); std::unordered_map fullSpeciesReverseMap; diff --git a/src/lib/engine/views/engine_multiscale.cpp b/src/lib/engine/views/engine_multiscale.cpp index 8f1ccd9d..3d17c48a 100644 --- a/src/lib/engine/views/engine_multiscale.cpp +++ b/src/lib/engine/views/engine_multiscale.cpp @@ -874,6 +874,43 @@ namespace gridfire { return std::ranges::find(m_dynamic_species, species) != m_dynamic_species.end(); } + fourdst::composition::Composition MultiscalePartitioningEngineView::collectComposition( + fourdst::composition::Composition &comp + ) const { + fourdst::composition::Composition result = m_baseEngine.collectComposition(comp); + bool didFinalize = result.finalize(false); + if (!didFinalize) { + std::string msg = "Failed to finalize collected composition from MultiscalePartitioningEngine view after calling base engines collectComposition method."; + LOG_ERROR(m_logger, "{}", msg); + throw exceptions::BadCollectionError(msg); + } + std::map Ym; // Use an ordered map here so that this is ordered by atomic mass (which is the comparator for Species) + for (const auto& [speciesName, entry] : result) { + Ym.emplace(entry.isotope(), result.getMolarAbundance(speciesName)); + } + for (const auto& [species, Yi] : m_algebraic_abundances) { + if (!Ym.contains(species)) { + throw exceptions::BadCollectionError("MuiltiscalePartitioningEngineView failed to collect composition for species " + std::string(species.name()) + " as the base engine did not report that species present in its composition!"); + } + Ym.at(species) = Yi; + } + std::vector M; + std::vector Y; + std::vector speciesNames; + M.reserve(Ym.size()); + Y.reserve(Ym.size()); + + for (const auto& [species, Yi] : Ym) { + M.emplace_back(species.mass()); + Y.emplace_back(Yi); + speciesNames.emplace_back(species.name()); + } + + std::vector X = utils::massFractionFromMolarAbundanceAndMolarMass(Y, M); + + return fourdst::composition::Composition(speciesNames, X); + } + size_t MultiscalePartitioningEngineView::getSpeciesIndex(const Species &species) const { return m_baseEngine.getSpeciesIndex(species); } diff --git a/src/lib/solver/strategies/CVODE_solver_strategy.cpp b/src/lib/solver/strategies/CVODE_solver_strategy.cpp index 7b8cd839..1c50858e 100644 --- a/src/lib/solver/strategies/CVODE_solver_strategy.cpp +++ b/src/lib/solver/strategies/CVODE_solver_strategy.cpp @@ -136,6 +136,13 @@ namespace gridfire::solver { } NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) { + return evaluate(netIn, false); + } + + NetOut CVODESolverStrategy::evaluate( + const NetIn &netIn, + bool displayTrigger + ) { LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density); LOG_TRACE_L1(m_logger, "Building engine update trigger...."); auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 1, true, 10); @@ -182,12 +189,7 @@ namespace gridfire::solver { check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData"); - int flag{}; - if (m_stdout_logging_enabled) { - flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP); - } else { - flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_NORMAL); - } + int flag = CVode(m_cvode_mem, netIn.tMax, m_Y, ¤t_time, CV_ONE_STEP); if (user_data.captured_exception){ std::rethrow_exception(std::make_exception_ptr(*user_data.captured_exception)); @@ -206,14 +208,17 @@ namespace gridfire::solver { sunrealtype* y_data = N_VGetArrayPointer(m_Y); const double current_energy = y_data[numSpecies]; // Specific energy rate - std::cout << std::scientific << std::setprecision(3) - << "Step: " << std::setw(6) << n_steps - << " | Time: " << current_time << " [s]" - << " | Last Step Size: " << last_step_size - << " | Accumulated Energy: " << current_energy << " [erg/g]" - << " | NonLinIters: " << std::setw(2) << nliters - << " | ConvFails: " << std::setw(2) << nlcfails - << std::endl; + if (m_stdout_logging_enabled) { + std::cout << std::scientific << std::setprecision(3) + << "Step: " << std::setw(6) << n_steps + << " | Time: " << current_time << " [s]" + << " | Last Step Size: " << last_step_size + << " | Current Lightest Molar Abundance: " << y_data[0] << " [mol/g]" + << " | Accumulated Energy: " << current_energy << " [erg/g]" + << " | Total Non Linear Iterations: " << std::setw(2) << nliters + << " | Total Convergence Failures: " << std::setw(2) << nlcfails + << "\n"; + } auto ctx = TimestepContext( current_time, @@ -227,7 +232,9 @@ namespace gridfire::solver { m_engine.getNetworkSpecies()); if (trigger->check(ctx)) { - trigger::printWhy(trigger->why(ctx)); + if (m_stdout_logging_enabled && displayTrigger) { + trigger::printWhy(trigger->why(ctx)); + } trigger->update(ctx); accumulated_energy += current_energy; // Add the specific energy rate to the accumulated energy LOG_INFO( @@ -299,8 +306,9 @@ namespace gridfire::solver { } - // TODO: Need a more reliable way to get the final composition out, probably some methods that bubble it or something - // aside from that this now seems to be working + if (m_stdout_logging_enabled) { // Flush the buffer if standard out logging is enabled + std::cout << std::flush; + } LOG_TRACE_L2(m_logger, "CVODE iteration complete"); @@ -323,13 +331,23 @@ namespace gridfire::solver { } LOG_TRACE_L2(m_logger, "Constructing final composition= with {} species", speciesNames.size()); - fourdst::composition::Composition outputComposition(speciesNames); - outputComposition.setMassFraction(speciesNames, finalMassFractions); - bool didFinalize = outputComposition.finalize(true); - if (!didFinalize) { - LOG_ERROR(m_logger, "Failed to finalize output composition after CVODE integration. Check output mass fractions for validity."); + fourdst::composition::Composition topLevelComposition(speciesNames); + topLevelComposition.setMassFraction(speciesNames, finalMassFractions); + bool didFinalizeTopLevel = topLevelComposition.finalize(true); + if (!didFinalizeTopLevel) { + LOG_ERROR(m_logger, "Failed to finalize top level reconstructed composition after CVODE integration. Check output mass fractions for validity."); throw std::runtime_error("Failed to finalize output composition after CVODE integration."); } + fourdst::composition::Composition outputComposition = m_engine.collectComposition(topLevelComposition); + + assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size()); + + bool didFinalizeOutput = outputComposition.finalize(false); + if (!didFinalizeOutput) { + LOG_ERROR(m_logger, "Failed to finalize output composition after CVODE integration."); + throw std::runtime_error("Failed to finalize output composition after CVODE integration."); + } + LOG_TRACE_L2(m_logger, "Final composition constructed successfully!"); LOG_TRACE_L2(m_logger, "Constructing output data..."); @@ -351,6 +369,7 @@ namespace gridfire::solver { LOG_TRACE_L2(m_logger, "Output data built!"); LOG_TRACE_L2(m_logger, "Solver evaluation complete!."); + return netOut; } @@ -362,8 +381,8 @@ namespace gridfire::solver { return m_stdout_logging_enabled; } - void CVODESolverStrategy::set_stdout_logging_enabled(const bool value) { - m_stdout_logging_enabled = value; + void CVODESolverStrategy::set_stdout_logging_enabled(const bool logging_enabled) { + m_stdout_logging_enabled = logging_enabled; } std::vector> CVODESolverStrategy::describe_callback_context() const { diff --git a/src/python/engine/bindings.cpp b/src/python/engine/bindings.cpp index 3543330e..7988df0a 100644 --- a/src/python/engine/bindings.cpp +++ b/src/python/engine/bindings.cpp @@ -5,6 +5,8 @@ #include "bindings.h" #include "gridfire/engine/engine.h" +#include "gridfire/engine/diagnostics/dynamic_engine_diagnostics.h" +#include "gridfire/exceptions/exceptions.h" #include "trampoline/py_engine.h" @@ -17,23 +19,70 @@ namespace { template void registerDynamicEngineDefs(py::class_ pyClass) { - pyClass.def("calculateRHSAndEnergy", &T::calculateRHSAndEnergy, - py::arg("Y"), + pyClass.def( + "calculateRHSAndEnergy", + []( + const gridfire::DynamicEngine& self, + const fourdst::composition::Composition& comp, + const double T9, + const double rho + ) { + auto result = self.calculateRHSAndEnergy(comp, T9, rho); + if (!result.has_value()) { + throw gridfire::exceptions::StaleEngineError("Engine reports stale state, call update()."); + } + return result.value(); + }, + py::arg("comp"), py::arg("T9"), py::arg("rho"), "Calculate the right-hand side (dY/dt) and energy generation rate." ) - .def("generateJacobianMatrix", py::overload_cast&, double, double>(&T::generateJacobianMatrix, py::const_), - py::arg("Y_dynamic"), + .def("calculateEpsDerivatives", + &gridfire::DynamicEngine::calculateEpsDerivatives, + py::arg("comp"), + py::arg("T9"), + py::arg("rho"), + "Calculate deps/dT and deps/drho" + ) + .def("generateJacobianMatrix", + py::overload_cast(&T::generateJacobianMatrix, py::const_), + py::arg("comp"), py::arg("T9"), py::arg("rho"), "Generate the Jacobian matrix for the current state." ) - .def("generateStoichiometryMatrix", &T::generateStoichiometryMatrix) + .def("generateJacobianMatrix", + py::overload_cast&>(&T::generateJacobianMatrix, py::const_), + py::arg("comp"), + py::arg("T9"), + py::arg("rho"), + py::arg("activeSpecies"), + "Generate the jacobian matrix only for the subset of the matrix representing the active species." + ) + .def("generateJacobianMatrix", + py::overload_cast(&T::generateJacobianMatrix, py::const_), + py::arg("comp"), + py::arg("T9"), + py::arg("rho"), + py::arg("sparsityPattern"), + "Generate the jacobian matrix for the given sparsity pattern" + ) + .def("generateStoichiometryMatrix", + &T::generateStoichiometryMatrix + ) .def("calculateMolarReactionFlow", - static_cast&, const double, const double) const>(&T::calculateMolarReactionFlow), + []( + const gridfire::DynamicEngine& self, + const gridfire::reaction::Reaction& reaction, + const fourdst::composition::Composition& comp, + const double T9, + const double rho + ) -> double { + return self.calculateMolarReactionFlow(reaction, comp, T9, rho); + }, py::arg("reaction"), - py::arg("Y"), + py::arg("comp"), py::arg("T9"), py::arg("rho"), "Calculate the molar reaction flow for a given reaction." @@ -49,61 +98,99 @@ namespace { "Set the network reactions to a new set of reactions." ) .def("getJacobianMatrixEntry", &T::getJacobianMatrixEntry, - py::arg("i"), - py::arg("j"), + py::arg("rowSpecies"), + py::arg("colSpecies"), "Get an entry from the previously generated Jacobian matrix." ) .def("getStoichiometryMatrixEntry", &T::getStoichiometryMatrixEntry, - py::arg("speciesIndex"), - py::arg("reactionIndex"), + py::arg("species"), + py::arg("reaction"), "Get an entry from the stoichiometry matrix." ) - .def("getSpeciesTimescales", &T::getSpeciesTimescales, - py::arg("Y"), + .def("getSpeciesTimescales", + []( + const gridfire::DynamicEngine& self, + const fourdst::composition::Composition& comp, + const double T9, + const double rho + ) -> std::unordered_map { + const auto result = self.getSpeciesTimescales(comp, T9, rho); + if (!result.has_value()) { + throw gridfire::exceptions::StaleEngineError("Engine reports stale state, call update()."); + } + return result.value(); + }, + py::arg("comp"), py::arg("T9"), py::arg("rho"), "Get the timescales for each species in the network." ) - .def("getSpeciesDestructionTimescales", &T::getSpeciesDestructionTimescales, - py::arg("Y"), + .def("getSpeciesDestructionTimescales", + []( + const gridfire::DynamicEngine& self, + const fourdst::composition::Composition& comp, + const double T9, + const double rho + ) -> std::unordered_map { + const auto result = self.getSpeciesDestructionTimescales(comp, T9, rho); + if (!result.has_value()) { + throw gridfire::exceptions::StaleEngineError("Engine reports stale state, call update()."); + } + return result.value(); + }, + py::arg("comp"), py::arg("T9"), py::arg("rho"), "Get the destruction timescales for each species in the network." ) - .def("update", &T::update, + .def("update", + &T::update, py::arg("netIn"), "Update the engine state based on the provided NetIn object." ) - .def("setScreeningModel", &T::setScreeningModel, + .def("setScreeningModel", + &T::setScreeningModel, py::arg("screeningModel"), "Set the screening model for the engine." ) - .def("getScreeningModel", &T::getScreeningModel, + .def("getScreeningModel", + &T::getScreeningModel, "Get the current screening model of the engine." ) - .def("getSpeciesIndex", &T::getSpeciesIndex, + .def("getSpeciesIndex", + &T::getSpeciesIndex, py::arg("species"), "Get the index of a species in the network." ) - .def("mapNetInToMolarAbundanceVector", &T::mapNetInToMolarAbundanceVector, + .def("mapNetInToMolarAbundanceVector", + &T::mapNetInToMolarAbundanceVector, py::arg("netIn"), "Map a NetIn object to a vector of molar abundances." ) - .def("primeEngine", &T::primeEngine, + .def("primeEngine", + &T::primeEngine, py::arg("netIn"), "Prime the engine with a NetIn object to prepare for calculations." ) - .def("getDepth", &T::getDepth, + .def("getDepth", + &T::getDepth, "Get the current build depth of the engine." ) - .def("rebuild", &T::rebuild, + .def("rebuild", + &T::rebuild, py::arg("composition"), py::arg("depth") = gridfire::NetworkBuildDepth::Full, "Rebuild the engine with a new composition and build depth." ) - .def("isStale", &T::isStale, + .def("isStale", + &T::isStale, py::arg("netIn"), "Check if the engine is stale based on the provided NetIn object." + ) + .def("collectComposition", + &T::collectComposition, + py::arg("composition"), + "Recursively collect composition from current engine and any sub engines if they exist." ); } @@ -112,14 +199,123 @@ namespace { void register_engine_bindings(py::module &m) { register_base_engine_bindings(m); register_engine_view_bindings(m); + register_engine_diagnostic_bindings(m); + register_engine_procedural_bindings(m); + register_engine_type_bindings(m); +} - m.def("build_reaclib_nuclear_network", &gridfire::build_reaclib_nuclear_network, - py::arg("composition"), - py::arg("maxLayers") = gridfire::NetworkBuildDepth::Full, - py::arg("reverse") = false, - "Build a nuclear network from a composition using ReacLib data." +void register_base_engine_bindings(const pybind11::module &m) { + + py::class_>(m, "StepDerivatives") + .def_readonly("dYdt", &gridfire::StepDerivatives::dydt, "The right-hand side (dY/dt) of the ODE system.") + .def_readonly("energy", &gridfire::StepDerivatives::nuclearEnergyGenerationRate, "The energy generation rate."); + + py::class_ py_sparsity_pattern(m, "SparsityPattern"); + + abs_stype_register_engine_bindings(m); + abs_stype_register_dynamic_engine_bindings(m); + con_stype_register_graph_engine_bindings(m); +} + +void abs_stype_register_engine_bindings(const pybind11::module &m) { + py::class_(m, "Engine"); +} + +void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m) { + const auto a = py::class_(m, "DynamicEngine"); +} + +void register_engine_procedural_bindings(pybind11::module &m) { + auto procedures = m.def_submodule("procedures", "Procedural functions associated with engine module"); + register_engine_construction_bindings(procedures); + register_engine_construction_bindings(procedures); +} + +void register_engine_diagnostic_bindings(pybind11::module &m) { + auto diagnostics = m.def_submodule("diagnostics", "A submodule for engine diagnostics"); + diagnostics.def("report_limiting_species", + &gridfire::diagnostics::report_limiting_species, + py::arg("engine"), + py::arg("Y_full"), + py::arg("E_full"), + py::arg("dydt_full"), + py::arg("relTol"), + py::arg("absTol"), + py::arg("top_n") = 10 ); + diagnostics.def("inspect_species_balance", + &gridfire::diagnostics::inspect_species_balance, + py::arg("engine"), + py::arg("species_name"), + py::arg("comp"), + py::arg("T9"), + py::arg("rho") + ); + + diagnostics.def("inspect_jacobian_stiffness", + &gridfire::diagnostics::inspect_jacobian_stiffness, + py::arg("engine"), + py::arg("comp"), + py::arg("T9"), + py::arg("rho") + ); +} + +void register_engine_construction_bindings(pybind11::module &m) { + m.def("build_nuclear_network", &gridfire::build_nuclear_network, + py::arg("composition"), + py::arg("weakInterpolator"), + py::arg("maxLayers") = gridfire::NetworkBuildDepth::Full, + py::arg("reverse") = false, + "Build a nuclear network from a composition using all archived reaction data." + ); +} + +void register_engine_priming_bindings(pybind11::module &m) { + + m.def("calculateDestructionRateConstant", + &gridfire::calculateDestructionRateConstant, + py::arg("engine"), + py::arg("species"), + py::arg("composition"), + py::arg("T9"), + py::arg("rho"), + py::arg("reactionTypesToIgnore") + ); + + m.def("calculateCreationRate", + &gridfire::calculateCreationRate, + py::arg("engine"), + py::arg("species"), + py::arg("composition"), + py::arg("T9"), + py::arg("rho"), + py::arg("reactionTypesToIgnore") + ); +} + +void register_engine_type_bindings(pybind11::module &m) { + auto types = m.def_submodule("types", "Types associated with engine module"); + register_engine_building_type_bindings(types); + register_engine_reporting_type_bindings(types); + +} + +void register_engine_building_type_bindings(pybind11::module &m) { + py::enum_(m, "NetworkBuildDepth") + .value("Full", gridfire::NetworkBuildDepth::Full, "Full network build depth") + .value("Shallow", gridfire::NetworkBuildDepth::Shallow, "Shallow network build depth") + .value("SecondOrder", gridfire::NetworkBuildDepth::SecondOrder, "Second order network build depth") + .value("ThirdOrder", gridfire::NetworkBuildDepth::ThirdOrder, "Third order network build depth") + .value("FourthOrder", gridfire::NetworkBuildDepth::FourthOrder, "Fourth order network build depth") + .value("FifthOrder", gridfire::NetworkBuildDepth::FifthOrder, "Fifth order network build depth") + .export_values(); + + py::class_ py_build_depth_type(m, "BuildDepthType"); +} + +void register_engine_reporting_type_bindings(pybind11::module &m) { py::enum_(m, "PrimingReportStatus") .value("FULL_SUCCESS", gridfire::PrimingReportStatus::FULL_SUCCESS, "Priming was full successful.") .value("NO_SPECIES_TO_PRIME", gridfire::PrimingReportStatus::NO_SPECIES_TO_PRIME, "No species to prime.") @@ -150,135 +346,128 @@ void register_engine_bindings(py::module &m) { ); } -void register_base_engine_bindings(const pybind11::module &m) { - - py::class_>(m, "StepDerivatives") - .def_readonly("dYdt", &gridfire::StepDerivatives::dydt, "The right-hand side (dY/dt) of the ODE system.") - .def_readonly("energy", &gridfire::StepDerivatives::nuclearEnergyGenerationRate, "The energy generation rate."); - - py::class_(m, "SparsityPattern"); - - abs_stype_register_engine_bindings(m); - abs_stype_register_dynamic_engine_bindings(m); - con_stype_register_graph_engine_bindings(m); -} - -void abs_stype_register_engine_bindings(const pybind11::module &m) { - py::class_(m, "Engine"); -} - -void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m) { - const auto a = py::class_(m, "DynamicEngine"); -} - void con_stype_register_graph_engine_bindings(const pybind11::module &m) { - py::enum_(m, "NetworkBuildDepth") - .value("Full", gridfire::NetworkBuildDepth::Full, "Full network build depth") - .value("Shallow", gridfire::NetworkBuildDepth::Shallow, "Shallow network build depth") - .value("SecondOrder", gridfire::NetworkBuildDepth::SecondOrder, "Second order network build depth") - .value("ThirdOrder", gridfire::NetworkBuildDepth::ThirdOrder, "Third order network build depth") - .value("FourthOrder", gridfire::NetworkBuildDepth::FourthOrder, "Fourth order network build depth") - .value("FifthOrder", gridfire::NetworkBuildDepth::FifthOrder, "Fifth order network build depth") - .export_values(); - py::class_(m, "BuildDepthType"); - - auto py_dynamic_engine_bindings = py::class_(m, "GraphEngine"); + auto py_graph_engine_bindings = py::class_(m, "GraphEngine"); // Register the Graph Engine Specific Bindings - py_dynamic_engine_bindings.def(py::init(), + py_graph_engine_bindings.def(py::init(), py::arg("composition"), py::arg("depth") = gridfire::NetworkBuildDepth::Full, "Initialize GraphEngine with a composition and build depth." ); - py_dynamic_engine_bindings.def(py::init(), + py_graph_engine_bindings.def(py::init(), py::arg("composition"), py::arg("partitionFunction"), py::arg("depth") = gridfire::NetworkBuildDepth::Full, "Initialize GraphEngine with a composition, partition function and build depth." ); - py_dynamic_engine_bindings.def(py::init(), + py_graph_engine_bindings.def(py::init(), py::arg("reactions"), "Initialize GraphEngine with a set of reactions." ); - py_dynamic_engine_bindings.def("generateJacobianMatrix", py::overload_cast&, double, double, const gridfire::SparsityPattern&>(&gridfire::GraphEngine::generateJacobianMatrix, py::const_), - py::arg("Y_dynamic"), - py::arg("T9"), - py::arg("rho"), - py::arg("sparsityPattern"), - "Generate the Jacobian matrix for the current state with a specified sparsity pattern." - ); - py_dynamic_engine_bindings.def_static("getNetReactionStoichiometry", &gridfire::GraphEngine::getNetReactionStoichiometry, + py_graph_engine_bindings.def_static("getNetReactionStoichiometry", + &gridfire::GraphEngine::getNetReactionStoichiometry, py::arg("reaction"), "Get the net stoichiometry for a given reaction." ); - py_dynamic_engine_bindings.def("involvesSpecies", &gridfire::GraphEngine::involvesSpecies, + py_graph_engine_bindings.def("getSpeciesTimescales", + py::overload_cast(&gridfire::GraphEngine::getSpeciesTimescales, py::const_), + py::arg("composition"), + py::arg("T9"), + py::arg("rho"), + py::arg("activeReactions") + ); + py_graph_engine_bindings.def("getSpeciesDestructionTimescales", + py::overload_cast(&gridfire::GraphEngine::getSpeciesDestructionTimescales, py::const_), + py::arg("composition"), + py::arg("T9"), + py::arg("rho"), + py::arg("activeReactions") + ); + py_graph_engine_bindings.def("involvesSpecies", + &gridfire::GraphEngine::involvesSpecies, py::arg("species"), "Check if a given species is involved in the network." ); - py_dynamic_engine_bindings.def("exportToDot", &gridfire::GraphEngine::exportToDot, + py_graph_engine_bindings.def("exportToDot", + &gridfire::GraphEngine::exportToDot, py::arg("filename"), "Export the network to a DOT file for visualization." ); - py_dynamic_engine_bindings.def("exportToCSV", &gridfire::GraphEngine::exportToCSV, + py_graph_engine_bindings.def("exportToCSV", + &gridfire::GraphEngine::exportToCSV, py::arg("filename"), "Export the network to a CSV file for analysis." ); - py_dynamic_engine_bindings.def("setPrecomputation", &gridfire::GraphEngine::setPrecomputation, + py_graph_engine_bindings.def("setPrecomputation", + &gridfire::GraphEngine::setPrecomputation, py::arg("precompute"), "Enable or disable precomputation for the engine." ); - py_dynamic_engine_bindings.def("isPrecomputationEnabled", &gridfire::GraphEngine::isPrecomputationEnabled, + py_graph_engine_bindings.def("isPrecomputationEnabled", + &gridfire::GraphEngine::isPrecomputationEnabled, "Check if precomputation is enabled for the engine." ); - py_dynamic_engine_bindings.def("getPartitionFunction", &gridfire::GraphEngine::getPartitionFunction, + py_graph_engine_bindings.def("getPartitionFunction", + &gridfire::GraphEngine::getPartitionFunction, "Get the partition function used by the engine." ); - py_dynamic_engine_bindings.def("calculateReverseRate", &gridfire::GraphEngine::calculateReverseRate, + py_graph_engine_bindings.def("calculateReverseRate", + &gridfire::GraphEngine::calculateReverseRate, py::arg("reaction"), py::arg("T9"), - "Calculate the reverse rate for a given reaction at a specific temperature." + py::arg("rho"), + py::arg("composition"), + "Calculate the reverse rate for a given reaction at a specific temperature, density, and composition." ); - py_dynamic_engine_bindings.def("calculateReverseRateTwoBody", &gridfire::GraphEngine::calculateReverseRateTwoBody, + py_graph_engine_bindings.def("calculateReverseRateTwoBody", + &gridfire::GraphEngine::calculateReverseRateTwoBody, py::arg("reaction"), py::arg("T9"), py::arg("forwardRate"), py::arg("expFactor"), "Calculate the reverse rate for a two-body reaction at a specific temperature." ); - py_dynamic_engine_bindings.def("calculateReverseRateTwoBodyDerivative", &gridfire::GraphEngine::calculateReverseRateTwoBodyDerivative, + py_graph_engine_bindings.def("calculateReverseRateTwoBodyDerivative", + &gridfire::GraphEngine::calculateReverseRateTwoBodyDerivative, py::arg("reaction"), py::arg("T9"), + py::arg("rho"), + py::arg("composition"), py::arg("reverseRate"), "Calculate the derivative of the reverse rate for a two-body reaction at a specific temperature." ); - py_dynamic_engine_bindings.def("isUsingReverseReactions", &gridfire::GraphEngine::isUsingReverseReactions, + py_graph_engine_bindings.def("isUsingReverseReactions", + &gridfire::GraphEngine::isUsingReverseReactions, "Check if the engine is using reverse reactions." ); - py_dynamic_engine_bindings.def("setUseReverseReactions", &gridfire::GraphEngine::setUseReverseReactions, + py_graph_engine_bindings.def("setUseReverseReactions", + &gridfire::GraphEngine::setUseReverseReactions, py::arg("useReverse"), "Enable or disable the use of reverse reactions in the engine." ); - // Register the general dynamic engine bindings - registerDynamicEngineDefs(py_dynamic_engine_bindings); + registerDynamicEngineDefs(py_graph_engine_bindings); } void register_engine_view_bindings(const pybind11::module &m) { auto py_defined_engine_view_bindings = py::class_(m, "DefinedEngineView"); - py_defined_engine_view_bindings.def(py::init, gridfire::DynamicEngine&>(), + py_defined_engine_view_bindings.def(py::init, gridfire::GraphEngine&>(), py::arg("peNames"), py::arg("baseEngine"), - "Construct a defined engine view with a list of tracked reactions and a base engine."); + "Construct a defined engine view with a list of tracked reactions and a base engine." + ); py_defined_engine_view_bindings.def("getBaseEngine", &gridfire::DefinedEngineView::getBaseEngine, "Get the base engine associated with this defined engine view."); registerDynamicEngineDefs(py_defined_engine_view_bindings); auto py_file_defined_engine_view_bindings = py::class_(m, "FileDefinedEngineView"); - py_file_defined_engine_view_bindings.def(py::init(), + py_file_defined_engine_view_bindings.def( + py::init(), py::arg("baseEngine"), py::arg("fileName"), py::arg("parser"), @@ -296,11 +485,11 @@ void register_engine_view_bindings(const pybind11::module &m) { registerDynamicEngineDefs(py_file_defined_engine_view_bindings); auto py_priming_engine_view_bindings = py::class_(m, "NetworkPrimingEngineView"); - py_priming_engine_view_bindings.def(py::init(), + py_priming_engine_view_bindings.def(py::init(), py::arg("primingSymbol"), py::arg("baseEngine"), "Construct a priming engine view with a priming symbol and a base engine."); - py_priming_engine_view_bindings.def(py::init(), + py_priming_engine_view_bindings.def(py::init(), py::arg("primingSpecies"), py::arg("baseEngine"), "Construct a priming engine view with a priming species and a base engine."); @@ -313,8 +502,10 @@ void register_engine_view_bindings(const pybind11::module &m) { py_adaptive_engine_view_bindings.def(py::init(), py::arg("baseEngine"), "Construct an adaptive engine view with a base engine."); - py_adaptive_engine_view_bindings.def("getBaseEngine", &gridfire::AdaptiveEngineView::getBaseEngine, - "Get the base engine associated with this adaptive engine view."); + py_adaptive_engine_view_bindings.def("getBaseEngine", + &gridfire::AdaptiveEngineView::getBaseEngine, + "Get the base engine associated with this adaptive engine view." + ); registerDynamicEngineDefs(py_adaptive_engine_view_bindings); @@ -341,43 +532,63 @@ void register_engine_view_bindings(const pybind11::module &m) { auto py_multiscale_engine_view_bindings = py::class_(m, "MultiscalePartitioningEngineView"); py_multiscale_engine_view_bindings.def(py::init(), py::arg("baseEngine"), - "Construct a multiscale partitioning engine view with a base engine."); - py_multiscale_engine_view_bindings.def("getBaseEngine", &gridfire::MultiscalePartitioningEngineView::getBaseEngine, - "Get the base engine associated with this multiscale partitioning engine view."); - py_multiscale_engine_view_bindings.def("analyzeTimescalePoolConnectivity", &gridfire::MultiscalePartitioningEngineView::analyzeTimescalePoolConnectivity, + "Construct a multiscale partitioning engine view with a base engine." + ); + py_multiscale_engine_view_bindings.def("getBaseEngine", + &gridfire::MultiscalePartitioningEngineView::getBaseEngine, + "Get the base engine associated with this multiscale partitioning engine view." + ); + py_multiscale_engine_view_bindings.def("analyzeTimescalePoolConnectivity", + &gridfire::MultiscalePartitioningEngineView::analyzeTimescalePoolConnectivity, py::arg("timescale_pools"), - py::arg("Y"), + py::arg("comp"), py::arg("T9"), py::arg("rho"), - "Analyze the connectivity of timescale pools in the network."); - py_multiscale_engine_view_bindings.def("partitionNetwork", py::overload_cast&, double, double>(&gridfire::MultiscalePartitioningEngineView::partitionNetwork), - py::arg("Y"), + "Analyze the connectivity of timescale pools in the network." + ); + py_multiscale_engine_view_bindings.def("partitionNetwork", + py::overload_cast(&gridfire::MultiscalePartitioningEngineView::partitionNetwork), + py::arg("comp"), py::arg("T9"), py::arg("rho"), "Partition the network based on species timescales and connectivity."); - py_multiscale_engine_view_bindings.def("partitionNetwork", py::overload_cast(&gridfire::MultiscalePartitioningEngineView::partitionNetwork), + py_multiscale_engine_view_bindings.def("partitionNetwork", + py::overload_cast(&gridfire::MultiscalePartitioningEngineView::partitionNetwork), py::arg("netIn"), - "Partition the network based on a NetIn object."); - py_multiscale_engine_view_bindings.def("exportToDot", &gridfire::MultiscalePartitioningEngineView::exportToDot, + "Partition the network based on a NetIn object." + ); + py_multiscale_engine_view_bindings.def("exportToDot", + &gridfire::MultiscalePartitioningEngineView::exportToDot, py::arg("filename"), - py::arg("Y"), + py::arg("comp"), py::arg("T9"), py::arg("rho"), - "Export the network to a DOT file for visualization."); - py_multiscale_engine_view_bindings.def("getFastSpecies", &gridfire::MultiscalePartitioningEngineView::getFastSpecies, - "Get the list of fast species in the network."); - py_multiscale_engine_view_bindings.def("getDynamicSpecies", &gridfire::MultiscalePartitioningEngineView::getDynamicSpecies, - "Get the list of dynamic species in the network."); - py_multiscale_engine_view_bindings.def("equilibrateNetwork", py::overload_cast&, double, double>(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork), - py::arg("Y"), + "Export the network to a DOT file for visualization." + ); + py_multiscale_engine_view_bindings.def("getFastSpecies", + &gridfire::MultiscalePartitioningEngineView::getFastSpecies, + "Get the list of fast species in the network." + ); + py_multiscale_engine_view_bindings.def("getDynamicSpecies", + &gridfire::MultiscalePartitioningEngineView::getDynamicSpecies, + "Get the list of dynamic species in the network." + ); + py_multiscale_engine_view_bindings.def("equilibrateNetwork", + py::overload_cast(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork), + py::arg("comp"), py::arg("T9"), py::arg("rho"), "Equilibrate the network based on species abundances and conditions."); - py_multiscale_engine_view_bindings.def("equilibrateNetwork", py::overload_cast(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork), + py_multiscale_engine_view_bindings.def("equilibrateNetwork", + py::overload_cast(&gridfire::MultiscalePartitioningEngineView::equilibrateNetwork), py::arg("netIn"), - "Equilibrate the network based on a NetIn object."); + "Equilibrate the network based on a NetIn object." + ); + + registerDynamicEngineDefs( + py_multiscale_engine_view_bindings + ); - registerDynamicEngineDefs(py_multiscale_engine_view_bindings); } @@ -387,3 +598,4 @@ void register_engine_view_bindings(const pybind11::module &m) { + diff --git a/src/python/engine/bindings.h b/src/python/engine/bindings.h index 31c7160b..86e0aac6 100644 --- a/src/python/engine/bindings.h +++ b/src/python/engine/bindings.h @@ -13,4 +13,14 @@ void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m); void con_stype_register_graph_engine_bindings(const pybind11::module &m); +void register_engine_diagnostic_bindings(pybind11::module &m); +void register_engine_procedural_bindings(pybind11::module &m); + +void register_engine_construction_bindings(pybind11::module &m); +void register_engine_priming_bindings(pybind11::module &m); + +void register_engine_type_bindings(pybind11::module &m); +void register_engine_building_type_bindings(pybind11::module &m); +void register_engine_reporting_type_bindings(pybind11::module &m); + diff --git a/src/python/engine/trampoline/py_engine.cpp b/src/python/engine/trampoline/py_engine.cpp index 91b496ad..abf25666 100644 --- a/src/python/engine/trampoline/py_engine.cpp +++ b/src/python/engine/trampoline/py_engine.cpp @@ -33,12 +33,12 @@ const std::vector& PyEngine::getNetworkSpecies() const py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\""); } -std::expected, gridfire::expectations::StaleEngineError> PyEngine::calculateRHSAndEnergy(const std::vector &Y, double T9, double rho) const { +std::expected, gridfire::expectations::StaleEngineError> PyEngine::calculateRHSAndEnergy(const fourdst::composition::Composition &comp, double T9, double rho) const { PYBIND11_OVERRIDE_PURE( PYBIND11_TYPE(std::expected, gridfire::expectations::StaleEngineError>), gridfire::Engine, calculateRHSAndEnergy, - Y, T9, rho + comp, T9, rho ); } @@ -65,39 +65,62 @@ const std::vector& PyDynamicEngine::getNetworkSpecies( py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\""); } -std::expected, gridfire::expectations::StaleEngineError> PyDynamicEngine::calculateRHSAndEnergy(const std::vector &Y, double T9, double rho) const { +std::expected, gridfire::expectations::StaleEngineError> PyDynamicEngine::calculateRHSAndEnergy(const fourdst::composition::Composition &comp, double T9, double rho) const { PYBIND11_OVERRIDE_PURE( PYBIND11_TYPE(std::expected, gridfire::expectations::StaleEngineError>), gridfire::Engine, calculateRHSAndEnergy, - Y, T9, rho + comp, T9, rho ); } -void PyDynamicEngine::generateJacobianMatrix(const std::vector &Y_dynamic, double T9, double rho) const { +void PyDynamicEngine::generateJacobianMatrix(const fourdst::composition::Composition& comp, double T9, double rho) const { PYBIND11_OVERRIDE_PURE( void, gridfire::DynamicEngine, generateJacobianMatrix, - Y_dynamic, T9, rho + comp, + T9, + rho ); } -void PyDynamicEngine::generateJacobianMatrix(const std::vector &Y_dynamic, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const { +void PyDynamicEngine::generateJacobianMatrix( + const fourdst::composition::Composition &comp, + const double T9, + const double rho, + const std::vector &activeSpecies +) const { PYBIND11_OVERRIDE_PURE( void, gridfire::DynamicEngine, generateJacobianMatrix, - Y_dynamic, T9, rho, sparsityPattern + comp, + T9, + rho, + activeSpecies ); } -double PyDynamicEngine::getJacobianMatrixEntry(int i, int j) const { +void PyDynamicEngine::generateJacobianMatrix(const fourdst::composition::Composition &comp, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const { + PYBIND11_OVERRIDE_PURE( + void, + gridfire::DynamicEngine, + generateJacobianMatrix, + comp, + T9, + rho, + sparsityPattern + ); +} + +double PyDynamicEngine::getJacobianMatrixEntry(const fourdst::atomic::Species& rowSpecies, const fourdst::atomic::Species& colSpecies) const { PYBIND11_OVERRIDE_PURE( double, gridfire::DynamicEngine, getJacobianMatrixEntry, - i, j + rowSpecies, + colSpecies ); } @@ -109,21 +132,25 @@ void PyDynamicEngine::generateStoichiometryMatrix() { ); } -int PyDynamicEngine::getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const { +int PyDynamicEngine::getStoichiometryMatrixEntry(const fourdst::atomic::Species& species, const gridfire::reaction::Reaction& reaction) const { PYBIND11_OVERRIDE_PURE( int, gridfire::DynamicEngine, getStoichiometryMatrixEntry, - speciesIndex, reactionIndex + species, + reaction ); } -double PyDynamicEngine::calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector &Y, double T9, double rho) const { +double PyDynamicEngine::calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const fourdst::composition::Composition &comp, double T9, double rho) const { PYBIND11_OVERRIDE_PURE( double, gridfire::DynamicEngine, calculateMolarReactionFlow, - reaction, Y, T9, rho + reaction, + comp, + T9, + rho ); } @@ -144,21 +171,23 @@ void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::ReactionSet& ); } -std::expected, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesTimescales(const std::vector &Y, double T9, double rho) const { +std::expected, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesTimescales(const fourdst::composition::Composition &comp, double T9, double rho) const { PYBIND11_OVERRIDE_PURE( PYBIND11_TYPE(std::expected, gridfire::expectations::StaleEngineError>), gridfire::DynamicEngine, getSpeciesTimescales, - Y, T9, rho + comp, + T9, + rho ); } -std::expected, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesDestructionTimescales(const std::vector &Y, double T9, double rho) const { +std::expected, gridfire::expectations::StaleEngineError> PyDynamicEngine::getSpeciesDestructionTimescales(const fourdst::composition::Composition &comp, double T9, double rho) const { PYBIND11_OVERRIDE_PURE( PYBIND11_TYPE(std::expected, gridfire::expectations::StaleEngineError>), gridfire::DynamicEngine, getSpeciesDestructionTimescales, - Y, T9, rho + comp, T9, rho ); } @@ -224,6 +253,31 @@ gridfire::PrimingReport PyDynamicEngine::primeEngine(const gridfire::NetIn &netI ); } +gridfire::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives( + const fourdst::composition::Composition &comp, + const double T9, + const double rho) const { + PYBIND11_OVERRIDE_PURE( + gridfire::EnergyDerivatives, + gridfire::DynamicEngine, + calculateEpsDerivatives, + comp, + T9, + rho + ); +} + +fourdst::composition::Composition PyDynamicEngine::collectComposition( + fourdst::composition::Composition &comp +) const { + PYBIND11_OVERRIDE_PURE( + fourdst::composition::Composition, + gridfire::DynamicEngine, + collectComposition, + comp + ); +} + const gridfire::Engine& PyEngineView::getBaseEngine() const { PYBIND11_OVERRIDE_PURE( const gridfire::Engine&, diff --git a/src/python/engine/trampoline/py_engine.h b/src/python/engine/trampoline/py_engine.h index 3634033f..aedcc8e5 100644 --- a/src/python/engine/trampoline/py_engine.h +++ b/src/python/engine/trampoline/py_engine.h @@ -12,7 +12,12 @@ class PyEngine final : public gridfire::Engine { public: const std::vector& getNetworkSpecies() const override; - std::expected,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(const std::vector &Y, double T9, double rho) const override; + + std::expected,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy( + const fourdst::composition::Composition& comp, + double T9, + double rho + ) const override; private: mutable std::vector m_species_cache; }; @@ -20,41 +25,124 @@ private: class PyDynamicEngine final : public gridfire::DynamicEngine { public: const std::vector& getNetworkSpecies() const override; - std::expected,gridfire::expectations::StaleEngineError> calculateRHSAndEnergy(const std::vector &Y, double T9, double rho) const override; - void generateJacobianMatrix(const std::vector &Y_dynamic, double T9, double rho) const override; - void generateJacobianMatrix(const std::vector &Y_dynamic, double T9, double rho, const gridfire::SparsityPattern &sparsityPattern) const override; - double getJacobianMatrixEntry(int i, int j) const override; + + std::expected, gridfire::expectations::StaleEngineError> calculateRHSAndEnergy( + const fourdst::composition::Composition& comp, + double T9, + double rho + ) const override; + + void generateJacobianMatrix( + const fourdst::composition::Composition& comp, + double T9, + double rho + ) const override; + + void generateJacobianMatrix( + const fourdst::composition::Composition &comp, + double T9, + double rho, + const std::vector &activeSpecies + ) const override; + + void generateJacobianMatrix( + const fourdst::composition::Composition& comp, + double T9, + double rho, + const gridfire::SparsityPattern &sparsityPattern + ) const override; + + double getJacobianMatrixEntry( + const fourdst::atomic::Species& rowSpecies, + const fourdst::atomic::Species& colSpecies + ) const override; + void generateStoichiometryMatrix() override; - int getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const override; - double calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector &Y, double T9, double rho) const override; + + int getStoichiometryMatrixEntry( + const fourdst::atomic::Species& species, + const gridfire::reaction::Reaction& reaction + ) const override; + + double calculateMolarReactionFlow( + const gridfire::reaction::Reaction &reaction, + const fourdst::composition::Composition& comp, + double T9, + double rho + ) const override; + const gridfire::reaction::ReactionSet& getNetworkReactions() const override; - void setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) override; - std::expected, gridfire::expectations::StaleEngineError> getSpeciesTimescales(const std::vector &Y, double T9, double rho) const override; - std::expected, gridfire::expectations::StaleEngineError> getSpeciesDestructionTimescales(const std::vector &Y, double T9, double rho) const override; - fourdst::composition::Composition update(const gridfire::NetIn &netIn) override; - bool isStale(const gridfire::NetIn &netIn) override; - void setScreeningModel(gridfire::screening::ScreeningType model) override; + + void setNetworkReactions( + const gridfire::reaction::ReactionSet& reactions + ) override; + + std::expected, gridfire::expectations::StaleEngineError> getSpeciesTimescales( + const fourdst::composition::Composition& comp, + double T9, + double rho + ) const override; + + std::expected, gridfire::expectations::StaleEngineError> getSpeciesDestructionTimescales( + const fourdst::composition::Composition &comp, + double T9, + double rho + ) const override; + + fourdst::composition::Composition update( + const gridfire::NetIn &netIn + ) override; + + bool isStale( + const gridfire::NetIn &netIn + ) override; + + void setScreeningModel( + gridfire::screening::ScreeningType model + ) override; + gridfire::screening::ScreeningType getScreeningModel() const override; - size_t getSpeciesIndex(const fourdst::atomic::Species &species) const override; - std::vector mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const override; - gridfire::PrimingReport primeEngine(const gridfire::NetIn &netIn) override; + size_t getSpeciesIndex( + const fourdst::atomic::Species &species + ) const override; + + std::vector mapNetInToMolarAbundanceVector( + const gridfire::NetIn &netIn + ) const override; + + gridfire::PrimingReport primeEngine( + const gridfire::NetIn &netIn + ) override; + gridfire::BuildDepthType getDepth() const override { throw std::logic_error("Network depth not supported by this engine."); } - void rebuild(const fourdst::composition::Composition& comp, gridfire::BuildDepthType depth) override { + void rebuild( + const fourdst::composition::Composition& comp, + gridfire::BuildDepthType depth + ) override { throw std::logic_error("Setting network depth not supported by this engine."); } + + [[nodiscard]] gridfire::EnergyDerivatives calculateEpsDerivatives( + const fourdst::composition::Composition &comp, + double T9, + double rho + ) const override; + + fourdst::composition::Composition collectComposition( + fourdst::composition::Composition &comp + ) const override; + private: mutable std::vector m_species_cache; - - }; class PyEngineView final : public gridfire::EngineView { - const gridfire::Engine& getBaseEngine() const override; + [[nodiscard]] const gridfire::Engine& getBaseEngine() const override; }; class PyDynamicEngineView final : public gridfire::EngineView { - const gridfire::DynamicEngine& getBaseEngine() const override; + [[nodiscard]] const gridfire::DynamicEngine& getBaseEngine() const override; }; \ No newline at end of file diff --git a/src/python/exceptions/bindings.cpp b/src/python/exceptions/bindings.cpp index 0b201851..5fe00648 100644 --- a/src/python/exceptions/bindings.cpp +++ b/src/python/exceptions/bindings.cpp @@ -38,4 +38,21 @@ void register_exception_bindings(const py::module &m) { return self.what(); }); + py::register_exception(m, "FailedToPartitionEngineError", m.attr("GridFireEngineError")); + py::register_exception(m, "NetworkResizedError", m.attr("GridFireEngineError")); + py::register_exception(m, "UnableToSetNetworkReactionsError", m.attr("GridFireEngineError")); + py::register_exception(m, "BadCollectionError", m.attr("GridFireEngineError")); + + py::register_exception(m, "JacobianError", m.attr("GridFireEngineError")); + + py::register_exception(m, "StaleJacobianError", m.attr("JacobianEngineError")); + py::register_exception(m, "UninitializedJacobianError", m.attr("JacobianEngineError")); + py::register_exception(m, "UnknownJacobianError", m.attr("JacobianEngineError")); + + py::register_exception(m, "UtilityError"); + py::register_exception(m, "HashingError", m.attr("UtilityError")); + + + + } diff --git a/src/python/reaction/bindings.cpp b/src/python/reaction/bindings.cpp index edecc3aa..c9e4e206 100644 --- a/src/python/reaction/bindings.cpp +++ b/src/python/reaction/bindings.cpp @@ -48,7 +48,7 @@ void register_reaction_bindings(py::module &m) { .def( "calculate_rate", [](const gridfire::reaction::ReaclibReaction& self, const double T9, const double rho, const std::vector& Y) -> double { - return self.calculate_rate(T9, rho, 0, TODO, Y, TODO); + return self.calculate_rate(T9, rho, 0, {}, Y, {}); }, py::arg("T9"), py::arg("rho"), @@ -183,9 +183,15 @@ void register_reaction_bindings(py::module &m) { py::class_(m, "LogicalReaclibReaction") .def( - py::init>(), + py::init>(), py::arg("reactions"), - "Construct a LogicalReaclibReaction from a vector of Reaction objects." + "Construct a LogicalReaclibReaction from a vector of ReaclibReaction objects." + ) + .def( + py::init, bool>(), + py::arg("reactions"), + py::arg("is_reverse"), + "Construct a LogicalReaclibReaction from a vector of ReaclibReaction objects." ) .def( "add_reaction", @@ -210,26 +216,56 @@ void register_reaction_bindings(py::module &m) { ) .def( "calculate_rate", - [](const gridfire::reaction::LogicalReaclibReaction& self, const double T9, const double rho, const std::vector& Y) -> double { - return self.calculate_rate(T9, rho, 0, TODO, Y, TODO); + []( + const gridfire::reaction::LogicalReaclibReaction& self, + const double T9, + const double rho, + const double Ye, + const double mue, + const std::vector& Y, + const std::unordered_map& index_to_species_map + ) -> double { + return self.calculate_rate(T9, rho, Ye, mue, Y, index_to_species_map); }, py::arg("T9"), - "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K)." + py::arg("rho"), + py::arg("Ye"), + py::arg("mue"), + py::arg("Y"), + py::arg("index_to_species_map"), + "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K). Note that for a reaclib reaction only T9 is actually used, all other parameters are there for interface compatibility." ) .def( "calculate_forward_rate_log_derivative", - &gridfire::reaction::LogicalReaclibReaction::calculate_forward_rate_log_derivative, + &gridfire::reaction::LogicalReaclibReaction::calculate_log_rate_partial_deriv_wrt_T9, py::arg("T9"), + py::arg("rho"), + py::arg("Ye"), + py::arg("mue"), + py::arg("Composition"), "Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K)." ); py::class_(m, "ReactionSet") - // TODO: Fix the constructor to accept a vector of unique ptrs to Reaclib Reactions .def( - py::init>(), + py::init>(), py::arg("reactions"), + py::keep_alive<1, 2>(), // Keep arg 2 (reactions) alive as long as arg 1 (self) is alive. This helps mitigate use-after-free errors "Construct a LogicalReactionSet from a vector of LogicalReaclibReaction objects." ) + .def_static( + "from_clones", + [](const std::vector& py_reactions) { + std::vector> cpp_reactions; + cpp_reactions.reserve(py_reactions.size()); + for (const auto& reaction : py_reactions) { + cpp_reactions.emplace_back(reaction->clone()); + } + return std::make_unique(std::move(cpp_reactions)); + }, + py::arg("reactions"), + "Create a ReactionSet that takes ownership of the reactions by cloning the input reactions." + ) .def( py::init<>(), "Default constructor for an empty LogicalReactionSet." diff --git a/src/python/solver/bindings.cpp b/src/python/solver/bindings.cpp index 3108fa47..a4035801 100644 --- a/src/python/solver/bindings.cpp +++ b/src/python/solver/bindings.cpp @@ -2,71 +2,97 @@ #include // Needed for vectors, maps, sets, strings #include // Needed for binding std::vector, std::map etc. if needed directly #include +#include #include #include "bindings.h" -#include "gridfire/solver/solver.h" +#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "trampoline/py_solver.h" namespace py = pybind11; void register_solver_bindings(const py::module &m) { - auto py_dynamic_network_solving_strategy = py::class_(m, "DynamicNetworkSolverStrategy"); - auto py_direct_network_solver = py::class_(m, "DirectNetworkSolver"); - - py_direct_network_solver.def(py::init(), - py::arg("engine"), - "Constructor for the DirectNetworkSolver. Takes a DynamicEngine instance to use for evaluating the network." - ); - - py_direct_network_solver.def("evaluate", - &gridfire::solver::DirectNetworkSolver::evaluate, + auto py_dynamic_network_solver_strategy = py::class_(m, "DynamicNetworkSolverStrategy"); + py_dynamic_network_solver_strategy.def( + "evaluate", + &gridfire::solver::DynamicNetworkSolverStrategy::evaluate, py::arg("netIn"), - "Evaluate the network for a given timestep. Returns the output conditions after the timestep." + "evaluate the dynamic engine using the dynamic engine class" ); - py_direct_network_solver.def("set_callback", - [](gridfire::solver::DirectNetworkSolver &self, const gridfire::solver::DirectNetworkSolver::TimestepCallback& cb) { + py_dynamic_network_solver_strategy.def( + "set_callback", + [](gridfire::solver::DynamicNetworkSolverStrategy& self, std::function cb) { self.set_callback(cb); }, - py::arg("callback"), - "Sets a callback function to be called at each timestep." + "Set a callback function which will run at the end of every successful timestep" ); - py::class_(py_direct_network_solver, "TimestepContext") - .def_readonly("t", &gridfire::solver::DirectNetworkSolver::TimestepContext::t, "Current time in the simulation.") - .def_property_readonly( - "state", [](const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) { - std::vector state(ctx.state.size()); - std::ranges::copy(ctx.state, state.begin()); - return py::array_t(static_cast(state.size()), state.data()); - }) - .def_readonly("dt", &gridfire::solver::DirectNetworkSolver::TimestepContext::dt, "Current timestep size.") - .def_readonly("cached_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::cached_time, "Cached time for the last computed result.") - .def_readonly("last_observed_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::last_observed_time, "Last time the state was observed.") - .def_readonly("last_step_time", &gridfire::solver::DirectNetworkSolver::TimestepContext::last_step_time, "Last step time taken for the integration.") - .def_readonly("T9", &gridfire::solver::DirectNetworkSolver::TimestepContext::T9, "Temperature in units of 10^9 K.") - .def_readonly("rho", &gridfire::solver::DirectNetworkSolver::TimestepContext::rho, "Temperature in units of 10^9 K.") - .def_property_readonly("cached_result", [](const gridfire::solver::DirectNetworkSolver::TimestepContext& ctx) -> py::object { - if (ctx.cached_result.has_value()) { - const auto&[dydt, nuclearEnergyGenerationRate] = ctx.cached_result.value(); - return py::make_tuple( - py::array_t(static_cast(dydt.size()), dydt.data()), - nuclearEnergyGenerationRate - ); - } - return py::none(); - }, "Cached result of the step derivatives.") - .def_readonly("num_steps", &gridfire::solver::DirectNetworkSolver::TimestepContext::num_steps, "Total number of steps taken in the simulation.") - .def_property_readonly("engine", [](const gridfire::solver::DirectNetworkSolver::TimestepContext &ctx) -> const gridfire::DynamicEngine & { - return ctx.engine; - }, py::return_value_policy::reference) + py_dynamic_network_solver_strategy.def( + "describe_callback_context", + &gridfire::solver::DynamicNetworkSolverStrategy::describe_callback_context, + "Get a structure representing what data is in the callback context in a human readable format" + ); + + auto py_cvode_solver_strategy = py::class_(m, "CVODESolverStrategy"); + + py_cvode_solver_strategy.def( + py::init(), + py::arg("engine"), + "Initialize the CVODESolverStrategy object." + ); + + py_cvode_solver_strategy.def( + "evaluate", + py::overload_cast(&gridfire::solver::CVODESolverStrategy::evaluate), + py::arg("netIn"), + py::arg("display_trigger"), + "evaluate the dynamic engine using the dynamic engine class" + ); + + py_cvode_solver_strategy.def( + "get_stdout_logging_enabled", + &gridfire::solver::CVODESolverStrategy::get_stdout_logging_enabled, + "Check if solver logging to standard output is enabled." + ); + + py_cvode_solver_strategy.def( + "set_stdout_logging_enabled", + &gridfire::solver::CVODESolverStrategy::set_stdout_logging_enabled, + py::arg("logging_enabled"), + "Enable logging to standard output." + ); + + auto py_cvode_timestep_context = py::class_(m, "CVODETimestepContext"); + py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t); + py_cvode_timestep_context.def_property_readonly( + "state", + [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector { + const sunrealtype* nvec_data = N_VGetArrayPointer(self.state); + const sunindextype length = N_VGetLength(self.state); + return std::vector(nvec_data, nvec_data + length); + } + ); + py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt); + py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time); + py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9); + py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho); + py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps); + py_cvode_timestep_context.def_property_readonly( + "engine", + [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::DynamicEngine& { + return self.engine; + } + ); + py_cvode_timestep_context.def_property_readonly( + "networkSpecies", + [](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector { + return self.networkSpecies; + } + ); - .def_property_readonly("network_species", [](const gridfire::solver::DirectNetworkSolver::TimestepContext &ctx) -> const std::vector & { - return ctx.networkSpecies; - }, py::return_value_policy::reference); } diff --git a/src/python/solver/trampoline/py_solver.cpp b/src/python/solver/trampoline/py_solver.cpp index 0e67f8b6..2b735571 100644 --- a/src/python/solver/trampoline/py_solver.cpp +++ b/src/python/solver/trampoline/py_solver.cpp @@ -39,3 +39,12 @@ std::vector> PyDynamicNetworkSolverStrategy describe_callback_context // Method name ); } + +std::vector> PySolverContextBase::describe() const { + using DescriptionVector = std::vector>; + PYBIND11_OVERRIDE_PURE( + DescriptionVector, + gridfire::solver::SolverContextBase, + describe + ); +} diff --git a/src/python/solver/trampoline/py_solver.h b/src/python/solver/trampoline/py_solver.h index 120971c9..25b19609 100644 --- a/src/python/solver/trampoline/py_solver.h +++ b/src/python/solver/trampoline/py_solver.h @@ -12,4 +12,9 @@ class PyDynamicNetworkSolverStrategy final : public gridfire::solver::DynamicNet gridfire::NetOut evaluate(const gridfire::NetIn &netIn) override; void set_callback(const std::any &callback) override; [[nodiscard]] std::vector> describe_callback_context() const override; +}; + +class PySolverContextBase final : public gridfire::solver::SolverContextBase { +public: + [[nodiscard]] std::vector> describe() const override; }; \ No newline at end of file diff --git a/src/python/types/bindings.cpp b/src/python/types/bindings.cpp index 2a9d088d..23596f33 100644 --- a/src/python/types/bindings.cpp +++ b/src/python/types/bindings.cpp @@ -1,6 +1,7 @@ #include #include // Needed for vectors, maps, sets, strings #include // Needed for binding std::vector, std::map etc. if needed directly +#include #include "bindings.h" @@ -32,12 +33,18 @@ void register_type_bindings(const pybind11::module &m) { .def_readonly("composition", &gridfire::NetOut::composition) .def_readonly("num_steps", &gridfire::NetOut::num_steps) .def_readonly("energy", &gridfire::NetOut::energy) + .def_readonly("dEps_dT", &gridfire::NetOut::dEps_dT) + .def_readonly("dEps_dRho", &gridfire::NetOut::dEps_dRho) .def("__repr__", [](const gridfire::NetOut &netOut) { - std::stringstream ss; - ss << "NetOut(composition=" << netOut.composition - << ", num_steps=" << netOut.num_steps - << ", energy=" << netOut.energy << ")"; - return ss.str(); + std::string repr = std::format( + "NetOut(<μ> = {} steps = {}, ε = {}, dε/dT = {}, dε/dρ = {})", + netOut.composition.getMeanParticleMass(), + netOut.num_steps, + netOut.energy, + netOut.dEps_dT, + netOut.dEps_dRho + ); + return repr; }); } diff --git a/src/python/utils/bindings.cpp b/src/python/utils/bindings.cpp index 5d9810db..97f08d68 100644 --- a/src/python/utils/bindings.cpp +++ b/src/python/utils/bindings.cpp @@ -3,6 +3,9 @@ #include "bindings.h" +#include "gridfire/utils/general_composition.h" +#include "gridfire/utils/hashing.h" + namespace py = pybind11; #include "gridfire/utils/logging.h" @@ -16,4 +19,65 @@ void register_utils_bindings(py::module &m) { py::arg("rho"), "Format a string for logging nuclear timescales based on temperature, density, and energy generation rate." ); + + m.def( + "massFractionFromMolarAbundanceAndComposition", + &gridfire::utils::massFractionFromMolarAbundanceAndComposition, + py::arg("composition"), + py::arg("species"), + py::arg("Yi"), + "Convert a specific species molar abundance into its mass fraction if it were present in a given composition." + ); + + m.def( + "massFractionFromMolarAbundanceAndMolarMass", + &gridfire::utils::massFractionFromMolarAbundanceAndMolarMass, + py::arg("molarAbundances"), + py::arg("molarMasses"), + "Convert a vector of molar abundances and a parallel vector of molar masses into a vector of mass fractions" + ); + + m.def( + "molarMassVectorFromComposition", + &gridfire::utils::molarMassVectorFromComposition, + py::arg("composition"), + "Extract vector of molar masses from a composition object, this will be sorted by species mass so that the lightest species are at the front of the list." + ); + + m.def( + "hash_atomic", + &gridfire::utils::hash_atomic, + py::arg("a"), + py::arg("z") + ); + + auto hashing_module = m.def_submodule("hashing", "module for gridfire hashing functions"); + auto reaction_hashing_module = hashing_module.def_submodule("reaction", "utility module for hashing gridfire reaction functions"); + + reaction_hashing_module.def( + "splitmix64", + &gridfire::utils::hashing::reaction::splitmix64, + py::arg("x") + ); + + reaction_hashing_module.def( + "mix_species", + &gridfire::utils::hashing::reaction::mix_species, + py::arg("a"), + py::arg("z") + ); + + reaction_hashing_module.def( + "multiset_combine", + &gridfire::utils::hashing::reaction::multiset_combine, + py::arg("acc"), + py::arg("x") + ); + + m.def( + "hash_reaction", + &gridfire::utils::hash_reaction, + py::arg("reaction") + ); + }