16 Commits

Author SHA1 Message Date
087926728a docs(version): v0.7.6rc3 -> v0.7.6rc3.1
this version uses the target workaround which we use for the SUNDIALS suite to get meson to recognize it as position indepenednt code to CppAD
2025-12-22 08:23:20 -05:00
df09564c9a fix(cppad): build static lib for cppad
this is needed for gcc build and does not break clang build
2025-12-22 08:22:07 -05:00
0bf3ae625b docs(meson): version v0.7.5rc3 -> v0.7.6rc3 2025-12-22 08:18:02 -05:00
5c1714410a fix(engine_defined): fixed gcc build warnings 2025-12-22 08:17:23 -05:00
e6a9d8c5bb fix(wheels): fixed macos wheel generation to pin meson==1.9.1 2025-12-22 07:09:51 -05:00
e73daf88b3 docs(version): v0.7.4rc3 -> v0.7.5rc4 2025-12-20 16:41:27 -05:00
e197227908 fix(meson): address #13
Address regresion in meson 1.10.1 by pinning version to 1.9.1
2025-12-20 16:40:28 -05:00
f1f793f775 Merge pull request #12 from tboudreaux/perf/coloring
Perf/coloring
2025-12-20 16:17:08 -05:00
e98c9a4050 feat(python): Python multi-test
Python now works with mulit-threading and zones
2025-12-20 16:08:33 -05:00
11a596b75b feat(python): Python Bindings
Python Bindings are working again
2025-12-20 16:02:52 -05:00
d65c237b26 feat(fortran): Fortran interface can now use multi-zone
Fortran interface uses the new C api ability to call the naieve
multi-zone solver. This allows fortran calling code to make use of in
build parellaism for solving multiple zones
2025-12-19 09:58:47 -05:00
2a9649a72e feat(C-API): C API now can use multi-zone solver
C api has been brought back up to support and can use paraellization along with multi zone solver
2025-12-18 15:14:47 -05:00
cd5e42b69a feat(omp): useful omp macros
A few macros which make turning on and off omp features cleaner without
#defines everywherer
2025-12-18 12:48:10 -05:00
dcfd7b60aa perf(multi): Simple parallel multi zone solver
Added a simple parallel multi-zone solver
2025-12-18 12:47:39 -05:00
4e1edfc142 feat(Spectral): Working on Spectral Solver 2025-12-15 12:14:00 -05:00
0b09ed1cb3 feat(SpectralSolver): Spectral Solver now works in a limited fashion
Major work on spectral solver, can now evolve up to about a year. At
that point we likely need to impliment repartitioning logic to stabalize
the network or some other scheme based on the jacobian structure
2025-12-12 17:24:53 -05:00
128 changed files with 6957 additions and 3097 deletions

View File

@@ -48,7 +48,7 @@ PROJECT_NAME = GridFire
# could be handy for archiving the generated documentation or if some version # could be handy for archiving the generated documentation or if some version
# control system is used. # control system is used.
PROJECT_NUMBER = v0.7.4_rc2 PROJECT_NUMBER = v0.7.5rc3
# Using the PROJECT_BRIEF tag one can provide an optional one line description # Using the PROJECT_BRIEF tag one can provide an optional one line description
# for a project that appears at the top of each page and should give viewers a # for a project that appears at the top of each page and should give viewers a

View File

@@ -19,15 +19,7 @@
#include <clocale> #include <clocale>
#include "gridfire/reaction/reaclib.h" #include "gridfire/reaction/reaclib.h"
#include <omp.h> #include "gridfire/utils/gf_omp.h"
unsigned long get_thread_id() {
return static_cast<unsigned long>(omp_get_thread_num());
}
bool in_parallel() {
return omp_in_parallel() != 0;
}
gridfire::NetIn init(const double temp, const double rho, const double tMax) { gridfire::NetIn init(const double temp, const double rho, const double tMax) {
std::setlocale(LC_ALL, ""); std::setlocale(LC_ALL, "");
@@ -55,6 +47,7 @@ gridfire::NetIn init(const double temp, const double rho, const double tMax) {
int main() { int main() {
GF_PAR_INIT()
using namespace gridfire; using namespace gridfire;
constexpr size_t breaks = 1; constexpr size_t breaks = 1;
@@ -70,7 +63,7 @@ int main() {
std::println("Scratch Blob State: {}", *construct.scratch_blob); std::println("Scratch Blob State: {}", *construct.scratch_blob);
constexpr size_t runs = 1000; constexpr size_t runs = 10;
auto startTime = std::chrono::high_resolution_clock::now(); auto startTime = std::chrono::high_resolution_clock::now();
// arrays to store timings // arrays to store timings
@@ -79,14 +72,15 @@ int main() {
std::array<NetOut, runs> serial_results; std::array<NetOut, runs> serial_results;
for (size_t i = 0; i < runs; ++i) { for (size_t i = 0; i < runs; ++i) {
auto start_setup_time = std::chrono::high_resolution_clock::now(); auto start_setup_time = std::chrono::high_resolution_clock::now();
solver::CVODESolverStrategy solver(construct.engine, *construct.scratch_blob); solver::PointSolverContext solverCtx(*construct.scratch_blob);
solver.set_stdout_logging_enabled(false); solverCtx.set_stdout_logging(false);
solver::PointSolver solver(construct.engine);
auto end_setup_time = std::chrono::high_resolution_clock::now(); auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time; std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setup_times[i] = setup_elapsed; setup_times[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now(); auto start_eval_time = std::chrono::high_resolution_clock::now();
const NetOut netOut = solver.evaluate(netIn); const NetOut netOut = solver.evaluate(solverCtx, netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now(); auto end_eval_time = std::chrono::high_resolution_clock::now();
serial_results[i] = netOut; serial_results[i] = netOut;
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time; std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
@@ -106,17 +100,7 @@ int main() {
std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs); std::println("Average Setup Time over {} runs: {:.6f} seconds", runs, total_setup_time / runs);
std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs); std::println("Average Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count()); std::println("Total Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Serial: {}", serial_results[0].composition.getMolarAbundance(fourdst::atomic::H_1));
CppAD::thread_alloc::parallel_setup(
static_cast<size_t>(omp_get_max_threads()), // Max threads
[]() -> bool { return in_parallel(); }, // Function to get thread ID
[]() -> size_t { return get_thread_id(); } // Function to check parallel state
);
// OPTIONAL: Prevent CppAD from returning memory to the system
// during execution to reduce overhead (can speed up tight loops)
CppAD::thread_alloc::hold_memory(true);
std::array<NetOut, runs> parallelResults; std::array<NetOut, runs> parallelResults;
std::array<std::chrono::duration<double>, runs> setupTimes; std::array<std::chrono::duration<double>, runs> setupTimes;
@@ -129,16 +113,17 @@ int main() {
// Parallel runs // Parallel runs
startTime = std::chrono::high_resolution_clock::now(); startTime = std::chrono::high_resolution_clock::now();
#pragma omp parallel for
for (size_t i = 0; i < runs; ++i) { GF_OMP(parallel for, for (size_t i = 0; i < runs; ++i)) {
auto start_setup_time = std::chrono::high_resolution_clock::now(); auto start_setup_time = std::chrono::high_resolution_clock::now();
solver::CVODESolverStrategy solver(construct.engine, *workspaces[i]); solver::PointSolverContext solverCtx(*construct.scratch_blob);
solver.set_stdout_logging_enabled(false); solverCtx.set_stdout_logging(false);
solver::PointSolver solver(construct.engine);
auto end_setup_time = std::chrono::high_resolution_clock::now(); auto end_setup_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time; std::chrono::duration<double> setup_elapsed = end_setup_time - start_setup_time;
setupTimes[i] = setup_elapsed; setupTimes[i] = setup_elapsed;
auto start_eval_time = std::chrono::high_resolution_clock::now(); auto start_eval_time = std::chrono::high_resolution_clock::now();
parallelResults[i] = solver.evaluate(netIn); parallelResults[i] = solver.evaluate(solverCtx, netIn);
auto end_eval_time = std::chrono::high_resolution_clock::now(); auto end_eval_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time; std::chrono::duration<double> eval_elapsed = end_eval_time - start_eval_time;
evalTimes[i] = eval_elapsed; evalTimes[i] = eval_elapsed;
@@ -159,10 +144,6 @@ int main() {
std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs); std::println("Average Parallel Evaluation Time over {} runs: {:.6f} seconds", runs, total_eval_time / runs);
std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count()); std::println("Total Parallel Time for {} runs: {:.6f} seconds", runs, elapsed.count());
std::println("Final H-1 Abundances Parallel: {}", utils::iterable_to_delimited_string(parallelResults, ",", [](const auto& result) {
return result.composition.getMolarAbundance(fourdst::atomic::H_1);
}));
std::println("========== Summary =========="); std::println("========== Summary ==========");
std::println("Serial Runs:"); std::println("Serial Runs:");
std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs); std::println(" Average Setup Time: {:.6f} seconds", total_setup_time / runs);

View File

@@ -0,0 +1,3 @@
if get_option('build_benchmarks')
subdir('SingleZoneSolver')
endif

View File

@@ -33,5 +33,18 @@ else
endif endif
if get_option('openmp_support') if get_option('openmp_support')
add_project_arguments('-DGRIDFIRE_USE_OPENMP', language: 'cpp') add_project_arguments('-DGF_USE_OPENMP', language: 'cpp')
endif
if get_option('asan') and get_option('buildtype') != 'debug' and get_option('buildtype') != 'debugoptimized'
error('AddressSanitizer (ASan) can only be enabled for debug or debugoptimized builds')
endif
if get_option('asan') and (get_option('buildtype') == 'debugoptimized' or get_option('buildtype') == 'debug')
message('enabling AddressSanitizer (ASan) support')
add_project_arguments('-fsanitize=address,undefined', language: 'cpp')
add_project_arguments('-fno-omit-frame-pointer', language: 'cpp')
add_project_link_arguments('-fsanitize=address,undefined', language: 'cpp')
add_project_link_arguments('-fno-omit-frame-pointer', language: 'cpp')
endif endif

View File

@@ -1,5 +1,10 @@
if get_option('build_fortran') if get_option('build_fortran')
add_languages('fortran', native: true) found_fortran = add_languages('fortran')
if not found_fortran
error('Fortran compiler not found, but build_fortran option is enabled.')
else
message('Fortran compiler found.')
endif
message('Found FORTRAN compiler: ' + meson.get_compiler('fortran').get_id()) message('Found FORTRAN compiler: ' + meson.get_compiler('fortran').get_id())
message('Fortran standard set to: ' + get_option('fortran_std')) message('Fortran standard set to: ' + get_option('fortran_std'))
message('Building fortran module (gridfire_mod.mod)') message('Building fortran module (gridfire_mod.mod)')

View File

@@ -1,26 +1,36 @@
#cppad_inc = include_directories('include', is_system: true)
#cppad_dep = declare_dependency(
# include_directories: cppad_inc,
#)
#
#message('Registering CppAD headers for installation...')
#install_subdir('include/cppad', install_dir: get_option('includedir'))
#message('Done registering CppAD headers for installation!')
#
cppad_cmake_options = cmake.subproject_options() cppad_cmake_options = cmake.subproject_options()
cppad_cmake_options.add_cmake_defines({ cppad_cmake_options.add_cmake_defines({
'cppad_static_lib': 'true', 'cppad_static_lib': 'true',
'cpp_mas_num_threads': '10', 'cpp_mas_num_threads': '10',
'cppad_debug_and_release': 'false', 'cppad_debug_and_release': 'false',
'include_doc': 'false' 'include_doc': 'false',
'CMAKE_POSITION_INDEPENDENT_CODE': true
}) })
cppad_cmake_options.set_install(false)
cppad_sp = cmake.subproject( cppad_sp = cmake.subproject(
'cppad', 'cppad',
options: cppad_cmake_options, options: cppad_cmake_options,
) )
cppad_dep = cppad_sp.dependency('cppad_lib').as_system() cppad_target = cppad_sp.target('cppad_lib')
cppad_objs = [cppad_target.extract_all_objects(recursive: true)]
cppad_incs = cppad_sp.include_directories('cppad_lib')
empty_cppad_file = configure_file(output: 'cppad_dummy_ar.cpp', command: ['echo'], capture: true)
libcppad_static = static_library(
'cppad-static',
empty_cppad_file,
objects: cppad_objs,
include_directories: cppad_incs,
pic: true,
install: false
)
cppad_dep = declare_dependency(
link_with: libcppad_static,
include_directories: cppad_incs
)

View File

@@ -7,10 +7,10 @@ if meson.is_cross_build() and host_machine.system() == 'darwin'
py_inc_dir = include_directories('../../cross/python_includes/python-' + py_ver + '/include/python' + py_ver) py_inc_dir = include_directories('../../cross/python_includes/python-' + py_ver + '/include/python' + py_ver)
py_dep = declare_dependency(include_directories: py_inc_dir) py_dep = declare_dependency(include_directories: py_inc_dir)
py_module_prefix = '' py_module_prefix = ''
py_module_suffic = 'so' py_module_suffix = 'so'
meson.override_dependency('python3', py_dep) meson.override_dependency('python3', py_dep)
else else
py_dep = py_installation.dependency() py_dep = py_installation.dependency()
py_module_prefix = '' py_module_prefix = ''
py_module_suffic = 'so' py_module_suffix = 'so'
endif endif

View File

@@ -7,7 +7,8 @@ cvode_cmake_options.add_cmake_defines({
'BUILD_SHARED_LIBS' : 'OFF', 'BUILD_SHARED_LIBS' : 'OFF',
'BUILD_STATIC_LIBS' : 'ON', 'BUILD_STATIC_LIBS' : 'ON',
'EXAMPLES_ENABLE_C' : 'OFF', 'EXAMPLES_ENABLE_C' : 'OFF',
'CMAKE_POSITION_INDEPENDENT_CODE': true 'CMAKE_POSITION_INDEPENDENT_CODE': true,
'CMAKE_PLATFORM_NO_VERSIONED_SONAME': 'ON'
}) })
@@ -16,6 +17,8 @@ cvode_cmake_options.add_cmake_defines({
'CMAKE_INSTALL_INCLUDEDIR': get_option('includedir') 'CMAKE_INSTALL_INCLUDEDIR': get_option('includedir')
}) })
cvode_cmake_options.set_install(false)
if meson.is_cross_build() and host_machine.system() == 'emscripten' if meson.is_cross_build() and host_machine.system() == 'emscripten'
cvode_cmake_options.add_cmake_defines({ cvode_cmake_options.add_cmake_defines({
'CMAKE_C_FLAGS': '-s MEMORY64=1 -s ALLOW_MEMORY_GROWTH=1', 'CMAKE_C_FLAGS': '-s MEMORY64=1 -s ALLOW_MEMORY_GROWTH=1',

View File

@@ -8,7 +8,8 @@ kinsol_cmake_options.add_cmake_defines({
'BUILD_SHARED_LIBS' : 'OFF', 'BUILD_SHARED_LIBS' : 'OFF',
'BUILD_STATIC_LIBS' : 'ON', 'BUILD_STATIC_LIBS' : 'ON',
'EXAMPLES_ENABLE_C' : 'OFF', 'EXAMPLES_ENABLE_C' : 'OFF',
'CMAKE_POSITION_INDEPENDENT_CODE': true 'CMAKE_POSITION_INDEPENDENT_CODE': true,
'CMAKE_PLATFORM_NO_VERSIONED_SONAME': 'ON'
}) })
kinsol_cmake_options.add_cmake_defines({ kinsol_cmake_options.add_cmake_defines({
@@ -16,6 +17,8 @@ kinsol_cmake_options.add_cmake_defines({
'CMAKE_INSTALL_INCLUDEDIR': get_option('includedir') 'CMAKE_INSTALL_INCLUDEDIR': get_option('includedir')
}) })
kinsol_cmake_options.set_install(false)
kinsol_sp = cmake.subproject( kinsol_sp = cmake.subproject(
'kinsol', 'kinsol',
options: kinsol_cmake_options, options: kinsol_cmake_options,

View File

@@ -28,6 +28,8 @@ if get_option('build_python')
meson.project_source_root() + '/src/python/policy/bindings.cpp', meson.project_source_root() + '/src/python/policy/bindings.cpp',
meson.project_source_root() + '/src/python/policy/trampoline/py_policy.cpp', meson.project_source_root() + '/src/python/policy/trampoline/py_policy.cpp',
meson.project_source_root() + '/src/python/utils/bindings.cpp', meson.project_source_root() + '/src/python/utils/bindings.cpp',
meson.project_source_root() + '/src/python/config/bindings.cpp',
meson.project_source_root() + '/src/python/engine/scratchpads/bindings.cpp',
] ]
@@ -56,6 +58,7 @@ if get_option('build_python')
files( files(
meson.project_source_root() + '/src/python/gridfire/__init__.py', meson.project_source_root() + '/src/python/gridfire/__init__.py',
meson.project_source_root() + '/stubs/gridfire/_gridfire/__init__.pyi', meson.project_source_root() + '/stubs/gridfire/_gridfire/__init__.pyi',
meson.project_source_root() + '/stubs/gridfire/_gridfire/config.pyi',
meson.project_source_root() + '/stubs/gridfire/_gridfire/exceptions.pyi', meson.project_source_root() + '/stubs/gridfire/_gridfire/exceptions.pyi',
meson.project_source_root() + '/stubs/gridfire/_gridfire/partition.pyi', meson.project_source_root() + '/stubs/gridfire/_gridfire/partition.pyi',
meson.project_source_root() + '/stubs/gridfire/_gridfire/reaction.pyi', meson.project_source_root() + '/stubs/gridfire/_gridfire/reaction.pyi',
@@ -72,6 +75,7 @@ if get_option('build_python')
files( files(
meson.project_source_root() + '/stubs/gridfire/_gridfire/engine/__init__.pyi', meson.project_source_root() + '/stubs/gridfire/_gridfire/engine/__init__.pyi',
meson.project_source_root() + '/stubs/gridfire/_gridfire/engine/diagnostics.pyi', meson.project_source_root() + '/stubs/gridfire/_gridfire/engine/diagnostics.pyi',
meson.project_source_root() + '/stubs/gridfire/_gridfire/engine/scratchpads.pyi'
), ),
subdir: 'gridfire/engine', subdir: 'gridfire/engine',
) )

View File

@@ -18,7 +18,7 @@
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
# #
# *********************************************************************** # # *********************************************************************** #
project('GridFire', ['c', 'cpp'], version: 'v0.7.4_rc2', default_options: ['cpp_std=c++23'], meson_version: '>=1.5.0') project('GridFire', ['c', 'cpp'], version: 'v0.7.6rc3.1', default_options: ['cpp_std=c++23'], meson_version: '>=1.5.0')
# Start by running the code which validates the build environment # Start by running the code which validates the build environment
subdir('build-check') subdir('build-check')

View File

@@ -12,3 +12,4 @@ option('build_tools', type: 'boolean', value: true, description: 'build the Grid
option('openmp_support', type: 'boolean', value: false, description: 'Enable OpenMP support for parallelization') option('openmp_support', type: 'boolean', value: false, description: 'Enable OpenMP support for parallelization')
option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.') option('use_mimalloc', type: 'boolean', value: true, description: 'Use mimalloc as the memory allocator for GridFire. Generally this is ~10% faster than the system allocator.')
option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite') option('build_benchmarks', type: 'boolean', value: false, description: 'build the benchmark suite')
option('asan', type: 'boolean', value: false, description: 'Enable AddressSanitizer (ASan) support for detecting memory errors')

View File

@@ -56,7 +56,7 @@ echo "Site packages: $SITE_PACKAGES"
echo "" echo ""
echo -e "${GREEN}Step 2: Installing fourdst with pip...${NC}" echo -e "${GREEN}Step 2: Installing fourdst with pip...${NC}"
$PYTHON_BIN -m pip install . -v $PYTHON_BIN -m pip install . -v --no-build-isolation
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo -e "${RED}Error: pip install failed${NC}" echo -e "${RED}Error: pip install failed${NC}"

View File

@@ -1,14 +1,14 @@
[build-system] [build-system]
requires = [ requires = [
"meson-python>=0.15.0", # Use a recent version "meson-python>=0.15.0", # Use a recent version
"meson>=1.6.0", # Specify your Meson version requirement "meson==1.9.1", # Specify your Meson version requirement
"pybind11>=2.10" # pybind11 headers needed at build time "pybind11>=2.10" # pybind11 headers needed at build time
] ]
build-backend = "mesonpy" build-backend = "mesonpy"
[project] [project]
name = "gridfire" # Choose your Python package name name = "gridfire" # Choose your Python package name
version = "0.7.4_rc2" # Your project's version version = "v0.7.5rc3" # Your project's version
description = "Python interface to the GridFire nuclear network code" description = "Python interface to the GridFire nuclear network code"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE.txt" } # Reference your license file [cite: 2] license = { file = "LICENSE.txt" } # Reference your license file [cite: 2]
@@ -22,4 +22,4 @@ maintainers = [
] ]
[tool.meson-python.args] [tool.meson-python.args]
setup = ['-Dpkg-config=false'] setup = ['-Dpkg_config=false', '-Dbuildtype=release', '-Dopenmp_support=true', '-Dasan=false', '-Dlog_level=error', '-Dbuild_tests=false', '-Dbuild_c_api=false', '-Dbuild_examples=false', '-Dbuild_benchmarks=false', '-Dbuild_tools=false', '-Dplugin_support=false', '-Duse_mimalloc=false', '-Dbuild_python=true']

View File

@@ -2,6 +2,14 @@ module gridfire_mod
use iso_c_binding use iso_c_binding
implicit none implicit none
type, public :: GF_TYPE
integer(c_int) :: value
end type GF_TYPE
type(GF_TYPE), parameter, public :: &
SINGLE_ZONE = GF_TYPE(1001), &
MULTI_ZONE = GF_TYPE(1002)
enum, bind (C) enum, bind (C)
enumerator :: FDSSE_NON_4DSTAR_ERROR = -102 enumerator :: FDSSE_NON_4DSTAR_ERROR = -102
enumerator :: FDSSE_UNKNOWN_ERROR = -101 enumerator :: FDSSE_UNKNOWN_ERROR = -101
@@ -50,24 +58,46 @@ module gridfire_mod
enumerator :: GF_DEBUG_ERRROR = 30 enumerator :: GF_DEBUG_ERRROR = 30
enumerator :: GF_GRIDFIRE_ERROR = 31 enumerator :: GF_GRIDFIRE_ERROR = 31
enumerator :: GF_UNINITIALIZED_INPUT_MEMORY_ERROR = 32
enumerator :: GF_UNINITIALIZED_OUTPUT_MEMORY_ERROR = 33
enumerator :: GF_INVALD_NUM_SPECIES = 34
enumerator :: GF_INVALID_TIMESTEPS = 35
enumerator :: GF_UNKNONWN_FREE_TYPE = 36
enumerator :: GF_INVALID_TYPE = 37
enumerator :: GF_SINGLE_ZONE = 1001
enumerator :: GF_MULTI_ZONE = 1002
end enum end enum
interface interface
! void* gf_init() ! void* gf_init()
function gf_init() bind(C, name="gf_init") function gf_init(ctx_type) bind(C, name="gf_init")
import :: c_ptr import :: c_ptr, c_int
type(c_ptr) :: gf_init type(c_ptr) :: gf_init
integer(c_int), value :: ctx_type
end function gf_init end function gf_init
! void gf_free(void* gf) ! int gf_free(void* gf)
subroutine gf_free(gf) bind(C, name="gf_free") function gf_free(ctx_type, ptr) result(c_res) bind(C, name="gf_free")
import :: c_ptr import :: c_ptr, c_int
type(c_ptr), value :: gf type(c_ptr), value :: ptr
end subroutine gf_free integer(c_int), value :: ctx_type
integer(c_int) :: c_res
end function gf_free
function gf_set_num_zones(ctx_type, ptr, num_zones) result(c_res) bind(C, name="gf_set_num_zones")
import :: c_ptr, c_int, c_size_t
type(c_ptr), value :: ptr
integer(c_int), value :: ctx_type
integer(c_size_t), value :: num_zones
integer(c_int) :: c_res
end function gf_set_num_zones
! char* gf_get_last_error_message(void* ptr); ! char* gf_get_last_error_message(void* ptr);
function gf_get_last_error_message(ptr) result(c_msg) bind(C, name="gf_get_last_error_message") function gf_get_last_error_message(ptr) result(c_msg) bind(C, name="gf_get_last_error_message")
import import :: c_ptr, c_int
type(c_ptr), value :: ptr type(c_ptr), value :: ptr
type(c_ptr) :: c_msg type(c_ptr) :: c_msg
end function end function
@@ -102,49 +132,116 @@ module gridfire_mod
end function end function
! int gf_evolve(...) ! int gf_evolve(...)
function gf_evolve(ptr, Y_in, num_species, T, rho, dt, Y_out, energy_out, dEps_dT, dEps_dRho, specific_neutrino_loss, specific_neutrino_flux, mass_lost) result(ierr) & function gf_evolve_c_scalar(ctx_type, ptr, Y_in, num_species, T, rho, tMax, dt0, &
Y_out, energy, dedt, dedrho, &
nue_loss, nu_flux, mass_lost) result(ierr) &
bind(C, name="gf_evolve") bind(C, name="gf_evolve")
import import :: c_ptr, c_int, c_double, c_size_t
type(c_ptr), value :: ptr type(c_ptr), value :: ptr
real(c_double), dimension(*), intent(in) :: Y_in integer(c_int), value :: ctx_type
integer(c_size_t), value :: num_species integer(c_size_t), value :: num_species
real(c_double), value :: T, rho, dt
! Arrays
real(c_double), dimension(*), intent(in) :: Y_in
real(c_double), dimension(*), intent(out) :: Y_out real(c_double), dimension(*), intent(out) :: Y_out
real(c_double), intent(out) :: energy_out, dEps_dT, dEps_dRho, specific_neutrino_loss, specific_neutrino_flux, mass_lost
! Scalars (Passed by Reference -> matches void*)
real(c_double), intent(in) :: T, rho
real(c_double), intent(out) :: energy, dedt, dedrho, nue_loss, nu_flux, mass_lost
! Scalars (Passed by Value)
real(c_double), value :: tMax, dt0
integer(c_int) :: ierr
end function
! 2. Interface for Multi Zone (Arrays)
function gf_evolve_c_array(ctx_type, ptr, Y_in, num_species, T, rho, tMax, dt0, &
Y_out, energy, dedt, dedrho, &
nue_loss, nu_flux, mass_lost) result(ierr) &
bind(C, name="gf_evolve")
import :: c_ptr, c_int, c_double, c_size_t
type(c_ptr), value :: ptr
integer(c_int), value :: ctx_type
integer(c_size_t), value :: num_species
! All Arrays (dimension(*))
real(c_double), dimension(*), intent(in) :: Y_in
real(c_double), dimension(*), intent(in) :: T, rho
real(c_double), dimension(*), intent(out) :: Y_out
real(c_double), dimension(*), intent(out) :: energy, dedt, dedrho, nue_loss, nu_flux, mass_lost
! Scalars (Passed by Value)
real(c_double), value :: tMax, dt0
integer(c_int) :: ierr integer(c_int) :: ierr
end function end function
end interface end interface
type :: GridFire type :: GridFire
type(c_ptr) :: ctx = c_null_ptr type(c_ptr) :: ctx = c_null_ptr
integer(c_int) :: ctx_type = SINGLE_ZONE%value
integer(c_size_t) :: num_species = 0 integer(c_size_t) :: num_species = 0
integer(c_size_t) :: num_zones = 1
contains contains
procedure :: gff_init procedure :: gff_init
procedure :: gff_free procedure :: gff_free
procedure :: register_species procedure :: gff_register_species
procedure :: setup_policy procedure :: gff_setup_policy
procedure :: setup_solver procedure :: gff_setup_solver
procedure :: evolve procedure :: gff_get_last_error
procedure :: get_last_error
procedure :: gff_evolve_single
procedure :: gff_evolve_multi
generic :: gff_evolve => gff_evolve_single, gff_evolve_multi
end type GridFire end type GridFire
contains contains
subroutine gff_init(self) subroutine gff_init(self, type, zones)
class(GridFire), intent(out) :: self class(GridFire), intent(out) :: self
type(GF_TYPE), intent(in) :: type
integer(c_size_t), intent(in), optional :: zones
integer(c_int) :: ierr
self%ctx = gf_init() if (type%value==1002) then
if (.not. present(zones)) then
print *, "GridFire Error: Multi-zone type requires number of zones to be specficied in the GridFire init method (i.e. GridFire(MULTI_ZONE, 10) for 10 zones)."
error stop
end if
self%num_zones = zones
end if
self%ctx_type = type%value
self%ctx = gf_init(self%ctx_type)
if (type%value==1002) then
ierr = gf_set_num_zones(self%ctx_type, self%ctx, self%num_zones)
if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then
print *, "GridFire Multi-Zone Error: ", self%gff_get_last_error()
error stop
end if
end if
end subroutine gff_init end subroutine gff_init
subroutine gff_free(self) subroutine gff_free(self)
class(GridFire), intent(inout) :: self class(GridFire), intent(inout) :: self
integer(c_int) :: ierr
if (c_associated(self%ctx)) then if (c_associated(self%ctx)) then
call gf_free(self%ctx) ierr = gf_free(self%ctx_type, self%ctx)
if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then
print *, "GridFire Free Error: ", self%gff_get_last_error()
error stop
end if
self%ctx = c_null_ptr self%ctx = c_null_ptr
end if end if
end subroutine gff_free end subroutine gff_free
function get_last_error(self) result(msg) function gff_get_last_error(self) result(msg)
class(GridFire), intent(in) :: self class(GridFire), intent(in) :: self
character(len=:), allocatable :: msg character(len=:), allocatable :: msg
type(c_ptr) :: c_msg_ptr type(c_ptr) :: c_msg_ptr
@@ -169,9 +266,9 @@ module gridfire_mod
do i = 1, len_str do i = 1, len_str
msg(i+10:i+10) = char_ptr(i) msg(i+10:i+10) = char_ptr(i)
end do end do
end function get_last_error end function gff_get_last_error
subroutine register_species(self, species_list) subroutine gff_register_species(self, species_list)
class(GridFire), intent(inout) :: self class(GridFire), intent(inout) :: self
character(len=*), dimension(:), intent(in) :: species_list character(len=*), dimension(:), intent(in) :: species_list
@@ -179,7 +276,6 @@ module gridfire_mod
character(kind=c_char, len=:), allocatable, target :: temp_strs(:) character(kind=c_char, len=:), allocatable, target :: temp_strs(:)
integer :: i, n, ierr integer :: i, n, ierr
print *, "Registering ", size(species_list), " species."
n = size(species_list) n = size(species_list)
self%num_species = int(n, c_size_t) self%num_species = int(n, c_size_t)
@@ -191,17 +287,14 @@ module gridfire_mod
c_ptrs(i) = c_loc(temp_strs(i)) c_ptrs(i) = c_loc(temp_strs(i))
end do end do
print *, "Calling gf_register_species..."
ierr = gf_register_species(self%ctx, int(n, c_int), c_ptrs) ierr = gf_register_species(self%ctx, int(n, c_int), c_ptrs)
print *, "gf_register_species returned with code: ", ierr
if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then
print *, "GridFire: ", self%get_last_error() print *, "GridFire: ", self%gff_get_last_error()
error stop error stop
end if end if
end subroutine register_species end subroutine gff_register_species
subroutine setup_policy(self, policy_name, abundances) subroutine gff_setup_policy(self, policy_name, abundances)
class(GridFire), intent(in) :: self class(GridFire), intent(in) :: self
character(len=*), intent(in) :: policy_name character(len=*), intent(in) :: policy_name
real(c_double), dimension(:), intent(in) :: abundances real(c_double), dimension(:), intent(in) :: abundances
@@ -218,41 +311,59 @@ module gridfire_mod
self%num_species) self%num_species)
if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then
print *, "GridFire Policy Error: ", self%get_last_error() print *, "GridFire Policy Error: ", self%gff_get_last_error()
error stop error stop
end if end if
end subroutine setup_policy end subroutine gff_setup_policy
subroutine setup_solver(self, solver_name) subroutine gff_setup_solver(self, solver_name)
class(GridFire), intent(in) :: self class(GridFire), intent(in) :: self
character(len=*), intent(in) :: solver_name character(len=*), intent(in) :: solver_name
integer(c_int) :: ierr integer(c_int) :: ierr
ierr = gf_construct_solver_from_engine(self%ctx, trim(solver_name) // c_null_char) ierr = gf_construct_solver_from_engine(self%ctx, trim(solver_name) // c_null_char)
if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then
print *, "GridFire Solver Error: ", self%get_last_error() print *, "GridFire Solver Error: ", self%gff_get_last_error()
error stop error stop
end if end if
end subroutine setup_solver end subroutine gff_setup_solver
subroutine evolve(self, Y_in, T, rho, dt, Y_out, energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost, ierr) subroutine gff_evolve_single(self, Y_in, T, rho, tMax, dt0, Y_out, energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost, ierr)
class(GridFire), intent(in) :: self class(GridFire), intent(in) :: self
real(c_double), dimension(:), intent(in) :: Y_in real(c_double), dimension(:), intent(in) :: Y_in
real(c_double), value :: T, rho, dt real(c_double), intent(in) :: T, rho
real(c_double), value :: tMax, dt0
real(c_double), dimension(:), intent(out) :: Y_out real(c_double), dimension(:), intent(out) :: Y_out
real(c_double), intent(out) :: energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost real(c_double), intent(out) :: energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost
integer, intent(out) :: ierr integer, intent(out) :: ierr
integer(c_int) :: c_ierr integer(c_int) :: c_ierr
c_ierr = gf_evolve(self%ctx, & c_ierr = gf_evolve_c_scalar(self%ctx_type, self%ctx, &
Y_in, self%num_species, & Y_in, self%num_species, &
T, rho, dt, & T, rho, tMax, dt0, &
Y_out, & Y_out, &
energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost) energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost)
ierr = int(c_ierr) ierr = int(c_ierr)
if (ierr /= GF_SUCCESS .AND. ierr /= FDSSE_SUCCESS) then end subroutine gff_evolve_single
print *, "GridFire Evolve Error: ", self%get_last_error()
end if subroutine gff_evolve_multi(self, Y_in, T, rho, tMax, dt0, Y_out, energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost, ierr)
end subroutine evolve class(GridFire), intent(in) :: self
real(c_double), dimension(:,:), intent(in) :: Y_in
real(c_double), dimension(:), intent(in) :: T, rho
real(c_double), value :: tMax, dt0
real(c_double), dimension(:,:), intent(out) :: Y_out
real(c_double), dimension(:), intent(out) :: energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost
integer, intent(out) :: ierr
integer(c_int) :: c_ierr
c_ierr = gf_evolve_c_array(self%ctx_type, self%ctx, &
Y_in, self%num_species, &
T, rho, tMax, dt0, &
Y_out, &
energy, dedt, dedrho, nu_e_loss, nu_flux, mass_lost)
ierr = int(c_ierr)
end subroutine gff_evolve_multi
end module gridfire_mod end module gridfire_mod

View File

@@ -7,19 +7,32 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
struct GridFireContext { enum class GFContextType {
POINT,
GRID
};
struct GFContext {
virtual ~GFContext() = default;
std::unique_ptr<gridfire::policy::NetworkPolicy> policy; std::unique_ptr<gridfire::policy::NetworkPolicy> policy;
gridfire::engine::DynamicEngine* engine; const gridfire::engine::DynamicEngine* engine;
std::unique_ptr<gridfire::solver::DynamicNetworkSolverStrategy> solver; std::unique_ptr<gridfire::engine::scratch::StateBlob> engine_ctx;
std::vector<fourdst::atomic::Species> speciesList; std::vector<fourdst::atomic::Species> speciesList;
fourdst::composition::Composition working_comp;
void init_species_map(const std::vector<std::string>& species_names); virtual void init_species_map(const std::vector<std::string>& species_names);
void init_engine_from_policy(const std::string& policy_name, const double *abundances, size_t num_species); virtual void init_engine_from_policy(const std::string& policy_name, const double *abundances, size_t num_species);
void init_solver_from_engine(const std::string& solver_name); virtual void init_solver_from_engine() = 0;
void init_composition_from_abundance_vector(const double* abundances, size_t num_species); fourdst::composition::Composition init_composition_from_abundance_vector(const std::vector<double> &abundances, size_t num_species) const;
std::string last_error;
};
struct GFPointContext final: GFContext{
std::unique_ptr<gridfire::solver::SingleZoneDynamicNetworkSolver> solver;
std::unique_ptr<gridfire::solver::SolverContextBase> solver_ctx;
void init_solver_from_engine() override;
int evolve( int evolve(
const double* Y_in, const double* Y_in,
@@ -35,9 +48,45 @@ struct GridFireContext {
double& specific_neutrino_energy_loss, double& specific_neutrino_energy_loss,
double& specific_neutrino_flux, double& specific_neutrino_flux,
double& mass_lost double& mass_lost
); ) const;
std::string last_error;
}; };
struct GFGridContext final : GFContext {
std::unique_ptr<gridfire::solver::SingleZoneDynamicNetworkSolver> local_solver;
std::unique_ptr<gridfire::solver::MultiZoneDynamicNetworkSolver> solver;
std::unique_ptr<gridfire::solver::SolverContextBase> solver_ctx;
void init_solver_from_engine() override;
size_t zones;
void set_zones(const size_t num_zones) {
zones = num_zones;
}
[[nodiscard]] size_t get_zones() const {
return zones;
}
int evolve(
const double* Y_in,
size_t num_species,
const double* T,
const double* rho,
double tMax,
double dt0,
double* Y_out,
double* energy_out,
double* dEps_dT,
double* dEps_dRho,
double* specific_neutrino_energy_loss,
double* specific_neutrino_flux,
double* mass_lost
) const;
};
std::unique_ptr<GFContext> make_gf_context(const GFContextType& type);
#endif #endif

View File

@@ -6,6 +6,12 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
enum GF_TYPE {
SINGLE_ZONE = 1001,
MULTI_ZONE = 1002
};
enum FDSSE_ERROR_CODES { enum FDSSE_ERROR_CODES {
FDSSE_NON_4DSTAR_ERROR = -102, FDSSE_NON_4DSTAR_ERROR = -102,
FDSSE_UNKNOWN_ERROR = -101, FDSSE_UNKNOWN_ERROR = -101,
@@ -54,37 +60,49 @@ extern "C" {
GF_DEBUG_ERROR = 30, GF_DEBUG_ERROR = 30,
GF_GRIDFIRE_ERROR = 31, GF_GRIDFIRE_ERROR = 31,
GF_UNINITIALIZED_INPUT_MEMORY_ERROR = 32,
GF_UNINITIALIZED_OUTPUT_MEMORY_ERROR = 33,
GF_INVALID_NUM_SPECIES = 34,
GF_INVALID_TIMESTEPS = 35,
GF_UNKNOWN_FREE_TYPE = 36,
GF_INVALID_TYPE = 37,
}; };
char* gf_get_last_error_message(void* ptr); char* gf_get_last_error_message(void* ptr);
char* gf_error_code_to_string(int error_code); char* gf_error_code_to_string(int error_code);
void* gf_init(); void* gf_init(const enum GF_TYPE type);
void gf_free(void* ctx); int gf_free(const enum GF_TYPE type, void *ctx);
int gf_set_num_zones(const enum GF_TYPE type, void* ptr, const size_t num_zones);
int gf_register_species(void* ptr, const int num_species, const char** species_names); int gf_register_species(void* ptr, const int num_species, const char** species_names);
int gf_construct_engine_from_policy(void* ptr, const char* policy_name, const double *abundances, size_t num_species); int gf_construct_engine_from_policy(void* ptr, const char* policy_name, const double *abundances, size_t num_species);
int gf_construct_solver_from_engine(void* ptr, const char* solver_name); int gf_construct_solver_from_engine(void* ptr);
int gf_evolve( int gf_evolve(
enum GF_TYPE type,
void* ptr, void* ptr,
const double* Y_in, const void* Y_in,
size_t num_species, size_t num_species,
double T, const void* T,
double rho, const void* rho,
double tMax, double tMax,
double dt0, double dt0,
double* Y_out, void* Y_out,
double* energy_out, void* energy_out,
double* dEps_dT, void* dEps_dT,
double* dEps_dRho, void* dEps_dRho,
double* specific_neutrino_energy_loss, void* specific_neutrino_energy_loss,
double* specific_neutrino_flux, void* specific_neutrino_flux,
double* mass_lost void* mass_lost
); );
#ifdef __cplusplus #ifdef __cplusplus

View File

@@ -3,11 +3,10 @@
#include "fourdst/atomic/species.h" #include "fourdst/atomic/species.h"
#include "fourdst/composition/exceptions/exceptions_composition.h" #include "fourdst/composition/exceptions/exceptions_composition.h"
#include "gridfire/exceptions/error_policy.h"
#include "gridfire/utils/logging.h"
void GridFireContext::init_species_map(const std::vector<std::string> &species_names) { void GFContext::init_species_map(const std::vector<std::string> &species_names) {
for (const auto& name: species_names) {
working_comp.registerSymbol(name);
}
this->speciesList.clear(); this->speciesList.clear();
this->speciesList.reserve(species_names.size()); this->speciesList.reserve(species_names.size());
@@ -24,8 +23,9 @@ void GridFireContext::init_species_map(const std::vector<std::string> &species_n
} }
void GridFireContext::init_engine_from_policy(const std::string &policy_name, const double *abundances, const size_t num_species) { void GFContext::init_engine_from_policy(const std::string &policy_name, const double *abundances, const size_t num_species) {
init_composition_from_abundance_vector(abundances, num_species); const std::vector<double> Y_scratch(abundances, abundances + num_species);
fourdst::composition::Composition comp = init_composition_from_abundance_vector(Y_scratch, num_species);
enum class EnginePolicy { enum class EnginePolicy {
MAIN_SEQUENCE_POLICY MAIN_SEQUENCE_POLICY
@@ -47,71 +47,53 @@ void GridFireContext::init_engine_from_policy(const std::string &policy_name, co
switch (engine_map.at(policy_name)) { switch (engine_map.at(policy_name)) {
case EnginePolicy::MAIN_SEQUENCE_POLICY: { case EnginePolicy::MAIN_SEQUENCE_POLICY: {
this->policy = std::make_unique<gridfire::policy::MainSequencePolicy>(this->working_comp); this->policy = std::make_unique<gridfire::policy::MainSequencePolicy>(comp);
this->engine = &policy->construct(); const auto& [e, ctx] = policy->construct();
this->engine = &e;
this->engine_ctx = ctx->clone_structure();
break; break;
} }
default: default:
throw gridfire::exceptions::PolicyError( throw gridfire::exceptions::PolicyError(
"Unhandled engine policy in GridFireContext::init_engine_from_policy" "Unhandled engine policy in GFPointContext::init_engine_from_policy"
); );
} }
} }
void GridFireContext::init_solver_from_engine(const std::string &solver_name) { fourdst::composition::Composition GFContext::init_composition_from_abundance_vector(const std::vector<double> &abundances, size_t num_species) const {
enum class SolverType {
CVODE
};
static const std::unordered_map<std::string, SolverType> solver_map = {
{"CVODE", SolverType::CVODE}
};
if (!solver_map.contains(solver_name)) {
throw gridfire::exceptions::SolverError(
std::format(
"Solver {} is not recognized. Valid solvers are: {}",
solver_name,
gridfire::utils::iterable_to_delimited_string(solver_map, ", ", [](const auto& pair){ return pair.first; })
)
);
}
switch (solver_map.at(solver_name)) {
case SolverType::CVODE: {
this->solver = std::make_unique<gridfire::solver::CVODESolverStrategy>(*this->engine);
break;
}
default:
throw gridfire::exceptions::SolverError(
"Unhandled solver type in GridFireContext::init_solver_from_engine"
);
}
}
void GridFireContext::init_composition_from_abundance_vector(const double *abundances, size_t num_species) {
if (num_species == 0) { if (num_species == 0) {
throw fourdst::composition::exceptions::InvalidCompositionError("Cannot initialize composition with zero species."); throw fourdst::composition::exceptions::InvalidCompositionError("Cannot initialize composition with zero species.");
} }
if (num_species != working_comp.size()) { if (num_species != speciesList.size()) {
throw fourdst::composition::exceptions::InvalidCompositionError( throw fourdst::composition::exceptions::InvalidCompositionError(
std::format( std::format(
"Number of species provided ({}) does not match the registered species count ({}).", "Number of species provided ({}) does not match the registered species count ({}).",
num_species, num_species,
working_comp.size() speciesList.size()
) )
); );
} }
fourdst::composition::Composition comp;
for (size_t i = 0; i < num_species; i++) { for (size_t i = 0; i < num_species; i++) {
this->working_comp.setMolarAbundance(this->speciesList[i], abundances[i]); comp.registerSpecies(this->speciesList[i]);
comp.setMolarAbundance(this->speciesList[i], abundances[i]);
} }
return comp;
} }
int GridFireContext::evolve(
void GFPointContext::init_solver_from_engine() {
this->solver = std::make_unique<gridfire::solver::PointSolver>(*this->engine);
this->solver_ctx = std::make_unique<gridfire::solver::PointSolverContext>(*engine_ctx);
}
int GFPointContext::evolve(
const double* Y_in, const double* Y_in,
const size_t num_species, const size_t num_species,
const double T, const double T,
@@ -125,17 +107,19 @@ int GridFireContext::evolve(
double& specific_neutrino_energy_loss, double& specific_neutrino_energy_loss,
double& specific_neutrino_flux, double& specific_neutrino_flux,
double& mass_lost double& mass_lost
) { ) const {
init_composition_from_abundance_vector(Y_in, num_species);
const std::vector<double> Y_scratch(Y_in, Y_in + num_species);
const fourdst::composition::Composition comp = init_composition_from_abundance_vector(Y_scratch, num_species);
gridfire::NetIn netIn; gridfire::NetIn netIn;
netIn.temperature = T; netIn.temperature = T;
netIn.density = rho; netIn.density = rho;
netIn.dt0 = dt0; netIn.dt0 = dt0;
netIn.tMax = tMax; netIn.tMax = tMax;
netIn.composition = this->working_comp; netIn.composition = comp;
const gridfire::NetOut result = this->solver->evaluate(netIn); const gridfire::NetOut result = this->solver->evaluate(*solver_ctx, netIn);
energy_out = result.energy; energy_out = result.energy;
dEps_dT = result.dEps_dT; dEps_dT = result.dEps_dT;
@@ -159,3 +143,99 @@ int GridFireContext::evolve(
return 0; return 0;
} }
void GFGridContext::init_solver_from_engine() {
this->local_solver = std::make_unique<gridfire::solver::PointSolver>(*this->engine);
this->solver = std::make_unique<gridfire::solver::GridSolver>(*this->engine, *this->local_solver);
this->solver_ctx = std::make_unique<gridfire::solver::GridSolverContext>(*engine_ctx);
}
int GFGridContext::evolve(
const double* Y_in,
size_t num_species,
const double *T,
const double *rho,
double tMax,
double dt0,
double *Y_out,
double *energy_out,
double *dEps_dT,
double *dEps_dRho,
double *specific_neutrino_energy_loss,
double *specific_neutrino_flux,
double *mass_lost
) const {
if (this->get_zones() == 0) {
throw gridfire::exceptions::GridFireError("GFGridContext has zero zones configured for evolution.");
}
if (Y_in == nullptr || T == nullptr || rho == nullptr) {
throw gridfire::exceptions::GridFireError("Input abundance, temperature, or density pointers are null.");
}
if (Y_out == nullptr) {
throw gridfire::exceptions::GridFireError("Output abundance pointer is null.");
}
std::vector<fourdst::composition::Composition> zone_compositions;
zone_compositions.reserve(this->get_zones());
std::vector<double> Y_scratch;
Y_scratch.resize(num_species);
for (size_t i = 0; i < this->get_zones(); ++i) {
for (size_t j = 0; j < num_species; ++j) {
Y_scratch[j] = Y_in[i * num_species + j];
}
zone_compositions.push_back(init_composition_from_abundance_vector(Y_scratch, num_species));
}
std::vector<gridfire::NetIn> netIns;
netIns.reserve(this->get_zones());
for (size_t i = 0; i < this->get_zones(); ++i) {
gridfire::NetIn netIn;
netIn.temperature = T[i];
netIn.density = rho[i];
netIn.dt0 = dt0;
netIn.tMax = tMax;
netIn.composition = zone_compositions[i];
netIns.push_back(netIn);
}
std::vector<gridfire::NetOut> results = this->solver->evaluate(*this->solver_ctx, netIns);
for (size_t zone_idx = 0; zone_idx < this->get_zones(); ++zone_idx) {
const gridfire::NetOut& netOut = results[zone_idx];
energy_out[zone_idx] = netOut.energy;
dEps_dT[zone_idx] = netOut.dEps_dT;;
dEps_dRho[zone_idx] = netOut.dEps_dRho;
specific_neutrino_energy_loss[zone_idx] = netOut.specific_neutrino_energy_loss;
specific_neutrino_flux[zone_idx] = netOut.specific_neutrino_flux;
std::set<fourdst::atomic::Species> seen_species;
for (size_t i = 0; i < num_species; ++i) {
fourdst::atomic::Species species = this->speciesList[i];
Y_out[zone_idx * num_species + i] = netOut.composition.getMolarAbundance(species);
seen_species.insert(species);
}
mass_lost[zone_idx] = 0.0;
for (const auto& species : netOut.composition.getRegisteredSpecies()) {
if (!seen_species.contains(species)) {
mass_lost[zone_idx] += species.mass() * netOut.composition.getMolarAbundance(species);
}
}
}
return 0;
}
std::unique_ptr<GFContext> make_gf_context(const GFContextType &type) {
switch (type) {
case GFContextType::POINT:
return std::make_unique<GFPointContext>();
case GFContextType::GRID:
return std::make_unique<GFGridContext>();
default:
throw gridfire::exceptions::GridFireError("Unhandled GFContextType in make_gf_context");
}
}

View File

@@ -3,120 +3,18 @@
#include "gridfire/extern/gridfire_context.h" #include "gridfire/extern/gridfire_context.h"
#include "gridfire/extern/gridfire_extern.h" #include "gridfire/extern/gridfire_extern.h"
extern "C" { namespace {
void* gf_init() {
return new GridFireContext();
}
void gf_free(void* ctx) { template<typename T>
delete static_cast<GridFireContext*>(ctx); concept ErrorTrackable = requires(T a) {
} { a.last_error } -> std::convertible_to<std::string>;
};
int gf_register_species(void* ptr, const int num_species, const char** species_names) { template <ErrorTrackable Context, typename Func>
auto* ctx = static_cast<GridFireContext*>(ptr); int execute_guarded(Context* ctx, Func&& action) {
try { try {
std::vector<std::string> names; const int result = action();
for(int i=0; i<num_species; ++i) {
names.emplace_back(species_names[i]);
}
ctx->init_species_map(names);
return FDSSE_SUCCESS;
} catch (const fourdst::composition::exceptions::UnknownSymbolError& e) {
ctx->last_error = e.what();
return FDSSE_UNKNOWN_SYMBOL_ERROR;
} catch (const fourdst::composition::exceptions::SpeciesError& e) {
ctx->last_error = e.what();
return FDSSE_SPECIES_ERROR;
} catch (const std::exception& e) {
ctx->last_error = e.what();
return FDSSE_NON_4DSTAR_ERROR;
} catch (...) {
ctx->last_error = "Unknown error occurred during species registration.";
return FDSSE_UNKNOWN_ERROR;
}
}
int gf_construct_engine_from_policy(
void* ptr,
const char* policy_name,
const double *abundances,
const size_t num_species
) {
auto* ctx = static_cast<GridFireContext*>(ptr);
try {
ctx->init_engine_from_policy(std::string(policy_name), abundances, num_species);
return GF_SUCCESS;
} catch (const gridfire::exceptions::MissingBaseReactionError& e) {
ctx->last_error = e.what();
return GF_MISSING_BASE_REACTION_ERROR;
} catch (const gridfire::exceptions::MissingSeedSpeciesError& e) {
ctx->last_error = e.what();
return GF_MISSING_SEED_SPECIES_ERROR;
} catch (const gridfire::exceptions::MissingKeyReactionError& e) {
ctx->last_error = e.what();
return GF_MISSING_KEY_REACTION_ERROR;
} catch (const gridfire::exceptions::PolicyError& e) {
ctx->last_error = e.what();
return GF_POLICY_ERROR;
} catch (std::exception& e) {
ctx->last_error = e.what();
return GF_NON_GRIDFIRE_ERROR;
} catch (...) {
ctx->last_error = "Unknown error occurred during engine construction.";
return GF_UNKNOWN_ERROR;
}
}
int gf_construct_solver_from_engine(
void* ptr,
const char* solver_name
) {
auto* ctx = static_cast<GridFireContext*>(ptr);
try {
ctx->init_solver_from_engine(std::string(solver_name));
return GF_SUCCESS;
} catch (std::exception& e) {
ctx->last_error = e.what();
return GF_NON_GRIDFIRE_ERROR;
} catch (...) {
ctx->last_error = "Unknown error occurred during solver construction.";
return GF_UNKNOWN_ERROR;
}
}
int gf_evolve(
void* ptr,
const double* Y_in,
const size_t num_species,
const double T,
const double rho,
const double tMax,
const double dt0,
double* Y_out,
double* energy_out,
double* dEps_dT,
double* dEps_dRho,
double* specific_neutrino_energy_loss,
double* specific_neutrino_flux,
double* mass_lost
) {
auto* ctx = static_cast<GridFireContext*>(ptr);
try {
const int result = ctx->evolve(
Y_in,
num_species,
T,
rho,
tMax,
dt0,
Y_out,
*energy_out,
*dEps_dT,
*dEps_dRho,
*specific_neutrino_energy_loss,
*specific_neutrino_flux,
*mass_lost
);
if (result != 0) { if (result != 0) {
return result; return result;
} }
@@ -211,8 +109,230 @@ extern "C" {
} }
} }
}
extern "C" {
void* gf_init(const enum GF_TYPE type) {
if (type == MULTI_ZONE) {
return new GFGridContext();
}
if (type == SINGLE_ZONE) {
return new GFPointContext();
}
return nullptr;
}
int gf_free(const enum GF_TYPE type, void *ctx) {
if (!ctx) {
return GF_UNINITIALIZED_INPUT_MEMORY_ERROR;
}
if (type == MULTI_ZONE) {
delete static_cast<GFGridContext*>(ctx);
return GF_SUCCESS;
}
if (type == SINGLE_ZONE) {
delete static_cast<GFPointContext*>(ctx);
return GF_SUCCESS;
}
return GF_UNKNOWN_FREE_TYPE;
}
int gf_set_num_zones(const enum GF_TYPE type, void* ptr, const size_t num_zones) {
if (type != MULTI_ZONE) {
return GF_INVALID_TYPE;
}
if (!ptr) {
return GF_UNINITIALIZED_INPUT_MEMORY_ERROR;
}
auto* ctx = static_cast<GFGridContext*>(ptr);
return execute_guarded(ctx, [&]() {
ctx->set_zones(num_zones);
return GF_SUCCESS;
});
}
int gf_register_species(void* ptr, const int num_species, const char** species_names) {
if (num_species < 0) return GF_INVALID_NUM_SPECIES;
if (num_species == 0) return GF_SUCCESS;
if (!ptr || !species_names) {
return GF_UNINITIALIZED_INPUT_MEMORY_ERROR;
}
for (int i=0; i < num_species; ++i) {
if (!species_names[i]) {
return GF_UNINITIALIZED_INPUT_MEMORY_ERROR;
}
}
auto* ctx = static_cast<GFContext*>(ptr);
return execute_guarded(ctx, [&]() {
std::vector<std::string> names;
for (int i=0; i<num_species; ++i) {
names.emplace_back(species_names[i]);
}
ctx->init_species_map(names);
return FDSSE_SUCCESS;
});
}
int gf_construct_engine_from_policy(
void* ptr,
const char* policy_name,
const double *abundances,
const size_t num_species
) {
auto* ctx = static_cast<GFContext*>(ptr);
return execute_guarded(ctx, [&]() {
ctx->init_engine_from_policy(std::string(policy_name), abundances, num_species);
return GF_SUCCESS;
});
}
int gf_construct_solver_from_engine(
void* ptr
) {
auto* ctx = static_cast<GFContext*>(ptr);
return execute_guarded(ctx, [&]() {
ctx->init_solver_from_engine();
return GF_SUCCESS;
});
}
int gf_evolve(
const enum GF_TYPE type,
void* ptr,
const void* Y_in,
const size_t num_species,
const void* T,
const void* rho,
const double tMax,
const double dt0,
void* Y_out,
void* energy_out,
void* dEps_dT,
void* dEps_dRho,
void* specific_neutrino_energy_loss,
void* specific_neutrino_flux,
void* mass_lost
) {
printf("In C Starting gf_evolve with type %d\n", type);
printf("In C num_species: %zu, tMax: %e, dt0: %e\n", num_species, tMax, dt0);
printf("In C Y_in ptr: %p, T ptr: %p, rho ptr: %p\n", Y_in, T, rho);
// values
printf("In C Y_in first 5 values: ");
const auto* Y_in_ptr = static_cast<const double*>(Y_in);
for (size_t i = 0; i < std::min(num_species, size_t(5)); ++i) {
printf("%e ", Y_in_ptr[i]);
}
printf("\n");
printf("In C T value: %e\n", *(static_cast<const double*>(T)));
printf("In C rho value: %e\n", *(static_cast<const double*>(rho)));
printf("In C tMax value: %e\n", tMax);
printf("In C dt0 value: %e\n", dt0);
if (!ptr || !Y_in || !T || !rho) {
return GF_UNINITIALIZED_INPUT_MEMORY_ERROR;
}
if (!Y_out || !energy_out || !dEps_dT || !dEps_dRho || !specific_neutrino_energy_loss || !specific_neutrino_flux || !mass_lost) {
return GF_UNINITIALIZED_OUTPUT_MEMORY_ERROR;
}
if (tMax <= 0 || dt0 <= 0) {
return GF_INVALID_TIMESTEPS;
}
if (num_species <= 0) {
return GF_INVALID_NUM_SPECIES;
}
switch (type) {
case SINGLE_ZONE : {
auto* ctx = static_cast<GFPointContext*>(ptr);
const auto T_ptr = static_cast<const double*>(T);
const auto *rho_ptr = static_cast<const double*>(rho);
auto* Y_out_local = static_cast<double*>(Y_out);
auto* energy_out_local = static_cast<double*>(energy_out);
auto* dEps_dT_local = static_cast<double*>(dEps_dT);
auto* dEps_dRho_local = static_cast<double*>(dEps_dRho);
auto* specific_neutrino_energy_loss_local = static_cast<double*>(specific_neutrino_energy_loss);
auto* specific_neutrino_flux_local = static_cast<double*>(specific_neutrino_flux);
auto* mass_lost_local = static_cast<double*>(mass_lost);
printf("Evolving single zone with T = %e, rho = %e for tMax = %e and dt0 = %e\n", *T_ptr, *rho_ptr, tMax, dt0);
return execute_guarded(ctx, [&]() {
return ctx->evolve(
static_cast<const double*>(Y_in),
num_species,
*T_ptr,
*rho_ptr,
tMax,
dt0,
Y_out_local,
*energy_out_local,
*dEps_dT_local,
*dEps_dRho_local,
*specific_neutrino_energy_loss_local,
*specific_neutrino_flux_local,
*mass_lost_local
);
});
}
case MULTI_ZONE : {
auto* ctx = static_cast<GFGridContext*>(ptr);
const auto *T_ptr = static_cast<const double*>(T);
const auto *rho_ptr = static_cast<const double*>(rho);
auto* Y_out_local = static_cast<double*>(Y_out);
auto* energy_out_local = static_cast<double*>(energy_out);
auto* dEps_dT_local = static_cast<double*>(dEps_dT);
auto* dEps_dRho_local = static_cast<double*>(dEps_dRho);
auto* specific_neutrino_energy_loss_local = static_cast<double*>(specific_neutrino_energy_loss);
auto* specific_neutrino_flux_local = static_cast<double*>(specific_neutrino_flux);
auto* mass_lost_local = static_cast<double*>(mass_lost);
printf("Evolving multi zone for tMax = %e and dt0 = %e\n", tMax, dt0);
return execute_guarded(ctx, [&]() {
return ctx->evolve(
static_cast<const double*>(Y_in),
num_species,
T_ptr, // T pointer
rho_ptr, // rho pointer
tMax,
dt0,
Y_out_local,
energy_out_local,
dEps_dT_local,
dEps_dRho_local,
specific_neutrino_energy_loss_local,
specific_neutrino_flux_local,
mass_lost_local
);
});
}
default :
return GF_UNKNOWN_ERROR;
}
}
char* gf_get_last_error_message(void* ptr) { char* gf_get_last_error_message(void* ptr) {
const auto* ctx = static_cast<GridFireContext*>(ptr); if (!ptr) {
return const_cast<char*>("GF_UNINITIALIZED_INPUT_MEMORY_ERROR");
}
const auto* ctx = static_cast<GFContext*>(ptr);
return const_cast<char*>(ctx->last_error.c_str()); return const_cast<char*>(ctx->last_error.c_str());
} }
@@ -278,6 +398,18 @@ extern "C" {
return const_cast<char*>("GF_DEBUG_ERROR"); return const_cast<char*>("GF_DEBUG_ERROR");
case GF_GRIDFIRE_ERROR: case GF_GRIDFIRE_ERROR:
return const_cast<char*>("GF_GRIDFIRE_ERROR"); return const_cast<char*>("GF_GRIDFIRE_ERROR");
case GF_UNINITIALIZED_INPUT_MEMORY_ERROR:
return const_cast<char*>("GF_UNINITIALIZED_INPUT_MEMORY_ERROR");
case GF_UNINITIALIZED_OUTPUT_MEMORY_ERROR:
return const_cast<char*>("GF_UNINITIALIZED_OUTPUT_MEMORY_ERROR");
case GF_INVALID_NUM_SPECIES:
return const_cast<char*>("GF_INVALID_NUM_SPECIES");
case GF_INVALID_TIMESTEPS:
return const_cast<char*>("GF_INVALID_TIMESTEPS");
case GF_UNKNOWN_FREE_TYPE:
return const_cast<char*>("GF_UNKNOWN_FREE_TYPE");
case GF_INVALID_TYPE:
return const_cast<char*>("GF_INVALID_TYPE");
case FDSSE_NON_4DSTAR_ERROR: case FDSSE_NON_4DSTAR_ERROR:
return const_cast<char*>("FDSSE_NON_4DSTAR_ERROR"); return const_cast<char*>("FDSSE_NON_4DSTAR_ERROR");
case FDSSE_UNKNOWN_ERROR: case FDSSE_UNKNOWN_ERROR:

View File

@@ -8,27 +8,8 @@ namespace gridfire::config {
double relTol = 1.0e-5; double relTol = 1.0e-5;
}; };
struct SpectralSolverConfig {
struct MonitorFunctionConfig {
double structure_weight = 1.0;
double abundance_weight = 10.0;
double alpha = 0.2;
double beta = 0.8;
};
struct BasisConfig {
size_t num_elements = 50;
};
double absTol = 1.0e-8;
double relTol = 1.0e-5;
size_t degree = 3;
MonitorFunctionConfig monitorFunction;
BasisConfig basis;
};
struct SolverConfig { struct SolverConfig {
CVODESolverConfig cvode; CVODESolverConfig cvode;
SpectralSolverConfig spectral;
}; };
struct AdaptiveEngineViewConfig { struct AdaptiveEngineViewConfig {
@@ -50,5 +31,4 @@ namespace gridfire::config {
}; };
} }

View File

@@ -807,8 +807,6 @@ namespace gridfire::engine {
CppAD::ADFun<double> m_authoritativeADFun; CppAD::ADFun<double> m_authoritativeADFun;
const size_t m_state_blob_offset;
private: private:
/** /**
* @brief Synchronizes the internal maps. * @brief Synchronizes the internal maps.

View File

@@ -359,7 +359,6 @@ namespace gridfire::engine {
private: private:
using LogManager = LogManager;
Config<config::GridFireConfig> m_config; Config<config::GridFireConfig> m_config;
quill::Logger* m_logger = LogManager::getInstance().getLogger("log"); quill::Logger* m_logger = LogManager::getInstance().getLogger("log");

View File

@@ -25,6 +25,7 @@
#include <set> #include <set>
#include "gridfire/engine/types/engine_types.h" #include "gridfire/engine/types/engine_types.h"
#include "gridfire/engine/scratchpads/blob.h"
namespace gridfire::policy { namespace gridfire::policy {

View File

@@ -0,0 +1,46 @@
#pragma once
#include "gridfire/solver/strategies/strategy_abstract.h"
#include <functional>
namespace gridfire::solver {
struct GridSolverContext final : SolverContextBase {
std::vector<std::unique_ptr<SolverContextBase>> solver_workspaces;
std::vector<std::function<void(const TimestepContextBase&)>> timestep_callbacks;
const engine::scratch::StateBlob& ctx_template;
bool zone_completion_logging = true;
bool zone_stdout_logging = false;
bool zone_detailed_logging = false;
void init() override;
void reset();
void set_callback(const std::function<void(const TimestepContextBase&)> &callback);
void set_callback(const std::function<void(const TimestepContextBase&)> &callback, size_t zone_idx);
void clear_callback();
void clear_callback(size_t zone_idx);
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;
explicit GridSolverContext(const engine::scratch::StateBlob& ctx_template);
};
class GridSolver final : public MultiZoneDynamicNetworkSolver {
public:
GridSolver(
const engine::DynamicEngine& engine,
const SingleZoneDynamicNetworkSolver& solver
);
~GridSolver() override = default;
std::vector<NetOut> evaluate(
SolverContextBase& ctx,
const std::vector<NetIn>& netIns
) const override;
};
}

View File

@@ -44,8 +44,88 @@
#endif #endif
namespace gridfire::solver { namespace gridfire::solver {
struct PointSolverTimestepContext final : TimestepContextBase {
const double t; ///< Current integration time [s].
const N_Vector& state; ///< Current CVODE state vector (N_Vector).
const double dt; ///< Last step size [s].
const double last_step_time; ///< Time at last callback [s].
const double T9; ///< Temperature in GK.
const double rho; ///< Density [g cm^-3].
const size_t num_steps; ///< Number of CVODE steps taken so far.
const engine::DynamicEngine& engine; ///< Reference to the engine.
const std::vector<fourdst::atomic::Species>& networkSpecies; ///< Species layout.
const size_t currentConvergenceFailures; ///< Total number of convergence failures
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
PointSolverTimestepContext(
double t,
const N_Vector& state,
double dt,
double last_step_time,
double t9,
double rho,
size_t num_steps,
const engine::DynamicEngine& engine,
const std::vector<fourdst::atomic::Species>& networkSpecies,
size_t currentConvergenceFailure,
size_t currentNonlinearIterations,
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
engine::scratch::StateBlob& state_ctx
);
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
using TimestepCallback = std::function<void(const PointSolverTimestepContext& context)>; ///< Type alias for a timestep callback function.
struct PointSolverContext final : SolverContextBase {
SUNContext sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
void* cvode_mem = nullptr; ///< CVODE memory block.
N_Vector Y = nullptr; ///< CVODE state vector (species + energy accumulator).
N_Vector YErr = nullptr; ///< Estimated local errors.
SUNMatrix J = nullptr; ///< Dense Jacobian matrix.
SUNLinearSolver LS = nullptr; ///< Dense linear solver.
std::unique_ptr<engine::scratch::StateBlob> engine_ctx;
std::optional<TimestepCallback> callback; ///< Optional per-step callback.
int num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool stdout_logging = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> abs_tol; ///< User-specified absolute tolerance.
std::optional<double> rel_tol; ///< User-specified relative tolerance.
bool detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
size_t last_size = 0;
size_t last_composition_hash = 0ULL;
sunrealtype last_good_time_step = 0ULL;
void init() override;
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;
void reset_all();
void reset_user();
void reset_cvode();
void clear_context();
void init_context();
[[nodiscard]] bool has_context() const;
explicit PointSolverContext(const engine::scratch::StateBlob& engine_ctx);
~PointSolverContext() override;
};
/** /**
* @class CVODESolverStrategy * @class PointSolver
* @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy. * @brief Stiff ODE integrator backed by SUNDIALS CVODE (BDF) for network + energy.
* *
* Integrates the nuclear network abundances along with a final accumulator entry for specific * Integrates the nuclear network abundances along with a final accumulator entry for specific
@@ -78,27 +158,16 @@ namespace gridfire::solver {
* std::cout << "Final energy: " << out.energy << " erg/g\n"; * std::cout << "Final energy: " << out.energy << " erg/g\n";
* @endcode * @endcode
*/ */
class CVODESolverStrategy final : public SingleZoneDynamicNetworkSolver { class PointSolver final : public SingleZoneDynamicNetworkSolver {
public: public:
/** /**
* @brief Construct the CVODE strategy and create a SUNDIALS context. * @brief Construct the CVODE strategy and create a SUNDIALS context.
* @param engine DynamicEngine used for RHS/Jacobian evaluation and network access. * @param engine DynamicEngine used for RHS/Jacobian evaluation and network access.
* @throws std::runtime_error If SUNContext_Create fails. * @throws std::runtime_error If SUNContext_Create fails.
*/ */
explicit CVODESolverStrategy( explicit PointSolver(
const engine::DynamicEngine& engine, const engine::DynamicEngine& engine
const engine::scratch::StateBlob& ctx
); );
/**
* @brief Destructor: cleans CVODE/SUNDIALS resources and frees SUNContext.
*/
~CVODESolverStrategy() override;
// Make the class non-copyable and non-movable to prevent shallow copies of CVODE pointers
CVODESolverStrategy(const CVODESolverStrategy&) = delete;
CVODESolverStrategy& operator=(const CVODESolverStrategy&) = delete;
CVODESolverStrategy(CVODESolverStrategy&&) = delete;
CVODESolverStrategy& operator=(CVODESolverStrategy&&) = delete;
/** /**
* @brief Integrate from t=0 to netIn.tMax and return final composition and energy. * @brief Integrate from t=0 to netIn.tMax and return final composition and energy.
@@ -114,6 +183,7 @@ namespace gridfire::solver {
* - At the end, converts molar abundances to mass fractions and assembles NetOut, * - At the end, converts molar abundances to mass fractions and assembles NetOut,
* including derivatives of energy w.r.t. T and rho from the engine. * including derivatives of energy w.r.t. T and rho from the engine.
* *
* @param solver_ctx
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition. * @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
* @return NetOut containing final Composition, accumulated energy [erg/g], step count, * @return NetOut containing final Composition, accumulated energy [erg/g], step count,
* and dEps/dT, dEps/dRho. * and dEps/dT, dEps/dRho.
@@ -122,10 +192,14 @@ namespace gridfire::solver {
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state * @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
* during RHS evaluation (captured in the wrapper then rethrown here). * during RHS evaluation (captured in the wrapper then rethrown here).
*/ */
NetOut evaluate(const NetIn& netIn) override; NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn
) const override;
/** /**
* @brief Call to evaluate which will let the user control if the trigger reasoning is displayed * @brief Call to evaluate which will let the user control if the trigger reasoning is displayed
* @param solver_ctx
* @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition. * @param netIn Inputs: temperature [K], density [g cm^-3], tMax [s], composition.
* @param displayTrigger Boolean flag to control if trigger reasoning is displayed * @param displayTrigger Boolean flag to control if trigger reasoning is displayed
* @param forceReinitialize Boolean flag to force reinitialization of CVODE resources at the start * @param forceReinitialize Boolean flag to force reinitialization of CVODE resources at the start
@@ -136,89 +210,13 @@ namespace gridfire::solver {
* @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state * @throws exceptions::StaleEngineTrigger Propagated if the engine signals a stale state
* during RHS evaluation (captured in the wrapper then rethrown here). * during RHS evaluation (captured in the wrapper then rethrown here).
*/ */
NetOut evaluate(const NetIn& netIn, bool displayTrigger, bool forceReinitialize = false); NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn,
bool displayTrigger,
bool forceReinitialize = false
) const;
/**
* @brief Install a timestep callback.
* @param callback std::any containing TimestepCallback (std::function<void(const TimestepContext&)>).
* @throws std::bad_any_cast If callback is not of the expected type.
*/
void set_callback(const std::any &callback) override;
/**
* @brief Whether per-step logs are printed to stdout and CVode is stepped with CV_ONE_STEP.
*/
[[nodiscard]] bool get_stdout_logging_enabled() const;
/**
* @brief Enable/disable per-step stdout logging.
* @param logging_enabled Flag to control if a timestep summary is written to standard output or not
*/
void set_stdout_logging_enabled(bool logging_enabled);
void set_absTol(double absTol);
void set_relTol(double relTol);
double get_absTol() const;
double get_relTol() const;
/**
* @brief Schema of fields exposed to the timestep callback context.
*/
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
/**
* @struct TimestepContext
* @brief Immutable view of the current integration state passed to callbacks.
*
* Fields capture CVODE time/state, step size, thermodynamic state, the engine reference,
* and the list of network species used to interpret the state vector layout.
*/
struct TimestepContext final : public SolverContextBase {
// This struct can be identical to the one in DirectNetworkSolver
const double t; ///< Current integration time [s].
const N_Vector& state; ///< Current CVODE state vector (N_Vector).
const double dt; ///< Last step size [s].
const double last_step_time; ///< Time at last callback [s].
const double T9; ///< Temperature in GK.
const double rho; ///< Density [g cm^-3].
const size_t num_steps; ///< Number of CVODE steps taken so far.
const engine::DynamicEngine& engine; ///< Reference to the engine.
const std::vector<fourdst::atomic::Species>& networkSpecies; ///< Species layout.
const size_t currentConvergenceFailures; ///< Total number of convergence failures
const size_t currentNonlinearIterations; ///< Total number of non-linear iterations
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>>& reactionContributionMap; ///< Map of reaction contributions for the current step
engine::scratch::StateBlob& state_ctx; ///< Reference to the engine scratch state blob
/**
* @brief Construct a context snapshot.
*/
TimestepContext(
double t,
const N_Vector& state,
double dt,
double last_step_time,
double t9,
double rho,
size_t num_steps,
const engine::DynamicEngine& engine,
const std::vector<fourdst::atomic::Species>& networkSpecies,
size_t currentConvergenceFailure,
size_t currentNonlinearIterations,
const std::map<fourdst::atomic::Species, std::unordered_map<std::string, double>> &reactionContributionMap,
engine::scratch::StateBlob& state_ctx
);
/**
* @brief Human-readable description of the context fields.
*/
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
};
/**
* @brief Type alias for a timestep callback.
*/
using TimestepCallback = std::function<void(const TimestepContext& context)>; ///< Type alias for a timestep callback function.
private: private:
/** /**
* @struct CVODEUserData * @struct CVODEUserData
@@ -230,7 +228,8 @@ namespace gridfire::solver {
* to CVODE, then the driver loop inspects and rethrows. * to CVODE, then the driver loop inspects and rethrows.
*/ */
struct CVODEUserData { struct CVODEUserData {
CVODESolverStrategy* solver_instance{}; // Pointer back to the class instance const PointSolver* solver_instance{}; // Pointer back to the class instance
PointSolverContext* sctx; // Pointer to the solver context
engine::scratch::StateBlob& ctx; engine::scratch::StateBlob& ctx;
const engine::DynamicEngine* engine{}; const engine::DynamicEngine* engine{};
double T9{}; double T9{};
@@ -283,6 +282,7 @@ namespace gridfire::solver {
* step size, creates a dense matrix and dense linear solver, and registers the Jacobian. * step size, creates a dense matrix and dense linear solver, and registers the Jacobian.
*/ */
void initialize_cvode_integration_resources( void initialize_cvode_integration_resources(
PointSolverContext* ctx,
uint64_t N, uint64_t N,
size_t numSpecies, size_t numSpecies,
double current_time, double current_time,
@@ -290,15 +290,7 @@ namespace gridfire::solver {
double absTol, double absTol,
double relTol, double relTol,
double accumulatedEnergy double accumulatedEnergy
); ) const;
/**
* @brief Destroy CVODE vectors/linear algebra and optionally the CVODE memory block.
* @param memFree If true, also calls CVodeFree on m_cvode_mem.
*/
void cleanup_cvode_resources(bool memFree);
void set_detailed_step_logging(bool enabled);
/** /**
@@ -308,31 +300,13 @@ namespace gridfire::solver {
* sorted table of species with the highest error ratios; then invokes diagnostic routines to * sorted table of species with the highest error ratios; then invokes diagnostic routines to
* inspect Jacobian stiffness and species balance. * inspect Jacobian stiffness and species balance.
*/ */
void log_step_diagnostics(engine::scratch::StateBlob &ctx, const CVODEUserData& user_data, bool displayJacobianStiffness, bool void log_step_diagnostics(
displaySpeciesBalance, bool to_file, std::optional<std::string> filename) const; PointSolverContext* sctx_p,
private: engine::scratch::StateBlob &ctx,
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver). const CVODEUserData& user_data,
void* m_cvode_mem = nullptr; ///< CVODE memory block. bool displayJacobianStiffness,
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator). bool displaySpeciesBalance,
N_Vector m_YErr = nullptr; ///< Estimated local errors. bool to_file, std::optional<std::string> filename
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix. ) const;
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
std::optional<double> m_relTol; ///< User-specified relative tolerance.
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
mutable size_t m_last_size = 0;
mutable size_t m_last_composition_hash = 0ULL;
mutable sunrealtype m_last_good_time_step = 0ULL;
}; };
} }

View File

@@ -1,196 +0,0 @@
#pragma once
#include "gridfire/solver/strategies/strategy_abstract.h"
#include "gridfire/engine/engine_abstract.h"
#include "gridfire/types/types.h"
#include "gridfire/config/config.h"
#include "fourdst/logging/logging.h"
#include "fourdst/constants/const.h"
#include <vector>
#include <cvode/cvode.h>
#include <sundials/sundials_types.h>
#ifdef SUNDIALS_HAVE_OPENMP
#include <nvector/nvector_openmp.h>
#endif
#ifdef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_pthreads.hh>
#endif
#ifndef SUNDIALS_HAVE_OPENMP
#ifndef SUNDIALS_HAVE_PTHREADS
#include <nvector/nvector_serial.h>
#endif
#endif
namespace gridfire::solver {
class SpectralSolverStrategy final : public MultiZoneDynamicNetworkSolverStrategy {
public:
explicit SpectralSolverStrategy(engine::DynamicEngine& engine);
~SpectralSolverStrategy() override;
std::vector<NetOut> evaluate(
const std::vector<NetIn> &netIns,
const std::vector<double>& mass_coords
) override;
void set_callback(const std::any &callback) override;
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
[[nodiscard]] bool get_stdout_logging_enabled() const;
void set_stdout_logging_enabled(bool logging_enabled);
public:
struct TimestepContext final : public SolverContextBase {
TimestepContext(
const double t,
const N_Vector &state,
const double dt,
const double last_time_step,
const engine::DynamicEngine &engine
) :
t(t),
state(state),
dt(dt),
last_time_step(last_time_step),
engine(engine) {}
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
const double t;
const N_Vector& state;
const double dt;
const double last_time_step;
const engine::DynamicEngine& engine;
};
struct BasisEval {
size_t start_idx;
std::vector<double> phi;
};
struct SplineBasis {
std::vector<double> knots;
std::vector<double> quadrature_nodes;
std::vector<double> quadrature_weights;
int degree = 3;
std::vector<BasisEval> quad_evals;
};
public:
using TimestepCallback = std::function<void(const TimestepContext&)>;
private:
struct SpectralCoefficients {
size_t num_sets;
size_t num_coefficients;
std::vector<double> coefficients;
double operator()(size_t i, size_t j) const;
};
struct GridPoint {
double T9;
double rho;
fourdst::composition::Composition composition;
};
struct Constants {
const double c = fourdst::constant::Constants::getInstance().get("c").value;
const double N_a = fourdst::constant::Constants::getInstance().get("N_a").value;
const double c2 = c * c;
};
struct DenseLinearSolver {
SUNMatrix A;
SUNLinearSolver LS;
N_Vector temp_vector;
SUNContext ctx;
DenseLinearSolver(size_t size, SUNContext sun_ctx);
~DenseLinearSolver();
DenseLinearSolver(const DenseLinearSolver&) = delete;
DenseLinearSolver& operator=(const DenseLinearSolver&) = delete;
void setup() const;
void zero() const;
void init_from_cache(size_t num_basis_funcs, const std::vector<BasisEval>& shell_cache) const;
void init_from_basis(size_t num_basis_funcs, const SplineBasis& basis) const;
void solve_inplace(N_Vector x, size_t num_vars, size_t basis_size) const;
};
struct CVODEUserData {
SpectralSolverStrategy* solver_instance{};
engine::DynamicEngine* engine;
DenseLinearSolver* mass_matrix_solver_instance;
const SplineBasis* basis;
};
private:
fourdst::config::Config<config::GridFireConfig> m_config;
quill::Logger* m_logger = fourdst::logging::LogManager::getInstance().getLogger("log");
SUNContext m_sun_ctx = nullptr; ///< SUNDIALS context (lifetime of the solver).
void* m_cvode_mem = nullptr; ///< CVODE memory block.
N_Vector m_Y = nullptr; ///< CVODE state vector (species + energy accumulator).
SUNMatrix m_J = nullptr; ///< Dense Jacobian matrix.
SUNLinearSolver m_LS = nullptr; ///< Dense linear solver.
std::optional<TimestepCallback> m_callback; ///< Optional per-step callback.
int m_num_steps = 0; ///< CVODE step counter (used for diagnostics and triggers).
bool m_stdout_logging_enabled = true; ///< If true, print per-step logs and use CV_ONE_STEP.
N_Vector m_constraints = nullptr; ///< CVODE constraints vector (>= 0 for species entries).
std::optional<double> m_absTol; ///< User-specified absolute tolerance.
std::optional<double> m_relTol; ///< User-specified relative tolerance.
bool m_detailed_step_logging = false; ///< If true, log detailed step diagnostics (error ratios, Jacobian, species balance).
mutable size_t m_last_size = 0;
mutable size_t m_last_composition_hash = 0ULL;
mutable sunrealtype m_last_good_time_step = 0ULL;
SplineBasis m_current_basis;
Constants m_constants;
N_Vector m_T_coeffs = nullptr;
N_Vector m_rho_coeffs = nullptr;
private:
std::vector<double> evaluate_monitor_function(const std::vector<NetIn>& current_shells) const;
SplineBasis generate_basis_from_monitor(const std::vector<double>& monitor_values, const std::vector<double>& mass_coordinates) const;
GridPoint reconstruct_at_quadrature(const N_Vector y_coeffs, size_t quad_index, const SplineBasis &basis) const;
std::vector<NetOut> reconstruct_solution(const std::vector<NetIn>& original_inputs, const std::vector<double>& mass_coordinates, const N_Vector final_coeffs, const SplineBasis& basis, double dt) const;
static int cvode_rhs_wrapper(sunrealtype t, N_Vector y, N_Vector, void* user_data);
static int cvode_jac_wrapper(sunrealtype t, N_Vector y, N_Vector ydot, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);
int calculate_rhs(sunrealtype t, N_Vector y_coeffs, N_Vector ydot_coeffs, CVODEUserData* data) const;
static void project_specific_variable(
const std::vector<NetIn>& current_shells,
const std::vector<double>& mass_coordinates,
const std::vector<BasisEval>& shell_cache,
const DenseLinearSolver& linear_solver,
N_Vector output_vec,
size_t output_offset,
const std::function<double(const NetIn&)> &getter,
bool use_log
);
};
}

View File

@@ -2,5 +2,5 @@
#include "gridfire/solver/strategies/triggers/triggers.h" #include "gridfire/solver/strategies/triggers/triggers.h"
#include "gridfire/solver/strategies/strategy_abstract.h" #include "gridfire/solver/strategies/strategy_abstract.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/solver/strategies/SpectralSolverStrategy.h" #include "gridfire/solver/strategies/GridSolver.h"

View File

@@ -13,17 +13,24 @@ namespace gridfire::solver {
template <typename EngineT> template <typename EngineT>
concept IsEngine = std::is_base_of_v<engine::Engine, EngineT>; concept IsEngine = std::is_base_of_v<engine::Engine, EngineT>;
struct SolverContextBase {
virtual void init() = 0;
virtual void set_stdout_logging(bool enable) = 0;
virtual void set_detailed_logging(bool enable) = 0;
virtual ~SolverContextBase() = default;
};
/** /**
* @struct SolverContextBase * @struct TimestepContextBase
* @brief Base class for solver callback contexts. * @brief Base class for solver callback contexts.
* *
* This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces * This struct serves as a base class for contexts that can be passed to solver callbacks, it enforces
* that derived classes implement a `describe` method that returns a vector of tuples describing * that derived classes implement a `describe` method that returns a vector of tuples describing
* the context that a callback will receive when called. * the context that a callback will receive when called.
*/ */
class SolverContextBase { class TimestepContextBase {
public: public:
virtual ~SolverContextBase() = default; virtual ~TimestepContextBase() = default;
/** /**
* @brief Describe the context for callback functions. * @brief Describe the context for callback functions.
@@ -54,11 +61,9 @@ namespace gridfire::solver {
* @param engine The engine to use for evaluating the network. * @param engine The engine to use for evaluating the network.
*/ */
explicit SingleZoneNetworkSolver( explicit SingleZoneNetworkSolver(
const EngineT& engine, const EngineT& engine
const engine::scratch::StateBlob& ctx
) : ) :
m_engine(engine), m_engine(engine) {};
m_scratch_blob(ctx.clone_structure()) {};
/** /**
* @brief Virtual destructor. * @brief Virtual destructor.
@@ -67,38 +72,18 @@ namespace gridfire::solver {
/** /**
* @brief Evaluates the network for a given timestep. * @brief Evaluates the network for a given timestep.
* @param solver_ctx
* @param engine_ctx
* @param netIn The input conditions for the network. * @param netIn The input conditions for the network.
* @return The output conditions after the timestep. * @return The output conditions after the timestep.
*/ */
virtual NetOut evaluate(const NetIn& netIn) = 0; virtual NetOut evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn
) const = 0;
/**
* @brief set the callback function to be called at the end of each timestep.
*
* This function allows the user to set a callback function that will be called at the end of each timestep.
* The callback function will receive a gridfire::solver::<SOMESOLVER>::TimestepContext object. Note that
* depending on the solver, this context may contain different information. Further, the exact
* signature of the callback function is left up to each solver. Every solver should provide a type or type alias
* TimestepCallback that defines the signature of the callback function so that the user can easily
* get that type information.
*
* @param callback The callback function to be called at the end of each timestep.
*/
virtual void set_callback(const std::any& callback) = 0;
/**
* @brief Describe the context that will be passed to the callback function.
* @return A vector of tuples, each containing a string for the parameter's name and a string for its type.
*
* This method should be overridden by derived classes to provide a description of the context
* that will be passed to the callback function. The intent of this method is that an end user can investigate
* the context that will be passed to the callback function, and use this information to craft their own
* callback function.
*/
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected: protected:
const EngineT& m_engine; ///< The engine used by this solver strategy. const EngineT& m_engine; ///< The engine used by this solver strategy.
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob;
}; };
template <IsEngine EngineT> template <IsEngine EngineT>
@@ -106,22 +91,20 @@ namespace gridfire::solver {
public: public:
explicit MultiZoneNetworkSolver( explicit MultiZoneNetworkSolver(
const EngineT& engine, const EngineT& engine,
const engine::scratch::StateBlob& ctx const SingleZoneNetworkSolver<EngineT>& solver
) : ) :
m_engine(engine), m_engine(engine),
m_scratch_blob_structure(ctx.clone_structure()){}; m_solver(solver) {};
virtual ~MultiZoneNetworkSolver() = default; virtual ~MultiZoneNetworkSolver() = default;
virtual std::vector<NetOut> evaluate( virtual std::vector<NetOut> evaluate(
const std::vector<NetIn>& netIns, SolverContextBase& solver_ctx,
const std::vector<double>& mass_coords const std::vector<NetIn>& netIns
) = 0; ) const = 0;
virtual void set_callback(const std::any& callback) = 0;
[[nodiscard]] virtual std::vector<std::tuple<std::string, std::string>> describe_callback_context() const = 0;
protected: protected:
const EngineT& m_engine; ///< The engine used by this solver strategy. const EngineT& m_engine; ///< The engine used by this solver strategy.
std::unique_ptr<engine::scratch::StateBlob> m_scratch_blob_structure; const SingleZoneNetworkSolver<EngineT>& m_solver;
}; };
/** /**

View File

@@ -2,7 +2,7 @@
#include "gridfire/trigger/trigger_abstract.h" #include "gridfire/trigger/trigger_abstract.h"
#include "gridfire/trigger/trigger_result.h" #include "gridfire/trigger/trigger_result.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "fourdst/logging/logging.h" #include "fourdst/logging/logging.h"
#include <string> #include <string>
@@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE {
* *
* See also: engine_partitioning_trigger.cpp for the concrete logic and logging. * See also: engine_partitioning_trigger.cpp for the concrete logic and logging.
*/ */
class SimulationTimeTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> { class SimulationTimeTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public: public:
/** /**
* @brief Construct with a positive time interval between firings. * @brief Construct with a positive time interval between firings.
@@ -62,7 +62,7 @@ namespace gridfire::trigger::solver::CVODE {
* *
* @post increments hit/miss counters and may emit trace logs. * @post increments hit/miss counters and may emit trace logs.
*/ */
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** /**
* @brief Update internal state; if check(ctx) is true, advance last_trigger_time. * @brief Update internal state; if check(ctx) is true, advance last_trigger_time.
* @param ctx CVODE timestep context. * @param ctx CVODE timestep context.
@@ -70,9 +70,9 @@ namespace gridfire::trigger::solver::CVODE {
* @note update() calls check(ctx) and, on success, records the overshoot delta * @note update() calls check(ctx) and, on success, records the overshoot delta
* (ctx.t - last_trigger_time) - interval for diagnostics. * (ctx.t - last_trigger_time) - interval for diagnostics.
*/ */
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/** /**
* @brief Reset counters and last trigger bookkeeping (time and delta) to zero. * @brief Reset counters and last trigger bookkeeping (time and delta) to zero.
*/ */
@@ -85,7 +85,7 @@ namespace gridfire::trigger::solver::CVODE {
* @param ctx CVODE timestep context. * @param ctx CVODE timestep context.
* @return TriggerResult including name, value, and description. * @return TriggerResult including name, value, and description.
*/ */
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including configured interval. */ /** @brief Textual description including configured interval. */
std::string describe() const override; std::string describe() const override;
/** @brief Number of true evaluations since last reset. */ /** @brief Number of true evaluations since last reset. */
@@ -130,7 +130,7 @@ namespace gridfire::trigger::solver::CVODE {
* @par See also * @par See also
* - engine_partitioning_trigger.cpp for concrete logic and trace logging. * - engine_partitioning_trigger.cpp for concrete logic and trace logging.
*/ */
class OffDiagonalTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> { class OffDiagonalTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public: public:
/** /**
* @brief Construct with a non-negative magnitude threshold. * @brief Construct with a non-negative magnitude threshold.
@@ -145,13 +145,13 @@ namespace gridfire::trigger::solver::CVODE {
* *
* @post increments hit/miss counters and may emit trace logs. * @post increments hit/miss counters and may emit trace logs.
*/ */
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** /**
* @brief Record an update; does not mutate any Jacobian-related state. * @brief Record an update; does not mutate any Jacobian-related state.
* @param ctx CVODE timestep context (unused except for symmetry with interface). * @param ctx CVODE timestep context (unused except for symmetry with interface).
*/ */
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/** @brief Reset counters to zero. */ /** @brief Reset counters to zero. */
void reset() override; void reset() override;
@@ -161,7 +161,7 @@ namespace gridfire::trigger::solver::CVODE {
* @brief Structured explanation of the evaluation outcome. * @brief Structured explanation of the evaluation outcome.
* @param ctx CVODE timestep context. * @param ctx CVODE timestep context.
*/ */
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including configured threshold. */ /** @brief Textual description including configured threshold. */
std::string describe() const override; std::string describe() const override;
/** @brief Number of true evaluations since last reset. */ /** @brief Number of true evaluations since last reset. */
@@ -206,7 +206,7 @@ namespace gridfire::trigger::solver::CVODE {
* *
* See also: engine_partitioning_trigger.cpp for exact logic and logging. * See also: engine_partitioning_trigger.cpp for exact logic and logging.
*/ */
class TimestepCollapseTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> { class TimestepCollapseTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public: public:
/** /**
* @brief Construct with threshold and relative/absolute mode; window size defaults to 1. * @brief Construct with threshold and relative/absolute mode; window size defaults to 1.
@@ -230,20 +230,20 @@ namespace gridfire::trigger::solver::CVODE {
* *
* @post increments hit/miss counters and may emit trace logs. * @post increments hit/miss counters and may emit trace logs.
*/ */
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** /**
* @brief Update sliding window with the most recent dt and increment update counter. * @brief Update sliding window with the most recent dt and increment update counter.
* @param ctx CVODE timestep context. * @param ctx CVODE timestep context.
*/ */
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
/** @brief Reset counters and clear the dt window. */ /** @brief Reset counters and clear the dt window. */
void reset() override; void reset() override;
/** @brief Stable human-readable name. */ /** @brief Stable human-readable name. */
std::string name() const override; std::string name() const override;
/** @brief Structured explanation of the evaluation outcome. */ /** @brief Structured explanation of the evaluation outcome. */
TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
/** @brief Textual description including threshold, mode, and window size. */ /** @brief Textual description including threshold, mode, and window size. */
std::string describe() const override; std::string describe() const override;
/** @brief Number of true evaluations since last reset. */ /** @brief Number of true evaluations since last reset. */
@@ -272,15 +272,15 @@ namespace gridfire::trigger::solver::CVODE {
std::deque<double> m_timestep_window; std::deque<double> m_timestep_window;
}; };
class ConvergenceFailureTrigger final : public Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext> { class ConvergenceFailureTrigger final : public Trigger<gridfire::solver::PointSolverTimestepContext> {
public: public:
explicit ConvergenceFailureTrigger(size_t totalFailures, float relativeFailureRate, size_t windowSize); explicit ConvergenceFailureTrigger(size_t totalFailures, float relativeFailureRate, size_t windowSize);
bool check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; bool check(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
void update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void update(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void step(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) override; void step(const gridfire::solver::PointSolverTimestepContext &ctx) override;
void reset() override; void reset() override;
@@ -288,7 +288,7 @@ namespace gridfire::trigger::solver::CVODE {
[[nodiscard]] std::string describe() const override; [[nodiscard]] std::string describe() const override;
[[nodiscard]] TriggerResult why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const override; [[nodiscard]] TriggerResult why(const gridfire::solver::PointSolverTimestepContext &ctx) const override;
[[nodiscard]] size_t numTriggers() const override; [[nodiscard]] size_t numTriggers() const override;
@@ -312,8 +312,8 @@ namespace gridfire::trigger::solver::CVODE {
private: private:
float current_mean() const; float current_mean() const;
bool abs_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const; bool abs_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const;
bool rel_failure(const gridfire::solver::CVODESolverStrategy::TimestepContext& ctx) const; bool rel_failure(const gridfire::solver::PointSolverTimestepContext& ctx) const;
}; };
/** /**
@@ -337,10 +337,10 @@ namespace gridfire::trigger::solver::CVODE {
* *
* @note The exact policy is subject to change; this function centralizes that decision. * @note The exact policy is subject to change; this function centralizes that decision.
*/ */
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger( std::unique_ptr<Trigger<gridfire::solver::PointSolverTimestepContext>> makeEnginePartitioningTrigger(
const double simulationTimeInterval, double simulationTimeInterval,
const double offDiagonalThreshold, double offDiagonalThreshold,
const double timestepCollapseRatio, double timestepCollapseRatio,
const size_t maxConvergenceFailures size_t maxConvergenceFailures
); );
} }

View File

@@ -59,3 +59,26 @@ namespace gridfire {
concept IsArithmeticOrAD = std::is_same_v<T, double> || std::is_same_v<T, CppAD::AD<double>>; concept IsArithmeticOrAD = std::is_same_v<T, double> || std::is_same_v<T, CppAD::AD<double>>;
} // namespace nuclearNetwork } // namespace nuclearNetwork
template<>
struct std::formatter<gridfire::NetIn> : std::formatter<std::string> {
auto format(const gridfire::NetIn& netIn, auto& ctx) {
std::string output = "NetIn(, tMax=" + std::to_string(netIn.tMax) +
", dt0=" + std::to_string(netIn.dt0) +
", temperature=" + std::to_string(netIn.temperature) +
", density=" + std::to_string(netIn.density) +
", energy=" + std::to_string(netIn.energy) + ")";
return std::formatter<std::string>::format(output, ctx);
}
};
template <>
struct std::formatter<gridfire::NetOut> : std::formatter<std::string> {
auto format(const gridfire::NetOut& netOut, auto& ctx) {
std::string output = "NetOut(, num_steps=" + std::to_string(netOut.num_steps) +
", energy=" + std::to_string(netOut.energy) +
", dEps_dT=" + std::to_string(netOut.dEps_dT) +
", dEps_dRho=" + std::to_string(netOut.dEps_dRho) + ")";
return std::formatter<std::string>::format(output, ctx);
}
};

View File

@@ -0,0 +1,51 @@
#pragma once
#include "fourdst/logging/logging.h"
#include "quill/LogMacros.h"
#if defined(GF_USE_OPENMP)
#include <omp.h>
namespace gridfire::omp {
static bool s_par_mode_initialized = false;
inline unsigned long get_thread_id() {
return static_cast<unsigned long>(omp_get_thread_num());
}
inline bool in_parallel() {
return omp_in_parallel() != 0;
}
inline void init_parallel_mode() {
if (s_par_mode_initialized) {
return; // Only initialize once
}
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
LOG_INFO(logger, "Initializing OpenMP parallel mode with {} threads", static_cast<unsigned long>(omp_get_max_threads()));
CppAD::thread_alloc::parallel_setup(
static_cast<size_t>(omp_get_max_threads()), // Max threads
[]() -> bool { return in_parallel(); }, // Function to get thread ID
[]() -> size_t { return get_thread_id(); } // Function to check parallel state
);
CppAD::thread_alloc::hold_memory(true);
CppAD::CheckSimpleVector<double, std::vector<double>>(0, 1);
s_par_mode_initialized = true;
}
}
#define GF_PAR_INIT() gridfire::omp::init_parallel_mode();
#else
namespace gridfire::omp {
inline void log_not_in_parallel_mode() {
quill::Logger* logger = fourdst::logging::LogManager::getInstance().getLogger("log");
LOG_INFO(logger, "This is not an error! Note: OpenMP parallel mode is not enabled. GF_USE_OPENMP is not defined. Pass -DGF_USE_OPENMP when compiling to enable OpenMP support. When using meson use the option -Dopenmp_support=true");
}
}
#define GF_PAR_INIT() gridfire::omp::log_not_in_parallel_mode();
#endif

View File

@@ -0,0 +1,13 @@
#pragma once
#if defined(GF_USE_OPENMP)
#define GF_OMP_PRAGMA(x) _Pragma(#x)
#define GF_OMP(omp_args, extra) GF_OMP_PRAGMA(omp omp_args) extra
#define GF_OMP_MAX_THREADS omp_get_max_threads()
#define GF_OMP_THREAD_NUM omp_get_thread_num()
#else
#define GF_OMP(_,fallback_args) fallback_args
#define GF_OMP_MAX_THREADS 1
#define GF_OMP_THREAD_NUM 0
#endif

View File

@@ -5,3 +5,4 @@
#include "gridfire/utils/logging.h" #include "gridfire/utils/logging.h"
#include "gridfire/utils/sundials.h" #include "gridfire/utils/sundials.h"
#include "gridfire/utils/table_format.h" #include "gridfire/utils/table_format.h"
#include "gridfire/utils/macros.h"

View File

@@ -118,8 +118,7 @@ namespace gridfire::engine {
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA), m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)), m_reactions(build_nuclear_network(composition, m_weakRateInterpolator, buildDepth, reactionTypes)),
m_partitionFunction(partitionFunction.clone()), m_partitionFunction(partitionFunction.clone()),
m_depth(buildDepth), m_depth(buildDepth)
m_state_blob_offset(0) // For a base engine the offset is always 0
{ {
syncInternalMaps(); syncInternalMaps();
} }
@@ -128,8 +127,7 @@ namespace gridfire::engine {
const reaction::ReactionSet &reactions const reaction::ReactionSet &reactions
) : ) :
m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA), m_weakRateInterpolator(rates::weak::UNIFIED_WEAK_DATA),
m_reactions(reactions), m_reactions(reactions)
m_state_blob_offset(0)
{ {
syncInternalMaps(); syncInternalMaps();
} }
@@ -398,8 +396,10 @@ namespace gridfire::engine {
else { factor = std::pow(abundance, static_cast<double>(power)); } else { factor = std::pow(abundance, static_cast<double>(power)); }
if (!std::isfinite(factor)) { if (!std::isfinite(factor)) {
LOG_CRITICAL(m_logger, "Non-finite factor encountered in forward abundance product for reaction '{}'. Check input abundances for validity.", reaction.id()); const auto& sp = m_indexToSpeciesMap.at(reactantIndex);
throw exceptions::BadRHSEngineError("Non-finite factor encountered in forward abundance product."); std::string error_msg = std::format("Non-finite factor encountered in forward abundance product in reaction {} for species {} (Abundance: {}). Check input abundances for validity.", reaction.id(), sp.name(), abundance);
LOG_CRITICAL(m_logger, "{}", error_msg);
throw exceptions::BadRHSEngineError(error_msg);
} }
forwardAbundanceProduct *= factor; forwardAbundanceProduct *= factor;
@@ -949,9 +949,6 @@ namespace gridfire::engine {
const double rho, const double rho,
const std::vector<fourdst::atomic::Species> &activeSpecies const std::vector<fourdst::atomic::Species> &activeSpecies
) const { ) const {
// PERF: For small k it may make sense to implement a purley forward mode AD computation,
// some heuristic could be used to switch between the two methods based on k and
// total network species
const size_t k_active = activeSpecies.size(); const size_t k_active = activeSpecies.size();
// --- 1. Get the list of global indices --- // --- 1. Get the list of global indices ---
@@ -1443,7 +1440,7 @@ namespace gridfire::engine {
m_precomputed_reactions.push_back(std::move(precomp)); m_precomputed_reactions.push_back(std::move(precomp));
} }
LOG_TRACE_L1(m_logger, "Pre-computation complete. Precomputed data for {} reactions.", m_precomputedReactions.size()); LOG_TRACE_L1(m_logger, "Pre-computation complete. Precomputed data for {} reactions.", m_precomputed_reactions.size());
} }
bool GraphEngine::AtomicReverseRate::forward( bool GraphEngine::AtomicReverseRate::forward(

View File

@@ -2,7 +2,6 @@
#include "fourdst/atomic/species.h" #include "fourdst/atomic/species.h"
#include "fourdst/composition/utils.h" #include "fourdst/composition/utils.h"
#include "gridfire/engine/views/engine_priming.h"
#include "gridfire/solver/solver.h" #include "gridfire/solver/solver.h"
#include "gridfire/engine/engine_abstract.h" #include "gridfire/engine/engine_abstract.h"
@@ -13,7 +12,7 @@
#include "gridfire/engine/scratchpads/engine_graph_scratchpad.h" #include "gridfire/engine/scratchpads/engine_graph_scratchpad.h"
#include "fourdst/logging/logging.h" #include "fourdst/logging/logging.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "quill/Logger.h" #include "quill/Logger.h"
#include "quill/LogMacros.h" #include "quill/LogMacros.h"
@@ -28,13 +27,12 @@ namespace gridfire::engine {
const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes const GraphEngine& engine, const std::optional<std::vector<reaction::ReactionType>>& ignoredReactionTypes
) { ) {
const auto logger = LogManager::getInstance().getLogger("log"); const auto logger = LogManager::getInstance().getLogger("log");
solver::CVODESolverStrategy integrator(engine, ctx); solver::PointSolver integrator(engine);
solver::PointSolverContext solverCtx(ctx);
solverCtx.abs_tol = 1e-3;
solverCtx.rel_tol = 1e-3;
solverCtx.stdout_logging = false;
// Do not need high precision for priming
integrator.set_absTol(1e-3);
integrator.set_relTol(1e-3);
integrator.set_stdout_logging_enabled(false);
NetIn solverInput(netIn); NetIn solverInput(netIn);
solverInput.tMax = 1e-15; solverInput.tMax = 1e-15;
@@ -43,7 +41,7 @@ namespace gridfire::engine {
LOG_INFO(logger, "Short timescale ({}) network ignition started.", solverInput.tMax); LOG_INFO(logger, "Short timescale ({}) network ignition started.", solverInput.tMax);
PrimingReport report; PrimingReport report;
try { try {
const NetOut netOut = integrator.evaluate(solverInput, false); const NetOut netOut = integrator.evaluate(solverCtx, solverInput);
LOG_INFO(logger, "Network ignition completed."); LOG_INFO(logger, "Network ignition completed.");
LOG_TRACE_L2( LOG_TRACE_L2(
logger, logger,

View File

@@ -129,8 +129,7 @@ namespace gridfire::engine {
const double rho, const double rho,
const std::vector<Species> &activeSpecies const std::vector<Species> &activeSpecies
) const { ) const {
const auto *state = scratch::get_state<scratch::AdaptiveEngineViewScratchPad, true>(ctx); return m_baseEngine.generateJacobianMatrix(ctx, comp, T9, rho, activeSpecies);
return m_baseEngine.generateJacobianMatrix(ctx, comp, T9, rho, state->active_species);
} }

View File

@@ -748,7 +748,7 @@ namespace gridfire::engine {
} }
} }
} }
LOG_TRACE_L1(m_logger, "Algebraic species identified: {}", utils::iterable_to_delimited_string(m_algebraic_species)); LOG_TRACE_L1(m_logger, "Algebraic species identified: {}", utils::iterable_to_delimited_string(state->algebraic_species));
LOG_INFO( LOG_INFO(
m_logger, m_logger,
@@ -773,7 +773,7 @@ namespace gridfire::engine {
state->dynamic_species.push_back(species); state->dynamic_species.push_back(species);
} }
} }
LOG_TRACE_L1(m_logger, "Final dynamic species set: {}", utils::iterable_to_delimited_string(m_dynamic_species)); LOG_TRACE_L1(m_logger, "Final dynamic species set: {}", utils::iterable_to_delimited_string(state->dynamic_species));
LOG_TRACE_L1(m_logger, "Creating QSE solvers for each identified QSE group..."); LOG_TRACE_L1(m_logger, "Creating QSE solvers for each identified QSE group...");
for (const auto& group : state->qse_groups) { for (const auto& group : state->qse_groups) {
@@ -783,7 +783,7 @@ namespace gridfire::engine {
} }
state->qse_solvers.push_back(std::make_unique<QSESolver>(groupAlgebraicSpecies, m_baseEngine, state->sun_ctx)); state->qse_solvers.push_back(std::make_unique<QSESolver>(groupAlgebraicSpecies, m_baseEngine, state->sun_ctx));
} }
LOG_TRACE_L1(m_logger, "{} QSE solvers created.", m_qse_solvers.size()); LOG_TRACE_L1(m_logger, "{} QSE solvers created.", state->qse_solvers.size());
LOG_TRACE_L1(m_logger, "Calculating final equilibrated composition..."); LOG_TRACE_L1(m_logger, "Calculating final equilibrated composition...");
fourdst::composition::Composition result = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false); fourdst::composition::Composition result = getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
@@ -1949,7 +1949,7 @@ namespace gridfire::engine {
LOG_TRACE_L2( LOG_TRACE_L2(
getLogger(), getLogger(),
"Starting KINSol QSE solver with initial state: {}", "Starting KINSol QSE solver with initial state: {}",
[&comp, &initial_rhs, &data]() -> std::string { [&comp, &rhsGuess, &data]() -> std::string {
std::ostringstream oss; std::ostringstream oss;
oss << "Solve species: <"; oss << "Solve species: <";
size_t count = 0; size_t count = 0;
@@ -1963,7 +1963,7 @@ namespace gridfire::engine {
oss << "> | Initial abundances and rates: "; oss << "> | Initial abundances and rates: ";
count = 0; count = 0;
for (const auto& [species, abundance] : comp) { for (const auto& [species, abundance] : comp) {
oss << species.name() << ": Y = " << abundance << ", dY/dt = " << initial_rhs.value().dydt.at(species); oss << species.name() << ": Y = " << abundance << ", dY/dt = " << rhsGuess.dydt.at(species);
if (count < comp.size() - 1) { if (count < comp.size() - 1) {
oss << ", "; oss << ", ";
} }
@@ -2005,7 +2005,32 @@ namespace gridfire::engine {
LOG_INFO(getLogger(), "KINSol failed to converge within the maximum number of iterations, but achieved acceptable accuracy with function norm {} < {}. Proceeding with solution.", LOG_INFO(getLogger(), "KINSol failed to converge within the maximum number of iterations, but achieved acceptable accuracy with function norm {} < {}. Proceeding with solution.",
fnorm, ACCEPTABLE_FTOL); fnorm, ACCEPTABLE_FTOL);
} else { } else {
LOG_WARNING(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Error {}", utils::kinsol_ret_code_map.at(flag), fnorm); LOG_CRITICAL(getLogger(), "KINSol failed to converge while solving QSE abundances with flag {}. Flag No.: {}, Error (fNorm): {}", utils::kinsol_ret_code_map.at(flag), flag, fnorm);
LOG_CRITICAL(getLogger(), "State prior to failure: {}",
[&comp, &data]() -> std::string {
std::ostringstream oss;
oss << "Solve species: <";
size_t count = 0;
for (const auto& species : data.qse_solve_species) {
oss << species.name();
if (count < data.qse_solve_species.size() - 1) {
oss << ", ";
}
count++;
}
oss << "> | Abundances and rates at failure: ";
count = 0;
for (const auto& [species, abundance] : comp) {
oss << species.name() << ": Y = " << abundance;
if (count < comp.size() - 1) {
oss << ", ";
}
count++;
}
oss << " | Temperature: " << data.T9 << ", Density: " << data.rho;
return oss.str();
}()
);
throw exceptions::InvalidQSESolutionError("KINSol failed to converge while solving QSE abundances. " + utils::kinsol_ret_code_map.at(flag)); throw exceptions::InvalidQSESolutionError("KINSol failed to converge while solving QSE abundances. " + utils::kinsol_ret_code_map.at(flag));
} }
} }

View File

@@ -282,11 +282,11 @@ namespace gridfire::reaction {
double Ye, double Ye,
double mue, const std::vector<double> &Y, const std::unordered_map<size_t, Species>& index_to_species_map double mue, const std::vector<double> &Y, const std::unordered_map<size_t, Species>& index_to_species_map
) const { ) const {
if (m_cached_rates.contains(T9)) { // if (m_cached_rates.contains(T9)) {
return m_cached_rates.at(T9); // return m_cached_rates.at(T9);
} // }
const double rate = calculate_rate<double>(T9); const double rate = calculate_rate<double>(T9);
m_cached_rates[T9] = rate; // m_cached_rates[T9] = rate;
return rate; return rate;
} }

View File

@@ -0,0 +1,107 @@
#include "gridfire/solver/strategies/GridSolver.h"
#include "gridfire/exceptions/error_solver.h"
#include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/utils/macros.h"
#include "gridfire/utils/gf_omp.h"
#include <cstdio>
#include <print>
namespace gridfire::solver {
void GridSolverContext::init() {}
void GridSolverContext::reset() {
solver_workspaces.clear();
timestep_callbacks.clear();
}
void GridSolverContext::set_callback(const std::function<void(const TimestepContextBase &)> &callback) {
for (auto &cb : timestep_callbacks) {
cb = callback;
}
}
void GridSolverContext::set_callback(const std::function<void(const TimestepContextBase &)> &callback, const size_t zone_idx) {
if (zone_idx >= timestep_callbacks.size()) {
throw exceptions::SolverError("GridSolverContext::set_callback: zone_idx out of range.");
}
timestep_callbacks[zone_idx] = callback;
}
void GridSolverContext::clear_callback() {
for (auto &cb : timestep_callbacks) {
cb = nullptr;
}
}
void GridSolverContext::clear_callback(const size_t zone_idx) {
if (zone_idx >= timestep_callbacks.size()) {
throw exceptions::SolverError("GridSolverContext::clear_callback: zone_idx out of range.");
}
timestep_callbacks[zone_idx] = nullptr;
}
void GridSolverContext::set_stdout_logging(const bool enable) {
zone_stdout_logging = enable;
}
void GridSolverContext::set_detailed_logging(const bool enable) {
zone_detailed_logging = enable;
}
GridSolverContext::GridSolverContext(
const engine::scratch::StateBlob &ctx_template
) :
ctx_template(ctx_template) {}
GridSolver::GridSolver(
const engine::DynamicEngine &engine,
const SingleZoneDynamicNetworkSolver &solver
) :
MultiZoneNetworkSolver(engine, solver) {
GF_PAR_INIT();
}
std::vector<NetOut> GridSolver::evaluate(
SolverContextBase& ctx,
const std::vector<NetIn>& netIns
) const {
auto* sctx_p = dynamic_cast<GridSolverContext*>(&ctx);
if (!sctx_p) {
throw exceptions::SolverError("GridSolver::evaluate: SolverContextBase is not of type GridSolverContext.");
}
const size_t n_zones = netIns.size();
if (n_zones == 0) { return {}; }
std::vector<NetOut> results(n_zones);
sctx_p->solver_workspaces.resize(n_zones);
GF_OMP(
parallel for default(none) shared(sctx_p, n_zones),
for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) {
sctx_p->solver_workspaces[zone_idx] = std::make_unique<PointSolverContext>(sctx_p->ctx_template);
sctx_p->solver_workspaces[zone_idx]->set_stdout_logging(sctx_p->zone_stdout_logging);
sctx_p->solver_workspaces[zone_idx]->set_detailed_logging(sctx_p->zone_detailed_logging);
}
GF_OMP(
parallel for default(none) shared(results, sctx_p, netIns, n_zones),
for (size_t zone_idx = 0; zone_idx < n_zones; ++zone_idx)) {
try {
results[zone_idx] = m_solver.evaluate(
*sctx_p->solver_workspaces[zone_idx],
netIns[zone_idx]
);
} catch (exceptions::GridFireError& e) {
std::println("CVODE Solver Failure in zone {}: {}", zone_idx, e.what());
}
if (sctx_p->zone_completion_logging) {
std::println("Thread {} completed zone {}", GF_OMP_THREAD_NUM, zone_idx);
}
}
return results;
}
}

View File

@@ -1,4 +1,4 @@
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/types/types.h" #include "gridfire/types/types.h"
#include "gridfire/utils/table_format.h" #include "gridfire/utils/table_format.h"
@@ -28,7 +28,7 @@
namespace gridfire::solver { namespace gridfire::solver {
using namespace gridfire::engine; using namespace gridfire::engine;
CVODESolverStrategy::TimestepContext::TimestepContext( PointSolverTimestepContext::PointSolverTimestepContext(
const double t, const double t,
const N_Vector &state, const N_Vector &state,
const double dt, const double dt,
@@ -58,7 +58,7 @@ namespace gridfire::solver {
state_ctx(ctx) state_ctx(ctx)
{} {}
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::TimestepContext::describe() const { std::vector<std::tuple<std::string, std::string>> PointSolverTimestepContext::describe() const {
std::vector<std::tuple<std::string, std::string>> description; std::vector<std::tuple<std::string, std::string>> description;
description.emplace_back("t", "Current Time"); description.emplace_back("t", "Current Time");
description.emplace_back("state", "Current State Vector (N_Vector)"); description.emplace_back("state", "Current State Vector (N_Vector)");
@@ -74,36 +74,112 @@ namespace gridfire::solver {
return description; return description;
} }
void PointSolverContext::init() {
reset_all();
init_context();
}
CVODESolverStrategy::CVODESolverStrategy( void PointSolverContext::set_stdout_logging(const bool enable) {
const DynamicEngine &engine, stdout_logging = enable;
const scratch::StateBlob& ctx }
): SingleZoneNetworkSolver<DynamicEngine>(engine, ctx) {
// PERF: In order to support MPI this function must be changed void PointSolverContext::set_detailed_logging(const bool enable) {
const int flag = SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx); detailed_step_logging = enable;
if (flag < 0) { }
throw std::runtime_error("Failed to create SUNDIALS context (SUNDIALS Errno: " + std::to_string(flag) + ")");
void PointSolverContext::reset_all() {
reset_user();
reset_cvode();
}
void PointSolverContext::reset_user() {
callback.reset();
num_steps = 0;
stdout_logging = true;
abs_tol.reset();
rel_tol.reset();
detailed_step_logging = false;
last_size = 0;
last_composition_hash = 0ULL;
}
void PointSolverContext::reset_cvode() {
if (LS) {
SUNLinSolFree(LS);
LS = nullptr;
}
if (J) {
SUNMatDestroy(J);
J = nullptr;
}
if (Y) {
N_VDestroy(Y);
Y = nullptr;
}
if (YErr) {
N_VDestroy(YErr);
YErr = nullptr;
}
if (constraints) {
N_VDestroy(constraints);
constraints = nullptr;
}
if (cvode_mem) {
CVodeFree(&cvode_mem);
cvode_mem = nullptr;
} }
} }
CVODESolverStrategy::~CVODESolverStrategy() { void PointSolverContext::clear_context() {
LOG_TRACE_L1(m_logger, "Cleaning up CVODE resources..."); if (sun_ctx) {
cleanup_cvode_resources(true); SUNContext_Free(&sun_ctx);
sun_ctx = nullptr;
if (m_sun_ctx) {
SUNContext_Free(&m_sun_ctx);
} }
} }
NetOut CVODESolverStrategy::evaluate(const NetIn& netIn) { void PointSolverContext::init_context() {
return evaluate(netIn, false); if (!sun_ctx) {
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
}
} }
NetOut CVODESolverStrategy::evaluate( bool PointSolverContext::has_context() const {
return sun_ctx != nullptr;
}
PointSolverContext::PointSolverContext(
const scratch::StateBlob& engine_ctx
) :
engine_ctx(engine_ctx.clone_structure())
{
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
}
PointSolverContext::~PointSolverContext() {
reset_cvode();
clear_context();
}
PointSolver::PointSolver(
const DynamicEngine &engine
): SingleZoneNetworkSolver(engine) {}
NetOut PointSolver::evaluate(
SolverContextBase& solver_ctx,
const NetIn& netIn
) const {
return evaluate(solver_ctx, netIn, false);
}
NetOut PointSolver::evaluate(
SolverContextBase& solver_ctx,
const NetIn &netIn, const NetIn &netIn,
bool displayTrigger, bool displayTrigger,
bool forceReinitialize bool forceReinitialize
) { ) const {
auto* sctx_p = dynamic_cast<PointSolverContext*>(&solver_ctx);
LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density); LOG_TRACE_L1(m_logger, "Starting solver evaluation with T9: {} and rho: {}", netIn.temperature/1e9, netIn.density);
LOG_TRACE_L1(m_logger, "Building engine update trigger...."); LOG_TRACE_L1(m_logger, "Building engine update trigger....");
auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 0.5, 2); auto trigger = trigger::solver::CVODE::makeEnginePartitioningTrigger(1e12, 1e10, 0.5, 2);
@@ -117,23 +193,24 @@ namespace gridfire::solver {
// 2. If the user has set tolerances in code, those override the config // 2. If the user has set tolerances in code, those override the config
// 3. If the user has not set tolerances in code and the config does not have them, use hardcoded defaults // 3. If the user has not set tolerances in code and the config does not have them, use hardcoded defaults
auto absTol = m_config->solver.cvode.absTol; if (!sctx_p->abs_tol.has_value()) {
auto relTol = m_config->solver.cvode.relTol; sctx_p->abs_tol = m_config->solver.cvode.absTol;
if (m_absTol) {
absTol = *m_absTol;
} }
if (m_relTol) { if (!sctx_p->rel_tol.has_value()) {
relTol = *m_relTol; sctx_p->rel_tol = m_config->solver.cvode.relTol;
} }
bool resourcesExist = (m_cvode_mem != nullptr) && (m_Y != nullptr);
bool inconsistentComposition = netIn.composition.hash() != m_last_composition_hash; bool resourcesExist = (sctx_p->cvode_mem != nullptr) && (sctx_p->Y != nullptr);
bool inconsistentComposition = netIn.composition.hash() != sctx_p->last_composition_hash;
fourdst::composition::Composition equilibratedComposition; fourdst::composition::Composition equilibratedComposition;
if (forceReinitialize || !resourcesExist || inconsistentComposition) { if (forceReinitialize || !resourcesExist || inconsistentComposition) {
cleanup_cvode_resources(true); sctx_p->reset_cvode();
if (!sctx_p->has_context()) {
sctx_p->init_context();
}
LOG_INFO( LOG_INFO(
m_logger, m_logger,
"Preforming full CVODE initialization (Reason: {})", "Preforming full CVODE initialization (Reason: {})",
@@ -141,26 +218,24 @@ namespace gridfire::solver {
(!resourcesExist ? "CVODE resources do not exist" : (!resourcesExist ? "CVODE resources do not exist" :
"Input composition inconsistent with previous state")); "Input composition inconsistent with previous state"));
LOG_TRACE_L1(m_logger, "Starting engine update chain..."); LOG_TRACE_L1(m_logger, "Starting engine update chain...");
equilibratedComposition = m_engine.project(*m_scratch_blob, netIn); equilibratedComposition = m_engine.project(*sctx_p->engine_ctx, netIn);
LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!"); LOG_TRACE_L1(m_logger, "Engine updated and equilibrated composition found!");
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
uint64_t N = numSpecies + 1; uint64_t N = numSpecies + 1;
LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N); LOG_TRACE_L1(m_logger, "Number of species: {} ({} independent variables)", numSpecies, N);
LOG_TRACE_L1(m_logger, "Initializing CVODE resources"); LOG_TRACE_L1(m_logger, "Initializing CVODE resources");
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
initialize_cvode_integration_resources(N, numSpecies, 0.0, equilibratedComposition, absTol, relTol, 0.0); initialize_cvode_integration_resources(sctx_p, N, numSpecies, 0.0, equilibratedComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), 0.0);
m_last_size = N; sctx_p->last_size = N;
} else { } else {
LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", m_last_size); LOG_INFO(m_logger, "Reusing existing CVODE resources (size: {})", sctx_p->last_size);
const size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); const size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
sunrealtype *y_data = N_VGetArrayPointer(m_Y); sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
for (size_t i = 0; i < numSpecies; i++) { for (size_t i = 0; i < numSpecies; i++) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
if (netIn.composition.contains(species)) { if (netIn.composition.contains(species)) {
y_data[i] = netIn.composition.getMolarAbundance(species); y_data[i] = netIn.composition.getMolarAbundance(species);
} else { } else {
@@ -168,16 +243,17 @@ namespace gridfire::solver {
} }
} }
y_data[numSpecies] = 0.0; // Reset energy accumulator y_data[numSpecies] = 0.0; // Reset energy accumulator
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances"); utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, sctx_p->rel_tol.value(), sctx_p->abs_tol.value()), "CVodeSStolerances");
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, 0.0, m_Y), "CVodeReInit"); utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, 0.0, sctx_p->Y), "CVodeReInit");
equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state equilibratedComposition = netIn.composition; // Use the provided composition as-is if we already have validated CVODE resources and that the composition is consistent with the previous state
} }
size_t numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); size_t numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
CVODEUserData user_data { CVODEUserData user_data {
.solver_instance = this, .solver_instance = this,
.ctx = *m_scratch_blob, .sctx = sctx_p,
.ctx = *sctx_p->engine_ctx,
.engine = &m_engine, .engine = &m_engine,
}; };
LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!"); LOG_TRACE_L1(m_logger, "CVODE resources successfully initialized!");
@@ -185,7 +261,7 @@ namespace gridfire::solver {
double current_time = 0; double current_time = 0;
// ReSharper disable once CppTooWideScope // ReSharper disable once CppTooWideScope
[[maybe_unused]] double last_callback_time = 0; [[maybe_unused]] double last_callback_time = 0;
m_num_steps = 0; sctx_p->num_steps = 0;
double accumulated_energy = 0.0; double accumulated_energy = 0.0;
double accumulated_neutrino_energy_loss = 0.0; double accumulated_neutrino_energy_loss = 0.0;
@@ -200,18 +276,18 @@ namespace gridfire::solver {
size_t total_steps = 0; size_t total_steps = 0;
LOG_TRACE_L1(m_logger, "Starting CVODE iteration"); LOG_TRACE_L1(m_logger, "Starting CVODE iteration...");
fourdst::composition::Composition postStep = equilibratedComposition; fourdst::composition::Composition postStep = equilibratedComposition;
while (current_time < netIn.tMax) { while (current_time < netIn.tMax) {
user_data.T9 = T9; user_data.T9 = T9;
user_data.rho = netIn.density; user_data.rho = netIn.density;
user_data.networkSpecies = &m_engine.getNetworkSpecies(*m_scratch_blob); user_data.networkSpecies = &m_engine.getNetworkSpecies(*sctx_p->engine_ctx);
user_data.captured_exception.reset(); user_data.captured_exception.reset();
utils::check_cvode_flag(CVodeSetUserData(m_cvode_mem, &user_data), "CVodeSetUserData"); utils::check_cvode_flag(CVodeSetUserData(sctx_p->cvode_mem, &user_data), "CVodeSetUserData");
LOG_TRACE_L2(m_logger, "Taking one CVODE step..."); LOG_TRACE_L2(m_logger, "Taking one CVODE step...");
int flag = CVode(m_cvode_mem, netIn.tMax, m_Y, &current_time, CV_ONE_STEP); int flag = CVode(sctx_p->cvode_mem, netIn.tMax, sctx_p->Y, &current_time, CV_ONE_STEP);
LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag)); LOG_TRACE_L2(m_logger, "CVODE step complete. Current time: {}, step status: {}", current_time, utils::cvode_ret_code_map.at(flag));
if (user_data.captured_exception){ if (user_data.captured_exception){
@@ -223,13 +299,13 @@ namespace gridfire::solver {
long int n_steps; long int n_steps;
double last_step_size; double last_step_size;
CVodeGetNumSteps(m_cvode_mem, &n_steps); CVodeGetNumSteps(sctx_p->cvode_mem, &n_steps);
CVodeGetLastStep(m_cvode_mem, &last_step_size); CVodeGetLastStep(sctx_p->cvode_mem, &last_step_size);
long int nliters, nlcfails; long int nliters, nlcfails;
CVodeGetNumNonlinSolvIters(m_cvode_mem, &nliters); CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nliters);
CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nlcfails); CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nlcfails);
sunrealtype* y_data = N_VGetArrayPointer(m_Y); sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y);
const double current_energy = y_data[numSpecies]; // Specific energy rate const double current_energy = y_data[numSpecies]; // Specific energy rate
// TODO: Accumulate neutrino loss through the state vector directly which will allow CVODE to properly integrate it // TODO: Accumulate neutrino loss through the state vector directly which will allow CVODE to properly integrate it
@@ -238,7 +314,7 @@ namespace gridfire::solver {
size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations; size_t iter_diff = (total_nonlinear_iterations + nliters) - prev_nonlinear_iterations;
size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures; size_t convFail_diff = (total_convergence_failures + nlcfails) - prev_convergence_failures;
if (m_stdout_logging_enabled) { if (sctx_p->stdout_logging) {
std::println( std::println(
"Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})", "Step: {:6} | Updates: {:3} | Epoch Steps: {:4} | t: {:.3e} [s] | dt: {:15.6E} [s] | Iterations: {:6} (+{:2}) | Total Convergence Failures: {:2} (+{:2})",
total_steps + n_steps, total_steps + n_steps,
@@ -253,20 +329,16 @@ namespace gridfire::solver {
); );
} }
for (size_t i = 0; i < numSpecies; ++i) { for (size_t i = 0; i < numSpecies; ++i) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
if (y_data[i] > 0.0) { if (y_data[i] > 0.0) {
postStep.setMolarAbundance(species, y_data[i]); postStep.setMolarAbundance(species, y_data[i]);
} }
} }
// fourdst::composition::Composition collectedComposition = m_engine.collectComposition(postStep, netIn.temperature/1e9, netIn.density);
// for (size_t i = 0; i < numSpecies; ++i) {
// y_data[i] = collectedComposition.getMolarAbundance(m_engine.getNetworkSpecies()[i]);
// }
LOG_INFO(m_logger, "Completed {:5} steps to time {:10.4E} [s] (dt = {:15.6E} [s]). Current specific energy: {:15.6E} [erg/g]", total_steps + n_steps, current_time, last_step_size, current_energy); LOG_INFO(m_logger, "Completed {:5} steps to time {:10.4E} [s] (dt = {:15.6E} [s]). Current specific energy: {:15.6E} [erg/g]", total_steps + n_steps, current_time, last_step_size, current_energy);
LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string { LOG_DEBUG(m_logger, "Current composition (molar abundance): {}", [&]() -> std::string {
std::stringstream ss; std::stringstream ss;
for (size_t i = 0; i < numSpecies; ++i) { for (size_t i = 0; i < numSpecies; ++i) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")"; ss << species.name() << ": (y_data = " << y_data[i] << ", collected = " << postStep.getMolarAbundance(species) << ")";
if (i < numSpecies - 1) { if (i < numSpecies - 1) {
ss << ", "; ss << ", ";
@@ -282,36 +354,44 @@ namespace gridfire::solver {
? user_data.reaction_contribution_map.value() ? user_data.reaction_contribution_map.value()
: kEmptyMap; : kEmptyMap;
auto ctx = TimestepContext( auto ctx = PointSolverTimestepContext(
current_time, current_time,
m_Y, sctx_p->Y,
last_step_size, last_step_size,
last_callback_time, last_callback_time,
T9, T9,
netIn.density, netIn.density,
n_steps, n_steps,
m_engine, m_engine,
m_engine.getNetworkSpecies(*m_scratch_blob), m_engine.getNetworkSpecies(*sctx_p->engine_ctx),
convFail_diff, convFail_diff,
iter_diff, iter_diff,
rcMap, rcMap,
*m_scratch_blob *sctx_p->engine_ctx
); );
prev_nonlinear_iterations = nliters + total_nonlinear_iterations; prev_nonlinear_iterations = nliters + total_nonlinear_iterations;
prev_convergence_failures = nlcfails + total_convergence_failures; prev_convergence_failures = nlcfails + total_convergence_failures;
if (m_callback.has_value()) { if (sctx_p->callback.has_value()) {
m_callback.value()(ctx); sctx_p->callback.value()(ctx);
} }
trigger->step(ctx); trigger->step(ctx);
if (m_detailed_step_logging) { if (sctx_p->detailed_step_logging) {
log_step_diagnostics(*m_scratch_blob, user_data, true, true, true, "step_" + std::to_string(total_steps + n_steps) + ".json"); log_step_diagnostics(
sctx_p,
*sctx_p->engine_ctx,
user_data,
true,
true,
true,
"step_" + std::to_string(total_steps + n_steps) + ".json"
);
} }
if (trigger->check(ctx)) { if (trigger->check(ctx)) {
if (m_stdout_logging_enabled && displayTrigger) { if (sctx_p->stdout_logging && displayTrigger) {
trigger::printWhy(trigger->why(ctx)); trigger::printWhy(trigger->why(ctx));
} }
trigger->update(ctx); trigger->update(ctx);
@@ -333,20 +413,20 @@ namespace gridfire::solver {
fourdst::composition::Composition temp_comp; fourdst::composition::Composition temp_comp;
std::vector<double> mass_fractions; std::vector<double> mass_fractions;
auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*m_scratch_blob).size()); auto num_species_at_stop = static_cast<long int>(m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size());
if (num_species_at_stop > m_Y->ops->nvgetlength(m_Y) - 1) { if (num_species_at_stop > sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1) {
LOG_ERROR( LOG_ERROR(
m_logger, m_logger,
"Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.", "Number of species at engine update ({}) exceeds the number of species in the CVODE solver ({}). This should never happen.",
num_species_at_stop, num_species_at_stop,
m_Y->ops->nvgetlength(m_Y) - 1 // -1 due to energy in the last index sctx_p->Y->ops->nvgetlength(sctx_p->Y) - 1 // -1 due to energy in the last index
); );
throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen."); throw std::runtime_error("Number of species at engine update exceeds the number of species in the CVODE solver. This should never happen.");
} }
for (const auto& species: m_engine.getNetworkSpecies(*m_scratch_blob)) { for (const auto& species: m_engine.getNetworkSpecies(*sctx_p->engine_ctx)) {
const size_t sid = m_engine.getSpeciesIndex(*m_scratch_blob, species); const size_t sid = m_engine.getSpeciesIndex(*sctx_p->engine_ctx, species);
temp_comp.registerSpecies(species); temp_comp.registerSpecies(species);
double y = end_of_step_abundances[sid]; double y = end_of_step_abundances[sid];
if (y > 0.0) { if (y > 0.0) {
@@ -356,7 +436,7 @@ namespace gridfire::solver {
#ifndef NDEBUG #ifndef NDEBUG
for (long int i = 0; i < num_species_at_stop; ++i) { for (long int i = 0; i < num_species_at_stop; ++i) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) { if (std::abs(temp_comp.getMolarAbundance(species) - y_data[i]) > 1e-12) {
throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification."); throw exceptions::UtilityError("Conversion from solver state to composition molar abundance failed verification.");
} }
@@ -391,7 +471,7 @@ namespace gridfire::solver {
"Prior to Engine Update active reactions are: {}", "Prior to Engine Update active reactions are: {}",
[&]() -> std::string { [&]() -> std::string {
std::stringstream ss; std::stringstream ss;
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob); const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx);
size_t count = 0; size_t count = 0;
for (const auto& reaction : reactions) { for (const auto& reaction : reactions) {
ss << reaction -> id(); ss << reaction -> id();
@@ -403,7 +483,7 @@ namespace gridfire::solver {
return ss.str(); return ss.str();
}() }()
); );
fourdst::composition::Composition currentComposition = m_engine.project(*m_scratch_blob, netInTemp); fourdst::composition::Composition currentComposition = m_engine.project(*sctx_p->engine_ctx, netInTemp);
LOG_DEBUG( LOG_DEBUG(
m_logger, m_logger,
"After to Engine update composition is (molar abundance) {}", "After to Engine update composition is (molar abundance) {}",
@@ -450,7 +530,7 @@ namespace gridfire::solver {
"After Engine Update active reactions are: {}", "After Engine Update active reactions are: {}",
[&]() -> std::string { [&]() -> std::string {
std::stringstream ss; std::stringstream ss;
const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*m_scratch_blob); const gridfire::reaction::ReactionSet& reactions = m_engine.getNetworkReactions(*sctx_p->engine_ctx);
size_t count = 0; size_t count = 0;
for (const auto& reaction : reactions) { for (const auto& reaction : reactions) {
ss << reaction -> id(); ss << reaction -> id();
@@ -466,34 +546,29 @@ namespace gridfire::solver {
m_logger, m_logger,
"Due to a triggered engine update the composition was updated from size {} to {} species.", "Due to a triggered engine update the composition was updated from size {} to {} species.",
num_species_at_stop, num_species_at_stop,
m_engine.getNetworkSpecies(*m_scratch_blob).size() m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size()
); );
numSpecies = m_engine.getNetworkSpecies(*m_scratch_blob).size(); numSpecies = m_engine.getNetworkSpecies(*sctx_p->engine_ctx).size();
size_t N = numSpecies + 1; size_t N = numSpecies + 1;
LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update..."); LOG_INFO(m_logger, "Starting CVODE reinitialization after engine update...");
cleanup_cvode_resources(true); sctx_p->reset_cvode();
initialize_cvode_integration_resources(sctx_p, N, numSpecies, current_time, currentComposition, sctx_p->abs_tol.value(), sctx_p->rel_tol.value(), accumulated_energy);
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx); utils::check_cvode_flag(CVodeReInit(sctx_p->cvode_mem, current_time, sctx_p->Y), "CVodeReInit");
utils::check_cvode_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
initialize_cvode_integration_resources(N, numSpecies, current_time, currentComposition, absTol, relTol, accumulated_energy);
utils::check_cvode_flag(CVodeReInit(m_cvode_mem, current_time, m_Y), "CVodeReInit");
// throw exceptions::DebugException("Debug");
LOG_INFO(m_logger, "Done reinitializing CVODE after engine update. The next log messages will be from the first step after reinitialization..."); LOG_INFO(m_logger, "Done reinitializing CVODE after engine update. The next log messages will be from the first step after reinitialization...");
} }
} }
if (m_stdout_logging_enabled) { // Flush the buffer if standard out logging is enabled if (sctx_p->stdout_logging) { // Flush the buffer if standard out logging is enabled
std::cout << std::flush; std::cout << std::flush;
} }
LOG_INFO(m_logger, "CVODE iteration complete"); LOG_INFO(m_logger, "CVODE iteration complete");
sunrealtype* y_data = N_VGetArrayPointer(m_Y); sunrealtype* y_data = N_VGetArrayPointer(sctx_p->Y);
accumulated_energy += y_data[numSpecies]; accumulated_energy += y_data[numSpecies];
std::vector<double> y_vec(y_data, y_data + numSpecies); std::vector<double> y_vec(y_data, y_data + numSpecies);
@@ -505,7 +580,7 @@ namespace gridfire::solver {
LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies); LOG_INFO(m_logger, "Constructing final composition= with {} species", numSpecies);
fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec); fourdst::composition::Composition topLevelComposition(m_engine.getNetworkSpecies(*sctx_p->engine_ctx), y_vec);
LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string { LOG_INFO(m_logger, "Final composition constructed from solver state successfully! ({})", [&topLevelComposition]() -> std::string {
std::ostringstream ss; std::ostringstream ss;
size_t i = 0; size_t i = 0;
@@ -520,7 +595,7 @@ namespace gridfire::solver {
}()); }());
LOG_INFO(m_logger, "Collecting final composition..."); LOG_INFO(m_logger, "Collecting final composition...");
fourdst::composition::Composition outputComposition = m_engine.collectComposition(*m_scratch_blob, topLevelComposition, netIn.temperature/1e9, netIn.density); fourdst::composition::Composition outputComposition = m_engine.collectComposition(*sctx_p->engine_ctx, topLevelComposition, netIn.temperature/1e9, netIn.density);
assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size()); assert(outputComposition.getRegisteredSymbols().size() == equilibratedComposition.getRegisteredSymbols().size());
@@ -541,11 +616,11 @@ namespace gridfire::solver {
NetOut netOut; NetOut netOut;
netOut.composition = outputComposition; netOut.composition = outputComposition;
netOut.energy = accumulated_energy; netOut.energy = accumulated_energy;
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps"); utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, reinterpret_cast<long int *>(&netOut.num_steps)), "CVodeGetNumSteps");
LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives..."); LOG_TRACE_L2(m_logger, "generating final nuclear energy generation rate derivatives...");
auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives( auto [dEps_dT, dEps_dRho] = m_engine.calculateEpsDerivatives(
*m_scratch_blob, *sctx_p->engine_ctx,
outputComposition, outputComposition,
T9, T9,
netIn.density netIn.density
@@ -559,53 +634,13 @@ namespace gridfire::solver {
LOG_TRACE_L2(m_logger, "Output data built!"); LOG_TRACE_L2(m_logger, "Output data built!");
LOG_TRACE_L2(m_logger, "Solver evaluation complete!."); LOG_TRACE_L2(m_logger, "Solver evaluation complete!.");
m_last_composition_hash = netOut.composition.hash(); sctx_p->last_composition_hash = netOut.composition.hash();
m_last_size = netOut.composition.size() + 1; sctx_p->last_size = netOut.composition.size() + 1;
CVodeGetLastStep(m_cvode_mem, &m_last_good_time_step); CVodeGetLastStep(sctx_p->cvode_mem, &sctx_p->last_good_time_step);
return netOut; return netOut;
} }
void CVODESolverStrategy::set_callback(const std::any &callback) { int PointSolver::cvode_rhs_wrapper(
m_callback = std::any_cast<TimestepCallback>(callback);
}
bool CVODESolverStrategy::get_stdout_logging_enabled() const {
return m_stdout_logging_enabled;
}
void CVODESolverStrategy::set_stdout_logging_enabled(const bool logging_enabled) {
m_stdout_logging_enabled = logging_enabled;
}
void CVODESolverStrategy::set_absTol(double absTol) {
m_absTol = absTol;
}
void CVODESolverStrategy::set_relTol(double relTol) {
m_relTol = relTol;
}
double CVODESolverStrategy::get_absTol() const {
if (m_absTol.has_value()) {
return m_absTol.value();
} else {
return -1.0;
}
}
double CVODESolverStrategy::get_relTol() const {
if (m_relTol.has_value()) {
return m_relTol.value();
} else {
return -1.0;
}
}
std::vector<std::tuple<std::string, std::string>> CVODESolverStrategy::describe_callback_context() const {
return {};
}
int CVODESolverStrategy::cvode_rhs_wrapper(
const sunrealtype t, const sunrealtype t,
const N_Vector y, const N_Vector y,
const N_Vector ydot, const N_Vector ydot,
@@ -633,7 +668,7 @@ namespace gridfire::solver {
} }
} }
int CVODESolverStrategy::cvode_jac_wrapper( int PointSolver::cvode_jac_wrapper(
sunrealtype t, sunrealtype t,
N_Vector y, N_Vector y,
N_Vector ydot, N_Vector ydot,
@@ -754,7 +789,7 @@ namespace gridfire::solver {
return 0; return 0;
} }
CVODESolverStrategy::CVODERHSOutputData CVODESolverStrategy::calculate_rhs( PointSolver::CVODERHSOutputData PointSolver::calculate_rhs(
const sunrealtype t, const sunrealtype t,
N_Vector y, N_Vector y,
N_Vector ydot, N_Vector ydot,
@@ -772,10 +807,10 @@ namespace gridfire::solver {
} }
} }
std::vector<double> y_vec(y_data, y_data + numSpecies); std::vector<double> y_vec(y_data, y_data + numSpecies);
fourdst::composition::Composition composition(m_engine.getNetworkSpecies(*m_scratch_blob), y_vec); fourdst::composition::Composition composition(m_engine.getNetworkSpecies(data->ctx), y_vec);
LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size()); LOG_TRACE_L2(m_logger, "Calculating RHS at time {} with {} species in composition", t, composition.size());
const auto result = m_engine.calculateRHSAndEnergy(*m_scratch_blob, composition, data->T9, data->rho, false); const auto result = m_engine.calculateRHSAndEnergy(data->ctx, composition, data->T9, data->rho, false);
if (!result) { if (!result) {
LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())); LOG_CRITICAL(m_logger, "Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()));
throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error()))); throw exceptions::BadRHSEngineError(std::format("Failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(result.error())));
@@ -805,7 +840,7 @@ namespace gridfire::solver {
}()); }());
for (size_t i = 0; i < numSpecies; ++i) { for (size_t i = 0; i < numSpecies; ++i) {
fourdst::atomic::Species species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; fourdst::atomic::Species species = m_engine.getNetworkSpecies(data->ctx)[i];
ydot_data[i] = dydt.at(species); ydot_data[i] = dydt.at(species);
} }
ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate ydot_data[numSpecies] = nuclearEnergyGenerationRate; // Set the last element to the specific energy rate
@@ -813,7 +848,8 @@ namespace gridfire::solver {
return {reactionContributions, result.value().neutrinoEnergyLossRate, result.value().totalNeutrinoFlux}; return {reactionContributions, result.value().neutrinoEnergyLossRate, result.value().totalNeutrinoFlux};
} }
void CVODESolverStrategy::initialize_cvode_integration_resources( void PointSolver::initialize_cvode_integration_resources(
PointSolverContext* sctx_p,
const uint64_t N, const uint64_t N,
const size_t numSpecies, const size_t numSpecies,
const double current_time, const double current_time,
@@ -821,16 +857,18 @@ namespace gridfire::solver {
const double absTol, const double absTol,
const double relTol, const double relTol,
const double accumulatedEnergy const double accumulatedEnergy
) { ) const {
LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol); LOG_TRACE_L2(m_logger, "Initializing CVODE integration resources with N: {}, current_time: {}, absTol: {}, relTol: {}", N, current_time, absTol, relTol);
cleanup_cvode_resources(false); // Cleanup any existing resources before initializing new ones sctx_p->reset_cvode();
m_Y = utils::init_sun_vector(N, m_sun_ctx); sctx_p->cvode_mem = CVodeCreate(CV_BDF, sctx_p->sun_ctx);
m_YErr = N_VClone(m_Y); utils::check_cvode_flag(sctx_p->cvode_mem == nullptr ? -1 : 0, "CVodeCreate");
sctx_p->Y = utils::init_sun_vector(N, sctx_p->sun_ctx);
sctx_p->YErr = N_VClone(sctx_p->Y);
sunrealtype *y_data = N_VGetArrayPointer(m_Y); sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
for (size_t i = 0; i < numSpecies; i++) { for (size_t i = 0; i < numSpecies; i++) {
const auto& species = m_engine.getNetworkSpecies(*m_scratch_blob)[i]; const auto& species = m_engine.getNetworkSpecies(*sctx_p->engine_ctx)[i];
if (composition.contains(species)) { if (composition.contains(species)) {
y_data[i] = composition.getMolarAbundance(species); y_data[i] = composition.getMolarAbundance(species);
} else { } else {
@@ -840,8 +878,8 @@ namespace gridfire::solver {
y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero y_data[numSpecies] = accumulatedEnergy; // Specific energy rate, initialized to zero
utils::check_cvode_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, current_time, m_Y), "CVodeInit"); utils::check_cvode_flag(CVodeInit(sctx_p->cvode_mem, cvode_rhs_wrapper, current_time, sctx_p->Y), "CVodeInit");
utils::check_cvode_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances"); utils::check_cvode_flag(CVodeSStolerances(sctx_p->cvode_mem, relTol, absTol), "CVodeSStolerances");
// Constraints // Constraints
// We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual // We constrain the solution vector using CVODE's built in constraint flags as outlines on page 53 of the CVODE manual
@@ -854,53 +892,30 @@ namespace gridfire::solver {
// -2.0: The corresponding component of y is constrained to be < 0 // -2.0: The corresponding component of y is constrained to be < 0
// Here we use 1.0 for all species to ensure they remain non-negative. // Here we use 1.0 for all species to ensure they remain non-negative.
m_constraints = N_VClone(m_Y); sctx_p->constraints = N_VClone(sctx_p->Y);
if (m_constraints == nullptr) { if (sctx_p->constraints == nullptr) {
LOG_ERROR(m_logger, "Failed to create constraints vector for CVODE"); LOG_ERROR(m_logger, "Failed to create constraints vector for CVODE");
throw std::runtime_error("Failed to create constraints vector for CVODE"); throw std::runtime_error("Failed to create constraints vector for CVODE");
} }
N_VConst(1.0, m_constraints); // Set all constraints to >= 0 (note this is where the flag values are set) N_VConst(1.0, sctx_p->constraints); // Set all constraints to >= 0 (note this is where the flag values are set)
utils::check_cvode_flag(CVodeSetConstraints(m_cvode_mem, m_constraints), "CVodeSetConstraints"); utils::check_cvode_flag(CVodeSetConstraints(sctx_p->cvode_mem, sctx_p->constraints), "CVodeSetConstraints");
utils::check_cvode_flag(CVodeSetMaxStep(m_cvode_mem, 1.0e20), "CVodeSetMaxStep"); utils::check_cvode_flag(CVodeSetMaxStep(sctx_p->cvode_mem, 1.0e20), "CVodeSetMaxStep");
m_J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), m_sun_ctx); sctx_p->J = SUNDenseMatrix(static_cast<sunindextype>(N), static_cast<sunindextype>(N), sctx_p->sun_ctx);
utils::check_cvode_flag(m_J == nullptr ? -1 : 0, "SUNDenseMatrix"); utils::check_cvode_flag(sctx_p->J == nullptr ? -1 : 0, "SUNDenseMatrix");
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx); sctx_p->LS = SUNLinSol_Dense(sctx_p->Y, sctx_p->J, sctx_p->sun_ctx);
utils::check_cvode_flag(m_LS == nullptr ? -1 : 0, "SUNLinSol_Dense"); utils::check_cvode_flag(sctx_p->LS == nullptr ? -1 : 0, "SUNLinSol_Dense");
utils::check_cvode_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver"); utils::check_cvode_flag(CVodeSetLinearSolver(sctx_p->cvode_mem, sctx_p->LS, sctx_p->J), "CVodeSetLinearSolver");
utils::check_cvode_flag(CVodeSetJacFn(m_cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn"); utils::check_cvode_flag(CVodeSetJacFn(sctx_p->cvode_mem, cvode_jac_wrapper), "CVodeSetJacFn");
LOG_TRACE_L2(m_logger, "CVODE solver initialized"); LOG_TRACE_L2(m_logger, "CVODE solver initialized");
} }
void CVODESolverStrategy::cleanup_cvode_resources(const bool memFree) {
LOG_TRACE_L2(m_logger, "Cleaning up cvode resources");
if (m_LS) SUNLinSolFree(m_LS);
if (m_J) SUNMatDestroy(m_J);
if (m_Y) N_VDestroy(m_Y);
if (m_YErr) N_VDestroy(m_YErr);
if (m_constraints) N_VDestroy(m_constraints);
m_LS = nullptr; void PointSolver::log_step_diagnostics(
m_J = nullptr; PointSolverContext* sctx_p,
m_Y = nullptr;
m_YErr = nullptr;
m_constraints = nullptr;
if (memFree) {
if (m_cvode_mem) CVodeFree(&m_cvode_mem);
m_cvode_mem = nullptr;
}
LOG_TRACE_L2(m_logger, "Done Cleaning up cvode resources");
}
void CVODESolverStrategy::set_detailed_step_logging(const bool enabled) {
m_detailed_step_logging = enabled;
}
void CVODESolverStrategy::log_step_diagnostics(
scratch::StateBlob &ctx, scratch::StateBlob &ctx,
const CVODEUserData &user_data, const CVODEUserData &user_data,
bool displayJacobianStiffness, bool displayJacobianStiffness,
@@ -916,10 +931,10 @@ namespace gridfire::solver {
sunrealtype hlast, hcur, tcur; sunrealtype hlast, hcur, tcur;
int qlast; int qlast;
utils::check_cvode_flag(CVodeGetLastStep(m_cvode_mem, &hlast), "CVodeGetLastStep"); utils::check_cvode_flag(CVodeGetLastStep(sctx_p->cvode_mem, &hlast), "CVodeGetLastStep");
utils::check_cvode_flag(CVodeGetCurrentStep(m_cvode_mem, &hcur), "CVodeGetCurrentStep"); utils::check_cvode_flag(CVodeGetCurrentStep(sctx_p->cvode_mem, &hcur), "CVodeGetCurrentStep");
utils::check_cvode_flag(CVodeGetLastOrder(m_cvode_mem, &qlast), "CVodeGetLastOrder"); utils::check_cvode_flag(CVodeGetLastOrder(sctx_p->cvode_mem, &qlast), "CVodeGetLastOrder");
utils::check_cvode_flag(CVodeGetCurrentTime(m_cvode_mem, &tcur), "CVodeGetCurrentTime"); utils::check_cvode_flag(CVodeGetCurrentTime(sctx_p->cvode_mem, &tcur), "CVodeGetCurrentTime");
nlohmann::json j; nlohmann::json j;
{ {
@@ -941,13 +956,13 @@ namespace gridfire::solver {
// These are the CRITICAL counters for diagnosing your problem // These are the CRITICAL counters for diagnosing your problem
long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails; long int nsteps, nfevals, nlinsetups, netfails, nniters, nconvfails, nsetfails;
utils::check_cvode_flag(CVodeGetNumSteps(m_cvode_mem, &nsteps), "CVodeGetNumSteps"); utils::check_cvode_flag(CVodeGetNumSteps(sctx_p->cvode_mem, &nsteps), "CVodeGetNumSteps");
utils::check_cvode_flag(CVodeGetNumRhsEvals(m_cvode_mem, &nfevals), "CVodeGetNumRhsEvals"); utils::check_cvode_flag(CVodeGetNumRhsEvals(sctx_p->cvode_mem, &nfevals), "CVodeGetNumRhsEvals");
utils::check_cvode_flag(CVodeGetNumLinSolvSetups(m_cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups"); utils::check_cvode_flag(CVodeGetNumLinSolvSetups(sctx_p->cvode_mem, &nlinsetups), "CVodeGetNumLinSolvSetups");
utils::check_cvode_flag(CVodeGetNumErrTestFails(m_cvode_mem, &netfails), "CVodeGetNumErrTestFails"); utils::check_cvode_flag(CVodeGetNumErrTestFails(sctx_p->cvode_mem, &netfails), "CVodeGetNumErrTestFails");
utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(m_cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters"); utils::check_cvode_flag(CVodeGetNumNonlinSolvIters(sctx_p->cvode_mem, &nniters), "CVodeGetNumNonlinSolvIters");
utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(m_cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails"); utils::check_cvode_flag(CVodeGetNumNonlinSolvConvFails(sctx_p->cvode_mem, &nconvfails), "CVodeGetNumNonlinSolvConvFails");
utils::check_cvode_flag(CVodeGetNumLinConvFails(m_cvode_mem, &nsetfails), "CVodeGetNumLinConvFails"); utils::check_cvode_flag(CVodeGetNumLinConvFails(sctx_p->cvode_mem, &nsetfails), "CVodeGetNumLinConvFails");
{ {
@@ -975,22 +990,26 @@ namespace gridfire::solver {
} }
// --- 3. Get Estimated Local Errors (Your Original Logic) --- // --- 3. Get Estimated Local Errors (Your Original Logic) ---
utils::check_cvode_flag(CVodeGetEstLocalErrors(m_cvode_mem, m_YErr), "CVodeGetEstLocalErrors"); utils::check_cvode_flag(CVodeGetEstLocalErrors(sctx_p->cvode_mem, sctx_p->YErr), "CVodeGetEstLocalErrors");
sunrealtype *y_data = N_VGetArrayPointer(m_Y); sunrealtype *y_data = N_VGetArrayPointer(sctx_p->Y);
sunrealtype *y_err_data = N_VGetArrayPointer(m_YErr); sunrealtype *y_err_data = N_VGetArrayPointer(sctx_p->YErr);
const auto absTol = m_config->solver.cvode.absTol;
const auto relTol = m_config->solver.cvode.relTol;
std::vector<double> err_ratios; std::vector<double> err_ratios;
const size_t num_components = N_VGetLength(m_Y); const size_t num_components = N_VGetLength(sctx_p->Y);
err_ratios.resize(num_components - 1); // Assuming -1 is for Energy or similar err_ratios.resize(num_components - 1); // Assuming -1 is for Energy or similar
std::vector<double> Y_full(y_data, y_data + num_components - 1); std::vector<double> Y_full(y_data, y_data + num_components - 1);
std::vector<double> E_full(y_err_data, y_err_data + num_components - 1); std::vector<double> E_full(y_err_data, y_err_data + num_components - 1);
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, relTol, absTol, 10, to_file); if (!sctx_p->abs_tol.has_value()) {
sctx_p->abs_tol = m_config->solver.cvode.absTol;
}
if (!sctx_p->rel_tol.has_value()) {
sctx_p->rel_tol = m_config->solver.cvode.relTol;
}
auto result = diagnostics::report_limiting_species(ctx, *user_data.engine, Y_full, E_full, sctx_p->rel_tol.value(), sctx_p->abs_tol.value(), 10, to_file);
if (to_file && result.has_value()) { if (to_file && result.has_value()) {
j["Limiting_Species"] = result.value(); j["Limiting_Species"] = result.value();
} }
@@ -1003,8 +1022,9 @@ namespace gridfire::solver {
0.0 0.0
); );
for (size_t i = 0; i < num_components - 1; i++) { for (size_t i = 0; i < num_components - 1; i++) {
const double weight = relTol * std::abs(y_data[i]) + absTol; const double weight = sctx_p->rel_tol.value() * std::abs(y_data[i]) + sctx_p->abs_tol.value();
if (weight == 0.0) { if (weight == 0.0) {
err_ratios[i] = 0.0; // Avoid division by zero err_ratios[i] = 0.0; // Avoid division by zero
continue; continue;
@@ -1013,11 +1033,11 @@ namespace gridfire::solver {
err_ratios[i] = err_ratio; err_ratios[i] = err_ratio;
} }
fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*m_scratch_blob), Y_full); fourdst::composition::Composition composition(user_data.engine->getNetworkSpecies(*sctx_p->engine_ctx), Y_full);
fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*m_scratch_blob, composition, user_data.T9, user_data.rho); fourdst::composition::Composition collectedComposition = user_data.engine->collectComposition(*sctx_p->engine_ctx, composition, user_data.T9, user_data.rho);
auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho); auto destructionTimescales = user_data.engine->getSpeciesDestructionTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho);
auto netTimescales = user_data.engine->getSpeciesTimescales(*m_scratch_blob, collectedComposition, user_data.T9, user_data.rho); auto netTimescales = user_data.engine->getSpeciesTimescales(*sctx_p->engine_ctx, collectedComposition, user_data.T9, user_data.rho);
bool timescaleOkay = false; bool timescaleOkay = false;
if (destructionTimescales && netTimescales) timescaleOkay = true; if (destructionTimescales && netTimescales) timescaleOkay = true;
@@ -1037,7 +1057,7 @@ namespace gridfire::solver {
if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp)); if (destructionTimescales.value().contains(sp)) destructionTimescales_list.emplace_back(destructionTimescales.value().at(sp));
else destructionTimescales_list.emplace_back(std::numeric_limits<double>::infinity()); else destructionTimescales_list.emplace_back(std::numeric_limits<double>::infinity());
speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*m_scratch_blob, sp))); speciesStatus_list.push_back(SpeciesStatus_to_string(user_data.engine->getSpeciesStatus(*sctx_p->engine_ctx, sp)));
} }
utils::Column<fourdst::atomic::Species> speciesColumn("Species", species_list); utils::Column<fourdst::atomic::Species> speciesColumn("Species", species_list);

View File

@@ -1,770 +0,0 @@
#include "gridfire/solver/strategies/SpectralSolverStrategy.h"
#include <sunlinsol/sunlinsol_dense.h>
#include "gridfire/utils/sundials.h"
#include "quill/LogMacros.h"
#include "sunmatrix/sunmatrix_dense.h"
namespace {
std::pair<size_t, std::vector<double>> evaluate_bspline(
double x,
const gridfire::solver::SpectralSolverStrategy::SplineBasis& basis
) {
const int p = basis.degree;
const std::vector<double>& t = basis.knots;
auto it = std::ranges::upper_bound(t, x);
size_t i = std::distance(t.begin(), it) - 1;
if (i < static_cast<size_t>(p)) i = p;
if (i >= t.size() - 1 - p) i = t.size() - 2 - p;
if (x >= t.back()) {
i = t.size() - p - 2;
}
// Cox-de Boor algorithm
std::vector<double> N(p + 1);
std::vector<double> left(p + 1);
std::vector<double> right(p + 1);
N[0] = 1.0;
for (int j = 1; j <= p; ++j) {
left[j] = x - t[i + 1 - j];
right[j] = t[i + j] - x;
double saved = 0.0;
for (int r = 0; r < j; ++r) {
double temp = N[r] / (right[r + 1] + left[j - r]);
N[r] = saved + right[r + 1] * temp;
saved = left[j - r] * temp;
}
N[j] = saved;
}
return {i - p, N};
}
}
namespace gridfire::solver {
SpectralSolverStrategy::SpectralSolverStrategy(engine::DynamicEngine& engine) : MultiZoneNetworkSolverStrategy<engine::DynamicEngine> (engine) {
LOG_INFO(m_logger, "Initializing SpectralSolverStrategy");
utils::check_sundials_flag(SUNContext_Create(SUN_COMM_NULL, &m_sun_ctx), "SUNContext_Create", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
m_absTol = m_config->solver.spectral.absTol;
m_relTol = m_config->solver.spectral.relTol;
LOG_INFO(m_logger, "SpectralSolverStrategy initialized successfully");
}
SpectralSolverStrategy::~SpectralSolverStrategy() {
LOG_INFO(m_logger, "Destroying SpectralSolverStrategy");
if (m_cvode_mem) {
CVodeFree(&m_cvode_mem);
m_cvode_mem = nullptr;
}
if (m_LS) SUNLinSolFree(m_LS);
if (m_J) SUNMatDestroy(m_J);
if (m_Y) N_VDestroy(m_Y);
if (m_constraints) N_VDestroy(m_constraints);
if (m_sun_ctx) {
SUNContext_Free(&m_sun_ctx);
m_sun_ctx = nullptr;
}
if (m_T_coeffs) N_VDestroy(m_T_coeffs);
if (m_rho_coeffs) N_VDestroy(m_rho_coeffs);
LOG_INFO(m_logger, "SpectralSolverStrategy destroyed successfully");
}
////////////////////////////////////////////////////////////////////////////////
/// Main Evaluation Loop
/////////////////////////////////////////////////////////////////////////////////
std::vector<NetOut> SpectralSolverStrategy::evaluate(const std::vector<NetIn>& netIns, const std::vector<double>& mass_coords) {
LOG_INFO(m_logger, "Starting spectral solver evaluation for {} zones", netIns.size());
assert(std::ranges::all_of(netIns, [&netIns](const NetIn& in) { return in.tMax == netIns[0].tMax; }) && "All NetIn entries must have the same tMax for spectral solver evaluation.");
std::vector<NetIn> updatedNetIns = netIns;
for (auto& netIn : updatedNetIns) {
netIn.composition = m_engine.update(netIn);
}
/////////////////////////////////////
/// Evaluate the monitor function ///
/////////////////////////////////////
const std::vector<double> monitor_function = evaluate_monitor_function(updatedNetIns);
m_current_basis = generate_basis_from_monitor(monitor_function, mass_coords);
size_t num_basis_funcs = m_current_basis.knots.size() - m_current_basis.degree - 1;
std::vector<BasisEval> shell_cache(updatedNetIns.size());
for (size_t shellID = 0; shellID < shell_cache.size(); ++shellID) {
auto [start, phi] = evaluate_bspline(mass_coords[shellID], m_current_basis);
shell_cache[shellID] = {.start_idx=start, .phi=phi};
}
DenseLinearSolver proj_solver(num_basis_funcs, m_sun_ctx);
proj_solver.init_from_cache(num_basis_funcs, shell_cache);
if (m_T_coeffs) N_VDestroy(m_T_coeffs);
m_T_coeffs = N_VNew_Serial(static_cast<sunindextype>(num_basis_funcs), m_sun_ctx);
project_specific_variable(updatedNetIns, mass_coords, shell_cache, proj_solver, m_T_coeffs, 0, [](const NetIn& s) { return s.temperature; }, true);
if (m_rho_coeffs) N_VDestroy(m_rho_coeffs);
m_rho_coeffs = N_VNew_Serial(static_cast<sunindextype>(num_basis_funcs), m_sun_ctx);
project_specific_variable(updatedNetIns, mass_coords, shell_cache, proj_solver, m_rho_coeffs, 0, [](const NetIn& s) { return s.density; }, true);
size_t num_species = m_engine.getNetworkSpecies().size();
size_t current_offset = 0;
size_t total_coefficients = num_basis_funcs * (num_species + 1);
if (m_Y) N_VDestroy(m_Y);
if (m_constraints) N_VDestroy(m_constraints);
m_Y = N_VNew_Serial(static_cast<sunindextype>(total_coefficients), m_sun_ctx);
m_constraints = N_VClone(m_Y);
N_VConst(0.0, m_constraints); // For now no constraints on coefficients
for (const auto& sp : m_engine.getNetworkSpecies()) {
project_specific_variable(
updatedNetIns,
mass_coords,
shell_cache,
proj_solver,
m_Y,
current_offset,
[&sp](const NetIn& s) { return s.composition.getMolarAbundance(sp); },
false
);
current_offset += num_basis_funcs;
}
sunrealtype* y_data = N_VGetArrayPointer(m_Y);
const size_t energy_offset = num_species * num_basis_funcs;
assert(energy_offset == current_offset && "Energy offset calculation mismatch in spectral solver initialization.");
for (size_t i = 0; i < num_basis_funcs; ++i) {
y_data[energy_offset + i] = 0.0;
}
DenseLinearSolver mass_solver(num_basis_funcs, m_sun_ctx);
mass_solver.init_from_basis(num_basis_funcs, m_current_basis);
/////////////////////////////////////
/// CVODE Initialization ///
/////////////////////////////////////
CVODEUserData data;
data.solver_instance = this;
data.engine = &m_engine;
data.mass_matrix_solver_instance = &mass_solver;
data.basis = &m_current_basis;
const double absTol = m_absTol.value_or(1e-10);
const double relTol = m_relTol.value_or(1e-6);
const bool size_changed = m_last_size != total_coefficients;
m_last_size = total_coefficients;
if (m_cvode_mem == nullptr || size_changed) {
if (m_cvode_mem) {
CVodeFree(&m_cvode_mem);
m_cvode_mem = nullptr;
}
if (m_LS) {
SUNLinSolFree(m_LS);
m_LS = nullptr;
}
if (m_J) {
SUNMatDestroy(m_J);
m_J = nullptr;
}
m_cvode_mem = CVodeCreate(CV_BDF, m_sun_ctx);
utils::check_sundials_flag(m_cvode_mem == nullptr ? -1 : 0, "CVodeCreate", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
utils::check_sundials_flag(CVodeInit(m_cvode_mem, cvode_rhs_wrapper, 0.0, m_Y), "CVodeInit", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
m_J = SUNDenseMatrix(static_cast<sunindextype>(total_coefficients), static_cast<sunindextype>(total_coefficients), m_sun_ctx);
m_LS = SUNLinSol_Dense(m_Y, m_J, m_sun_ctx);
utils::check_sundials_flag(CVodeSetLinearSolver(m_cvode_mem, m_LS, m_J), "CVodeSetLinearSolver", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
// For now, we will not attach a Jacobian function, using finite differences
} else {
utils::check_sundials_flag(CVodeReInit(m_cvode_mem, 0.0, m_Y), "CVodeReInit", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
}
utils::check_sundials_flag(CVodeSStolerances(m_cvode_mem, relTol, absTol), "CVodeSStolerances", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
utils::check_sundials_flag(CVodeSetUserData(m_cvode_mem, &data), "CVodeSetUserData", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
/////////////////////////////////////
/// Time Integration Loop ///
/////////////////////////////////////
const double target_time = updatedNetIns[0].tMax;
double current_time = 0.0;
while (current_time < target_time) {
int flag = CVode(m_cvode_mem, target_time, m_Y, &current_time, CV_ONE_STEP);
utils::check_sundials_flag(flag, "CVode", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
std::println("Advanced to time: {:10.4e} / {:10.4e}", current_time, target_time);
}
std::vector<NetOut> results = reconstruct_solution(updatedNetIns, mass_coords, m_Y, m_current_basis, target_time);
return results;
}
void SpectralSolverStrategy::set_callback(const std::any &callback) {
m_callback = std::any_cast<TimestepCallback>(callback);
}
std::vector<std::tuple<std::string, std::string>> SpectralSolverStrategy::describe_callback_context() const {
throw std::runtime_error("SpectralSolverStrategy does not yet implement describe_callback_context.");
}
bool SpectralSolverStrategy::get_stdout_logging_enabled() const {
return m_stdout_logging_enabled;
}
void SpectralSolverStrategy::set_stdout_logging_enabled(bool logging_enabled) {
m_stdout_logging_enabled = logging_enabled;
}
////////////////////////////////////////////////////////////////////////////////
/// Static Wrappers for SUNDIALS Callbacks
////////////////////////////////////////////////////////////////////////////////
int SpectralSolverStrategy::cvode_rhs_wrapper(
const sunrealtype t,
const N_Vector y_coeffs,
const N_Vector ydot_coeffs,
void *user_data
) {
auto *data = static_cast<CVODEUserData*>(user_data);
const auto *instance = data->solver_instance;
try {
return instance -> calculate_rhs(t, y_coeffs, ydot_coeffs, data);
} catch (const std::exception& e) {
LOG_CRITICAL(instance->m_logger, "Uncaught exception in Spectral Solver RHS wrapper at time {}: {}", t, e.what());
return -1;
} catch (...) {
LOG_CRITICAL(instance->m_logger, "Unknown uncaught exception in Spectral Solver RHS wrapper at time {}", t);
return -1;
}
}
int SpectralSolverStrategy::cvode_jac_wrapper(
const sunrealtype t,
const N_Vector y,
const N_Vector ydot,
const SUNMatrix J,
void *user_data,
const N_Vector tmp1,
const N_Vector tmp2,
const N_Vector tmp3
) {
const auto *data = static_cast<CVODEUserData*>(user_data);
const auto *instance = data->solver_instance;
try {
LOG_WARNING_LIMIT_EVERY_N(1000, instance->m_logger, "Analytic Jacobian Generation not yet implemented, using finite difference approximation");
return 0;
} catch (const std::exception& e) {
LOG_CRITICAL(instance->m_logger, "Uncaught exception in Spectral Solver Jacobian wrapper at time {}: {}", t, e.what());
return -1;
} catch (...) {
LOG_CRITICAL(instance->m_logger, "Unknown uncaught exception in Spectral Solver Jacobian wrapper at time {}", t);
return -1;
}
}
////////////////////////////////////////////////////////////////////////////////
/// RHS implementation
////////////////////////////////////////////////////////////////////////////////
int SpectralSolverStrategy::calculate_rhs(
sunrealtype t,
N_Vector y_coeffs,
N_Vector ydot_coeffs,
CVODEUserData* data
) const {
const auto& basis = m_current_basis;
DenseLinearSolver* mass_solver = data->mass_matrix_solver_instance;
const auto& species_list = m_engine.getNetworkSpecies();
const size_t num_basis_funcs = basis.knots.size() - basis.degree - 1;
const size_t num_species = species_list.size();
sunrealtype* rhs_data = N_VGetArrayPointer(ydot_coeffs);
N_VConst(0.0, ydot_coeffs);
// PERF: In future we can use openMP to parallelize over these basis functions once we make the engines thread safe
for (size_t q = 0; q < basis.quadrature_nodes.size(); ++q) {
double w_q = basis.quadrature_weights[q];
const auto& [start_idx, phi] = basis.quad_evals[q];
GridPoint gp = reconstruct_at_quadrature(y_coeffs, q, basis);
std::expected<engine::StepDerivatives<double>, engine::EngineStatus> results = m_engine.calculateRHSAndEnergy(gp.composition, gp.T9, gp.rho, false);
// PERF: When switching to parallel execution, we will need to protect this section with a mutex or use atomic operations since we cannot throw safely from multiple threads
if (!results) {
LOG_CRITICAL(m_logger, "Engine failed to calculate RHS at time {}: {}", t, EngineStatus_to_string(results.error()));
return -1;
}
const auto& [dydt, eps_nuc, contributions, nu_loss, nu_flux] = results.value();
for (size_t s = 0; s < num_species; ++s) {
double rate = dydt.at(species_list[s]);
size_t species_offset = s * num_basis_funcs;
for (size_t k = 0; k < phi.size(); ++k) {
size_t global_idx = species_offset + start_idx + k;
rhs_data[global_idx] += w_q * phi[k] * rate;
}
}
size_t energy_offset = num_species * num_basis_funcs;
for (size_t k = 0; k < phi.size(); ++k) {
size_t global_idx = energy_offset + start_idx + k;
rhs_data[global_idx] += eps_nuc * w_q * phi[k];
}
}
size_t total_vars = num_species + 1;
mass_solver->solve_inplace(ydot_coeffs, total_vars, num_basis_funcs);
return 0;
}
////////////////////////////////////////////////////////////////////////////////
/// Spectral Utilities
/// These include basis generation, monitor function evaluation
/// projection and reconstruction routines.
////////////////////////////////////////////////////////////////////////////////
std::vector<double> SpectralSolverStrategy::evaluate_monitor_function(const std::vector<NetIn>& current_shells) const {
const size_t n_shells = current_shells.size();
if (n_shells < 3) {
return std::vector<double>(n_shells, 1.0); // NOLINT(*-return-braced-init-list)
}
std::vector<double> M(n_shells, 1.0);
auto accumulate_variable = [&](auto getter, double weight, bool use_log) {
std::vector<double> data(n_shells);
double min_val = std::numeric_limits<double>::max();
double max_val = std::numeric_limits<double>::lowest();
for (size_t i = 0 ; i < n_shells; ++i) {
double val = getter(current_shells[i]);
if (use_log) {
val = std::log10(std::max(val, 1e-100));
}
data[i] = val;
if (val < min_val) min_val = val;
if (val > max_val) max_val = val;
}
const double scale = max_val - min_val;
if (scale < 1e-10) return;
for (size_t i = 1; i < n_shells - 1; ++i) {
const double v_prev = data[i-1];
const double v_curr = data[i];
const double v_next = data[i+1];
// Finite difference estimates for first and second derivatives
double d1 = std::abs(v_next - v_prev) / 2.0;
double d2 = std::abs(v_next - 2.0 * v_curr + v_prev);
d1 /= scale;
d2 /= scale;
const double alpha = m_config->solver.spectral.monitorFunction.alpha;
const double beta = m_config->solver.spectral.monitorFunction.beta;
M[i] += weight * (alpha * d1 + beta * d2);
}
};
const double structure_weight = m_config->solver.spectral.monitorFunction.structure_weight;
double abundance_weight = m_config->solver.spectral.monitorFunction.abundance_weight;
accumulate_variable([](const NetIn& s) { return s.temperature; }, structure_weight, true);
accumulate_variable([](const NetIn& s) { return s.density; }, structure_weight, true);
for (const auto& sp : m_engine.getNetworkSpecies()) {
accumulate_variable([&sp](const NetIn& s) { return s.composition.getMolarAbundance(sp); }, abundance_weight, false);
}
//////////////////////////////
/// Smoothing the Monitor ///
//////////////////////////////
std::vector<double> M_smooth = M;
for (size_t i = 1; i < n_shells - 1; ++i) {
M_smooth[i] = (M[i-1] + 2.0 * M[i] + M[i+1]) / 4.0;
}
M_smooth[0] = M_smooth[1];
M_smooth[n_shells-1] = M_smooth[n_shells-2];
return M_smooth;
}
SpectralSolverStrategy::SplineBasis SpectralSolverStrategy::generate_basis_from_monitor(
const std::vector<double>& monitor_values,
const std::vector<double>& mass_coordinates
) const {
SplineBasis basis;
basis.degree = 3; // Cubic Spline
const size_t n_shells = monitor_values.size();
std::vector<double> I(n_shells, 0.0);
double current_integral = 0.0;
for (size_t i = 1; i < n_shells; ++i) {
const double dx = mass_coordinates[i] - mass_coordinates[i-1];
double dI = 0.5 * (monitor_values[i] + monitor_values[i-1]) * dx;
dI = std::max(dI, 1e-30);
current_integral += dI;
I[i] = current_integral;
}
const double total_integral = I.back();
for (size_t i = 0; i < n_shells; ++i) {
I[i] /= total_integral;
}
const size_t num_elements = m_config->solver.spectral.basis.num_elements;
basis.knots.reserve(num_elements + 1 + 2 * basis.degree);
// Note that these imply that mass_coordinates must be sorted in increasing order
double min_mass = mass_coordinates.front();
double max_mass = mass_coordinates.back();
for (int i = 0; i < basis.degree; ++i) {
basis.knots.push_back(min_mass);
}
for (size_t k = 1; k < num_elements; ++k) {
double target_I = static_cast<double>(k) / static_cast<double>(num_elements);
auto it = std::ranges::lower_bound(I, target_I);
size_t idx = std::distance(I.begin(), it);
if (idx == 0) idx = 1;
if (idx >= n_shells) idx = n_shells - 1;
double I0 = I[idx-1];
double I1 = I[idx];
double m0 = mass_coordinates[idx-1];
double m1 = mass_coordinates[idx];
double fraction = (target_I - I0) / (I1 - I0);
double knot_location = m0 + fraction * (m1 - m0);
basis.knots.push_back(knot_location);
}
for (int i = 0; i < basis.degree; ++i) {
basis.knots.push_back(max_mass);
}
constexpr double sqrt_3_over_5 = 0.77459666924;
constexpr double five_over_nine = 5.0 / 9.0;
constexpr double eight_over_nine = 8.0 / 9.0;
static constexpr std::array<double, 3> gl_nodes = {-sqrt_3_over_5, 0.0, sqrt_3_over_5};
static constexpr std::array<double, 3> gl_weights = {five_over_nine, eight_over_nine, five_over_nine};
basis.quadrature_nodes.clear();
basis.quadrature_weights.clear();
for (size_t i = basis.degree; i < basis.knots.size() - basis.degree - 1; ++i) {
double a = basis.knots[i];
double b = basis.knots[i+1];
if ( b - a < 1e-14) continue;
double mid = 0.5 * (a + b);
double half_width = 0.5 * (b - a);
for (size_t j = 0; j < gl_nodes.size(); ++j) {
double phys_node = mid + gl_nodes[j] * half_width;
double phys_weight = gl_weights[j] * half_width;
basis.quadrature_nodes.push_back(phys_node);
basis.quadrature_weights.push_back(phys_weight);
auto [start, phi] = evaluate_bspline(phys_node, basis);
basis.quad_evals.push_back({start, phi});
}
}
return basis;
}
SpectralSolverStrategy::GridPoint SpectralSolverStrategy::reconstruct_at_quadrature(
const N_Vector y_coeffs,
const size_t quad_index,
const SplineBasis &basis
) const {
auto [start_idx, vals] = basis.quad_evals[quad_index];
const sunrealtype* T_ptr = N_VGetArrayPointer(m_T_coeffs);
const sunrealtype* rho_ptr = N_VGetArrayPointer(m_rho_coeffs);
const sunrealtype* y_data = N_VGetArrayPointer(y_coeffs);
const size_t num_basis_funcs = basis.knots.size() - basis.degree - 1;
const std::vector<fourdst::atomic::Species>& species_list = m_engine.getNetworkSpecies();
const size_t num_species = species_list.size();
double logT = 0.0;
double logRho = 0.0;
for (size_t k = 0; k < vals.size(); ++k) {
size_t idx = start_idx + k;
logT += T_ptr[idx] * vals[k];
logRho += rho_ptr[idx] * vals[k];
}
GridPoint result;
result.T9 = std::pow(10.0, logT) / 1e9;
result.rho = std::pow(10.0, logRho);
for (size_t s = 0; s < num_species; ++s) {
const fourdst::atomic::Species& species = species_list[s];
double abundance = 0.0;
const size_t offset = s * num_basis_funcs;
for (size_t k = 0; k < vals.size(); ++k) {
abundance += y_data[offset + start_idx + k] * vals[k];
}
// Note: It is possible this will lead to a loss of mass conservation. In future we may want to implement a better way to handle this.
if (abundance < 0.0) abundance = 0.0;
result.composition.registerSpecies(species);
result.composition.setMolarAbundance(species, abundance);
}
return result;
}
std::vector<NetOut> SpectralSolverStrategy::reconstruct_solution(
const std::vector<NetIn>& original_inputs,
const std::vector<double>& mass_coordinates,
const N_Vector final_coeffs,
const SplineBasis& basis,
const double dt
) const {
const size_t n_shells = original_inputs.size();
const size_t num_basis_funcs = basis.knots.size() - basis.degree - 1;
std::vector<NetOut> outputs;
outputs.reserve(n_shells);
const sunrealtype* c_data = N_VGetArrayPointer(final_coeffs);
const auto& species_list = m_engine.getNetworkSpecies();
for (size_t shellID = 0; shellID < n_shells; ++shellID) {
const double x = mass_coordinates[shellID];
auto [start_idx, vals] = evaluate_bspline(x, basis);
auto reconstruct_var = [&](const size_t coeff_offset) -> double {
double result = 0.0;
for (size_t i = 0; i < vals.size(); ++i) {
result += c_data[coeff_offset + start_idx + i] * vals[i];
}
return result;
};
fourdst::composition::Composition comp_new;
for (size_t s_idx = 0; s_idx < species_list.size(); ++s_idx) {
const fourdst::atomic::Species& sp = species_list[s_idx];
comp_new.registerSpecies(sp);
const size_t current_offset = s_idx * num_basis_funcs;
double Y_val = reconstruct_var(current_offset);
if (Y_val < 0.0 && Y_val > -1.0e-16) {
Y_val = 0.0;
}
if (Y_val < 0.0 && Y_val > -1e-16) Y_val = 0.0;
if (Y_val >= 0.0) {
comp_new.setMolarAbundance(sp, Y_val);
}
}
const double energy = reconstruct_var(species_list.size() * num_basis_funcs);
NetOut netOut;
netOut.composition = comp_new;
netOut.energy = energy;
netOut.num_steps = -1; // Not tracked in spectral solver
outputs.push_back(std::move(netOut));
}
return outputs;
}
void SpectralSolverStrategy::project_specific_variable(
const std::vector<NetIn> &current_shells,
const std::vector<double> &mass_coordinates,
const std::vector<BasisEval> &shell_cache,
const DenseLinearSolver &linear_solver,
N_Vector output_vec,
size_t output_offset,
const std::function<double(const NetIn &)> &getter,
bool use_log
) {
const size_t n_shells = current_shells.size();
sunrealtype* out_ptr = N_VGetArrayPointer(output_vec);
size_t basis_size = N_VGetLength(linear_solver.temp_vector);
for (size_t i = 0; i < basis_size; ++i ) {
out_ptr[output_offset + i] = 0.0;
}
for (size_t shellID = 0; shellID < n_shells; ++shellID) {
double val = getter(current_shells[shellID]);
if (use_log) val = std::log10(std::max(val, 1e-100));
const auto& eval = shell_cache[shellID];
for (size_t i = 0; i < eval.phi.size(); ++i) {
out_ptr[output_offset + eval.start_idx + i] += val * eval.phi[i];
}
}
sunrealtype* tmp_data = N_VGetArrayPointer(linear_solver.temp_vector);
for (size_t i = 0; i < basis_size; ++i) tmp_data[i] = out_ptr[output_offset + i];
SUNLinSolSolve(linear_solver.LS, linear_solver.A, linear_solver.temp_vector, linear_solver.temp_vector, 0.0);
for (size_t i = 0; i < basis_size; ++i) out_ptr[output_offset + i] = tmp_data[i];
}
///////////////////////////////////////////////////////////////////////////////
/// SpectralSolverStrategy::MassMatrixSolver Implementation
///////////////////////////////////////////////////////////////////////////////
SpectralSolverStrategy::DenseLinearSolver::DenseLinearSolver(
size_t size,
SUNContext sun_ctx
) : ctx(sun_ctx) {
A = SUNDenseMatrix(size, size, sun_ctx);
temp_vector = N_VNew_Serial(size, sun_ctx);
LS = SUNLinSol_Dense(temp_vector, A, sun_ctx);
if (!A || !temp_vector || !LS) {
throw std::runtime_error("Failed to create MassMatrixSolver components.");
}
zero();
}
SpectralSolverStrategy::DenseLinearSolver::~DenseLinearSolver() {
if (LS) SUNLinSolFree(LS);
if (A) SUNMatDestroy(A);
if (temp_vector) N_VDestroy(temp_vector);
}
void SpectralSolverStrategy::DenseLinearSolver::zero() const {
SUNMatZero(A);
}
void SpectralSolverStrategy::DenseLinearSolver::init_from_cache(
const size_t num_basis_funcs,
const std::vector<BasisEval> &shell_cache
) const {
sunrealtype* a_data = SUNDenseMatrix_Data(A);
for (const auto&[start_idx, phi] : shell_cache) {
for (size_t i = 0; i < phi.size(); ++i) {
const size_t row = start_idx + i;
for (size_t j = 0; j < phi.size(); ++j) {
const size_t col = start_idx + j;
a_data[col * num_basis_funcs + row] += phi[i] * phi[j];
}
}
}
setup();
}
void SpectralSolverStrategy::DenseLinearSolver::init_from_basis(
const size_t num_basis_funcs,
const SplineBasis &basis
) const {
sunrealtype* m_data = SUNDenseMatrix_Data(A);
for (size_t q = 0; q < basis.quadrature_nodes.size(); ++q) {
double w_q = basis.quadrature_weights[q];
const auto& eval = basis.quad_evals[q];
for (size_t i = 0; i < eval.phi.size(); ++i) {
size_t row = eval.start_idx + i;
for (size_t j = 0; j < eval.phi.size(); ++j) {
size_t col = eval.start_idx + j;
m_data[col * num_basis_funcs + row] += w_q * eval.phi[j] * eval.phi[i];
}
}
}
setup();
}
void SpectralSolverStrategy::DenseLinearSolver::setup() const {
utils::check_sundials_flag(SUNLinSolSetup(LS, A), "SUNLinSolSetup - Mass Matrix Solver", utils::SUNDIALS_RET_CODE_TYPES::CVODE);
}
// ReSharper disable once CppMemberFunctionMayBeConst
void SpectralSolverStrategy::DenseLinearSolver::solve_inplace(const N_Vector x, const size_t num_vars, const size_t basis_size) const {
sunrealtype* x_data = N_VGetArrayPointer(x);
sunrealtype* tmp_data = N_VGetArrayPointer(temp_vector);
for (size_t v = 0; v < num_vars; ++v) {
const size_t offset = v * basis_size;
for (size_t i = 0; i < basis_size; ++i) {
tmp_data[i] = x_data[offset + i];
}
SUNLinSolSolve(LS, A, temp_vector, temp_vector, 0.0);
for (size_t i = 0; i < basis_size; ++i) {
x_data[offset + i] = tmp_data[i];
}
}
}
}

View File

@@ -1,5 +1,5 @@
#include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h" #include "gridfire/solver/strategies/triggers/engine_partitioning_trigger.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/trigger/trigger_logical.h" #include "gridfire/trigger/trigger_logical.h"
#include "gridfire/trigger/trigger_abstract.h" #include "gridfire/trigger/trigger_abstract.h"
@@ -28,7 +28,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
} }
bool SimulationTimeTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { bool SimulationTimeTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
if (ctx.t - m_last_trigger_time >= m_interval) { if (ctx.t - m_last_trigger_time >= m_interval) {
m_hits++; m_hits++;
LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta); LOG_TRACE_L2(m_logger, "SimulationTimeTrigger triggered at t = {}, last trigger time was {}, delta = {}", ctx.t, m_last_trigger_time, m_last_trigger_time_delta);
@@ -38,7 +38,7 @@ namespace gridfire::trigger::solver::CVODE {
return false; return false;
} }
void SimulationTimeTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { void SimulationTimeTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
if (check(ctx)) { if (check(ctx)) {
m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval; m_last_trigger_time_delta = (ctx.t - m_last_trigger_time) - m_interval;
m_last_trigger_time = ctx.t; m_last_trigger_time = ctx.t;
@@ -47,7 +47,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
void SimulationTimeTrigger::step( void SimulationTimeTrigger::step(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
// --- SimulationTimeTrigger::step does nothing and is intentionally left blank --- // // --- SimulationTimeTrigger::step does nothing and is intentionally left blank --- //
} }
@@ -65,7 +65,7 @@ namespace gridfire::trigger::solver::CVODE {
return "Simulation Time Trigger"; return "Simulation Time Trigger";
} }
TriggerResult SimulationTimeTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { TriggerResult SimulationTimeTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
TriggerResult result; TriggerResult result;
result.name = name(); result.name = name();
if (check(ctx)) { if (check(ctx)) {
@@ -99,18 +99,18 @@ namespace gridfire::trigger::solver::CVODE {
} }
} }
bool OffDiagonalTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { bool OffDiagonalTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
//TODO : This currently does nothing //TODO : This currently does nothing
return false; return false;
} }
void OffDiagonalTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { void OffDiagonalTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
m_updates++; m_updates++;
} }
void OffDiagonalTrigger::step( void OffDiagonalTrigger::step(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
// --- OffDiagonalTrigger::step does nothing and is intentionally left blank --- // // --- OffDiagonalTrigger::step does nothing and is intentionally left blank --- //
} }
@@ -126,7 +126,7 @@ namespace gridfire::trigger::solver::CVODE {
return "Off-Diagonal Trigger"; return "Off-Diagonal Trigger";
} }
TriggerResult OffDiagonalTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { TriggerResult OffDiagonalTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
TriggerResult result; TriggerResult result;
result.name = name(); result.name = name();
@@ -173,7 +173,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
} }
bool TimestepCollapseTrigger::check(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { bool TimestepCollapseTrigger::check(const gridfire::solver::PointSolverTimestepContext &ctx) const {
if (m_timestep_window.size() < m_windowSize) { if (m_timestep_window.size() < m_windowSize) {
m_misses++; m_misses++;
return false; return false;
@@ -201,13 +201,13 @@ namespace gridfire::trigger::solver::CVODE {
return false; return false;
} }
void TimestepCollapseTrigger::update(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) { void TimestepCollapseTrigger::update(const gridfire::solver::PointSolverTimestepContext &ctx) {
m_updates++; m_updates++;
m_timestep_window.clear(); m_timestep_window.clear();
} }
void TimestepCollapseTrigger::step( void TimestepCollapseTrigger::step(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize); push_to_fixed_deque(m_timestep_window, ctx.dt, m_windowSize);
// --- TimestepCollapseTrigger::step does nothing and is intentionally left blank --- // // --- TimestepCollapseTrigger::step does nothing and is intentionally left blank --- //
@@ -226,7 +226,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
TriggerResult TimestepCollapseTrigger::why( TriggerResult TimestepCollapseTrigger::why(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) const { ) const {
TriggerResult result; TriggerResult result;
result.name = name(); result.name = name();
@@ -263,7 +263,7 @@ namespace gridfire::trigger::solver::CVODE {
m_windowSize(windowSize) {} m_windowSize(windowSize) {}
bool ConvergenceFailureTrigger::check( bool ConvergenceFailureTrigger::check(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) const { ) const {
if (m_window.size() != m_windowSize) { if (m_window.size() != m_windowSize) {
m_misses++; m_misses++;
@@ -278,13 +278,13 @@ namespace gridfire::trigger::solver::CVODE {
} }
void ConvergenceFailureTrigger::update( void ConvergenceFailureTrigger::update(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
m_window.clear(); m_window.clear();
} }
void ConvergenceFailureTrigger::step( void ConvergenceFailureTrigger::step(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) { ) {
push_to_fixed_deque(m_window, ctx.currentConvergenceFailures, m_windowSize); push_to_fixed_deque(m_window, ctx.currentConvergenceFailures, m_windowSize);
m_updates++; m_updates++;
@@ -306,7 +306,7 @@ namespace gridfire::trigger::solver::CVODE {
return "ConvergenceFailureTrigger(abs_failure_threshold=" + std::to_string(m_totalFailures) + ", rel_failure_threshold=" + std::to_string(m_relativeFailureRate) + ", windowSize=" + std::to_string(m_windowSize) + ")"; return "ConvergenceFailureTrigger(abs_failure_threshold=" + std::to_string(m_totalFailures) + ", rel_failure_threshold=" + std::to_string(m_relativeFailureRate) + ", windowSize=" + std::to_string(m_windowSize) + ")";
} }
TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx) const { TriggerResult ConvergenceFailureTrigger::why(const gridfire::solver::PointSolverTimestepContext &ctx) const {
TriggerResult result; TriggerResult result;
result.name = name(); result.name = name();
@@ -348,7 +348,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
bool ConvergenceFailureTrigger::abs_failure( bool ConvergenceFailureTrigger::abs_failure(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) const { ) const {
if (ctx.currentConvergenceFailures > m_totalFailures) { if (ctx.currentConvergenceFailures > m_totalFailures) {
return true; return true;
@@ -357,7 +357,7 @@ namespace gridfire::trigger::solver::CVODE {
} }
bool ConvergenceFailureTrigger::rel_failure( bool ConvergenceFailureTrigger::rel_failure(
const gridfire::solver::CVODESolverStrategy::TimestepContext &ctx const gridfire::solver::PointSolverTimestepContext &ctx
) const { ) const {
const float mean = current_mean(); const float mean = current_mean();
if (mean < 10) { if (mean < 10) {
@@ -369,13 +369,13 @@ namespace gridfire::trigger::solver::CVODE {
return false; return false;
} }
std::unique_ptr<Trigger<gridfire::solver::CVODESolverStrategy::TimestepContext>> makeEnginePartitioningTrigger( std::unique_ptr<Trigger<gridfire::solver::PointSolverTimestepContext>> makeEnginePartitioningTrigger(
const double simulationTimeInterval, const double simulationTimeInterval,
const double offDiagonalThreshold, const double offDiagonalThreshold,
const double timestepCollapseRatio, const double timestepCollapseRatio,
const size_t maxConvergenceFailures const size_t maxConvergenceFailures
) { ) {
using ctx_t = gridfire::solver::CVODESolverStrategy::TimestepContext; using ctx_t = gridfire::solver::PointSolverTimestepContext;
// 1. INSTABILITY TRIGGERS (High Priority) // 1. INSTABILITY TRIGGERS (High Priority)
auto convergenceFailureTrigger = std::make_unique<ConvergenceFailureTrigger>( auto convergenceFailureTrigger = std::make_unique<ConvergenceFailureTrigger>(

View File

@@ -15,8 +15,8 @@ gridfire_sources = files(
'lib/reaction/weak/weak_interpolator.cpp', 'lib/reaction/weak/weak_interpolator.cpp',
'lib/io/network_file.cpp', 'lib/io/network_file.cpp',
'lib/io/generative/python.cpp', 'lib/io/generative/python.cpp',
'lib/solver/strategies/CVODE_solver_strategy.cpp', 'lib/solver/strategies/PointSolver.cpp',
# 'lib/solver/strategies/SpectralSolverStrategy.cpp', 'lib/solver/strategies/GridSolver.cpp',
'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp', 'lib/solver/strategies/triggers/engine_partitioning_trigger.cpp',
'lib/screening/screening_types.cpp', 'lib/screening/screening_types.cpp',
'lib/screening/screening_weak.cpp', 'lib/screening/screening_weak.cpp',
@@ -58,12 +58,21 @@ if get_option('openmp_support')
endif endif
# Define the libnetwork library so it can be linked against by other parts of the build system # Define the libnetwork library so it can be linked against by other parts of the build system
libgridfire = library('gridfire', if get_option('build_python')
gridfire_sources, libgridfire = static_library('gridfire',
include_directories: include_directories('include'), gridfire_sources,
dependencies: gridfire_build_dependencies, include_directories: include_directories('include'),
objects: [cvode_objs, kinsol_objs], dependencies: gridfire_build_dependencies,
install : true) objects: [cvode_objs, kinsol_objs],
install : false)
else
libgridfire = library('gridfire',
gridfire_sources,
include_directories: include_directories('include'),
dependencies: gridfire_build_dependencies,
objects: [cvode_objs, kinsol_objs],
install : true)
endif
gridfire_dep = declare_dependency( gridfire_dep = declare_dependency(
include_directories: include_directories('include'), include_directories: include_directories('include'),

View File

@@ -4,6 +4,7 @@
#include "types/bindings.h" #include "types/bindings.h"
#include "partition/bindings.h" #include "partition/bindings.h"
#include "engine/bindings.h" #include "engine/bindings.h"
#include "engine/scratchpads/bindings.h"
#include "exceptions/bindings.h" #include "exceptions/bindings.h"
#include "io/bindings.h" #include "io/bindings.h"
#include "reaction/bindings.h" #include "reaction/bindings.h"
@@ -11,6 +12,7 @@
#include "solver/bindings.h" #include "solver/bindings.h"
#include "utils/bindings.h" #include "utils/bindings.h"
#include "policy/bindings.h" #include "policy/bindings.h"
#include "config/bindings.h"
PYBIND11_MODULE(_gridfire, m) { PYBIND11_MODULE(_gridfire, m) {
m.doc() = "Python bindings for the fourdst utility modules which are a part of the 4D-STAR project."; m.doc() = "Python bindings for the fourdst utility modules which are a part of the 4D-STAR project.";
@@ -20,6 +22,9 @@ PYBIND11_MODULE(_gridfire, m) {
pybind11::module::import("fourdst.config"); pybind11::module::import("fourdst.config");
pybind11::module::import("fourdst.atomic"); pybind11::module::import("fourdst.atomic");
auto configMod = m.def_submodule("config", "GridFire configuration bindings");
register_config_bindings(configMod);
auto typeMod = m.def_submodule("type", "GridFire type bindings"); auto typeMod = m.def_submodule("type", "GridFire type bindings");
register_type_bindings(typeMod); register_type_bindings(typeMod);
@@ -39,6 +44,12 @@ PYBIND11_MODULE(_gridfire, m) {
register_exception_bindings(exceptionMod); register_exception_bindings(exceptionMod);
auto engineMod = m.def_submodule("engine", "Engine and Engine View bindings"); auto engineMod = m.def_submodule("engine", "Engine and Engine View bindings");
auto scratchpadMod = engineMod.def_submodule("scratchpads", "Engine ScratchPad bindings");
register_scratchpad_types_bindings(scratchpadMod);
register_scratchpad_bindings(scratchpadMod);
register_state_blob_bindings(scratchpadMod);
register_engine_bindings(engineMod); register_engine_bindings(engineMod);
auto solverMod = m.def_submodule("solver", "GridFire numerical solver bindings"); auto solverMod = m.def_submodule("solver", "GridFire numerical solver bindings");

View File

@@ -0,0 +1,34 @@
#include "bindings.h"
#include "gridfire/config/config.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
void register_config_bindings(pybind11::module &m) {
py::class_<gridfire::config::CVODESolverConfig>(m, "CVODESolverConfig")
.def(py::init<>())
.def_readwrite("absTol", &gridfire::config::CVODESolverConfig::absTol)
.def_readwrite("relTol", &gridfire::config::CVODESolverConfig::relTol);
py::class_<gridfire::config::SolverConfig>(m, "SolverConfig")
.def(py::init<>())
.def_readwrite("cvode", &gridfire::config::SolverConfig::cvode);
py::class_<gridfire::config::AdaptiveEngineViewConfig>(m, "AdaptiveEngineViewConfig")
.def(py::init<>())
.def_readwrite("relativeCullingThreshold", &gridfire::config::AdaptiveEngineViewConfig::relativeCullingThreshold);
py::class_<gridfire::config::EngineViewConfig>(m, "EngineViewConfig")
.def(py::init<>())
.def_readwrite("adaptiveEngineView", &gridfire::config::EngineViewConfig::adaptiveEngineView);
py::class_<gridfire::config::EngineConfig>(m, "EngineConfig")
.def(py::init<>())
.def_readwrite("views", &gridfire::config::EngineConfig::views);
py::class_<gridfire::config::GridFireConfig>(m, "GridFireConfig")
.def(py::init<>())
.def_readwrite("solver", &gridfire::config::GridFireConfig::solver)
.def_readwrite("engine", &gridfire::config::GridFireConfig::engine);
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include <pybind11/pybind11.h>
void register_config_bindings(pybind11::module &m);

View File

@@ -12,6 +12,7 @@
namespace py = pybind11; namespace py = pybind11;
namespace sp = gridfire::engine::scratch;
namespace { namespace {
template <typename T> template <typename T>
@@ -23,16 +24,18 @@ namespace {
"calculateRHSAndEnergy", "calculateRHSAndEnergy",
[]( [](
const gridfire::engine::DynamicEngine& self, const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp, const fourdst::composition::Composition& comp,
const double T9, const double T9,
const double rho const double rho
) { ) {
auto result = self.calculateRHSAndEnergy(comp, T9, rho); auto result = self.calculateRHSAndEnergy(ctx, comp, T9, rho, false);
if (!result.has_value()) { if (!result.has_value()) {
throw gridfire::exceptions::EngineError(std::format("calculateRHSAndEnergy returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error()))); throw gridfire::exceptions::EngineError(std::format("calculateRHSAndEnergy returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
} }
return result.value(); return result.value();
}, },
py::arg("ctx"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
@@ -40,6 +43,7 @@ namespace {
) )
.def("calculateEpsDerivatives", .def("calculateEpsDerivatives",
&gridfire::engine::DynamicEngine::calculateEpsDerivatives, &gridfire::engine::DynamicEngine::calculateEpsDerivatives,
py::arg("ctx"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
@@ -47,11 +51,13 @@ namespace {
) )
.def("generateJacobianMatrix", .def("generateJacobianMatrix",
[](const gridfire::engine::DynamicEngine& self, [](const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp, const fourdst::composition::Composition& comp,
const double T9, const double T9,
const double rho) -> gridfire::engine::NetworkJacobian { const double rho) -> gridfire::engine::NetworkJacobian {
return self.generateJacobianMatrix(comp, T9, rho); return self.generateJacobianMatrix(ctx, comp, T9, rho);
}, },
py::arg("ctx"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
@@ -59,12 +65,14 @@ namespace {
) )
.def("generateJacobianMatrix", .def("generateJacobianMatrix",
[](const gridfire::engine::DynamicEngine& self, [](const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp, const fourdst::composition::Composition& comp,
const double T9, const double T9,
const double rho, const double rho,
const std::vector<fourdst::atomic::Species>& activeSpecies) -> gridfire::engine::NetworkJacobian { const std::vector<fourdst::atomic::Species>& activeSpecies) -> gridfire::engine::NetworkJacobian {
return self.generateJacobianMatrix(comp, T9, rho, activeSpecies); return self.generateJacobianMatrix(ctx, comp, T9, rho, activeSpecies);
}, },
py::arg("ctx"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
@@ -73,31 +81,32 @@ namespace {
) )
.def("generateJacobianMatrix", .def("generateJacobianMatrix",
[](const gridfire::engine::DynamicEngine& self, [](const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp, const fourdst::composition::Composition& comp,
const double T9, const double T9,
const double rho, const double rho,
const gridfire::engine::SparsityPattern& sparsityPattern) -> gridfire::engine::NetworkJacobian { const gridfire::engine::SparsityPattern& sparsityPattern) -> gridfire::engine::NetworkJacobian {
return self.generateJacobianMatrix(comp, T9, rho, sparsityPattern); return self.generateJacobianMatrix(ctx, comp, T9, rho, sparsityPattern);
}, },
py::arg("ctx"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
py::arg("sparsityPattern"), py::arg("sparsityPattern"),
"Generate the jacobian matrix for the given sparsity pattern" "Generate the jacobian matrix for the given sparsity pattern"
) )
.def("generateStoichiometryMatrix",
&T::generateStoichiometryMatrix
)
.def("calculateMolarReactionFlow", .def("calculateMolarReactionFlow",
[]( [](
const gridfire::engine::DynamicEngine& self, const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const gridfire::reaction::Reaction& reaction, const gridfire::reaction::Reaction& reaction,
const fourdst::composition::Composition& comp, const fourdst::composition::Composition& comp,
const double T9, const double T9,
const double rho const double rho
) -> double { ) -> double {
return self.calculateMolarReactionFlow(reaction, comp, T9, rho); return self.calculateMolarReactionFlow(ctx, reaction, comp, T9, rho);
}, },
py::arg("ctx"),
py::arg("reaction"), py::arg("reaction"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
@@ -110,28 +119,21 @@ namespace {
.def("getNetworkReactions", &T::getNetworkReactions, .def("getNetworkReactions", &T::getNetworkReactions,
"Get the set of logical reactions in the network." "Get the set of logical reactions in the network."
) )
.def ("setNetworkReactions", &T::setNetworkReactions,
py::arg("reactions"),
"Set the network reactions to a new set of reactions."
)
.def("getStoichiometryMatrixEntry", &T::getStoichiometryMatrixEntry,
py::arg("species"),
py::arg("reaction"),
"Get an entry from the stoichiometry matrix."
)
.def("getSpeciesTimescales", .def("getSpeciesTimescales",
[]( [](
const gridfire::engine::DynamicEngine& self, const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp, const fourdst::composition::Composition& comp,
const double T9, const double T9,
const double rho const double rho
) -> std::unordered_map<fourdst::atomic::Species, double> { ) -> std::unordered_map<fourdst::atomic::Species, double> {
const auto result = self.getSpeciesTimescales(comp, T9, rho); const auto result = self.getSpeciesTimescales(ctx, comp, T9, rho);
if (!result.has_value()) { if (!result.has_value()) {
throw gridfire::exceptions::EngineError(std::format("getSpeciesTimescales has returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error()))); throw gridfire::exceptions::EngineError(std::format("getSpeciesTimescales has returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
} }
return result.value(); return result.value();
}, },
py::arg("ctx"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
@@ -140,67 +142,48 @@ namespace {
.def("getSpeciesDestructionTimescales", .def("getSpeciesDestructionTimescales",
[]( [](
const gridfire::engine::DynamicEngine& self, const gridfire::engine::DynamicEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp, const fourdst::composition::Composition& comp,
const double T9, const double T9,
const double rho const double rho
) -> std::unordered_map<fourdst::atomic::Species, double> { ) -> std::unordered_map<fourdst::atomic::Species, double> {
const auto result = self.getSpeciesDestructionTimescales(comp, T9, rho); const auto result = self.getSpeciesDestructionTimescales(ctx, comp, T9, rho);
if (!result.has_value()) { if (!result.has_value()) {
throw gridfire::exceptions::EngineError(std::format("getSpeciesDestructionTimescales has returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error()))); throw gridfire::exceptions::EngineError(std::format("getSpeciesDestructionTimescales has returned a potentially recoverable error {}", gridfire::engine::EngineStatus_to_string(result.error())));
} }
return result.value(); return result.value();
}, },
py::arg("ctx"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
"Get the destruction timescales for each species in the network." "Get the destruction timescales for each species in the network."
) )
.def("update", .def("project",
&T::update, &T::project,
py::arg("ctx"),
py::arg("netIn"), py::arg("netIn"),
"Update the engine state based on the provided NetIn object." "Update the engine state based on the provided NetIn object."
) )
.def("setScreeningModel",
&T::setScreeningModel,
py::arg("screeningModel"),
"Set the screening model for the engine."
)
.def("getScreeningModel", .def("getScreeningModel",
&T::getScreeningModel, &T::getScreeningModel,
"Get the current screening model of the engine." "Get the current screening model of the engine."
) )
.def("getSpeciesIndex", .def("getSpeciesIndex",
&T::getSpeciesIndex, &T::getSpeciesIndex,
py::arg("ctx"),
py::arg("species"), py::arg("species"),
"Get the index of a species in the network." "Get the index of a species in the network."
) )
.def("mapNetInToMolarAbundanceVector",
&T::mapNetInToMolarAbundanceVector,
py::arg("netIn"),
"Map a NetIn object to a vector of molar abundances."
)
.def("primeEngine", .def("primeEngine",
&T::primeEngine, &T::primeEngine,
py::arg("ctx"),
py::arg("netIn"), py::arg("netIn"),
"Prime the engine with a NetIn object to prepare for calculations." "Prime the engine with a NetIn object to prepare for calculations."
) )
.def("getDepth",
&T::getDepth,
"Get the current build depth of the engine."
)
.def("rebuild",
&T::rebuild,
py::arg("composition"),
py::arg("depth") = gridfire::engine::NetworkBuildDepth::Full,
"Rebuild the engine with a new composition and build depth."
)
.def("isStale",
&T::isStale,
py::arg("netIn"),
"Check if the engine is stale based on the provided NetIn object."
)
.def("collectComposition", .def("collectComposition",
&T::collectComposition, &T::collectComposition,
py::arg("ctx"),
py::arg("composition"), py::arg("composition"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
@@ -208,6 +191,7 @@ namespace {
) )
.def("getSpeciesStatus", .def("getSpeciesStatus",
&T::getSpeciesStatus, &T::getSpeciesStatus,
py::arg("ctx"),
py::arg("species"), py::arg("species"),
"Get the status of a species in the network." "Get the status of a species in the network."
); );
@@ -253,6 +237,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
auto diagnostics = m.def_submodule("diagnostics", "A submodule for engine diagnostics"); auto diagnostics = m.def_submodule("diagnostics", "A submodule for engine diagnostics");
diagnostics.def("report_limiting_species", diagnostics.def("report_limiting_species",
&gridfire::engine::diagnostics::report_limiting_species, &gridfire::engine::diagnostics::report_limiting_species,
py::arg("ctx"),
py::arg("engine"), py::arg("engine"),
py::arg("Y_full"), py::arg("Y_full"),
py::arg("E_full"), py::arg("E_full"),
@@ -264,6 +249,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
diagnostics.def("inspect_species_balance", diagnostics.def("inspect_species_balance",
&gridfire::engine::diagnostics::inspect_species_balance, &gridfire::engine::diagnostics::inspect_species_balance,
py::arg("ctx"),
py::arg("engine"), py::arg("engine"),
py::arg("species_name"), py::arg("species_name"),
py::arg("comp"), py::arg("comp"),
@@ -274,6 +260,7 @@ void register_engine_diagnostic_bindings(pybind11::module &m) {
diagnostics.def("inspect_jacobian_stiffness", diagnostics.def("inspect_jacobian_stiffness",
&gridfire::engine::diagnostics::inspect_jacobian_stiffness, &gridfire::engine::diagnostics::inspect_jacobian_stiffness,
py::arg("ctx"),
py::arg("engine"), py::arg("engine"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
@@ -311,6 +298,7 @@ void register_engine_construction_bindings(pybind11::module &m) {
void register_engine_priming_bindings(pybind11::module &m) { void register_engine_priming_bindings(pybind11::module &m) {
m.def("primeNetwork", m.def("primeNetwork",
&gridfire::engine::primeNetwork, &gridfire::engine::primeNetwork,
py::arg("ctx"),
py::arg("netIn"), py::arg("netIn"),
py::arg("engine"), py::arg("engine"),
py::arg("ignoredReactionTypes") = std::nullopt, py::arg("ignoredReactionTypes") = std::nullopt,
@@ -456,19 +444,16 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
py::arg("reactions"), py::arg("reactions"),
"Initialize GraphEngine with a set of reactions." "Initialize GraphEngine with a set of reactions."
); );
py_graph_engine_bindings.def_static("getNetReactionStoichiometry",
&gridfire::engine::GraphEngine::getNetReactionStoichiometry,
py::arg("reaction"),
"Get the net stoichiometry for a given reaction."
);
py_graph_engine_bindings.def("getSpeciesTimescales", py_graph_engine_bindings.def("getSpeciesTimescales",
[](const gridfire::engine::GraphEngine& self, [](const gridfire::engine::GraphEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& composition, const fourdst::composition::Composition& composition,
const double T9, const double T9,
const double rho, const double rho,
const gridfire::reaction::ReactionSet& activeReactions) { const gridfire::reaction::ReactionSet& activeReactions) {
return self.getSpeciesTimescales(composition, T9, rho, activeReactions); return self.getSpeciesTimescales(ctx, composition, T9, rho, activeReactions);
}, },
py::arg("ctx"),
py::arg("composition"), py::arg("composition"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
@@ -476,12 +461,14 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
); );
py_graph_engine_bindings.def("getSpeciesDestructionTimescales", py_graph_engine_bindings.def("getSpeciesDestructionTimescales",
[](const gridfire::engine::GraphEngine& self, [](const gridfire::engine::GraphEngine& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& composition, const fourdst::composition::Composition& composition,
const double T9, const double T9,
const double rho, const double rho,
const gridfire::reaction::ReactionSet& activeReactions) { const gridfire::reaction::ReactionSet& activeReactions) {
return self.getSpeciesDestructionTimescales(composition, T9, rho, activeReactions); return self.getSpeciesDestructionTimescales(ctx, composition, T9, rho, activeReactions);
}, },
py::arg("ctx"),
py::arg("composition"), py::arg("composition"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),
@@ -489,24 +476,22 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
); );
py_graph_engine_bindings.def("involvesSpecies", py_graph_engine_bindings.def("involvesSpecies",
&gridfire::engine::GraphEngine::involvesSpecies, &gridfire::engine::GraphEngine::involvesSpecies,
py::arg("ctx"),
py::arg("species"), py::arg("species"),
"Check if a given species is involved in the network." "Check if a given species is involved in the network."
); );
py_graph_engine_bindings.def("exportToDot", py_graph_engine_bindings.def("exportToDot",
&gridfire::engine::GraphEngine::exportToDot, &gridfire::engine::GraphEngine::exportToDot,
py::arg("ctx"),
py::arg("filename"), py::arg("filename"),
"Export the network to a DOT file for visualization." "Export the network to a DOT file for visualization."
); );
py_graph_engine_bindings.def("exportToCSV", py_graph_engine_bindings.def("exportToCSV",
&gridfire::engine::GraphEngine::exportToCSV, &gridfire::engine::GraphEngine::exportToCSV,
py::arg("ctx"),
py::arg("filename"), py::arg("filename"),
"Export the network to a CSV file for analysis." "Export the network to a CSV file for analysis."
); );
py_graph_engine_bindings.def("setPrecomputation",
&gridfire::engine::GraphEngine::setPrecomputation,
py::arg("precompute"),
"Enable or disable precomputation for the engine."
);
py_graph_engine_bindings.def("isPrecomputationEnabled", py_graph_engine_bindings.def("isPrecomputationEnabled",
&gridfire::engine::GraphEngine::isPrecomputationEnabled, &gridfire::engine::GraphEngine::isPrecomputationEnabled,
"Check if precomputation is enabled for the engine." "Check if precomputation is enabled for the engine."
@@ -544,11 +529,6 @@ void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
&gridfire::engine::GraphEngine::isUsingReverseReactions, &gridfire::engine::GraphEngine::isUsingReverseReactions,
"Check if the engine is using reverse reactions." "Check if the engine is using reverse reactions."
); );
py_graph_engine_bindings.def("setUseReverseReactions",
&gridfire::engine::GraphEngine::setUseReverseReactions,
py::arg("useReverse"),
"Enable or disable the use of reverse reactions in the engine."
);
// Register the general dynamic engine bindings // Register the general dynamic engine bindings
registerDynamicEngineDefs<gridfire::engine::GraphEngine, gridfire::engine::DynamicEngine>(py_graph_engine_bindings); registerDynamicEngineDefs<gridfire::engine::GraphEngine, gridfire::engine::DynamicEngine>(py_graph_engine_bindings);
@@ -587,11 +567,13 @@ void register_engine_view_bindings(const pybind11::module &m) {
registerDynamicEngineDefs<gridfire::engine::FileDefinedEngineView, gridfire::engine::DefinedEngineView>(py_file_defined_engine_view_bindings); registerDynamicEngineDefs<gridfire::engine::FileDefinedEngineView, gridfire::engine::DefinedEngineView>(py_file_defined_engine_view_bindings);
auto py_priming_engine_view_bindings = py::class_<gridfire::engine::NetworkPrimingEngineView, gridfire::engine::DefinedEngineView>(m, "NetworkPrimingEngineView"); auto py_priming_engine_view_bindings = py::class_<gridfire::engine::NetworkPrimingEngineView, gridfire::engine::DefinedEngineView>(m, "NetworkPrimingEngineView");
py_priming_engine_view_bindings.def(py::init<const std::string&, gridfire::engine::GraphEngine&>(), py_priming_engine_view_bindings.def(py::init<sp::StateBlob&, const std::string&, gridfire::engine::GraphEngine&>(),
py::arg("ctx"),
py::arg("primingSymbol"), py::arg("primingSymbol"),
py::arg("baseEngine"), py::arg("baseEngine"),
"Construct a priming engine view with a priming symbol and a base engine."); "Construct a priming engine view with a priming symbol and a base engine.");
py_priming_engine_view_bindings.def(py::init<const fourdst::atomic::Species&, gridfire::engine::GraphEngine&>(), py_priming_engine_view_bindings.def(py::init<sp::StateBlob&, const fourdst::atomic::Species&, gridfire::engine::GraphEngine&>(),
py::arg("ctx"),
py::arg("primingSpecies"), py::arg("primingSpecies"),
py::arg("baseEngine"), py::arg("baseEngine"),
"Construct a priming engine view with a priming species and a base engine."); "Construct a priming engine view with a priming species and a base engine.");
@@ -622,15 +604,12 @@ void register_engine_view_bindings(const pybind11::module &m) {
); );
py_multiscale_engine_view_bindings.def("partitionNetwork", py_multiscale_engine_view_bindings.def("partitionNetwork",
&gridfire::engine::MultiscalePartitioningEngineView::partitionNetwork, &gridfire::engine::MultiscalePartitioningEngineView::partitionNetwork,
py::arg("ctx"),
py::arg("netIn"), py::arg("netIn"),
"Partition the network based on species timescales and connectivity."); "Partition the network based on species timescales and connectivity.");
py_multiscale_engine_view_bindings.def("partitionNetwork",
py::overload_cast<const gridfire::NetIn&>(&gridfire::engine::MultiscalePartitioningEngineView::partitionNetwork),
py::arg("netIn"),
"Partition the network based on a NetIn object."
);
py_multiscale_engine_view_bindings.def("exportToDot", py_multiscale_engine_view_bindings.def("exportToDot",
&gridfire::engine::MultiscalePartitioningEngineView::exportToDot, &gridfire::engine::MultiscalePartitioningEngineView::exportToDot,
py::arg("ctx"),
py::arg("filename"), py::arg("filename"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
@@ -661,7 +640,16 @@ void register_engine_view_bindings(const pybind11::module &m) {
"Check if a given species is involved in the network's dynamic set." "Check if a given species is involved in the network's dynamic set."
); );
py_multiscale_engine_view_bindings.def("getNormalizedEquilibratedComposition", py_multiscale_engine_view_bindings.def("getNormalizedEquilibratedComposition",
&gridfire::engine::MultiscalePartitioningEngineView::getNormalizedEquilibratedComposition, [](
const gridfire::engine::MultiscalePartitioningEngineView& self,
sp::StateBlob& ctx,
const fourdst::composition::Composition& comp,
const double T9,
const double rho
) {
return self.getNormalizedEquilibratedComposition(ctx, comp, T9, rho, false);
},
py::arg("ctx"),
py::arg("comp"), py::arg("comp"),
py::arg("T9"), py::arg("T9"),
py::arg("rho"), py::arg("rho"),

View File

@@ -1,21 +0,0 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
message('⏳ Python bindings for GridFire Engine are being registered...')
shared_module('py_gf_engine',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)
message('✅ Python bindings for GridFire Engine registered successfully!')

View File

@@ -0,0 +1,152 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
#include "gridfire/engine/scratchpads/scratchpads.h"
#include "bindings.h"
namespace py = pybind11;
namespace sp = gridfire::engine::scratch;
template<typename... ScratchPadTypes>
void build_state_getter(py::module& m) {
}
void register_scratchpad_types_bindings(pybind11::module &m) {
py::enum_<sp::ScratchPadType>(m, "ScratchPadType")
.value("GRAPH_ENGINE_SCRATCHPAD", sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD)
.value("MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD)
.value("ADAPTIVE_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD)
.value("DEFINED_ENGINE_VIEW_SCRATCHPAD", sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD)
.export_values();
}
void register_scratchpad_bindings(pybind11::module_ &m) {
py::enum_<sp::GraphEngineScratchPad::ADFunRegistrationResult>(m, "ADFunRegistrationResult")
.value("SUCCESS", sp::GraphEngineScratchPad::ADFunRegistrationResult::SUCCESS)
.value("ALREADY_REGISTERED", sp::GraphEngineScratchPad::ADFunRegistrationResult::ALREADY_REGISTERED)
.export_values();
py::class_<sp::GraphEngineScratchPad>(m, "GraphEngineScratchPad")
.def(py::init<>())
.def("initialize", &sp::GraphEngineScratchPad::initialize, py::arg("engine"))
.def("clone", &sp::GraphEngineScratchPad::clone)
.def("is_initialized", &sp::GraphEngineScratchPad::is_initialized)
.def_readonly("most_recent_rhs_calculation", &sp::GraphEngineScratchPad::most_recent_rhs_calculation)
.def_readonly("local_abundance_cache", &sp::GraphEngineScratchPad::local_abundance_cache)
.def_readonly("has_initialized", &sp::GraphEngineScratchPad::has_initialized)
.def_readonly("stepDerivativesCache", &sp::GraphEngineScratchPad::stepDerivativesCache)
.def_readonly_static("ID", &sp::GraphEngineScratchPad::ID)
.def("__repr__", [](const sp::GraphEngineScratchPad &self) {
return std::format("{}", self);
});
py::class_<sp::MultiscalePartitioningEngineViewScratchPad>(m, "MultiscalePartitioningEngineViewScratchPad")
.def(py::init<>())
.def("initialize", &sp::MultiscalePartitioningEngineViewScratchPad::initialize)
.def("clone", &sp::MultiscalePartitioningEngineViewScratchPad::clone)
.def("is_initialized", &sp::MultiscalePartitioningEngineViewScratchPad::is_initialized)
.def_readonly("qse_groups", &sp::MultiscalePartitioningEngineViewScratchPad::qse_groups)
.def_readonly("dynamic_species", &sp::MultiscalePartitioningEngineViewScratchPad::dynamic_species)
.def_readonly("algebraic_species", &sp::MultiscalePartitioningEngineViewScratchPad::algebraic_species)
.def_readonly("composition_cache", &sp::MultiscalePartitioningEngineViewScratchPad::composition_cache)
.def_readonly("has_initialized", &sp::MultiscalePartitioningEngineViewScratchPad::has_initialized)
.def_readonly_static("ID", &sp::MultiscalePartitioningEngineViewScratchPad::ID)
.def("__repr__", [](const sp::MultiscalePartitioningEngineViewScratchPad &self) {
return std::format("{}", self);
});
py::class_<sp::AdaptiveEngineViewScratchPad>(m, "AdaptiveEngineViewScratchPad")
.def(py::init<>())
.def("initialize", &sp::AdaptiveEngineViewScratchPad::initialize)
.def("clone", &sp::AdaptiveEngineViewScratchPad::clone)
.def("is_initialized", &sp::AdaptiveEngineViewScratchPad::is_initialized)
.def_readonly("active_species", &sp::AdaptiveEngineViewScratchPad::active_species)
.def_readonly("active_reactions", &sp::AdaptiveEngineViewScratchPad::active_reactions)
.def_readonly("has_initialized", &sp::AdaptiveEngineViewScratchPad::has_initialized)
.def_readonly_static("ID", &sp::AdaptiveEngineViewScratchPad::ID)
.def("__repr__", [](const sp::AdaptiveEngineViewScratchPad &self) {
return std::format("{}", self);
});
py::class_<sp::DefinedEngineViewScratchPad>(m, "DefinedEngineViewScratchPad")
.def(py::init<>())
.def("clone", &sp::DefinedEngineViewScratchPad::clone)
.def("is_initialized", &sp::DefinedEngineViewScratchPad::is_initialized)
.def_readonly("active_species", &sp::DefinedEngineViewScratchPad::active_species)
.def_readonly("active_reactions", &sp::DefinedEngineViewScratchPad::active_reactions)
.def_readonly("species_index_map", &sp::DefinedEngineViewScratchPad::species_index_map)
.def_readonly("reaction_index_map", &sp::DefinedEngineViewScratchPad::reaction_index_map)
.def_readonly("has_initialized", &sp::DefinedEngineViewScratchPad::has_initialized)
.def_readonly_static("ID", &sp::DefinedEngineViewScratchPad::ID)
.def("__repr__", [](const sp::DefinedEngineViewScratchPad &self) {
return std::format("{}", self);
});
}
void register_state_blob_bindings(pybind11::module_ &m) {
py::enum_<sp::StateBlob::Error>(m, "StateBlobError")
.value("SCRATCHPAD_OUT_OF_BOUNDS", sp::StateBlob::Error::SCRATCHPAD_OUT_OF_BOUNDS)
.value("SCRATCHPAD_NOT_FOUND", sp::StateBlob::Error::SCRATCHPAD_NOT_FOUND)
.value("SCRATCHPAD_BAD_CAST", sp::StateBlob::Error::SCRATCHPAD_BAD_CAST)
.value("SCRATCHPAD_NOT_INITIALIZED", sp::StateBlob::Error::SCRATCHPAD_NOT_INITIALIZED)
.value("SCRATCHPAD_TYPE_COLLISION", sp::StateBlob::Error::SCRATCHPAD_TYPE_COLLISION)
.value("SCRATCHPAD_UNKNOWN_ERROR", sp::StateBlob::Error::SCRATCHPAD_UNKNOWN_ERROR)
.export_values();
py::class_<sp::StateBlob>(m, "StateBlob")
.def(py::init<>())
.def("enroll", [](sp::StateBlob &self, const sp::ScratchPadType type) {
switch (type) {
case sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD:
self.enroll<sp::GraphEngineScratchPad>();
break;
case sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD:
self.enroll<sp::MultiscalePartitioningEngineViewScratchPad>();
break;
case sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD:
self.enroll<sp::AdaptiveEngineViewScratchPad>();
break;
case sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD:
self.enroll<sp::DefinedEngineViewScratchPad>();
break;
default:
throw std::invalid_argument("Unknown ScratchPadType for enrollment.");
}
})
.def("get", [](const sp::StateBlob &self, const sp::ScratchPadType type) {
auto result = self.get(type);
if (!result.has_value()) {
throw std::runtime_error("Error retrieving scratchpad: " + sp::StateBlob::error_to_string(result.error()));
}
return result.value();
},
pybind11::return_value_policy::reference_internal
)
.def("clone_structure", &sp::StateBlob::clone_structure)
.def("get_registered_scratchpads", &sp::StateBlob::get_registered_scratchpads)
.def("get_status", [](const sp::StateBlob &self, const sp::ScratchPadType type) -> sp::StateBlob::ScratchPadStatus {
switch (type) {
case sp::ScratchPadType::GRAPH_ENGINE_SCRATCHPAD:
return self.get_status<sp::GraphEngineScratchPad>();
case sp::ScratchPadType::MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD:
return self.get_status<sp::MultiscalePartitioningEngineViewScratchPad>();
case sp::ScratchPadType::ADAPTIVE_ENGINE_VIEW_SCRATCHPAD:
return self.get_status<sp::AdaptiveEngineViewScratchPad>();
case sp::ScratchPadType::DEFINED_ENGINE_VIEW_SCRATCHPAD:
return self.get_status<sp::DefinedEngineViewScratchPad>();
default:
throw std::invalid_argument("Unknown ScratchPadType for status retrieval.");
}
})
.def("get_status_map", &sp::StateBlob::get_status_map)
.def_static("error_to_string", &sp::StateBlob::error_to_string)
.def("__repr__", [](const sp::StateBlob &self) {
return std::format("{}", self);
});
}

View File

@@ -0,0 +1,7 @@
#pragma once
#include <pybind11/pybind11.h>
void register_scratchpad_types_bindings(pybind11::module_& m);
void register_scratchpad_bindings(pybind11::module_& m);
void register_state_blob_bindings(pybind11::module_& m);

View File

@@ -1,21 +0,0 @@
gf_engine_trampoline_sources = files('py_engine.cpp')
gf_engine_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_engine_trampoline_lib = static_library(
'engine_trampolines',
gf_engine_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_engine_trapoline_dependencies,
install: false,
)
gr_engine_trampoline_dep = declare_dependency(
link_with: gf_engine_trampoline_lib,
include_directories: ('.'),
dependencies: gf_engine_trapoline_dependencies,
)

View File

@@ -13,36 +13,29 @@
namespace py = pybind11; namespace py = pybind11;
const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies() const { const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies(
/* gridfire::engine::scratch::StateBlob& ctx
* Acquire the GIL (Global Interpreter Lock) for thread safety ) const {
* with the Python interpreter. PYBIND11_OVERRIDE_PURE(
*/ const std::vector<fourdst::atomic::Species>&,
py::gil_scoped_acquire gil; gridfire::engine::Engine,
getNetworkSpecies,
/* ctx
* get_override() looks for a Python method that overrides this C++ one. );
*/
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
const py::object result = override();
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
return m_species_cache;
}
py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\"");
} }
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> PyEngine::calculateRHSAndEnergy( std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> PyEngine::calculateRHSAndEnergy(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho,
bool trust
) const { ) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus>), PYBIND11_TYPE(std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus>),
gridfire::engine::Engine, gridfire::engine::Engine,
calculateRHSAndEnergy, calculateRHSAndEnergy,
comp, T9, rho ctx, comp, T9, rho, trust
); );
} }
@@ -50,41 +43,35 @@ std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::Engin
/// PyDynamicEngine Implementation /// /// PyDynamicEngine Implementation ///
///////////////////////////////////// /////////////////////////////////////
const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies() const { const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies(
/* gridfire::engine::scratch::StateBlob& ctx
* Acquire the GIL (Global Interpreter Lock) for thread safety ) const {
* with the Python interpreter. PYBIND11_OVERRIDE_PURE(
*/ const std::vector<fourdst::atomic::Species>&,
py::gil_scoped_acquire gil; gridfire::engine::DynamicEngine,
getNetworkSpecies,
/* ctx
* get_override() looks for a Python method that overrides this C++ one. );
*/
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
const py::object result = override();
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
return m_species_cache;
}
py::pybind11_fail("Tried to call pure virtual function \"DynamicEngine::getNetworkSpecies\"");
} }
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> PyDynamicEngine::calculateRHSAndEnergy( std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> PyDynamicEngine::calculateRHSAndEnergy(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho,
bool trust
) const { ) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus>), PYBIND11_TYPE(std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus>),
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
calculateRHSAndEnergy, calculateRHSAndEnergy,
comp, T9, rho ctx, comp, T9, rho, trust
); );
} }
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix( gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract& comp, const fourdst::composition::CompositionAbstract& comp,
double T9, double T9,
double rho double rho
@@ -100,6 +87,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
} }
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix( gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
const double T9, const double T9,
const double rho, const double rho,
@@ -109,6 +97,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::NetworkJacobian, gridfire::engine::NetworkJacobian,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
generateJacobianMatrix, generateJacobianMatrix,
ctx,
comp, comp,
T9, T9,
rho, rho,
@@ -117,6 +106,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
} }
gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix( gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho, double rho,
@@ -126,6 +116,7 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
gridfire::engine::NetworkJacobian, gridfire::engine::NetworkJacobian,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
generateJacobianMatrix, generateJacobianMatrix,
ctx,
comp, comp,
T9, T9,
rho, rho,
@@ -133,28 +124,8 @@ gridfire::engine::NetworkJacobian PyDynamicEngine::generateJacobianMatrix(
); );
} }
void PyDynamicEngine::generateStoichiometryMatrix() {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::engine::DynamicEngine,
generateStoichiometryMatrix
);
}
int PyDynamicEngine::getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const gridfire::reaction::Reaction& reaction
) const {
PYBIND11_OVERRIDE_PURE(
int,
gridfire::engine::DynamicEngine,
getStoichiometryMatrixEntry,
species,
reaction
);
}
double PyDynamicEngine::calculateMolarReactionFlow( double PyDynamicEngine::calculateMolarReactionFlow(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::reaction::Reaction &reaction, const gridfire::reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
@@ -164,6 +135,7 @@ double PyDynamicEngine::calculateMolarReactionFlow(
double, double,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
calculateMolarReactionFlow, calculateMolarReactionFlow,
ctx,
reaction, reaction,
comp, comp,
T9, T9,
@@ -171,24 +143,19 @@ double PyDynamicEngine::calculateMolarReactionFlow(
); );
} }
const gridfire::reaction::ReactionSet& PyDynamicEngine::getNetworkReactions() const { const gridfire::reaction::ReactionSet& PyDynamicEngine::getNetworkReactions(
gridfire::engine::scratch::StateBlob& ctx
) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
const gridfire::reaction::ReactionSet&, const gridfire::reaction::ReactionSet&,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
getNetworkReactions getNetworkReactions,
); ctx
}
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::engine::DynamicEngine,
setNetworkReactions,
reactions
); );
} }
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> PyDynamicEngine::getSpeciesTimescales( std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> PyDynamicEngine::getSpeciesTimescales(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho
@@ -197,6 +164,7 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus>), PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus>),
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
getSpeciesTimescales, getSpeciesTimescales,
ctx,
comp, comp,
T9, T9,
rho rho
@@ -204,6 +172,7 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
} }
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> PyDynamicEngine::getSpeciesDestructionTimescales( std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> PyDynamicEngine::getSpeciesDestructionTimescales(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho
@@ -212,80 +181,71 @@ std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::en
PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus>), PYBIND11_TYPE(std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus>),
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
getSpeciesDestructionTimescales, getSpeciesDestructionTimescales,
comp, T9, rho ctx, comp, T9, rho
); );
} }
fourdst::composition::Composition PyDynamicEngine::update(const gridfire::NetIn &netIn) { fourdst::composition::Composition PyDynamicEngine::project(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::NetIn &netIn
) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
fourdst::composition::Composition, fourdst::composition::Composition,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
update, project,
ctx,
netIn netIn
); );
} }
bool PyDynamicEngine::isStale(const gridfire::NetIn &netIn) { gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel(
PYBIND11_OVERRIDE_PURE( gridfire::engine::scratch::StateBlob& ctx
bool, ) const {
gridfire::engine::DynamicEngine,
isStale,
netIn
);
}
void PyDynamicEngine::setScreeningModel(gridfire::screening::ScreeningType model) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::engine::DynamicEngine,
setScreeningModel,
model
);
}
gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel() const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
gridfire::screening::ScreeningType, gridfire::screening::ScreeningType,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
getScreeningModel getScreeningModel,
ctx
); );
} }
size_t PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const { size_t PyDynamicEngine::getSpeciesIndex(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
int, int,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
getSpeciesIndex, getSpeciesIndex,
ctx,
species species
); );
} }
std::vector<double> PyDynamicEngine::mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const { gridfire::engine::PrimingReport PyDynamicEngine::primeEngine(
PYBIND11_OVERRIDE_PURE( gridfire::engine::scratch::StateBlob& ctx,
std::vector<double>, const gridfire::NetIn &netIn
gridfire::engine::DynamicEngine, ) const {
mapNetInToMolarAbundanceVector,
netIn
);
}
gridfire::engine::PrimingReport PyDynamicEngine::primeEngine(const gridfire::NetIn &netIn) {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
gridfire::engine::PrimingReport, gridfire::engine::PrimingReport,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
primeEngine, primeEngine,
ctx,
netIn netIn
); );
} }
gridfire::engine::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives( gridfire::engine::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
const double T9, const double T9,
const double rho) const { const double rho
) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
gridfire::engine::EnergyDerivatives, gridfire::engine::EnergyDerivatives,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
calculateEpsDerivatives, calculateEpsDerivatives,
ctx,
comp, comp,
T9, T9,
rho rho
@@ -293,6 +253,7 @@ gridfire::engine::EnergyDerivatives PyDynamicEngine::calculateEpsDerivatives(
} }
fourdst::composition::Composition PyDynamicEngine::collectComposition( fourdst::composition::Composition PyDynamicEngine::collectComposition(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
const double T9, const double T9,
const double rho const double rho
@@ -301,21 +262,37 @@ fourdst::composition::Composition PyDynamicEngine::collectComposition(
fourdst::composition::Composition, fourdst::composition::Composition,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
collectComposition, collectComposition,
ctx,
comp, comp,
T9, T9,
rho rho
); );
} }
gridfire::engine::SpeciesStatus PyDynamicEngine::getSpeciesStatus(const fourdst::atomic::Species &species) const { gridfire::engine::SpeciesStatus PyDynamicEngine::getSpeciesStatus(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::atomic::Species &species
) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
gridfire::engine::SpeciesStatus, gridfire::engine::SpeciesStatus,
gridfire::engine::DynamicEngine, gridfire::engine::DynamicEngine,
getSpeciesStatus, getSpeciesStatus,
ctx,
species species
); );
} }
std::optional<gridfire::engine::StepDerivatives<double>> PyDynamicEngine::getMostRecentRHSCalculation(
gridfire::engine::scratch::StateBlob &ctx
) const {
PYBIND11_OVERRIDE_PURE(
PYBIND11_TYPE(std::optional<gridfire::engine::StepDerivatives<double>>),
gridfire::engine::DynamicEngine,
getMostRecentRHSCalculation,
ctx
);
}
const gridfire::engine::Engine& PyEngineView::getBaseEngine() const { const gridfire::engine::Engine& PyEngineView::getBaseEngine() const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
const gridfire::engine::Engine&, const gridfire::engine::Engine&,

View File

@@ -10,12 +10,16 @@
class PyEngine final : public gridfire::engine::Engine { class PyEngine final : public gridfire::engine::Engine {
public: public:
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override; const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
gridfire::engine::scratch::StateBlob& ctx
) const override;
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> calculateRHSAndEnergy( std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> calculateRHSAndEnergy(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho,
bool trust
) const override; ) const override;
private: private:
mutable std::vector<fourdst::atomic::Species> m_species_cache; mutable std::vector<fourdst::atomic::Species> m_species_cache;
@@ -23,21 +27,27 @@ private:
class PyDynamicEngine final : public gridfire::engine::DynamicEngine { class PyDynamicEngine final : public gridfire::engine::DynamicEngine {
public: public:
const std::vector<fourdst::atomic::Species>& getNetworkSpecies() const override; const std::vector<fourdst::atomic::Species>& getNetworkSpecies(
gridfire::engine::scratch::StateBlob& ctx
) const override;
std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> calculateRHSAndEnergy( std::expected<gridfire::engine::StepDerivatives<double>, gridfire::engine::EngineStatus> calculateRHSAndEnergy(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho,
bool trust
) const override; ) const override;
gridfire::engine::NetworkJacobian generateJacobianMatrix( gridfire::engine::NetworkJacobian generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract& comp, const fourdst::composition::CompositionAbstract& comp,
double T9, double T9,
double rho double rho
) const override; ) const override;
gridfire::engine::NetworkJacobian generateJacobianMatrix( gridfire::engine::NetworkJacobian generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho, double rho,
@@ -45,96 +55,81 @@ public:
) const override; ) const override;
gridfire::engine::NetworkJacobian generateJacobianMatrix( gridfire::engine::NetworkJacobian generateJacobianMatrix(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract& comp, const fourdst::composition::CompositionAbstract& comp,
double T9, double T9,
double rho, double rho,
const gridfire::engine::SparsityPattern &sparsityPattern const gridfire::engine::SparsityPattern &sparsityPattern
) const override; ) const override;
void generateStoichiometryMatrix() override;
int getStoichiometryMatrixEntry(
const fourdst::atomic::Species& species,
const gridfire::reaction::Reaction& reaction
) const override;
double calculateMolarReactionFlow( double calculateMolarReactionFlow(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::reaction::Reaction &reaction, const gridfire::reaction::Reaction &reaction,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho
) const override; ) const override;
const gridfire::reaction::ReactionSet& getNetworkReactions() const override; const gridfire::reaction::ReactionSet& getNetworkReactions(
gridfire::engine::scratch::StateBlob& ctx
void setNetworkReactions( ) const override;
const gridfire::reaction::ReactionSet& reactions
) override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> getSpeciesTimescales( std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> getSpeciesTimescales(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho
) const override; ) const override;
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> getSpeciesDestructionTimescales( std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::engine::EngineStatus> getSpeciesDestructionTimescales(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho
) const override; ) const override;
fourdst::composition::Composition update( fourdst::composition::Composition project(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::NetIn &netIn const gridfire::NetIn &netIn
) override; ) const override;
bool isStale( gridfire::screening::ScreeningType getScreeningModel(
const gridfire::NetIn &netIn gridfire::engine::scratch::StateBlob& ctx
) override; ) const override;
void setScreeningModel(
gridfire::screening::ScreeningType model
) override;
gridfire::screening::ScreeningType getScreeningModel() const override;
size_t getSpeciesIndex( size_t getSpeciesIndex(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::atomic::Species &species const fourdst::atomic::Species &species
) const override; ) const override;
std::vector<double> mapNetInToMolarAbundanceVector( gridfire::engine::PrimingReport primeEngine(
gridfire::engine::scratch::StateBlob& ctx,
const gridfire::NetIn &netIn const gridfire::NetIn &netIn
) const override; ) const override;
gridfire::engine::PrimingReport primeEngine(
const gridfire::NetIn &netIn
) override;
gridfire::engine::BuildDepthType getDepth() const override {
throw std::logic_error("Network depth not supported by this engine.");
}
void rebuild(
const fourdst::composition::CompositionAbstract &comp,
gridfire::engine::BuildDepthType depth
) override {
throw std::logic_error("Setting network depth not supported by this engine.");
}
[[nodiscard]] gridfire::engine::EnergyDerivatives calculateEpsDerivatives( [[nodiscard]] gridfire::engine::EnergyDerivatives calculateEpsDerivatives(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho
) const override; ) const override;
fourdst::composition::Composition collectComposition( fourdst::composition::Composition collectComposition(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::composition::CompositionAbstract &comp, const fourdst::composition::CompositionAbstract &comp,
double T9, double T9,
double rho double rho
) const override; ) const override;
gridfire::engine::SpeciesStatus getSpeciesStatus( gridfire::engine::SpeciesStatus getSpeciesStatus(
gridfire::engine::scratch::StateBlob& ctx,
const fourdst::atomic::Species &species const fourdst::atomic::Species &species
) const override; ) const override;
std::optional<gridfire::engine::StepDerivatives<double>> getMostRecentRHSCalculation(
gridfire::engine::scratch::StateBlob &ctx
) const override;
private: private:
mutable std::vector<fourdst::atomic::Species> m_species_cache; mutable std::vector<fourdst::atomic::Species> m_species_cache;
}; };

View File

@@ -2,6 +2,8 @@
#include "bindings.h" #include "bindings.h"
#include "gridfire/exceptions/error_scratchpad.h"
namespace py = pybind11; namespace py = pybind11;
#include "gridfire/exceptions/exceptions.h" #include "gridfire/exceptions/exceptions.h"
@@ -44,4 +46,6 @@ void register_exception_bindings(const py::module &m) {
py::register_exception<gridfire::exceptions::CVODESolverFailureError>(m, "CVODESolverFailureError", m.attr("SUNDIALSError")); py::register_exception<gridfire::exceptions::CVODESolverFailureError>(m, "CVODESolverFailureError", m.attr("SUNDIALSError"));
py::register_exception<gridfire::exceptions::KINSolSolverFailureError>(m, "KINSolSolverFailureError", m.attr("SUNDIALSError")); py::register_exception<gridfire::exceptions::KINSolSolverFailureError>(m, "KINSolSolverFailureError", m.attr("SUNDIALSError"));
py::register_exception<gridfire::exceptions::ScratchPadError>(m, "ScratchPadError", m.attr("GridFireError"));
} }

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_exceptions',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,7 +1,7 @@
from ._gridfire import * from ._gridfire import *
import sys import sys
from ._gridfire import type, utils, engine, solver, exceptions, partition, reaction, screening, io, policy from ._gridfire import type, utils, engine, solver, exceptions, partition, reaction, screening, io, policy, config
sys.modules['gridfire.type'] = type sys.modules['gridfire.type'] = type
sys.modules['gridfire.utils'] = utils sys.modules['gridfire.utils'] = utils
@@ -13,8 +13,61 @@ sys.modules['gridfire.reaction'] = reaction
sys.modules['gridfire.screening'] = screening sys.modules['gridfire.screening'] = screening
sys.modules['gridfire.policy'] = policy sys.modules['gridfire.policy'] = policy
sys.modules['gridfire.io'] = io sys.modules['gridfire.io'] = io
sys.modules['gridfire.config'] = config
__all__ = ['type', 'utils', 'engine', 'solver', 'exceptions', 'partition', 'reaction', 'screening', 'io', 'policy'] __all__ = ['type', 'utils', 'engine', 'solver', 'exceptions', 'partition', 'reaction', 'screening', 'io', 'policy', 'config']
__version__ = "v0.7.4_rc2" import importlib.metadata
try:
_meta = importlib.metadata.metadata('gridfire')
__version__ = _meta['Version']
__author__ = _meta['Author']
__license__ = _meta['License']
__email__ = _meta['Author-email']
__url__ = _meta['Home-page'] or _meta.get('Project-URL', '').split(',')[0].split(' ')[-1].strip()
__description__ = _meta['Summary']
except importlib.metadata.PackageNotFoundError :
__version__ = 'unknown - Package not installed'
__author__ = 'Emily M. Boudreaux'
__license__ = 'GNU General Public License v3.0'
__email__ = 'emily.boudreaux@dartmouth.edu'
__url__ = 'https://github.com/4D-STAR/GridFire'
def gf_metadata():
return {
'version': __version__,
'author': __author__,
'license': __license__,
'email': __email__,
'url': __url__,
'description': __description__
}
def gf_version():
return __version__
def gf_author():
return __author__
def gf_license():
return __license__
def gf_email():
return __email__
def gf_url():
return __url__
def gf_description():
return __description__
def gf_collaboration():
return "4D-STAR Collaboration"
def gf_credits():
return [
"Emily M. Boudreaux - Lead Developer",
"Aaron Dotter - Co-Developer",
"4D-STAR Collaboration - Contributors"
]

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_io',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_io_trampoline_sources = files('py_io.cpp')
gf_io_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_io_trampoline_lib = static_library(
'io_trampolines',
gf_io_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_io_trapoline_dependencies,
install: false,
)
gr_io_trampoline_dep = declare_dependency(
link_with: gf_io_trampoline_lib,
include_directories: ('.'),
dependencies: gf_io_trapoline_dependencies,
)

View File

@@ -1,10 +0,0 @@
subdir('types')
subdir('utils')
subdir('exceptions')
subdir('io')
subdir('partition')
subdir('reaction')
subdir('screening')
subdir('engine')
subdir('policy')
subdir('solver')

View File

@@ -1,19 +0,0 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_partition',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_partition_trampoline_sources = files('py_partition.cpp')
gf_partition_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_partition_trampoline_lib = static_library(
'partition_trampolines',
gf_partition_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_partition_trapoline_dependencies,
install: false,
)
gr_partition_trampoline_dep = declare_dependency(
link_with: gf_partition_trampoline_lib,
include_directories: ('.'),
dependencies: gf_partition_trapoline_dependencies,
)

View File

@@ -8,7 +8,6 @@
#include "gridfire/policy/policy.h" #include "gridfire/policy/policy.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
namespace py = pybind11; namespace py = pybind11;
@@ -103,9 +102,30 @@ namespace {
.def( .def(
"construct", "construct",
&T::construct, &T::construct,
py::return_value_policy::reference,
"Construct the network according to the policy." "Construct the network according to the policy."
)
.def(
"get_engine_stack",
[](const T &self) {
const auto& stack = self.get_engine_stack();
std::vector<gridfire::engine::DynamicEngine*> engine_ptrs;
engine_ptrs.reserve(stack.size());
for (const auto& engine_uptr : stack) {
engine_ptrs.push_back(engine_uptr.get());
}
return engine_ptrs;
},
py::return_value_policy::reference_internal
)
.def(
"get_stack_scratch_blob",
&T::get_stack_scratch_blob
)
.def(
"get_partition_function",
&T::get_partition_function
); );
} }
} }
@@ -215,6 +235,26 @@ void register_network_policy_bindings(pybind11::module &m) {
.value("INITIALIZED_VERIFIED", gridfire::policy::NetworkPolicyStatus::INITIALIZED_VERIFIED) .value("INITIALIZED_VERIFIED", gridfire::policy::NetworkPolicyStatus::INITIALIZED_VERIFIED)
.export_values(); .export_values();
m.def("network_policy_status_to_string",
&gridfire::policy::NetworkPolicyStatusToString,
py::arg("status"),
"Convert a NetworkPolicyStatus enum value to its string representation."
);
py::class_<gridfire::policy::ConstructionResults>(m, "ConstructionResults")
.def_property_readonly("engine",
[](const gridfire::policy::ConstructionResults &self) -> const gridfire::engine::DynamicEngine& {
return self.engine;
},
py::return_value_policy::reference
)
.def_property_readonly("scratch_blob",
[](const gridfire::policy::ConstructionResults &self) {
return self.scratch_blob.get();
},
py::return_value_policy::reference_internal
);
py::class_<gridfire::policy::NetworkPolicy, PyNetworkPolicy> py_networkPolicy(m, "NetworkPolicy"); py::class_<gridfire::policy::NetworkPolicy, PyNetworkPolicy> py_networkPolicy(m, "NetworkPolicy");
py::class_<gridfire::policy::MainSequencePolicy, gridfire::policy::NetworkPolicy> py_mainSeqPolicy(m, "MainSequencePolicy"); py::class_<gridfire::policy::MainSequencePolicy, gridfire::policy::NetworkPolicy> py_mainSeqPolicy(m, "MainSequencePolicy");
py_mainSeqPolicy.def( py_mainSeqPolicy.def(

View File

@@ -1,19 +0,0 @@
# Define the library
subdir('trampoline')
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_policy',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_policy_trampoline_sources = files('py_policy.cpp')
gf_policy_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_policy_trampoline_lib = static_library(
'policy_trampolines',
gf_policy_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_policy_trapoline_dependencies,
install: false,
)
gr_policy_trampoline_dep = declare_dependency(
link_with: gf_policy_trampoline_lib,
include_directories: ('.'),
dependencies: gf_policy_trapoline_dependencies,
)

View File

@@ -39,9 +39,9 @@ const gridfire::reaction::ReactionSet& PyNetworkPolicy::get_seed_reactions() con
); );
} }
gridfire::engine::DynamicEngine& PyNetworkPolicy::construct() { gridfire::policy::ConstructionResults PyNetworkPolicy::construct() {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
gridfire::engine::DynamicEngine&, gridfire::policy::ConstructionResults,
gridfire::policy::NetworkPolicy, gridfire::policy::NetworkPolicy,
construct construct
); );
@@ -79,6 +79,14 @@ const std::unique_ptr<gridfire::partition::PartitionFunction>& PyNetworkPolicy::
); );
} }
std::unique_ptr<gridfire::engine::scratch::StateBlob> PyNetworkPolicy::get_stack_scratch_blob() const {
PYBIND11_OVERRIDE_PURE(
std::unique_ptr<gridfire::engine::scratch::StateBlob>,
gridfire::policy::NetworkPolicy,
get_stack_scratch_blob
);
}
const gridfire::reaction::ReactionSet &PyReactionChainPolicy::get_reactions() const { const gridfire::reaction::ReactionSet &PyReactionChainPolicy::get_reactions() const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
const gridfire::reaction::ReactionSet &, const gridfire::reaction::ReactionSet &,

View File

@@ -13,7 +13,7 @@ public:
[[nodiscard]] const gridfire::reaction::ReactionSet& get_seed_reactions() const override; [[nodiscard]] const gridfire::reaction::ReactionSet& get_seed_reactions() const override;
[[nodiscard]] gridfire::engine::DynamicEngine& construct() override; [[nodiscard]] gridfire::policy::ConstructionResults construct() override;
[[nodiscard]] gridfire::policy::NetworkPolicyStatus get_status() const override; [[nodiscard]] gridfire::policy::NetworkPolicyStatus get_status() const override;
@@ -22,6 +22,8 @@ public:
[[nodiscard]] std::vector<gridfire::engine::EngineTypes> get_engine_types_stack() const override; [[nodiscard]] std::vector<gridfire::engine::EngineTypes> get_engine_types_stack() const override;
[[nodiscard]] const std::unique_ptr<gridfire::partition::PartitionFunction>& get_partition_function() const override; [[nodiscard]] const std::unique_ptr<gridfire::partition::PartitionFunction>& get_partition_function() const override;
[[nodiscard]] std::unique_ptr<gridfire::engine::scratch::StateBlob> get_stack_scratch_blob() const override;
}; };
class PyReactionChainPolicy final : public gridfire::policy::ReactionChainPolicy { class PyReactionChainPolicy final : public gridfire::policy::ReactionChainPolicy {

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_reaction',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,19 +0,0 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_screening',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_screening_trampoline_sources = files('py_screening.cpp')
gf_screening_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_screening_trampoline_lib = static_library(
'screening_trampolines',
gf_screening_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_screening_trapoline_dependencies,
install: false,
)
gr_screening_trampoline_dep = declare_dependency(
link_with: gf_screening_trampoline_lib,
include_directories: ('.'),
dependencies: gf_screening_trapoline_dependencies,
)

View File

@@ -7,125 +7,226 @@
#include "bindings.h" #include "bindings.h"
#include "gridfire/solver/strategies/CVODE_solver_strategy.h" #include "gridfire/solver/strategies/PointSolver.h"
#include "gridfire/engine/scratchpads/blob.h"
#include "trampoline/py_solver.h" #include "trampoline/py_solver.h"
namespace py = pybind11; namespace py = pybind11;
void register_solver_bindings(const py::module &m) { void register_solver_bindings(const py::module &m) {
auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase"); auto py_cvode_timestep_context = py::class_<gridfire::solver::PointSolverTimestepContext>(m, "PointSolverTimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::PointSolverTimestepContext::t);
auto py_cvode_timestep_context = py::class_<gridfire::solver::CVODESolverStrategy::TimestepContext, gridfire::solver::SolverContextBase>(m, "CVODETimestepContext");
py_cvode_timestep_context.def_readonly("t", &gridfire::solver::CVODESolverStrategy::TimestepContext::t);
py_cvode_timestep_context.def_property_readonly( py_cvode_timestep_context.def_property_readonly(
"state", "state",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<double> { [](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<double> {
const sunrealtype* nvec_data = N_VGetArrayPointer(self.state); const sunrealtype* nvec_data = N_VGetArrayPointer(self.state);
const sunindextype length = N_VGetLength(self.state); const sunindextype length = N_VGetLength(self.state);
return std::vector<double>(nvec_data, nvec_data + length); return {nvec_data, nvec_data + length};
} }
); );
py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::CVODESolverStrategy::TimestepContext::dt); py_cvode_timestep_context.def_readonly("dt", &gridfire::solver::PointSolverTimestepContext::dt);
py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::CVODESolverStrategy::TimestepContext::last_step_time); py_cvode_timestep_context.def_readonly("last_step_time", &gridfire::solver::PointSolverTimestepContext::last_step_time);
py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::CVODESolverStrategy::TimestepContext::T9); py_cvode_timestep_context.def_readonly("T9", &gridfire::solver::PointSolverTimestepContext::T9);
py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::CVODESolverStrategy::TimestepContext::rho); py_cvode_timestep_context.def_readonly("rho", &gridfire::solver::PointSolverTimestepContext::rho);
py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::CVODESolverStrategy::TimestepContext::num_steps); py_cvode_timestep_context.def_readonly("num_steps", &gridfire::solver::PointSolverTimestepContext::num_steps);
py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentConvergenceFailures); py_cvode_timestep_context.def_readonly("currentConvergenceFailures", &gridfire::solver::PointSolverTimestepContext::currentConvergenceFailures);
py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::CVODESolverStrategy::TimestepContext::currentNonlinearIterations); py_cvode_timestep_context.def_readonly("currentNonlinearIterations", &gridfire::solver::PointSolverTimestepContext::currentNonlinearIterations);
py_cvode_timestep_context.def_property_readonly( py_cvode_timestep_context.def_property_readonly(
"engine", "engine",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> const gridfire::engine::DynamicEngine& { [](const gridfire::solver::PointSolverTimestepContext& self) -> const gridfire::engine::DynamicEngine& {
return self.engine; return self.engine;
} }
); );
py_cvode_timestep_context.def_property_readonly( py_cvode_timestep_context.def_property_readonly(
"networkSpecies", "networkSpecies",
[](const gridfire::solver::CVODESolverStrategy::TimestepContext& self) -> std::vector<fourdst::atomic::Species> { [](const gridfire::solver::PointSolverTimestepContext& self) -> std::vector<fourdst::atomic::Species> {
return self.networkSpecies; return self.networkSpecies;
} }
); );
py_cvode_timestep_context.def_property_readonly(
"state_ctx",
[](const gridfire::solver::PointSolverTimestepContext& self) {
return &(self.state_ctx);
},
py::return_value_policy::reference_internal
);
auto py_dynamic_network_solver_strategy = py::class_<gridfire::solver::DynamicNetworkSolverStrategy, PyDynamicNetworkSolverStrategy>(m, "DynamicNetworkSolverStrategy");
py_dynamic_network_solver_strategy.def( auto py_solver_context_base = py::class_<gridfire::solver::SolverContextBase>(m, "SolverContextBase");
auto py_point_solver_context = py::class_<gridfire::solver::PointSolverContext, gridfire::solver::SolverContextBase>(m, "PointSolverContext");
py_point_solver_context
.def_readonly(
"sun_ctx", &gridfire::solver::PointSolverContext::sun_ctx
)
.def_readonly(
"cvode_mem", &gridfire::solver::PointSolverContext::cvode_mem
)
.def_readonly(
"Y", &gridfire::solver::PointSolverContext::Y
)
.def_readonly(
"YErr", &gridfire::solver::PointSolverContext::YErr
)
.def_readonly(
"J", &gridfire::solver::PointSolverContext::J
)
.def_readonly(
"LS", &gridfire::solver::PointSolverContext::LS
)
.def_property_readonly(
"engine_ctx",
[](const gridfire::solver::PointSolverContext& self) -> gridfire::engine::scratch::StateBlob& {
return *(self.engine_ctx);
},
py::return_value_policy::reference
)
.def_readonly(
"num_steps", &gridfire::solver::PointSolverContext::num_steps
)
.def_property(
"abs_tol",
[](const gridfire::solver::PointSolverContext& self) -> double {
return self.abs_tol.value();
},
[](gridfire::solver::PointSolverContext& self, double abs_tol) -> void {
self.abs_tol = abs_tol;
}
)
.def_property(
"rel_tol",
[](const gridfire::solver::PointSolverContext& self) -> double {
return self.rel_tol.value();
},
[](gridfire::solver::PointSolverContext& self, double rel_tol) -> void {
self.rel_tol = rel_tol;
}
)
.def_property(
"stdout_logging",
[](const gridfire::solver::PointSolverContext& self) -> bool {
return self.stdout_logging;
},
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
self.stdout_logging = enable;
}
)
.def_property(
"detailed_logging",
[](const gridfire::solver::PointSolverContext& self) -> bool {
return self.detailed_step_logging;
},
[](gridfire::solver::PointSolverContext& self, const bool enable) -> void {
self.detailed_step_logging = enable;
}
)
.def_property(
"callback",
[](const gridfire::solver::PointSolverContext& self) -> std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>> {
return self.callback;
},
[](gridfire::solver::PointSolverContext& self, const std::optional<std::function<void(const gridfire::solver::PointSolverTimestepContext&)>>& cb) {
self.callback = cb;
}
)
.def("reset_all", &gridfire::solver::PointSolverContext::reset_all)
.def("reset_user", &gridfire::solver::PointSolverContext::reset_user)
.def("reset_cvode", &gridfire::solver::PointSolverContext::reset_cvode)
.def("clear_context", &gridfire::solver::PointSolverContext::clear_context)
.def("init_context", &gridfire::solver::PointSolverContext::init_context)
.def("has_context", &gridfire::solver::PointSolverContext::has_context)
.def("init", &gridfire::solver::PointSolverContext::init)
.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("engine_ctx"));
auto py_single_zone_dynamic_network_solver = py::class_<gridfire::solver::SingleZoneDynamicNetworkSolver, PySingleZoneDynamicNetworkSolver>(m, "SingleZoneDynamicNetworkSolver");
py_single_zone_dynamic_network_solver.def(
"evaluate", "evaluate",
&gridfire::solver::DynamicNetworkSolverStrategy::evaluate, &gridfire::solver::SingleZoneDynamicNetworkSolver::evaluate,
py::arg("solver_ctx"),
py::arg("netIn"), py::arg("netIn"),
"evaluate the dynamic engine using the dynamic engine class" "evaluate the dynamic engine using the dynamic engine class for a single zone"
);
auto py_multi_zone_dynamic_network_solver = py::class_<gridfire::solver::MultiZoneDynamicNetworkSolver, PyMultiZoneDynamicNetworkSolver>(m, "MultiZoneDynamicNetworkSolver");
py_multi_zone_dynamic_network_solver.def(
"evaluate",
&gridfire::solver::MultiZoneDynamicNetworkSolver::evaluate,
py::arg("solver_ctx"),
py::arg("netIns"),
"evaluate the dynamic engine using the dynamic engine class for multiple zones (using openmp if available)"
); );
auto py_point_solver = py::class_<gridfire::solver::PointSolver, gridfire::solver::SingleZoneDynamicNetworkSolver>(m, "PointSolver");
py_dynamic_network_solver_strategy.def( py_point_solver.def(
"describe_callback_context",
&gridfire::solver::DynamicNetworkSolverStrategy::describe_callback_context,
"Get a structure representing what data is in the callback context in a human readable format"
);
auto py_cvode_solver_strategy = py::class_<gridfire::solver::CVODESolverStrategy, gridfire::solver::DynamicNetworkSolverStrategy>(m, "CVODESolverStrategy");
py_cvode_solver_strategy.def(
py::init<gridfire::engine::DynamicEngine&>(), py::init<gridfire::engine::DynamicEngine&>(),
py::arg("engine"), py::arg("engine"),
"Initialize the CVODESolverStrategy object." "Initialize the PointSolver object."
); );
py_cvode_solver_strategy.def( py_point_solver.def(
"evaluate", "evaluate",
py::overload_cast<const gridfire::NetIn&, bool>(&gridfire::solver::CVODESolverStrategy::evaluate), py::overload_cast<gridfire::solver::SolverContextBase&, const gridfire::NetIn&, bool, bool>(&gridfire::solver::PointSolver::evaluate, py::const_),
py::arg("solver_ctx"),
py::arg("netIn"), py::arg("netIn"),
py::arg("display_trigger") = false, py::arg("display_trigger") = false,
py::arg("force_reinitialization") = false,
"evaluate the dynamic engine using the dynamic engine class" "evaluate the dynamic engine using the dynamic engine class"
); );
py_cvode_solver_strategy.def( auto py_grid_solver_context = py::class_<gridfire::solver::GridSolverContext, gridfire::solver::SolverContextBase>(m, "GridSolverContext");
"get_stdout_logging_enabled", py_grid_solver_context.def(py::init<const gridfire::engine::scratch::StateBlob&>(), py::arg("ctx_template"));
&gridfire::solver::CVODESolverStrategy::get_stdout_logging_enabled, py_grid_solver_context.def("init", &gridfire::solver::GridSolverContext::init);
"Check if solver logging to standard output is enabled." py_grid_solver_context.def("reset", &gridfire::solver::GridSolverContext::reset);
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"));
py_grid_solver_context.def("set_callback", py::overload_cast<const std::function<void(const gridfire::solver::TimestepContextBase&)>&, size_t>(&gridfire::solver::GridSolverContext::set_callback) , py::arg("callback"), py::arg("zone_idx"));
py_grid_solver_context.def("clear_callback", py::overload_cast<>(&gridfire::solver::GridSolverContext::clear_callback));
py_grid_solver_context.def("clear_callback", py::overload_cast<size_t>(&gridfire::solver::GridSolverContext::clear_callback), py::arg("zone_idx"));
py_grid_solver_context.def_property(
"stdout_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_stdout_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_stdout_logging = enable;
}
)
.def_property(
"detailed_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_detailed_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_detailed_logging = enable;
}
)
.def_property(
"zone_completion_logging",
[](const gridfire::solver::GridSolverContext& self) -> bool {
return self.zone_completion_logging;
},
[](gridfire::solver::GridSolverContext& self, const bool enable) -> void {
self.zone_completion_logging = enable;
}
); );
py_cvode_solver_strategy.def( auto py_grid_solver = py::class_<gridfire::solver::GridSolver, gridfire::solver::MultiZoneDynamicNetworkSolver>(m, "GridSolver");
"set_stdout_logging_enabled", py_grid_solver.def(
&gridfire::solver::CVODESolverStrategy::set_stdout_logging_enabled, py::init<const gridfire::engine::DynamicEngine&, const gridfire::solver::SingleZoneDynamicNetworkSolver&>(),
py::arg("logging_enabled"), py::arg("engine"),
"Enable logging to standard output." py::arg("solver"),
"Initialize the GridSolver object."
); );
py_cvode_solver_strategy.def( py_grid_solver.def(
"set_absTol", "evaluate",
&gridfire::solver::CVODESolverStrategy::set_absTol, &gridfire::solver::GridSolver::evaluate,
py::arg("absTol"), py::arg("solver_ctx"),
"Set the absolute tolerance for the CVODE solver." py::arg("netIns"),
"evaluate the dynamic engine using the dynamic engine class"
); );
py_cvode_solver_strategy.def(
"set_relTol",
&gridfire::solver::CVODESolverStrategy::set_relTol,
py::arg("relTol"),
"Set the relative tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"get_absTol",
&gridfire::solver::CVODESolverStrategy::get_absTol,
"Get the absolute tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"get_relTol",
&gridfire::solver::CVODESolverStrategy::get_relTol,
"Get the relative tolerance for the CVODE solver."
);
py_cvode_solver_strategy.def(
"set_callback",
[](
gridfire::solver::CVODESolverStrategy& self,
std::function<void(const gridfire::solver::CVODESolverStrategy::TimestepContext&)> cb
) {
self.set_callback(std::any(cb));
},
py::arg("cb"),
"Set a callback function which will run at the end of every successful timestep"
);
} }

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_solver',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_solver_trampoline_sources = files('py_solver.cpp')
gf_solver_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_solver_trampoline_lib = static_library(
'solver_trampolines',
gf_solver_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_solver_trapoline_dependencies,
install: false,
)
gr_solver_trampoline_dep = declare_dependency(
link_with: gf_solver_trampoline_lib,
include_directories: ('.'),
dependencies: gf_solver_trapoline_dependencies,
)

View File

@@ -13,38 +13,63 @@
namespace py = pybind11; namespace py = pybind11;
gridfire::NetOut PyDynamicNetworkSolverStrategy::evaluate(const gridfire::NetIn &netIn) { gridfire::NetOut PySingleZoneDynamicNetworkSolver::evaluate(
gridfire::solver::SolverContextBase &solver_ctx,
const gridfire::NetIn &netIn
) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
gridfire::NetOut, // Return type gridfire::NetOut,
gridfire::solver::DynamicNetworkSolverStrategy, // Base class gridfire::solver::SingleZoneDynamicNetworkSolver,
evaluate, // Method name evaluate,
netIn // Arguments solver_ctx,
netIn
); );
} }
void PyDynamicNetworkSolverStrategy::set_callback(const std::any &callback) { std::vector<gridfire::NetOut> PyMultiZoneDynamicNetworkSolver::evaluate(
gridfire::solver::SolverContextBase &solver_ctx,
const std::vector<gridfire::NetIn> &netIns
) const {
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
void, std::vector<gridfire::NetOut>,
gridfire::solver::DynamicNetworkSolverStrategy, // Base class gridfire::solver::MultiZoneDynamicNetworkSolver,
set_callback, // Method name evaluate,
callback // Arguments solver_ctx,
netIns
); );
} }
std::vector<std::tuple<std::string, std::string>> PyDynamicNetworkSolverStrategy::describe_callback_context() const { std::vector<std::tuple<std::string, std::string>> PyTimestepContextBase::describe() const {
using DescriptionVector = std::vector<std::tuple<std::string, std::string>>; using ReturnType = std::vector<std::tuple<std::string, std::string>>;
PYBIND11_OVERRIDE_PURE( PYBIND11_OVERRIDE_PURE(
DescriptionVector, // Return type ReturnType,
gridfire::solver::DynamicNetworkSolverStrategy, // Base class gridfire::solver::TimestepContextBase,
describe_callback_context // Method name
);
}
std::vector<std::tuple<std::string, std::string>> PySolverContextBase::describe() const {
using DescriptionVector = std::vector<std::tuple<std::string, std::string>>;
PYBIND11_OVERRIDE_PURE(
DescriptionVector,
gridfire::solver::SolverContextBase,
describe describe
); );
} }
void PySolverContextBase::init() {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::solver::SolverContextBase,
init
);
}
void PySolverContextBase::set_stdout_logging(bool enable) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::solver::SolverContextBase,
set_stdout_logging,
enable
);
}
void PySolverContextBase::set_detailed_logging(bool enable) {
PYBIND11_OVERRIDE_PURE(
void,
gridfire::solver::SolverContextBase,
set_detailed_logging,
enable
);
}

View File

@@ -7,14 +7,37 @@
#include <string> #include <string>
#include <any> #include <any>
class PyDynamicNetworkSolverStrategy final : public gridfire::solver::DynamicNetworkSolverStrategy { class PySingleZoneDynamicNetworkSolver final : public gridfire::solver::SingleZoneDynamicNetworkSolver {
explicit PyDynamicNetworkSolverStrategy(gridfire::engine::DynamicEngine &engine) : gridfire::solver::DynamicNetworkSolverStrategy(engine) {} public:
gridfire::NetOut evaluate(const gridfire::NetIn &netIn) override; explicit PySingleZoneDynamicNetworkSolver(const gridfire::engine::DynamicEngine &engine) : gridfire::solver::SingleZoneDynamicNetworkSolver(engine) {}
void set_callback(const std::any &callback) override;
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override; gridfire::NetOut evaluate(
gridfire::solver::SolverContextBase &solver_ctx,
const gridfire::NetIn &netIn
) const override;
};
class PyMultiZoneDynamicNetworkSolver final : public gridfire::solver::MultiZoneDynamicNetworkSolver {
public:
explicit PyMultiZoneDynamicNetworkSolver(
const gridfire::engine::DynamicEngine &engine,
const gridfire::solver::SingleZoneDynamicNetworkSolver &local_solver
) : gridfire::solver::MultiZoneDynamicNetworkSolver(engine, local_solver) {}
std::vector<gridfire::NetOut> evaluate(
gridfire::solver::SolverContextBase &solver_ctx,
const std::vector<gridfire::NetIn> &netIns
) const override;
};
class PyTimestepContextBase final : public gridfire::solver::TimestepContextBase {
public:
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override;
}; };
class PySolverContextBase final : public gridfire::solver::SolverContextBase { class PySolverContextBase final : public gridfire::solver::SolverContextBase {
public: public:
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe() const override; void init() override;
void set_stdout_logging(bool enable) override;
void set_detailed_logging(bool enable) override;
}; };

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_types',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -12,6 +12,7 @@ namespace py = pybind11;
void register_utils_bindings(py::module &m) { void register_utils_bindings(py::module &m) {
m.def("formatNuclearTimescaleLogString", m.def("formatNuclearTimescaleLogString",
&gridfire::utils::formatNuclearTimescaleLogString, &gridfire::utils::formatNuclearTimescaleLogString,
py::arg("ctx"),
py::arg("engine"), py::arg("engine"),
py::arg("Y"), py::arg("Y"),
py::arg("T9"), py::arg("T9"),

View File

@@ -1,17 +0,0 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_utils',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

File diff suppressed because one or more lines are too long

View File

@@ -2,6 +2,7 @@
Python bindings for the fourdst utility modules which are a part of the 4D-STAR project. Python bindings for the fourdst utility modules which are a part of the 4D-STAR project.
""" """
from __future__ import annotations from __future__ import annotations
from . import config
from . import engine from . import engine
from . import exceptions from . import exceptions
from . import io from . import io
@@ -12,4 +13,4 @@ from . import screening
from . import solver from . import solver
from . import type from . import type
from . import utils from . import utils
__all__: list[str] = ['engine', 'exceptions', 'io', 'partition', 'policy', 'reaction', 'screening', 'solver', 'type', 'utils'] __all__: list[str] = ['config', 'engine', 'exceptions', 'io', 'partition', 'policy', 'reaction', 'screening', 'solver', 'type', 'utils']

View File

@@ -0,0 +1,47 @@
"""
GridFire configuration bindings
"""
from __future__ import annotations
import typing
__all__: list[str] = ['AdaptiveEngineViewConfig', 'CVODESolverConfig', 'EngineConfig', 'EngineViewConfig', 'GridFireConfig', 'SolverConfig']
class AdaptiveEngineViewConfig:
def __init__(self) -> None:
...
@property
def relativeCullingThreshold(self) -> float:
...
@relativeCullingThreshold.setter
def relativeCullingThreshold(self, arg0: typing.SupportsFloat) -> None:
...
class CVODESolverConfig:
def __init__(self) -> None:
...
@property
def absTol(self) -> float:
...
@absTol.setter
def absTol(self, arg0: typing.SupportsFloat) -> None:
...
@property
def relTol(self) -> float:
...
@relTol.setter
def relTol(self, arg0: typing.SupportsFloat) -> None:
...
class EngineConfig:
views: EngineViewConfig
def __init__(self) -> None:
...
class EngineViewConfig:
adaptiveEngineView: AdaptiveEngineViewConfig
def __init__(self) -> None:
...
class GridFireConfig:
engine: EngineConfig
solver: SolverConfig
def __init__(self) -> None:
...
class SolverConfig:
cvode: CVODESolverConfig
def __init__(self) -> None:
...

View File

@@ -14,110 +14,81 @@ import numpy
import numpy.typing import numpy.typing
import typing import typing
from . import diagnostics from . import diagnostics
__all__: list[str] = ['ACTIVE', 'ADAPTIVE_ENGINE_VIEW', 'AdaptiveEngineView', 'BuildDepthType', 'DEFAULT', 'DEFINED_ENGINE_VIEW', 'DefinedEngineView', 'DynamicEngine', 'EQUILIBRIUM', 'Engine', 'EngineTypes', 'FILE_DEFINED_ENGINE_VIEW', 'FULL_SUCCESS', 'FifthOrder', 'FileDefinedEngineView', 'FourthOrder', 'Full', 'GRAPH_ENGINE', 'GraphEngine', 'INACTIVE_FLOW', 'MAX_ITERATIONS_REACHED', 'MULTISCALE_PARTITIONING_ENGINE_VIEW', 'MultiscalePartitioningEngineView', 'NONE', 'NOT_PRESENT', 'NO_SPECIES_TO_PRIME', 'NetworkBuildDepth', 'NetworkConstructionFlags', 'NetworkJacobian', 'NetworkPrimingEngineView', 'PRIMING_ENGINE_VIEW', 'PrimingReport', 'PrimingReportStatus', 'REACLIB', 'REACLIB_STRONG', 'REACLIB_WEAK', 'SecondOrder', 'Shallow', 'SparsityPattern', 'SpeciesStatus', 'StepDerivatives', 'ThirdOrder', 'WRL_BETA_MINUS', 'WRL_BETA_PLUS', 'WRL_ELECTRON_CAPTURE', 'WRL_POSITRON_CAPTURE', 'WRL_WEAK', 'build_nuclear_network', 'diagnostics', 'primeNetwork', 'regularize_jacobian'] from . import scratchpads
__all__: list[str] = ['ACTIVE', 'ADAPTIVE_ENGINE_VIEW', 'AdaptiveEngineView', 'BuildDepthType', 'DEFAULT', 'DEFINED_ENGINE_VIEW', 'DefinedEngineView', 'DynamicEngine', 'EQUILIBRIUM', 'Engine', 'EngineTypes', 'FILE_DEFINED_ENGINE_VIEW', 'FULL_SUCCESS', 'FifthOrder', 'FileDefinedEngineView', 'FourthOrder', 'Full', 'GRAPH_ENGINE', 'GraphEngine', 'INACTIVE_FLOW', 'MAX_ITERATIONS_REACHED', 'MULTISCALE_PARTITIONING_ENGINE_VIEW', 'MultiscalePartitioningEngineView', 'NONE', 'NOT_PRESENT', 'NO_SPECIES_TO_PRIME', 'NetworkBuildDepth', 'NetworkConstructionFlags', 'NetworkJacobian', 'NetworkPrimingEngineView', 'PRIMING_ENGINE_VIEW', 'PrimingReport', 'PrimingReportStatus', 'REACLIB', 'REACLIB_STRONG', 'REACLIB_WEAK', 'SecondOrder', 'Shallow', 'SparsityPattern', 'SpeciesStatus', 'StepDerivatives', 'ThirdOrder', 'WRL_BETA_MINUS', 'WRL_BETA_PLUS', 'WRL_ELECTRON_CAPTURE', 'WRL_POSITRON_CAPTURE', 'WRL_WEAK', 'build_nuclear_network', 'diagnostics', 'primeNetwork', 'regularize_jacobian', 'scratchpads']
class AdaptiveEngineView(DynamicEngine): class AdaptiveEngineView(DynamicEngine):
def __init__(self, baseEngine: DynamicEngine) -> None: def __init__(self, baseEngine: DynamicEngine) -> None:
""" """
Construct an adaptive engine view with a base engine. Construct an adaptive engine view with a base engine.
""" """
def calculateEpsDerivatives(self, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...: def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
""" """
Calculate deps/dT and deps/drho Calculate deps/dT and deps/drho
""" """
def calculateMolarReactionFlow(self: DynamicEngine, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float: def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
""" """
Calculate the molar reaction flow for a given reaction. Calculate the molar reaction flow for a given reaction.
""" """
def calculateRHSAndEnergy(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives: def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
""" """
Calculate the right-hand side (dY/dt) and energy generation rate. Calculate the right-hand side (dY/dt) and energy generation rate.
""" """
def collectComposition(self, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition: def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
""" """
Recursively collect composition from current engine and any sub engines if they exist. Recursively collect composition from current engine and any sub engines if they exist.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
""" """
Generate the Jacobian matrix for the current state. Generate the Jacobian matrix for the current state.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
""" """
Generate the jacobian matrix only for the subset of the matrix representing the active species. Generate the jacobian matrix only for the subset of the matrix representing the active species.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
""" """
Generate the jacobian matrix for the given sparsity pattern Generate the jacobian matrix for the given sparsity pattern
""" """
def generateStoichiometryMatrix(self) -> None:
...
def getBaseEngine(self) -> DynamicEngine: def getBaseEngine(self) -> DynamicEngine:
""" """
Get the base engine associated with this adaptive engine view. Get the base engine associated with this adaptive engine view.
""" """
def getDepth(self) -> gridfire._gridfire.engine.NetworkBuildDepth | int: def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the current build depth of the engine.
"""
def getNetworkReactions(self) -> gridfire._gridfire.reaction.ReactionSet:
""" """
Get the set of logical reactions in the network. Get the set of logical reactions in the network.
""" """
def getNetworkSpecies(self) -> list[fourdst._phys.atomic.Species]: def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
""" """
Get the list of species in the network. Get the list of species in the network.
""" """
def getScreeningModel(self) -> gridfire._gridfire.screening.ScreeningType: def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
""" """
Get the current screening model of the engine. Get the current screening model of the engine.
""" """
def getSpeciesDestructionTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the destruction timescales for each species in the network. Get the destruction timescales for each species in the network.
""" """
def getSpeciesIndex(self, species: fourdst._phys.atomic.Species) -> int: def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
""" """
Get the index of a species in the network. Get the index of a species in the network.
""" """
def getSpeciesStatus(self, species: fourdst._phys.atomic.Species) -> SpeciesStatus: def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
""" """
Get the status of a species in the network. Get the status of a species in the network.
""" """
def getSpeciesTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the timescales for each species in the network. Get the timescales for each species in the network.
""" """
def getStoichiometryMatrixEntry(self, species: fourdst._phys.atomic.Species, reaction: ...) -> int: def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Get an entry from the stoichiometry matrix.
"""
def isStale(self, netIn: gridfire._gridfire.type.NetIn) -> bool:
"""
Check if the engine is stale based on the provided NetIn object.
"""
def mapNetInToMolarAbundanceVector(self, netIn: gridfire._gridfire.type.NetIn) -> list[float]:
"""
Map a NetIn object to a vector of molar abundances.
"""
def primeEngine(self, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
""" """
Prime the engine with a NetIn object to prepare for calculations. Prime the engine with a NetIn object to prepare for calculations.
""" """
def rebuild(self, composition: ..., depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None: def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Rebuild the engine with a new composition and build depth.
"""
def setNetworkReactions(self, reactions: gridfire._gridfire.reaction.ReactionSet) -> None:
"""
Set the network reactions to a new set of reactions.
"""
def setScreeningModel(self, screeningModel: gridfire._gridfire.screening.ScreeningType) -> None:
"""
Set the screening model for the engine.
"""
def update(self, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
""" """
Update the engine state based on the provided NetIn object. Update the engine state based on the provided NetIn object.
""" """
@@ -128,104 +99,74 @@ class DefinedEngineView(DynamicEngine):
""" """
Construct a defined engine view with a list of tracked reactions and a base engine. Construct a defined engine view with a list of tracked reactions and a base engine.
""" """
def calculateEpsDerivatives(self, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...: def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
""" """
Calculate deps/dT and deps/drho Calculate deps/dT and deps/drho
""" """
def calculateMolarReactionFlow(self: DynamicEngine, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float: def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
""" """
Calculate the molar reaction flow for a given reaction. Calculate the molar reaction flow for a given reaction.
""" """
def calculateRHSAndEnergy(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives: def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
""" """
Calculate the right-hand side (dY/dt) and energy generation rate. Calculate the right-hand side (dY/dt) and energy generation rate.
""" """
def collectComposition(self, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition: def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
""" """
Recursively collect composition from current engine and any sub engines if they exist. Recursively collect composition from current engine and any sub engines if they exist.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
""" """
Generate the Jacobian matrix for the current state. Generate the Jacobian matrix for the current state.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
""" """
Generate the jacobian matrix only for the subset of the matrix representing the active species. Generate the jacobian matrix only for the subset of the matrix representing the active species.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
""" """
Generate the jacobian matrix for the given sparsity pattern Generate the jacobian matrix for the given sparsity pattern
""" """
def generateStoichiometryMatrix(self) -> None:
...
def getBaseEngine(self) -> DynamicEngine: def getBaseEngine(self) -> DynamicEngine:
""" """
Get the base engine associated with this defined engine view. Get the base engine associated with this defined engine view.
""" """
def getDepth(self) -> gridfire._gridfire.engine.NetworkBuildDepth | int: def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the current build depth of the engine.
"""
def getNetworkReactions(self) -> gridfire._gridfire.reaction.ReactionSet:
""" """
Get the set of logical reactions in the network. Get the set of logical reactions in the network.
""" """
def getNetworkSpecies(self) -> list[fourdst._phys.atomic.Species]: def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
""" """
Get the list of species in the network. Get the list of species in the network.
""" """
def getScreeningModel(self) -> gridfire._gridfire.screening.ScreeningType: def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
""" """
Get the current screening model of the engine. Get the current screening model of the engine.
""" """
def getSpeciesDestructionTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the destruction timescales for each species in the network. Get the destruction timescales for each species in the network.
""" """
def getSpeciesIndex(self, species: fourdst._phys.atomic.Species) -> int: def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
""" """
Get the index of a species in the network. Get the index of a species in the network.
""" """
def getSpeciesStatus(self, species: fourdst._phys.atomic.Species) -> SpeciesStatus: def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
""" """
Get the status of a species in the network. Get the status of a species in the network.
""" """
def getSpeciesTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the timescales for each species in the network. Get the timescales for each species in the network.
""" """
def getStoichiometryMatrixEntry(self, species: fourdst._phys.atomic.Species, reaction: ...) -> int: def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Get an entry from the stoichiometry matrix.
"""
def isStale(self, netIn: gridfire._gridfire.type.NetIn) -> bool:
"""
Check if the engine is stale based on the provided NetIn object.
"""
def mapNetInToMolarAbundanceVector(self, netIn: gridfire._gridfire.type.NetIn) -> list[float]:
"""
Map a NetIn object to a vector of molar abundances.
"""
def primeEngine(self, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
""" """
Prime the engine with a NetIn object to prepare for calculations. Prime the engine with a NetIn object to prepare for calculations.
""" """
def rebuild(self, composition: ..., depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None: def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Rebuild the engine with a new composition and build depth.
"""
def setNetworkReactions(self, reactions: gridfire._gridfire.reaction.ReactionSet) -> None:
"""
Set the network reactions to a new set of reactions.
"""
def setScreeningModel(self, screeningModel: gridfire._gridfire.screening.ScreeningType) -> None:
"""
Set the screening model for the engine.
"""
def update(self, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
""" """
Update the engine state based on the provided NetIn object. Update the engine state based on the provided NetIn object.
""" """
@@ -293,56 +234,50 @@ class FileDefinedEngineView(DefinedEngineView):
""" """
Construct a defined engine view from a file and a base engine. Construct a defined engine view from a file and a base engine.
""" """
def calculateEpsDerivatives(self, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...: def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
""" """
Calculate deps/dT and deps/drho Calculate deps/dT and deps/drho
""" """
def calculateMolarReactionFlow(self: DynamicEngine, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float: def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
""" """
Calculate the molar reaction flow for a given reaction. Calculate the molar reaction flow for a given reaction.
""" """
def calculateRHSAndEnergy(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives: def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
""" """
Calculate the right-hand side (dY/dt) and energy generation rate. Calculate the right-hand side (dY/dt) and energy generation rate.
""" """
def collectComposition(self, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition: def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
""" """
Recursively collect composition from current engine and any sub engines if they exist. Recursively collect composition from current engine and any sub engines if they exist.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
""" """
Generate the Jacobian matrix for the current state. Generate the Jacobian matrix for the current state.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
""" """
Generate the jacobian matrix only for the subset of the matrix representing the active species. Generate the jacobian matrix only for the subset of the matrix representing the active species.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
""" """
Generate the jacobian matrix for the given sparsity pattern Generate the jacobian matrix for the given sparsity pattern
""" """
def generateStoichiometryMatrix(self) -> None:
...
def getBaseEngine(self) -> DynamicEngine: def getBaseEngine(self) -> DynamicEngine:
""" """
Get the base engine associated with this file defined engine view. Get the base engine associated with this file defined engine view.
""" """
def getDepth(self) -> gridfire._gridfire.engine.NetworkBuildDepth | int:
"""
Get the current build depth of the engine.
"""
def getNetworkFile(self) -> str: def getNetworkFile(self) -> str:
""" """
Get the network file associated with this defined engine view. Get the network file associated with this defined engine view.
""" """
def getNetworkReactions(self) -> gridfire._gridfire.reaction.ReactionSet: def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
""" """
Get the set of logical reactions in the network. Get the set of logical reactions in the network.
""" """
def getNetworkSpecies(self) -> list[fourdst._phys.atomic.Species]: def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
""" """
Get the list of species in the network. Get the list of species in the network.
""" """
@@ -350,64 +285,35 @@ class FileDefinedEngineView(DefinedEngineView):
""" """
Get the parser used for this defined engine view. Get the parser used for this defined engine view.
""" """
def getScreeningModel(self) -> gridfire._gridfire.screening.ScreeningType: def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
""" """
Get the current screening model of the engine. Get the current screening model of the engine.
""" """
def getSpeciesDestructionTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the destruction timescales for each species in the network. Get the destruction timescales for each species in the network.
""" """
def getSpeciesIndex(self, species: fourdst._phys.atomic.Species) -> int: def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
""" """
Get the index of a species in the network. Get the index of a species in the network.
""" """
def getSpeciesStatus(self, species: fourdst._phys.atomic.Species) -> SpeciesStatus: def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
""" """
Get the status of a species in the network. Get the status of a species in the network.
""" """
def getSpeciesTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the timescales for each species in the network. Get the timescales for each species in the network.
""" """
def getStoichiometryMatrixEntry(self, species: fourdst._phys.atomic.Species, reaction: ...) -> int: def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Get an entry from the stoichiometry matrix.
"""
def isStale(self, netIn: gridfire._gridfire.type.NetIn) -> bool:
"""
Check if the engine is stale based on the provided NetIn object.
"""
def mapNetInToMolarAbundanceVector(self, netIn: gridfire._gridfire.type.NetIn) -> list[float]:
"""
Map a NetIn object to a vector of molar abundances.
"""
def primeEngine(self, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
""" """
Prime the engine with a NetIn object to prepare for calculations. Prime the engine with a NetIn object to prepare for calculations.
""" """
def rebuild(self, composition: ..., depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None: def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Rebuild the engine with a new composition and build depth.
"""
def setNetworkReactions(self, reactions: gridfire._gridfire.reaction.ReactionSet) -> None:
"""
Set the network reactions to a new set of reactions.
"""
def setScreeningModel(self, screeningModel: gridfire._gridfire.screening.ScreeningType) -> None:
"""
Set the screening model for the engine.
"""
def update(self, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
""" """
Update the engine state based on the provided NetIn object. Update the engine state based on the provided NetIn object.
""" """
class GraphEngine(DynamicEngine): class GraphEngine(DynamicEngine):
@staticmethod
def getNetReactionStoichiometry(reaction: ...) -> dict[fourdst._phys.atomic.Species, int]:
"""
Get the net stoichiometry for a given reaction.
"""
@typing.overload @typing.overload
def __init__(self, composition: fourdst._phys.composition.Composition, depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None: def __init__(self, composition: fourdst._phys.composition.Composition, depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None:
""" """
@@ -423,15 +329,15 @@ class GraphEngine(DynamicEngine):
""" """
Initialize GraphEngine with a set of reactions. Initialize GraphEngine with a set of reactions.
""" """
def calculateEpsDerivatives(self, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...: def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
""" """
Calculate deps/dT and deps/drho Calculate deps/dT and deps/drho
""" """
def calculateMolarReactionFlow(self: DynamicEngine, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float: def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
""" """
Calculate the molar reaction flow for a given reaction. Calculate the molar reaction flow for a given reaction.
""" """
def calculateRHSAndEnergy(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives: def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
""" """
Calculate the right-hand side (dY/dt) and energy generation rate. Calculate the right-hand side (dY/dt) and energy generation rate.
""" """
@@ -447,128 +353,90 @@ class GraphEngine(DynamicEngine):
""" """
Calculate the derivative of the reverse rate for a two-body reaction at a specific temperature. Calculate the derivative of the reverse rate for a two-body reaction at a specific temperature.
""" """
def collectComposition(self, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition: def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
""" """
Recursively collect composition from current engine and any sub engines if they exist. Recursively collect composition from current engine and any sub engines if they exist.
""" """
def exportToCSV(self, filename: str) -> None: def exportToCSV(self, ctx: scratchpads.StateBlob, filename: str) -> None:
""" """
Export the network to a CSV file for analysis. Export the network to a CSV file for analysis.
""" """
def exportToDot(self, filename: str) -> None: def exportToDot(self, ctx: scratchpads.StateBlob, filename: str) -> None:
""" """
Export the network to a DOT file for visualization. Export the network to a DOT file for visualization.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
""" """
Generate the Jacobian matrix for the current state. Generate the Jacobian matrix for the current state.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
""" """
Generate the jacobian matrix only for the subset of the matrix representing the active species. Generate the jacobian matrix only for the subset of the matrix representing the active species.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
""" """
Generate the jacobian matrix for the given sparsity pattern Generate the jacobian matrix for the given sparsity pattern
""" """
def generateStoichiometryMatrix(self) -> None: def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
...
def getDepth(self) -> gridfire._gridfire.engine.NetworkBuildDepth | int:
"""
Get the current build depth of the engine.
"""
def getNetworkReactions(self) -> gridfire._gridfire.reaction.ReactionSet:
""" """
Get the set of logical reactions in the network. Get the set of logical reactions in the network.
""" """
def getNetworkSpecies(self) -> list[fourdst._phys.atomic.Species]: def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
""" """
Get the list of species in the network. Get the list of species in the network.
""" """
def getPartitionFunction(self) -> gridfire._gridfire.partition.PartitionFunction: def getPartitionFunction(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.partition.PartitionFunction:
""" """
Get the partition function used by the engine. Get the partition function used by the engine.
""" """
def getScreeningModel(self) -> gridfire._gridfire.screening.ScreeningType: def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
""" """
Get the current screening model of the engine. Get the current screening model of the engine.
""" """
@typing.overload @typing.overload
def getSpeciesDestructionTimescales(self, composition: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeReactions: gridfire._gridfire.reaction.ReactionSet) -> ...: def getSpeciesDestructionTimescales(self, ctx: scratchpads.StateBlob, composition: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeReactions: gridfire._gridfire.reaction.ReactionSet) -> ...:
... ...
@typing.overload @typing.overload
def getSpeciesDestructionTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the destruction timescales for each species in the network. Get the destruction timescales for each species in the network.
""" """
def getSpeciesIndex(self, species: fourdst._phys.atomic.Species) -> int: def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
""" """
Get the index of a species in the network. Get the index of a species in the network.
""" """
def getSpeciesStatus(self, species: fourdst._phys.atomic.Species) -> SpeciesStatus: def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
""" """
Get the status of a species in the network. Get the status of a species in the network.
""" """
@typing.overload @typing.overload
def getSpeciesTimescales(self, composition: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeReactions: gridfire._gridfire.reaction.ReactionSet) -> ...: def getSpeciesTimescales(self, ctx: scratchpads.StateBlob, composition: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeReactions: gridfire._gridfire.reaction.ReactionSet) -> ...:
... ...
@typing.overload @typing.overload
def getSpeciesTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the timescales for each species in the network. Get the timescales for each species in the network.
""" """
def getStoichiometryMatrixEntry(self, species: fourdst._phys.atomic.Species, reaction: ...) -> int: def involvesSpecies(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> bool:
"""
Get an entry from the stoichiometry matrix.
"""
def involvesSpecies(self, species: fourdst._phys.atomic.Species) -> bool:
""" """
Check if a given species is involved in the network. Check if a given species is involved in the network.
""" """
def isPrecomputationEnabled(self) -> bool: def isPrecomputationEnabled(self, arg0: scratchpads.StateBlob) -> bool:
""" """
Check if precomputation is enabled for the engine. Check if precomputation is enabled for the engine.
""" """
def isStale(self, netIn: gridfire._gridfire.type.NetIn) -> bool: def isUsingReverseReactions(self, arg0: scratchpads.StateBlob) -> bool:
"""
Check if the engine is stale based on the provided NetIn object.
"""
def isUsingReverseReactions(self) -> bool:
""" """
Check if the engine is using reverse reactions. Check if the engine is using reverse reactions.
""" """
def mapNetInToMolarAbundanceVector(self, netIn: gridfire._gridfire.type.NetIn) -> list[float]: def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Map a NetIn object to a vector of molar abundances.
"""
def primeEngine(self, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
""" """
Prime the engine with a NetIn object to prepare for calculations. Prime the engine with a NetIn object to prepare for calculations.
""" """
def rebuild(self, composition: ..., depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None: def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Rebuild the engine with a new composition and build depth.
"""
def setNetworkReactions(self, reactions: gridfire._gridfire.reaction.ReactionSet) -> None:
"""
Set the network reactions to a new set of reactions.
"""
def setPrecomputation(self, precompute: bool) -> None:
"""
Enable or disable precomputation for the engine.
"""
def setScreeningModel(self, screeningModel: gridfire._gridfire.screening.ScreeningType) -> None:
"""
Set the screening model for the engine.
"""
def setUseReverseReactions(self, useReverse: bool) -> None:
"""
Enable or disable the use of reverse reactions in the engine.
"""
def update(self, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
""" """
Update the engine state based on the provided NetIn object. Update the engine state based on the provided NetIn object.
""" """
@@ -577,142 +445,106 @@ class MultiscalePartitioningEngineView(DynamicEngine):
""" """
Construct a multiscale partitioning engine view with a base engine. Construct a multiscale partitioning engine view with a base engine.
""" """
def calculateEpsDerivatives(self, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...: def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
""" """
Calculate deps/dT and deps/drho Calculate deps/dT and deps/drho
""" """
def calculateMolarReactionFlow(self: DynamicEngine, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float: def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
""" """
Calculate the molar reaction flow for a given reaction. Calculate the molar reaction flow for a given reaction.
""" """
def calculateRHSAndEnergy(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives: def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
""" """
Calculate the right-hand side (dY/dt) and energy generation rate. Calculate the right-hand side (dY/dt) and energy generation rate.
""" """
def collectComposition(self, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition: def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
""" """
Recursively collect composition from current engine and any sub engines if they exist. Recursively collect composition from current engine and any sub engines if they exist.
""" """
def exportToDot(self, filename: str, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> None: def exportToDot(self, ctx: scratchpads.StateBlob, filename: str, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> None:
""" """
Export the network to a DOT file for visualization. Export the network to a DOT file for visualization.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
""" """
Generate the Jacobian matrix for the current state. Generate the Jacobian matrix for the current state.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
""" """
Generate the jacobian matrix only for the subset of the matrix representing the active species. Generate the jacobian matrix only for the subset of the matrix representing the active species.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
""" """
Generate the jacobian matrix for the given sparsity pattern Generate the jacobian matrix for the given sparsity pattern
""" """
def generateStoichiometryMatrix(self) -> None:
...
def getBaseEngine(self) -> DynamicEngine: def getBaseEngine(self) -> DynamicEngine:
""" """
Get the base engine associated with this multiscale partitioning engine view. Get the base engine associated with this multiscale partitioning engine view.
""" """
def getDepth(self) -> gridfire._gridfire.engine.NetworkBuildDepth | int: def getDynamicSpecies(self: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the current build depth of the engine.
"""
def getDynamicSpecies(self) -> list[fourdst._phys.atomic.Species]:
""" """
Get the list of dynamic species in the network. Get the list of dynamic species in the network.
""" """
def getFastSpecies(self) -> list[fourdst._phys.atomic.Species]: def getFastSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
""" """
Get the list of fast species in the network. Get the list of fast species in the network.
""" """
def getNetworkReactions(self) -> gridfire._gridfire.reaction.ReactionSet: def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
""" """
Get the set of logical reactions in the network. Get the set of logical reactions in the network.
""" """
def getNetworkSpecies(self) -> list[fourdst._phys.atomic.Species]: def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
""" """
Get the list of species in the network. Get the list of species in the network.
""" """
def getNormalizedEquilibratedComposition(self, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition: def getNormalizedEquilibratedComposition(self, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
""" """
Get the normalized equilibrated composition for the algebraic species. Get the normalized equilibrated composition for the algebraic species.
""" """
def getScreeningModel(self) -> gridfire._gridfire.screening.ScreeningType: def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
""" """
Get the current screening model of the engine. Get the current screening model of the engine.
""" """
def getSpeciesDestructionTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the destruction timescales for each species in the network. Get the destruction timescales for each species in the network.
""" """
def getSpeciesIndex(self, species: fourdst._phys.atomic.Species) -> int: def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
""" """
Get the index of a species in the network. Get the index of a species in the network.
""" """
def getSpeciesStatus(self, species: fourdst._phys.atomic.Species) -> SpeciesStatus: def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
""" """
Get the status of a species in the network. Get the status of a species in the network.
""" """
def getSpeciesTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the timescales for each species in the network. Get the timescales for each species in the network.
""" """
def getStoichiometryMatrixEntry(self, species: fourdst._phys.atomic.Species, reaction: ...) -> int: def involvesSpecies(self: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> bool:
"""
Get an entry from the stoichiometry matrix.
"""
def involvesSpecies(self, species: fourdst._phys.atomic.Species) -> bool:
""" """
Check if a given species is involved in the network (in either the algebraic or dynamic set). Check if a given species is involved in the network (in either the algebraic or dynamic set).
""" """
def involvesSpeciesInDynamic(self, species: fourdst._phys.atomic.Species) -> bool: def involvesSpeciesInDynamic(self: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> bool:
""" """
Check if a given species is involved in the network's dynamic set. Check if a given species is involved in the network's dynamic set.
""" """
def involvesSpeciesInQSE(self, species: fourdst._phys.atomic.Species) -> bool: def involvesSpeciesInQSE(self: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> bool:
""" """
Check if a given species is involved in the network's algebraic set. Check if a given species is involved in the network's algebraic set.
""" """
def isStale(self, netIn: gridfire._gridfire.type.NetIn) -> bool: def partitionNetwork(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Check if the engine is stale based on the provided NetIn object.
"""
def mapNetInToMolarAbundanceVector(self, netIn: gridfire._gridfire.type.NetIn) -> list[float]:
"""
Map a NetIn object to a vector of molar abundances.
"""
@typing.overload
def partitionNetwork(self, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
""" """
Partition the network based on species timescales and connectivity. Partition the network based on species timescales and connectivity.
""" """
@typing.overload def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
def partitionNetwork(self, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Partition the network based on a NetIn object.
"""
def primeEngine(self, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
""" """
Prime the engine with a NetIn object to prepare for calculations. Prime the engine with a NetIn object to prepare for calculations.
""" """
def rebuild(self, composition: ..., depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None: def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Rebuild the engine with a new composition and build depth.
"""
def setNetworkReactions(self, reactions: gridfire._gridfire.reaction.ReactionSet) -> None:
"""
Set the network reactions to a new set of reactions.
"""
def setScreeningModel(self, screeningModel: gridfire._gridfire.screening.ScreeningType) -> None:
"""
Set the screening model for the engine.
"""
def update(self, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
""" """
Update the engine state based on the provided NetIn object. Update the engine state based on the provided NetIn object.
""" """
@@ -893,113 +725,83 @@ class NetworkJacobian:
""" """
class NetworkPrimingEngineView(DefinedEngineView): class NetworkPrimingEngineView(DefinedEngineView):
@typing.overload @typing.overload
def __init__(self, primingSymbol: str, baseEngine: GraphEngine) -> None: def __init__(self, ctx: scratchpads.StateBlob, primingSymbol: str, baseEngine: GraphEngine) -> None:
""" """
Construct a priming engine view with a priming symbol and a base engine. Construct a priming engine view with a priming symbol and a base engine.
""" """
@typing.overload @typing.overload
def __init__(self, primingSpecies: fourdst._phys.atomic.Species, baseEngine: GraphEngine) -> None: def __init__(self, ctx: scratchpads.StateBlob, primingSpecies: fourdst._phys.atomic.Species, baseEngine: GraphEngine) -> None:
""" """
Construct a priming engine view with a priming species and a base engine. Construct a priming engine view with a priming species and a base engine.
""" """
def calculateEpsDerivatives(self, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...: def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
""" """
Calculate deps/dT and deps/drho Calculate deps/dT and deps/drho
""" """
def calculateMolarReactionFlow(self: DynamicEngine, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float: def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
""" """
Calculate the molar reaction flow for a given reaction. Calculate the molar reaction flow for a given reaction.
""" """
def calculateRHSAndEnergy(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives: def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
""" """
Calculate the right-hand side (dY/dt) and energy generation rate. Calculate the right-hand side (dY/dt) and energy generation rate.
""" """
def collectComposition(self, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition: def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
""" """
Recursively collect composition from current engine and any sub engines if they exist. Recursively collect composition from current engine and any sub engines if they exist.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
""" """
Generate the Jacobian matrix for the current state. Generate the Jacobian matrix for the current state.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
""" """
Generate the jacobian matrix only for the subset of the matrix representing the active species. Generate the jacobian matrix only for the subset of the matrix representing the active species.
""" """
@typing.overload @typing.overload
def generateJacobianMatrix(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian: def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
""" """
Generate the jacobian matrix for the given sparsity pattern Generate the jacobian matrix for the given sparsity pattern
""" """
def generateStoichiometryMatrix(self) -> None:
...
def getBaseEngine(self) -> DynamicEngine: def getBaseEngine(self) -> DynamicEngine:
""" """
Get the base engine associated with this priming engine view. Get the base engine associated with this priming engine view.
""" """
def getDepth(self) -> gridfire._gridfire.engine.NetworkBuildDepth | int: def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the current build depth of the engine.
"""
def getNetworkReactions(self) -> gridfire._gridfire.reaction.ReactionSet:
""" """
Get the set of logical reactions in the network. Get the set of logical reactions in the network.
""" """
def getNetworkSpecies(self) -> list[fourdst._phys.atomic.Species]: def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
""" """
Get the list of species in the network. Get the list of species in the network.
""" """
def getScreeningModel(self) -> gridfire._gridfire.screening.ScreeningType: def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
""" """
Get the current screening model of the engine. Get the current screening model of the engine.
""" """
def getSpeciesDestructionTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the destruction timescales for each species in the network. Get the destruction timescales for each species in the network.
""" """
def getSpeciesIndex(self, species: fourdst._phys.atomic.Species) -> int: def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
""" """
Get the index of a species in the network. Get the index of a species in the network.
""" """
def getSpeciesStatus(self, species: fourdst._phys.atomic.Species) -> SpeciesStatus: def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
""" """
Get the status of a species in the network. Get the status of a species in the network.
""" """
def getSpeciesTimescales(self: DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]: def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
""" """
Get the timescales for each species in the network. Get the timescales for each species in the network.
""" """
def getStoichiometryMatrixEntry(self, species: fourdst._phys.atomic.Species, reaction: ...) -> int: def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Get an entry from the stoichiometry matrix.
"""
def isStale(self, netIn: gridfire._gridfire.type.NetIn) -> bool:
"""
Check if the engine is stale based on the provided NetIn object.
"""
def mapNetInToMolarAbundanceVector(self, netIn: gridfire._gridfire.type.NetIn) -> list[float]:
"""
Map a NetIn object to a vector of molar abundances.
"""
def primeEngine(self, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
""" """
Prime the engine with a NetIn object to prepare for calculations. Prime the engine with a NetIn object to prepare for calculations.
""" """
def rebuild(self, composition: ..., depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None: def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Rebuild the engine with a new composition and build depth.
"""
def setNetworkReactions(self, reactions: gridfire._gridfire.reaction.ReactionSet) -> None:
"""
Set the network reactions to a new set of reactions.
"""
def setScreeningModel(self, screeningModel: gridfire._gridfire.screening.ScreeningType) -> None:
"""
Set the screening model for the engine.
"""
def update(self, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
""" """
Update the engine state based on the provided NetIn object. Update the engine state based on the provided NetIn object.
""" """
@@ -1131,7 +933,7 @@ def build_nuclear_network(composition: ..., weakInterpolator: ..., maxLayers: gr
""" """
Build a nuclear network from a composition using all archived reaction data. Build a nuclear network from a composition using all archived reaction data.
""" """
def primeNetwork(netIn: gridfire._gridfire.type.NetIn, engine: ..., ignoredReactionTypes: collections.abc.Sequence[...] | None = None) -> PrimingReport: def primeNetwork(ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn, engine: ..., ignoredReactionTypes: collections.abc.Sequence[...] | None = None) -> PrimingReport:
""" """
Prime a network with a short timescale ignition Prime a network with a short timescale ignition
""" """

View File

@@ -5,11 +5,12 @@ from __future__ import annotations
import collections.abc import collections.abc
import fourdst._phys.composition import fourdst._phys.composition
import gridfire._gridfire.engine import gridfire._gridfire.engine
import gridfire._gridfire.engine.scratchpads
import typing import typing
__all__: list[str] = ['inspect_jacobian_stiffness', 'inspect_species_balance', 'report_limiting_species'] __all__: list[str] = ['inspect_jacobian_stiffness', 'inspect_species_balance', 'report_limiting_species']
def inspect_jacobian_stiffness(engine: gridfire._gridfire.engine.DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, json: bool) -> ... | None: def inspect_jacobian_stiffness(ctx: gridfire._gridfire.engine.scratchpads.StateBlob, engine: gridfire._gridfire.engine.DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, json: bool) -> ... | None:
... ...
def inspect_species_balance(engine: gridfire._gridfire.engine.DynamicEngine, species_name: str, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, json: bool) -> ... | None: def inspect_species_balance(ctx: gridfire._gridfire.engine.scratchpads.StateBlob, engine: gridfire._gridfire.engine.DynamicEngine, species_name: str, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, json: bool) -> ... | None:
... ...
def report_limiting_species(engine: gridfire._gridfire.engine.DynamicEngine, Y_full: collections.abc.Sequence[typing.SupportsFloat], E_full: collections.abc.Sequence[typing.SupportsFloat], relTol: typing.SupportsFloat, absTol: typing.SupportsFloat, top_n: typing.SupportsInt, json: bool) -> ... | None: def report_limiting_species(ctx: gridfire._gridfire.engine.scratchpads.StateBlob, engine: gridfire._gridfire.engine.DynamicEngine, Y_full: collections.abc.Sequence[typing.SupportsFloat], E_full: collections.abc.Sequence[typing.SupportsFloat], relTol: typing.SupportsFloat, absTol: typing.SupportsFloat, top_n: typing.SupportsInt, json: bool) -> ... | None:
... ...

View File

@@ -0,0 +1,267 @@
"""
Engine ScratchPad bindings
"""
from __future__ import annotations
import fourdst._phys.atomic
import fourdst._phys.composition
import gridfire._gridfire.reaction
import typing
__all__: list[str] = ['ADAPTIVE_ENGINE_VIEW_SCRATCHPAD', 'ADFunRegistrationResult', 'ALREADY_REGISTERED', 'AdaptiveEngineViewScratchPad', 'DEFINED_ENGINE_VIEW_SCRATCHPAD', 'DefinedEngineViewScratchPad', 'GRAPH_ENGINE_SCRATCHPAD', 'GraphEngineScratchPad', 'MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD', 'MultiscalePartitioningEngineViewScratchPad', 'SCRATCHPAD_BAD_CAST', 'SCRATCHPAD_NOT_FOUND', 'SCRATCHPAD_NOT_INITIALIZED', 'SCRATCHPAD_OUT_OF_BOUNDS', 'SCRATCHPAD_TYPE_COLLISION', 'SCRATCHPAD_UNKNOWN_ERROR', 'SUCCESS', 'ScratchPadType', 'StateBlob', 'StateBlobError']
class ADFunRegistrationResult:
"""
Members:
SUCCESS
ALREADY_REGISTERED
"""
ALREADY_REGISTERED: typing.ClassVar[ADFunRegistrationResult] # value = <ADFunRegistrationResult.ALREADY_REGISTERED: 1>
SUCCESS: typing.ClassVar[ADFunRegistrationResult] # value = <ADFunRegistrationResult.SUCCESS: 0>
__members__: typing.ClassVar[dict[str, ADFunRegistrationResult]] # value = {'SUCCESS': <ADFunRegistrationResult.SUCCESS: 0>, 'ALREADY_REGISTERED': <ADFunRegistrationResult.ALREADY_REGISTERED: 1>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class AdaptiveEngineViewScratchPad:
ID: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: 2>
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone(self) -> ...:
...
def initialize(self, arg0: ...) -> None:
...
def is_initialized(self) -> bool:
...
@property
def active_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
...
@property
def active_species(self) -> list[fourdst._phys.atomic.Species]:
...
@property
def has_initialized(self) -> bool:
...
class DefinedEngineViewScratchPad:
ID: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.DEFINED_ENGINE_VIEW_SCRATCHPAD: 3>
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone(self) -> ...:
...
def is_initialized(self) -> bool:
...
@property
def active_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
...
@property
def active_species(self) -> set[fourdst._phys.atomic.Species]:
...
@property
def has_initialized(self) -> bool:
...
@property
def reaction_index_map(self) -> list[int]:
...
@property
def species_index_map(self) -> list[int]:
...
class GraphEngineScratchPad:
ID: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.GRAPH_ENGINE_SCRATCHPAD: 0>
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone(self) -> ...:
...
def initialize(self, engine: ...) -> None:
...
def is_initialized(self) -> bool:
...
@property
def has_initialized(self) -> bool:
...
@property
def local_abundance_cache(self) -> list[float]:
...
@property
def most_recent_rhs_calculation(self) -> ... | None:
...
@property
def stepDerivativesCache(self) -> dict[int, ...]:
...
class MultiscalePartitioningEngineViewScratchPad:
ID: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: 1>
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone(self) -> ...:
...
def initialize(self) -> None:
...
def is_initialized(self) -> bool:
...
@property
def algebraic_species(self) -> list[fourdst._phys.atomic.Species]:
...
@property
def composition_cache(self) -> dict[int, fourdst._phys.composition.Composition]:
...
@property
def dynamic_species(self) -> list[fourdst._phys.atomic.Species]:
...
@property
def has_initialized(self) -> bool:
...
@property
def qse_groups(self) -> list[...]:
...
class ScratchPadType:
"""
Members:
GRAPH_ENGINE_SCRATCHPAD
MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD
ADAPTIVE_ENGINE_VIEW_SCRATCHPAD
DEFINED_ENGINE_VIEW_SCRATCHPAD
"""
ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: 2>
DEFINED_ENGINE_VIEW_SCRATCHPAD: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.DEFINED_ENGINE_VIEW_SCRATCHPAD: 3>
GRAPH_ENGINE_SCRATCHPAD: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.GRAPH_ENGINE_SCRATCHPAD: 0>
MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: 1>
__members__: typing.ClassVar[dict[str, ScratchPadType]] # value = {'GRAPH_ENGINE_SCRATCHPAD': <ScratchPadType.GRAPH_ENGINE_SCRATCHPAD: 0>, 'MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD': <ScratchPadType.MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: 1>, 'ADAPTIVE_ENGINE_VIEW_SCRATCHPAD': <ScratchPadType.ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: 2>, 'DEFINED_ENGINE_VIEW_SCRATCHPAD': <ScratchPadType.DEFINED_ENGINE_VIEW_SCRATCHPAD: 3>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class StateBlob:
@staticmethod
def error_to_string(arg0: StateBlobError) -> str:
...
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone_structure(self) -> StateBlob:
...
def enroll(self, arg0: ScratchPadType) -> None:
...
def get(self, arg0: ScratchPadType) -> ...:
...
def get_registered_scratchpads(self) -> set[ScratchPadType]:
...
def get_status(self, arg0: ScratchPadType) -> ...:
...
def get_status_map(self) -> dict[ScratchPadType, ...]:
...
class StateBlobError:
"""
Members:
SCRATCHPAD_OUT_OF_BOUNDS
SCRATCHPAD_NOT_FOUND
SCRATCHPAD_BAD_CAST
SCRATCHPAD_NOT_INITIALIZED
SCRATCHPAD_TYPE_COLLISION
SCRATCHPAD_UNKNOWN_ERROR
"""
SCRATCHPAD_BAD_CAST: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_BAD_CAST: 1>
SCRATCHPAD_NOT_FOUND: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_NOT_FOUND: 0>
SCRATCHPAD_NOT_INITIALIZED: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_NOT_INITIALIZED: 2>
SCRATCHPAD_OUT_OF_BOUNDS: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_OUT_OF_BOUNDS: 4>
SCRATCHPAD_TYPE_COLLISION: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_TYPE_COLLISION: 3>
SCRATCHPAD_UNKNOWN_ERROR: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_UNKNOWN_ERROR: 5>
__members__: typing.ClassVar[dict[str, StateBlobError]] # value = {'SCRATCHPAD_OUT_OF_BOUNDS': <StateBlobError.SCRATCHPAD_OUT_OF_BOUNDS: 4>, 'SCRATCHPAD_NOT_FOUND': <StateBlobError.SCRATCHPAD_NOT_FOUND: 0>, 'SCRATCHPAD_BAD_CAST': <StateBlobError.SCRATCHPAD_BAD_CAST: 1>, 'SCRATCHPAD_NOT_INITIALIZED': <StateBlobError.SCRATCHPAD_NOT_INITIALIZED: 2>, 'SCRATCHPAD_TYPE_COLLISION': <StateBlobError.SCRATCHPAD_TYPE_COLLISION: 3>, 'SCRATCHPAD_UNKNOWN_ERROR': <StateBlobError.SCRATCHPAD_UNKNOWN_ERROR: 5>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: ScratchPadType # value = <ScratchPadType.ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: 2>
ALREADY_REGISTERED: ADFunRegistrationResult # value = <ADFunRegistrationResult.ALREADY_REGISTERED: 1>
DEFINED_ENGINE_VIEW_SCRATCHPAD: ScratchPadType # value = <ScratchPadType.DEFINED_ENGINE_VIEW_SCRATCHPAD: 3>
GRAPH_ENGINE_SCRATCHPAD: ScratchPadType # value = <ScratchPadType.GRAPH_ENGINE_SCRATCHPAD: 0>
MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: ScratchPadType # value = <ScratchPadType.MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: 1>
SCRATCHPAD_BAD_CAST: StateBlobError # value = <StateBlobError.SCRATCHPAD_BAD_CAST: 1>
SCRATCHPAD_NOT_FOUND: StateBlobError # value = <StateBlobError.SCRATCHPAD_NOT_FOUND: 0>
SCRATCHPAD_NOT_INITIALIZED: StateBlobError # value = <StateBlobError.SCRATCHPAD_NOT_INITIALIZED: 2>
SCRATCHPAD_OUT_OF_BOUNDS: StateBlobError # value = <StateBlobError.SCRATCHPAD_OUT_OF_BOUNDS: 4>
SCRATCHPAD_TYPE_COLLISION: StateBlobError # value = <StateBlobError.SCRATCHPAD_TYPE_COLLISION: 3>
SCRATCHPAD_UNKNOWN_ERROR: StateBlobError # value = <StateBlobError.SCRATCHPAD_UNKNOWN_ERROR: 5>
SUCCESS: ADFunRegistrationResult # value = <ADFunRegistrationResult.SUCCESS: 0>

View File

@@ -2,7 +2,7 @@
GridFire exceptions bindings GridFire exceptions bindings
""" """
from __future__ import annotations from __future__ import annotations
__all__: list[str] = ['BadCollectionError', 'BadRHSEngineError', 'CVODESolverFailureError', 'DebugException', 'EngineError', 'FailedToPartitionEngineError', 'GridFireError', 'HashingError', 'IllConditionedJacobianError', 'InvalidQSESolutionError', 'JacobianError', 'KINSolSolverFailureError', 'MissingBaseReactionError', 'MissingKeyReactionError', 'MissingSeedSpeciesError', 'NetworkResizedError', 'PolicyError', 'ReactionError', 'ReactionParsingError', 'SUNDIALSError', 'SingularJacobianError', 'SolverError', 'StaleJacobianError', 'UnableToSetNetworkReactionsError', 'UninitializedJacobianError', 'UnknownJacobianError', 'UtilityError'] __all__: list[str] = ['BadCollectionError', 'BadRHSEngineError', 'CVODESolverFailureError', 'DebugException', 'EngineError', 'FailedToPartitionEngineError', 'GridFireError', 'HashingError', 'IllConditionedJacobianError', 'InvalidQSESolutionError', 'JacobianError', 'KINSolSolverFailureError', 'MissingBaseReactionError', 'MissingKeyReactionError', 'MissingSeedSpeciesError', 'NetworkResizedError', 'PolicyError', 'ReactionError', 'ReactionParsingError', 'SUNDIALSError', 'ScratchPadError', 'SingularJacobianError', 'SolverError', 'StaleJacobianError', 'UnableToSetNetworkReactionsError', 'UninitializedJacobianError', 'UnknownJacobianError', 'UtilityError']
class BadCollectionError(EngineError): class BadCollectionError(EngineError):
pass pass
class BadRHSEngineError(EngineError): class BadRHSEngineError(EngineError):
@@ -43,6 +43,8 @@ class ReactionParsingError(ReactionError):
pass pass
class SUNDIALSError(SolverError): class SUNDIALSError(SolverError):
pass pass
class ScratchPadError(GridFireError):
pass
class SingularJacobianError(SolverError): class SingularJacobianError(SolverError):
pass pass
class SolverError(GridFireError): class SolverError(GridFireError):

View File

@@ -6,9 +6,11 @@ import collections.abc
import fourdst._phys.atomic import fourdst._phys.atomic
import fourdst._phys.composition import fourdst._phys.composition
import gridfire._gridfire.engine import gridfire._gridfire.engine
import gridfire._gridfire.engine.scratchpads
import gridfire._gridfire.partition
import gridfire._gridfire.reaction import gridfire._gridfire.reaction
import typing import typing
__all__: list[str] = ['CNOChainPolicy', 'CNOIChainPolicy', 'CNOIIChainPolicy', 'CNOIIIChainPolicy', 'CNOIVChainPolicy', 'HotCNOChainPolicy', 'HotCNOIChainPolicy', 'HotCNOIIChainPolicy', 'HotCNOIIIChainPolicy', 'INITIALIZED_UNVERIFIED', 'INITIALIZED_VERIFIED', 'MISSING_KEY_REACTION', 'MISSING_KEY_SPECIES', 'MainSequencePolicy', 'MainSequenceReactionChainPolicy', 'MultiReactionChainPolicy', 'NetworkPolicy', 'NetworkPolicyStatus', 'ProtonProtonChainPolicy', 'ProtonProtonIChainPolicy', 'ProtonProtonIIChainPolicy', 'ProtonProtonIIIChainPolicy', 'ReactionChainPolicy', 'TemperatureDependentChainPolicy', 'TripleAlphaChainPolicy', 'UNINITIALIZED'] __all__: list[str] = ['CNOChainPolicy', 'CNOIChainPolicy', 'CNOIIChainPolicy', 'CNOIIIChainPolicy', 'CNOIVChainPolicy', 'ConstructionResults', 'HotCNOChainPolicy', 'HotCNOIChainPolicy', 'HotCNOIIChainPolicy', 'HotCNOIIIChainPolicy', 'INITIALIZED_UNVERIFIED', 'INITIALIZED_VERIFIED', 'MISSING_KEY_REACTION', 'MISSING_KEY_SPECIES', 'MainSequencePolicy', 'MainSequenceReactionChainPolicy', 'MultiReactionChainPolicy', 'NetworkPolicy', 'NetworkPolicyStatus', 'ProtonProtonChainPolicy', 'ProtonProtonIChainPolicy', 'ProtonProtonIIChainPolicy', 'ProtonProtonIIIChainPolicy', 'ReactionChainPolicy', 'TemperatureDependentChainPolicy', 'TripleAlphaChainPolicy', 'UNINITIALIZED', 'network_policy_status_to_string']
class CNOChainPolicy(MultiReactionChainPolicy): class CNOChainPolicy(MultiReactionChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool: def __eq__(self, other: ReactionChainPolicy) -> bool:
""" """
@@ -224,6 +226,13 @@ class CNOIVChainPolicy(TemperatureDependentChainPolicy):
""" """
Get the name of the reaction chain policy. Get the name of the reaction chain policy.
""" """
class ConstructionResults:
@property
def engine(self) -> gridfire._gridfire.engine.DynamicEngine:
...
@property
def scratch_blob(self) -> gridfire._gridfire.engine.scratchpads.StateBlob:
...
class HotCNOChainPolicy(MultiReactionChainPolicy): class HotCNOChainPolicy(MultiReactionChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool: def __eq__(self, other: ReactionChainPolicy) -> bool:
""" """
@@ -407,14 +416,18 @@ class MainSequencePolicy(NetworkPolicy):
""" """
Construct MainSequencePolicy from seed species and mass fractions. Construct MainSequencePolicy from seed species and mass fractions.
""" """
def construct(self) -> gridfire._gridfire.engine.DynamicEngine: def construct(self) -> ConstructionResults:
""" """
Construct the network according to the policy. Construct the network according to the policy.
""" """
def get_engine_stack(self) -> list[gridfire._gridfire.engine.DynamicEngine]:
...
def get_engine_types_stack(self) -> list[gridfire._gridfire.engine.EngineTypes]: def get_engine_types_stack(self) -> list[gridfire._gridfire.engine.EngineTypes]:
""" """
Get the types of engines in the stack constructed by the network policy. Get the types of engines in the stack constructed by the network policy.
""" """
def get_partition_function(self) -> gridfire._gridfire.partition.PartitionFunction:
...
def get_seed_reactions(self) -> gridfire._gridfire.reaction.ReactionSet: def get_seed_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
""" """
Get the set of seed reactions required by the network policy. Get the set of seed reactions required by the network policy.
@@ -423,6 +436,8 @@ class MainSequencePolicy(NetworkPolicy):
""" """
Get the set of seed species required by the network policy. Get the set of seed species required by the network policy.
""" """
def get_stack_scratch_blob(self) -> gridfire._gridfire.engine.scratchpads.StateBlob:
...
def get_status(self) -> NetworkPolicyStatus: def get_status(self) -> NetworkPolicyStatus:
""" """
Get the current status of the network policy. Get the current status of the network policy.
@@ -743,6 +758,10 @@ class TripleAlphaChainPolicy(TemperatureDependentChainPolicy):
""" """
Get the name of the reaction chain policy. Get the name of the reaction chain policy.
""" """
def network_policy_status_to_string(status: NetworkPolicyStatus) -> str:
"""
Convert a NetworkPolicyStatus enum value to its string representation.
"""
INITIALIZED_UNVERIFIED: NetworkPolicyStatus # value = <NetworkPolicyStatus.INITIALIZED_UNVERIFIED: 1> INITIALIZED_UNVERIFIED: NetworkPolicyStatus # value = <NetworkPolicyStatus.INITIALIZED_UNVERIFIED: 1>
INITIALIZED_VERIFIED: NetworkPolicyStatus # value = <NetworkPolicyStatus.INITIALIZED_VERIFIED: 4> INITIALIZED_VERIFIED: NetworkPolicyStatus # value = <NetworkPolicyStatus.INITIALIZED_VERIFIED: 4>
MISSING_KEY_REACTION: NetworkPolicyStatus # value = <NetworkPolicyStatus.MISSING_KEY_REACTION: 2> MISSING_KEY_REACTION: NetworkPolicyStatus # value = <NetworkPolicyStatus.MISSING_KEY_REACTION: 2>

View File

@@ -5,47 +5,113 @@ from __future__ import annotations
import collections.abc import collections.abc
import fourdst._phys.atomic import fourdst._phys.atomic
import gridfire._gridfire.engine import gridfire._gridfire.engine
import gridfire._gridfire.engine.scratchpads
import gridfire._gridfire.type import gridfire._gridfire.type
import types
import typing import typing
__all__: list[str] = ['CVODESolverStrategy', 'CVODETimestepContext', 'DynamicNetworkSolverStrategy', 'SolverContextBase'] __all__: list[str] = ['GridSolver', 'GridSolverContext', 'MultiZoneDynamicNetworkSolver', 'PointSolver', 'PointSolverContext', 'PointSolverTimestepContext', 'SingleZoneDynamicNetworkSolver', 'SolverContextBase']
class CVODESolverStrategy(DynamicNetworkSolverStrategy): class GridSolver(MultiZoneDynamicNetworkSolver):
def __init__(self, engine: gridfire._gridfire.engine.DynamicEngine) -> None: def __init__(self, engine: gridfire._gridfire.engine.DynamicEngine, solver: SingleZoneDynamicNetworkSolver) -> None:
""" """
Initialize the CVODESolverStrategy object. Initialize the GridSolver object.
""" """
def evaluate(self, netIn: gridfire._gridfire.type.NetIn, display_trigger: bool = False) -> gridfire._gridfire.type.NetOut: def evaluate(self, solver_ctx: SolverContextBase, netIns: collections.abc.Sequence[gridfire._gridfire.type.NetIn]) -> list[gridfire._gridfire.type.NetOut]:
""" """
evaluate the dynamic engine using the dynamic engine class evaluate the dynamic engine using the dynamic engine class
""" """
def get_absTol(self) -> float: class GridSolverContext(SolverContextBase):
detailed_logging: bool
stdout_logging: bool
zone_completion_logging: bool
def __init__(self, ctx_template: gridfire._gridfire.engine.scratchpads.StateBlob) -> None:
...
@typing.overload
def clear_callback(self) -> None:
...
@typing.overload
def clear_callback(self, zone_idx: typing.SupportsInt) -> None:
...
def init(self) -> None:
...
def reset(self) -> None:
...
@typing.overload
def set_callback(self, callback: collections.abc.Callable[[...], None]) -> None:
...
@typing.overload
def set_callback(self, callback: collections.abc.Callable[[...], None], zone_idx: typing.SupportsInt) -> None:
...
class MultiZoneDynamicNetworkSolver:
def evaluate(self, solver_ctx: SolverContextBase, netIns: collections.abc.Sequence[gridfire._gridfire.type.NetIn]) -> list[gridfire._gridfire.type.NetOut]:
""" """
Get the absolute tolerance for the CVODE solver. evaluate the dynamic engine using the dynamic engine class for multiple zones (using openmp if available)
""" """
def get_relTol(self) -> float: class PointSolver(SingleZoneDynamicNetworkSolver):
def __init__(self, engine: gridfire._gridfire.engine.DynamicEngine) -> None:
""" """
Get the relative tolerance for the CVODE solver. Initialize the PointSolver object.
""" """
def get_stdout_logging_enabled(self) -> bool: def evaluate(self, solver_ctx: SolverContextBase, netIn: gridfire._gridfire.type.NetIn, display_trigger: bool = False, force_reinitialization: bool = False) -> gridfire._gridfire.type.NetOut:
""" """
Check if solver logging to standard output is enabled. evaluate the dynamic engine using the dynamic engine class
""" """
def set_absTol(self, absTol: typing.SupportsFloat) -> None: class PointSolverContext:
""" callback: collections.abc.Callable[[PointSolverTimestepContext], None] | None
Set the absolute tolerance for the CVODE solver. detailed_logging: bool
""" stdout_logging: bool
def set_callback(self, cb: collections.abc.Callable[[CVODETimestepContext], None]) -> None: def __init__(self, engine_ctx: gridfire._gridfire.engine.scratchpads.StateBlob) -> None:
""" ...
Set a callback function which will run at the end of every successful timestep def clear_context(self) -> None:
""" ...
def set_relTol(self, relTol: typing.SupportsFloat) -> None: def has_context(self) -> bool:
""" ...
Set the relative tolerance for the CVODE solver. def init(self) -> None:
""" ...
def set_stdout_logging_enabled(self, logging_enabled: bool) -> None: def init_context(self) -> None:
""" ...
Enable logging to standard output. def reset_all(self) -> None:
""" ...
class CVODETimestepContext(SolverContextBase): def reset_cvode(self) -> None:
...
def reset_user(self) -> None:
...
@property
def J(self) -> _generic_SUNMatrix:
...
@property
def LS(self) -> _generic_SUNLinearSolver:
...
@property
def Y(self) -> _generic_N_Vector:
...
@property
def YErr(self) -> _generic_N_Vector:
...
@property
def abs_tol(self) -> float:
...
@abs_tol.setter
def abs_tol(self, arg1: typing.SupportsFloat) -> None:
...
@property
def cvode_mem(self) -> types.CapsuleType:
...
@property
def engine_ctx(self) -> gridfire._gridfire.engine.scratchpads.StateBlob:
...
@property
def num_steps(self) -> int:
...
@property
def rel_tol(self) -> float:
...
@rel_tol.setter
def rel_tol(self, arg1: typing.SupportsFloat) -> None:
...
@property
def sun_ctx(self) -> SUNContext_:
...
class PointSolverTimestepContext:
@property @property
def T9(self) -> float: def T9(self) -> float:
... ...
@@ -77,16 +143,15 @@ class CVODETimestepContext(SolverContextBase):
def state(self) -> list[float]: def state(self) -> list[float]:
... ...
@property @property
def state_ctx(self) -> gridfire._gridfire.engine.scratchpads.StateBlob:
...
@property
def t(self) -> float: def t(self) -> float:
... ...
class DynamicNetworkSolverStrategy: class SingleZoneDynamicNetworkSolver:
def describe_callback_context(self) -> list[tuple[str, str]]: def evaluate(self, solver_ctx: SolverContextBase, netIn: gridfire._gridfire.type.NetIn) -> gridfire._gridfire.type.NetOut:
""" """
Get a structure representing what data is in the callback context in a human readable format evaluate the dynamic engine using the dynamic engine class for a single zone
"""
def evaluate(self, netIn: gridfire._gridfire.type.NetIn) -> gridfire._gridfire.type.NetOut:
"""
evaluate the dynamic engine using the dynamic engine class
""" """
class SolverContextBase: class SolverContextBase:
pass pass

View File

@@ -59,3 +59,9 @@ class NetOut:
@property @property
def num_steps(self) -> int: def num_steps(self) -> int:
... ...
@property
def specific_neutrino_energy_loss(self) -> float:
...
@property
def specific_neutrino_flux(self) -> float:
...

View File

@@ -4,10 +4,11 @@ GridFire utility method bindings
from __future__ import annotations from __future__ import annotations
import fourdst._phys.composition import fourdst._phys.composition
import gridfire._gridfire.engine import gridfire._gridfire.engine
import gridfire._gridfire.engine.scratchpads
import typing import typing
from . import hashing from . import hashing
__all__: list[str] = ['formatNuclearTimescaleLogString', 'hash_atomic', 'hash_reaction', 'hashing'] __all__: list[str] = ['formatNuclearTimescaleLogString', 'hash_atomic', 'hash_reaction', 'hashing']
def formatNuclearTimescaleLogString(engine: gridfire._gridfire.engine.DynamicEngine, Y: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> str: def formatNuclearTimescaleLogString(ctx: gridfire._gridfire.engine.scratchpads.StateBlob, engine: gridfire._gridfire.engine.DynamicEngine, Y: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> str:
""" """
Format a string for logging nuclear timescales based on temperature, density, and energy generation rate. Format a string for logging nuclear timescales based on temperature, density, and energy generation rate.
""" """

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,16 @@
"""
Python bindings for the fourdst utility modules which are a part of the 4D-STAR project.
"""
from __future__ import annotations
from . import config
from . import engine
from . import exceptions
from . import io
from . import partition
from . import policy
from . import reaction
from . import screening
from . import solver
from . import type
from . import utils
__all__: list[str] = ['config', 'engine', 'exceptions', 'io', 'partition', 'policy', 'reaction', 'screening', 'solver', 'type', 'utils']

View File

@@ -0,0 +1,47 @@
"""
GridFire configuration bindings
"""
from __future__ import annotations
import typing
__all__: list[str] = ['AdaptiveEngineViewConfig', 'CVODESolverConfig', 'EngineConfig', 'EngineViewConfig', 'GridFireConfig', 'SolverConfig']
class AdaptiveEngineViewConfig:
def __init__(self) -> None:
...
@property
def relativeCullingThreshold(self) -> float:
...
@relativeCullingThreshold.setter
def relativeCullingThreshold(self, arg0: typing.SupportsFloat) -> None:
...
class CVODESolverConfig:
def __init__(self) -> None:
...
@property
def absTol(self) -> float:
...
@absTol.setter
def absTol(self, arg0: typing.SupportsFloat) -> None:
...
@property
def relTol(self) -> float:
...
@relTol.setter
def relTol(self, arg0: typing.SupportsFloat) -> None:
...
class EngineConfig:
views: EngineViewConfig
def __init__(self) -> None:
...
class EngineViewConfig:
adaptiveEngineView: AdaptiveEngineViewConfig
def __init__(self) -> None:
...
class GridFireConfig:
engine: EngineConfig
solver: SolverConfig
def __init__(self) -> None:
...
class SolverConfig:
cvode: CVODESolverConfig
def __init__(self) -> None:
...

View File

@@ -0,0 +1,972 @@
"""
Engine and Engine View bindings
"""
from __future__ import annotations
import collections.abc
import fourdst._phys.atomic
import fourdst._phys.composition
import gridfire._gridfire.io
import gridfire._gridfire.partition
import gridfire._gridfire.reaction
import gridfire._gridfire.screening
import gridfire._gridfire.type
import numpy
import numpy.typing
import typing
from . import diagnostics
from . import scratchpads
__all__: list[str] = ['ACTIVE', 'ADAPTIVE_ENGINE_VIEW', 'AdaptiveEngineView', 'BuildDepthType', 'DEFAULT', 'DEFINED_ENGINE_VIEW', 'DefinedEngineView', 'DynamicEngine', 'EQUILIBRIUM', 'Engine', 'EngineTypes', 'FILE_DEFINED_ENGINE_VIEW', 'FULL_SUCCESS', 'FifthOrder', 'FileDefinedEngineView', 'FourthOrder', 'Full', 'GRAPH_ENGINE', 'GraphEngine', 'INACTIVE_FLOW', 'MAX_ITERATIONS_REACHED', 'MULTISCALE_PARTITIONING_ENGINE_VIEW', 'MultiscalePartitioningEngineView', 'NONE', 'NOT_PRESENT', 'NO_SPECIES_TO_PRIME', 'NetworkBuildDepth', 'NetworkConstructionFlags', 'NetworkJacobian', 'NetworkPrimingEngineView', 'PRIMING_ENGINE_VIEW', 'PrimingReport', 'PrimingReportStatus', 'REACLIB', 'REACLIB_STRONG', 'REACLIB_WEAK', 'SecondOrder', 'Shallow', 'SparsityPattern', 'SpeciesStatus', 'StepDerivatives', 'ThirdOrder', 'WRL_BETA_MINUS', 'WRL_BETA_PLUS', 'WRL_ELECTRON_CAPTURE', 'WRL_POSITRON_CAPTURE', 'WRL_WEAK', 'build_nuclear_network', 'diagnostics', 'primeNetwork', 'regularize_jacobian', 'scratchpads']
class AdaptiveEngineView(DynamicEngine):
def __init__(self, baseEngine: DynamicEngine) -> None:
"""
Construct an adaptive engine view with a base engine.
"""
def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
"""
Calculate deps/dT and deps/drho
"""
def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
"""
Calculate the molar reaction flow for a given reaction.
"""
def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
"""
Calculate the right-hand side (dY/dt) and energy generation rate.
"""
def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
"""
Recursively collect composition from current engine and any sub engines if they exist.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
"""
Generate the Jacobian matrix for the current state.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
"""
Generate the jacobian matrix only for the subset of the matrix representing the active species.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
"""
Generate the jacobian matrix for the given sparsity pattern
"""
def getBaseEngine(self) -> DynamicEngine:
"""
Get the base engine associated with this adaptive engine view.
"""
def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the set of logical reactions in the network.
"""
def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the list of species in the network.
"""
def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
"""
Get the current screening model of the engine.
"""
def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the destruction timescales for each species in the network.
"""
def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
"""
Get the index of a species in the network.
"""
def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
"""
Get the status of a species in the network.
"""
def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the timescales for each species in the network.
"""
def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Prime the engine with a NetIn object to prepare for calculations.
"""
def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Update the engine state based on the provided NetIn object.
"""
class BuildDepthType:
pass
class DefinedEngineView(DynamicEngine):
def __init__(self, peNames: collections.abc.Sequence[str], baseEngine: GraphEngine) -> None:
"""
Construct a defined engine view with a list of tracked reactions and a base engine.
"""
def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
"""
Calculate deps/dT and deps/drho
"""
def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
"""
Calculate the molar reaction flow for a given reaction.
"""
def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
"""
Calculate the right-hand side (dY/dt) and energy generation rate.
"""
def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
"""
Recursively collect composition from current engine and any sub engines if they exist.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
"""
Generate the Jacobian matrix for the current state.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
"""
Generate the jacobian matrix only for the subset of the matrix representing the active species.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
"""
Generate the jacobian matrix for the given sparsity pattern
"""
def getBaseEngine(self) -> DynamicEngine:
"""
Get the base engine associated with this defined engine view.
"""
def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the set of logical reactions in the network.
"""
def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the list of species in the network.
"""
def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
"""
Get the current screening model of the engine.
"""
def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the destruction timescales for each species in the network.
"""
def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
"""
Get the index of a species in the network.
"""
def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
"""
Get the status of a species in the network.
"""
def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the timescales for each species in the network.
"""
def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Prime the engine with a NetIn object to prepare for calculations.
"""
def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Update the engine state based on the provided NetIn object.
"""
class DynamicEngine:
pass
class Engine:
pass
class EngineTypes:
"""
Members:
GRAPH_ENGINE : The standard graph-based engine.
ADAPTIVE_ENGINE_VIEW : An engine that adapts based on certain criteria.
MULTISCALE_PARTITIONING_ENGINE_VIEW : An engine that partitions the system at multiple scales.
PRIMING_ENGINE_VIEW : An engine that uses a priming strategy for simulations.
DEFINED_ENGINE_VIEW : An engine defined by user specifications.
FILE_DEFINED_ENGINE_VIEW : An engine defined through external files.
"""
ADAPTIVE_ENGINE_VIEW: typing.ClassVar[EngineTypes] # value = <EngineTypes.ADAPTIVE_ENGINE_VIEW: 1>
DEFINED_ENGINE_VIEW: typing.ClassVar[EngineTypes] # value = <EngineTypes.DEFINED_ENGINE_VIEW: 4>
FILE_DEFINED_ENGINE_VIEW: typing.ClassVar[EngineTypes] # value = <EngineTypes.FILE_DEFINED_ENGINE_VIEW: 5>
GRAPH_ENGINE: typing.ClassVar[EngineTypes] # value = <EngineTypes.GRAPH_ENGINE: 0>
MULTISCALE_PARTITIONING_ENGINE_VIEW: typing.ClassVar[EngineTypes] # value = <EngineTypes.MULTISCALE_PARTITIONING_ENGINE_VIEW: 2>
PRIMING_ENGINE_VIEW: typing.ClassVar[EngineTypes] # value = <EngineTypes.PRIMING_ENGINE_VIEW: 3>
__members__: typing.ClassVar[dict[str, EngineTypes]] # value = {'GRAPH_ENGINE': <EngineTypes.GRAPH_ENGINE: 0>, 'ADAPTIVE_ENGINE_VIEW': <EngineTypes.ADAPTIVE_ENGINE_VIEW: 1>, 'MULTISCALE_PARTITIONING_ENGINE_VIEW': <EngineTypes.MULTISCALE_PARTITIONING_ENGINE_VIEW: 2>, 'PRIMING_ENGINE_VIEW': <EngineTypes.PRIMING_ENGINE_VIEW: 3>, 'DEFINED_ENGINE_VIEW': <EngineTypes.DEFINED_ENGINE_VIEW: 4>, 'FILE_DEFINED_ENGINE_VIEW': <EngineTypes.FILE_DEFINED_ENGINE_VIEW: 5>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
@typing.overload
def __repr__(self) -> str:
...
@typing.overload
def __repr__(self) -> str:
"""
String representation of the EngineTypes.
"""
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class FileDefinedEngineView(DefinedEngineView):
def __init__(self, baseEngine: GraphEngine, fileName: str, parser: gridfire._gridfire.io.NetworkFileParser) -> None:
"""
Construct a defined engine view from a file and a base engine.
"""
def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
"""
Calculate deps/dT and deps/drho
"""
def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
"""
Calculate the molar reaction flow for a given reaction.
"""
def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
"""
Calculate the right-hand side (dY/dt) and energy generation rate.
"""
def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
"""
Recursively collect composition from current engine and any sub engines if they exist.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
"""
Generate the Jacobian matrix for the current state.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
"""
Generate the jacobian matrix only for the subset of the matrix representing the active species.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
"""
Generate the jacobian matrix for the given sparsity pattern
"""
def getBaseEngine(self) -> DynamicEngine:
"""
Get the base engine associated with this file defined engine view.
"""
def getNetworkFile(self) -> str:
"""
Get the network file associated with this defined engine view.
"""
def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the set of logical reactions in the network.
"""
def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the list of species in the network.
"""
def getParser(self) -> gridfire._gridfire.io.NetworkFileParser:
"""
Get the parser used for this defined engine view.
"""
def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
"""
Get the current screening model of the engine.
"""
def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the destruction timescales for each species in the network.
"""
def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
"""
Get the index of a species in the network.
"""
def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
"""
Get the status of a species in the network.
"""
def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the timescales for each species in the network.
"""
def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Prime the engine with a NetIn object to prepare for calculations.
"""
def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Update the engine state based on the provided NetIn object.
"""
class GraphEngine(DynamicEngine):
@typing.overload
def __init__(self, composition: fourdst._phys.composition.Composition, depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None:
"""
Initialize GraphEngine with a composition and build depth.
"""
@typing.overload
def __init__(self, composition: fourdst._phys.composition.Composition, partitionFunction: gridfire._gridfire.partition.PartitionFunction, depth: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ...) -> None:
"""
Initialize GraphEngine with a composition, partition function and build depth.
"""
@typing.overload
def __init__(self, reactions: gridfire._gridfire.reaction.ReactionSet) -> None:
"""
Initialize GraphEngine with a set of reactions.
"""
def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
"""
Calculate deps/dT and deps/drho
"""
def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
"""
Calculate the molar reaction flow for a given reaction.
"""
def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
"""
Calculate the right-hand side (dY/dt) and energy generation rate.
"""
def calculateReverseRate(self, reaction: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat, composition: ...) -> float:
"""
Calculate the reverse rate for a given reaction at a specific temperature, density, and composition.
"""
def calculateReverseRateTwoBody(self, reaction: ..., T9: typing.SupportsFloat, forwardRate: typing.SupportsFloat, expFactor: typing.SupportsFloat) -> float:
"""
Calculate the reverse rate for a two-body reaction at a specific temperature.
"""
def calculateReverseRateTwoBodyDerivative(self, reaction: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat, composition: fourdst._phys.composition.Composition, reverseRate: typing.SupportsFloat) -> float:
"""
Calculate the derivative of the reverse rate for a two-body reaction at a specific temperature.
"""
def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
"""
Recursively collect composition from current engine and any sub engines if they exist.
"""
def exportToCSV(self, ctx: scratchpads.StateBlob, filename: str) -> None:
"""
Export the network to a CSV file for analysis.
"""
def exportToDot(self, ctx: scratchpads.StateBlob, filename: str) -> None:
"""
Export the network to a DOT file for visualization.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
"""
Generate the Jacobian matrix for the current state.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
"""
Generate the jacobian matrix only for the subset of the matrix representing the active species.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
"""
Generate the jacobian matrix for the given sparsity pattern
"""
def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the set of logical reactions in the network.
"""
def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the list of species in the network.
"""
def getPartitionFunction(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.partition.PartitionFunction:
"""
Get the partition function used by the engine.
"""
def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
"""
Get the current screening model of the engine.
"""
@typing.overload
def getSpeciesDestructionTimescales(self, ctx: scratchpads.StateBlob, composition: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeReactions: gridfire._gridfire.reaction.ReactionSet) -> ...:
...
@typing.overload
def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the destruction timescales for each species in the network.
"""
def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
"""
Get the index of a species in the network.
"""
def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
"""
Get the status of a species in the network.
"""
@typing.overload
def getSpeciesTimescales(self, ctx: scratchpads.StateBlob, composition: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeReactions: gridfire._gridfire.reaction.ReactionSet) -> ...:
...
@typing.overload
def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the timescales for each species in the network.
"""
def involvesSpecies(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> bool:
"""
Check if a given species is involved in the network.
"""
def isPrecomputationEnabled(self, arg0: scratchpads.StateBlob) -> bool:
"""
Check if precomputation is enabled for the engine.
"""
def isUsingReverseReactions(self, arg0: scratchpads.StateBlob) -> bool:
"""
Check if the engine is using reverse reactions.
"""
def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Prime the engine with a NetIn object to prepare for calculations.
"""
def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Update the engine state based on the provided NetIn object.
"""
class MultiscalePartitioningEngineView(DynamicEngine):
def __init__(self, baseEngine: GraphEngine) -> None:
"""
Construct a multiscale partitioning engine view with a base engine.
"""
def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
"""
Calculate deps/dT and deps/drho
"""
def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
"""
Calculate the molar reaction flow for a given reaction.
"""
def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
"""
Calculate the right-hand side (dY/dt) and energy generation rate.
"""
def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
"""
Recursively collect composition from current engine and any sub engines if they exist.
"""
def exportToDot(self, ctx: scratchpads.StateBlob, filename: str, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> None:
"""
Export the network to a DOT file for visualization.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
"""
Generate the Jacobian matrix for the current state.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
"""
Generate the jacobian matrix only for the subset of the matrix representing the active species.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
"""
Generate the jacobian matrix for the given sparsity pattern
"""
def getBaseEngine(self) -> DynamicEngine:
"""
Get the base engine associated with this multiscale partitioning engine view.
"""
def getDynamicSpecies(self: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the list of dynamic species in the network.
"""
def getFastSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the list of fast species in the network.
"""
def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the set of logical reactions in the network.
"""
def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the list of species in the network.
"""
def getNormalizedEquilibratedComposition(self, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
"""
Get the normalized equilibrated composition for the algebraic species.
"""
def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
"""
Get the current screening model of the engine.
"""
def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the destruction timescales for each species in the network.
"""
def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
"""
Get the index of a species in the network.
"""
def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
"""
Get the status of a species in the network.
"""
def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the timescales for each species in the network.
"""
def involvesSpecies(self: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> bool:
"""
Check if a given species is involved in the network (in either the algebraic or dynamic set).
"""
def involvesSpeciesInDynamic(self: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> bool:
"""
Check if a given species is involved in the network's dynamic set.
"""
def involvesSpeciesInQSE(self: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> bool:
"""
Check if a given species is involved in the network's algebraic set.
"""
def partitionNetwork(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Partition the network based on species timescales and connectivity.
"""
def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Prime the engine with a NetIn object to prepare for calculations.
"""
def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Update the engine state based on the provided NetIn object.
"""
class NetworkBuildDepth:
"""
Members:
Full : Full network build depth
Shallow : Shallow network build depth
SecondOrder : Second order network build depth
ThirdOrder : Third order network build depth
FourthOrder : Fourth order network build depth
FifthOrder : Fifth order network build depth
"""
FifthOrder: typing.ClassVar[NetworkBuildDepth] # value = <NetworkBuildDepth.FifthOrder: 5>
FourthOrder: typing.ClassVar[NetworkBuildDepth] # value = <NetworkBuildDepth.FourthOrder: 4>
Full: typing.ClassVar[NetworkBuildDepth] # value = <NetworkBuildDepth.Full: -1>
SecondOrder: typing.ClassVar[NetworkBuildDepth] # value = <NetworkBuildDepth.SecondOrder: 2>
Shallow: typing.ClassVar[NetworkBuildDepth] # value = <NetworkBuildDepth.Shallow: 1>
ThirdOrder: typing.ClassVar[NetworkBuildDepth] # value = <NetworkBuildDepth.ThirdOrder: 3>
__members__: typing.ClassVar[dict[str, NetworkBuildDepth]] # value = {'Full': <NetworkBuildDepth.Full: -1>, 'Shallow': <NetworkBuildDepth.Shallow: 1>, 'SecondOrder': <NetworkBuildDepth.SecondOrder: 2>, 'ThirdOrder': <NetworkBuildDepth.ThirdOrder: 3>, 'FourthOrder': <NetworkBuildDepth.FourthOrder: 4>, 'FifthOrder': <NetworkBuildDepth.FifthOrder: 5>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class NetworkConstructionFlags:
"""
Members:
NONE : No special construction flags.
REACLIB_STRONG : Include strong reactions from reaclib.
WRL_BETA_MINUS : Include beta-minus decay reactions from weak rate library.
WRL_BETA_PLUS : Include beta-plus decay reactions from weak rate library.
WRL_ELECTRON_CAPTURE : Include electron capture reactions from weak rate library.
WRL_POSITRON_CAPTURE : Include positron capture reactions from weak rate library.
REACLIB_WEAK : Include weak reactions from reaclib.
WRL_WEAK : Include all weak reactions from weak rate library.
REACLIB : Include all reactions from reaclib.
DEFAULT : Default construction flags (Reaclib strong and weak).
"""
DEFAULT: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.REACLIB: 33>
NONE: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.NONE: 0>
REACLIB: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.REACLIB: 33>
REACLIB_STRONG: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.REACLIB_STRONG: 1>
REACLIB_WEAK: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.REACLIB_WEAK: 32>
WRL_BETA_MINUS: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.WRL_BETA_MINUS: 2>
WRL_BETA_PLUS: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.WRL_BETA_PLUS: 4>
WRL_ELECTRON_CAPTURE: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.WRL_ELECTRON_CAPTURE: 8>
WRL_POSITRON_CAPTURE: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.WRL_POSITRON_CAPTURE: 16>
WRL_WEAK: typing.ClassVar[NetworkConstructionFlags] # value = <NetworkConstructionFlags.WRL_WEAK: 30>
__members__: typing.ClassVar[dict[str, NetworkConstructionFlags]] # value = {'NONE': <NetworkConstructionFlags.NONE: 0>, 'REACLIB_STRONG': <NetworkConstructionFlags.REACLIB_STRONG: 1>, 'WRL_BETA_MINUS': <NetworkConstructionFlags.WRL_BETA_MINUS: 2>, 'WRL_BETA_PLUS': <NetworkConstructionFlags.WRL_BETA_PLUS: 4>, 'WRL_ELECTRON_CAPTURE': <NetworkConstructionFlags.WRL_ELECTRON_CAPTURE: 8>, 'WRL_POSITRON_CAPTURE': <NetworkConstructionFlags.WRL_POSITRON_CAPTURE: 16>, 'REACLIB_WEAK': <NetworkConstructionFlags.REACLIB_WEAK: 32>, 'WRL_WEAK': <NetworkConstructionFlags.WRL_WEAK: 30>, 'REACLIB': <NetworkConstructionFlags.REACLIB: 33>, 'DEFAULT': <NetworkConstructionFlags.REACLIB: 33>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
@typing.overload
def __repr__(self) -> str:
...
@typing.overload
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class NetworkJacobian:
@typing.overload
def __getitem__(self, key: tuple[fourdst._phys.atomic.Species, fourdst._phys.atomic.Species]) -> float:
"""
Get an entry from the Jacobian matrix using species identifiers.
"""
@typing.overload
def __getitem__(self, key: tuple[typing.SupportsInt, typing.SupportsInt]) -> float:
"""
Get an entry from the Jacobian matrix using indices.
"""
@typing.overload
def __setitem__(self, key: tuple[fourdst._phys.atomic.Species, fourdst._phys.atomic.Species], value: typing.SupportsFloat) -> None:
"""
Set an entry in the Jacobian matrix using species identifiers.
"""
@typing.overload
def __setitem__(self, key: tuple[typing.SupportsInt, typing.SupportsInt], value: typing.SupportsFloat) -> None:
"""
Set an entry in the Jacobian matrix using indices.
"""
def data(self) -> ...:
"""
Get the underlying sparse matrix data.
"""
def infs(self) -> list[tuple[tuple[fourdst._phys.atomic.Species, fourdst._phys.atomic.Species], float]]:
"""
Get all infinite entries in the Jacobian matrix.
"""
def mapping(self) -> dict[fourdst._phys.atomic.Species, int]:
"""
Get the species-to-index mapping.
"""
def nans(self) -> list[tuple[tuple[fourdst._phys.atomic.Species, fourdst._phys.atomic.Species], float]]:
"""
Get all NaN entries in the Jacobian matrix.
"""
def nnz(self) -> int:
"""
Get the number of non-zero entries in the Jacobian matrix.
"""
def rank(self) -> int:
"""
Get the rank of the Jacobian matrix.
"""
def shape(self) -> tuple[int, int]:
"""
Get the shape of the Jacobian matrix as (rows, columns).
"""
def singular(self) -> bool:
"""
Check if the Jacobian matrix is singular.
"""
def to_csv(self, filename: str) -> None:
"""
Export the Jacobian matrix to a CSV file.
"""
def to_numpy(self) -> numpy.typing.NDArray[numpy.float64]:
"""
Convert the Jacobian matrix to a NumPy array.
"""
class NetworkPrimingEngineView(DefinedEngineView):
@typing.overload
def __init__(self, ctx: scratchpads.StateBlob, primingSymbol: str, baseEngine: GraphEngine) -> None:
"""
Construct a priming engine view with a priming symbol and a base engine.
"""
@typing.overload
def __init__(self, ctx: scratchpads.StateBlob, primingSpecies: fourdst._phys.atomic.Species, baseEngine: GraphEngine) -> None:
"""
Construct a priming engine view with a priming species and a base engine.
"""
def calculateEpsDerivatives(self, ctx: scratchpads.StateBlob, comp: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> ...:
"""
Calculate deps/dT and deps/drho
"""
def calculateMolarReactionFlow(self: DynamicEngine, ctx: scratchpads.StateBlob, reaction: ..., comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> float:
"""
Calculate the molar reaction flow for a given reaction.
"""
def calculateRHSAndEnergy(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> StepDerivatives:
"""
Calculate the right-hand side (dY/dt) and energy generation rate.
"""
def collectComposition(self, ctx: scratchpads.StateBlob, composition: ..., T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> fourdst._phys.composition.Composition:
"""
Recursively collect composition from current engine and any sub engines if they exist.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> NetworkJacobian:
"""
Generate the Jacobian matrix for the current state.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, activeSpecies: collections.abc.Sequence[fourdst._phys.atomic.Species]) -> NetworkJacobian:
"""
Generate the jacobian matrix only for the subset of the matrix representing the active species.
"""
@typing.overload
def generateJacobianMatrix(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, sparsityPattern: collections.abc.Sequence[tuple[typing.SupportsInt, typing.SupportsInt]]) -> NetworkJacobian:
"""
Generate the jacobian matrix for the given sparsity pattern
"""
def getBaseEngine(self) -> DynamicEngine:
"""
Get the base engine associated with this priming engine view.
"""
def getNetworkReactions(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the set of logical reactions in the network.
"""
def getNetworkSpecies(self, arg0: scratchpads.StateBlob) -> list[fourdst._phys.atomic.Species]:
"""
Get the list of species in the network.
"""
def getScreeningModel(self, arg0: scratchpads.StateBlob) -> gridfire._gridfire.screening.ScreeningType:
"""
Get the current screening model of the engine.
"""
def getSpeciesDestructionTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the destruction timescales for each species in the network.
"""
def getSpeciesIndex(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> int:
"""
Get the index of a species in the network.
"""
def getSpeciesStatus(self, ctx: scratchpads.StateBlob, species: fourdst._phys.atomic.Species) -> SpeciesStatus:
"""
Get the status of a species in the network.
"""
def getSpeciesTimescales(self: DynamicEngine, ctx: scratchpads.StateBlob, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> dict[fourdst._phys.atomic.Species, float]:
"""
Get the timescales for each species in the network.
"""
def primeEngine(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> PrimingReport:
"""
Prime the engine with a NetIn object to prepare for calculations.
"""
def project(self, ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn) -> fourdst._phys.composition.Composition:
"""
Update the engine state based on the provided NetIn object.
"""
class PrimingReport:
def __repr__(self) -> str:
...
@property
def primedComposition(self) -> fourdst._phys.composition.Composition:
"""
The composition after priming.
"""
@property
def status(self) -> PrimingReportStatus:
"""
Status message from the priming process.
"""
@property
def success(self) -> bool:
"""
Indicates if the priming was successful.
"""
class PrimingReportStatus:
"""
Members:
FULL_SUCCESS : Priming was full successful.
NO_SPECIES_TO_PRIME : Solver Failed to converge during priming.
MAX_ITERATIONS_REACHED : Engine has already been primed.
"""
FULL_SUCCESS: typing.ClassVar[PrimingReportStatus] # value = <PrimingReportStatus.FULL_SUCCESS: 0>
MAX_ITERATIONS_REACHED: typing.ClassVar[PrimingReportStatus] # value = <PrimingReportStatus.MAX_ITERATIONS_REACHED: 1>
NO_SPECIES_TO_PRIME: typing.ClassVar[PrimingReportStatus] # value = <PrimingReportStatus.NO_SPECIES_TO_PRIME: 2>
__members__: typing.ClassVar[dict[str, PrimingReportStatus]] # value = {'FULL_SUCCESS': <PrimingReportStatus.FULL_SUCCESS: 0>, 'NO_SPECIES_TO_PRIME': <PrimingReportStatus.NO_SPECIES_TO_PRIME: 2>, 'MAX_ITERATIONS_REACHED': <PrimingReportStatus.MAX_ITERATIONS_REACHED: 1>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
@typing.overload
def __repr__(self) -> str:
...
@typing.overload
def __repr__(self) -> str:
"""
String representation of the PrimingReport.
"""
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class SparsityPattern:
pass
class SpeciesStatus:
"""
Members:
ACTIVE : Species is active in the network.
EQUILIBRIUM : Species is in equilibrium.
INACTIVE_FLOW : Species is inactive due to flow.
NOT_PRESENT : Species is not present in the network.
"""
ACTIVE: typing.ClassVar[SpeciesStatus] # value = <SpeciesStatus.ACTIVE: 0>
EQUILIBRIUM: typing.ClassVar[SpeciesStatus] # value = <SpeciesStatus.EQUILIBRIUM: 1>
INACTIVE_FLOW: typing.ClassVar[SpeciesStatus] # value = <SpeciesStatus.INACTIVE_FLOW: 2>
NOT_PRESENT: typing.ClassVar[SpeciesStatus] # value = <SpeciesStatus.NOT_PRESENT: 3>
__members__: typing.ClassVar[dict[str, SpeciesStatus]] # value = {'ACTIVE': <SpeciesStatus.ACTIVE: 0>, 'EQUILIBRIUM': <SpeciesStatus.EQUILIBRIUM: 1>, 'INACTIVE_FLOW': <SpeciesStatus.INACTIVE_FLOW: 2>, 'NOT_PRESENT': <SpeciesStatus.NOT_PRESENT: 3>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
@typing.overload
def __repr__(self) -> str:
...
@typing.overload
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class StepDerivatives:
@property
def dYdt(self) -> dict[fourdst._phys.atomic.Species, float]:
"""
The right-hand side (dY/dt) of the ODE system.
"""
@property
def energy(self) -> float:
"""
The energy generation rate.
"""
def build_nuclear_network(composition: ..., weakInterpolator: ..., maxLayers: gridfire._gridfire.engine.NetworkBuildDepth | typing.SupportsInt = ..., ReactionTypes: NetworkConstructionFlags = ...) -> gridfire._gridfire.reaction.ReactionSet:
"""
Build a nuclear network from a composition using all archived reaction data.
"""
def primeNetwork(ctx: scratchpads.StateBlob, netIn: gridfire._gridfire.type.NetIn, engine: ..., ignoredReactionTypes: collections.abc.Sequence[...] | None = None) -> PrimingReport:
"""
Prime a network with a short timescale ignition
"""
def regularize_jacobian(jacobian: NetworkJacobian, composition: fourdst._phys.composition.Composition) -> NetworkJacobian:
"""
regularize_jacobian
"""
ACTIVE: SpeciesStatus # value = <SpeciesStatus.ACTIVE: 0>
ADAPTIVE_ENGINE_VIEW: EngineTypes # value = <EngineTypes.ADAPTIVE_ENGINE_VIEW: 1>
DEFAULT: NetworkConstructionFlags # value = <NetworkConstructionFlags.REACLIB: 33>
DEFINED_ENGINE_VIEW: EngineTypes # value = <EngineTypes.DEFINED_ENGINE_VIEW: 4>
EQUILIBRIUM: SpeciesStatus # value = <SpeciesStatus.EQUILIBRIUM: 1>
FILE_DEFINED_ENGINE_VIEW: EngineTypes # value = <EngineTypes.FILE_DEFINED_ENGINE_VIEW: 5>
FULL_SUCCESS: PrimingReportStatus # value = <PrimingReportStatus.FULL_SUCCESS: 0>
FifthOrder: NetworkBuildDepth # value = <NetworkBuildDepth.FifthOrder: 5>
FourthOrder: NetworkBuildDepth # value = <NetworkBuildDepth.FourthOrder: 4>
Full: NetworkBuildDepth # value = <NetworkBuildDepth.Full: -1>
GRAPH_ENGINE: EngineTypes # value = <EngineTypes.GRAPH_ENGINE: 0>
INACTIVE_FLOW: SpeciesStatus # value = <SpeciesStatus.INACTIVE_FLOW: 2>
MAX_ITERATIONS_REACHED: PrimingReportStatus # value = <PrimingReportStatus.MAX_ITERATIONS_REACHED: 1>
MULTISCALE_PARTITIONING_ENGINE_VIEW: EngineTypes # value = <EngineTypes.MULTISCALE_PARTITIONING_ENGINE_VIEW: 2>
NONE: NetworkConstructionFlags # value = <NetworkConstructionFlags.NONE: 0>
NOT_PRESENT: SpeciesStatus # value = <SpeciesStatus.NOT_PRESENT: 3>
NO_SPECIES_TO_PRIME: PrimingReportStatus # value = <PrimingReportStatus.NO_SPECIES_TO_PRIME: 2>
PRIMING_ENGINE_VIEW: EngineTypes # value = <EngineTypes.PRIMING_ENGINE_VIEW: 3>
REACLIB: NetworkConstructionFlags # value = <NetworkConstructionFlags.REACLIB: 33>
REACLIB_STRONG: NetworkConstructionFlags # value = <NetworkConstructionFlags.REACLIB_STRONG: 1>
REACLIB_WEAK: NetworkConstructionFlags # value = <NetworkConstructionFlags.REACLIB_WEAK: 32>
SecondOrder: NetworkBuildDepth # value = <NetworkBuildDepth.SecondOrder: 2>
Shallow: NetworkBuildDepth # value = <NetworkBuildDepth.Shallow: 1>
ThirdOrder: NetworkBuildDepth # value = <NetworkBuildDepth.ThirdOrder: 3>
WRL_BETA_MINUS: NetworkConstructionFlags # value = <NetworkConstructionFlags.WRL_BETA_MINUS: 2>
WRL_BETA_PLUS: NetworkConstructionFlags # value = <NetworkConstructionFlags.WRL_BETA_PLUS: 4>
WRL_ELECTRON_CAPTURE: NetworkConstructionFlags # value = <NetworkConstructionFlags.WRL_ELECTRON_CAPTURE: 8>
WRL_POSITRON_CAPTURE: NetworkConstructionFlags # value = <NetworkConstructionFlags.WRL_POSITRON_CAPTURE: 16>
WRL_WEAK: NetworkConstructionFlags # value = <NetworkConstructionFlags.WRL_WEAK: 30>

View File

@@ -0,0 +1,16 @@
"""
A submodule for engine diagnostics
"""
from __future__ import annotations
import collections.abc
import fourdst._phys.composition
import gridfire._gridfire.engine
import gridfire._gridfire.engine.scratchpads
import typing
__all__: list[str] = ['inspect_jacobian_stiffness', 'inspect_species_balance', 'report_limiting_species']
def inspect_jacobian_stiffness(ctx: gridfire._gridfire.engine.scratchpads.StateBlob, engine: gridfire._gridfire.engine.DynamicEngine, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, json: bool) -> ... | None:
...
def inspect_species_balance(ctx: gridfire._gridfire.engine.scratchpads.StateBlob, engine: gridfire._gridfire.engine.DynamicEngine, species_name: str, comp: fourdst._phys.composition.Composition, T9: typing.SupportsFloat, rho: typing.SupportsFloat, json: bool) -> ... | None:
...
def report_limiting_species(ctx: gridfire._gridfire.engine.scratchpads.StateBlob, engine: gridfire._gridfire.engine.DynamicEngine, Y_full: collections.abc.Sequence[typing.SupportsFloat], E_full: collections.abc.Sequence[typing.SupportsFloat], relTol: typing.SupportsFloat, absTol: typing.SupportsFloat, top_n: typing.SupportsInt, json: bool) -> ... | None:
...

View File

@@ -0,0 +1,267 @@
"""
Engine ScratchPad bindings
"""
from __future__ import annotations
import fourdst._phys.atomic
import fourdst._phys.composition
import gridfire._gridfire.reaction
import typing
__all__: list[str] = ['ADAPTIVE_ENGINE_VIEW_SCRATCHPAD', 'ADFunRegistrationResult', 'ALREADY_REGISTERED', 'AdaptiveEngineViewScratchPad', 'DEFINED_ENGINE_VIEW_SCRATCHPAD', 'DefinedEngineViewScratchPad', 'GRAPH_ENGINE_SCRATCHPAD', 'GraphEngineScratchPad', 'MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD', 'MultiscalePartitioningEngineViewScratchPad', 'SCRATCHPAD_BAD_CAST', 'SCRATCHPAD_NOT_FOUND', 'SCRATCHPAD_NOT_INITIALIZED', 'SCRATCHPAD_OUT_OF_BOUNDS', 'SCRATCHPAD_TYPE_COLLISION', 'SCRATCHPAD_UNKNOWN_ERROR', 'SUCCESS', 'ScratchPadType', 'StateBlob', 'StateBlobError']
class ADFunRegistrationResult:
"""
Members:
SUCCESS
ALREADY_REGISTERED
"""
ALREADY_REGISTERED: typing.ClassVar[ADFunRegistrationResult] # value = <ADFunRegistrationResult.ALREADY_REGISTERED: 1>
SUCCESS: typing.ClassVar[ADFunRegistrationResult] # value = <ADFunRegistrationResult.SUCCESS: 0>
__members__: typing.ClassVar[dict[str, ADFunRegistrationResult]] # value = {'SUCCESS': <ADFunRegistrationResult.SUCCESS: 0>, 'ALREADY_REGISTERED': <ADFunRegistrationResult.ALREADY_REGISTERED: 1>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class AdaptiveEngineViewScratchPad:
ID: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: 2>
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone(self) -> ...:
...
def initialize(self, arg0: ...) -> None:
...
def is_initialized(self) -> bool:
...
@property
def active_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
...
@property
def active_species(self) -> list[fourdst._phys.atomic.Species]:
...
@property
def has_initialized(self) -> bool:
...
class DefinedEngineViewScratchPad:
ID: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.DEFINED_ENGINE_VIEW_SCRATCHPAD: 3>
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone(self) -> ...:
...
def is_initialized(self) -> bool:
...
@property
def active_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
...
@property
def active_species(self) -> set[fourdst._phys.atomic.Species]:
...
@property
def has_initialized(self) -> bool:
...
@property
def reaction_index_map(self) -> list[int]:
...
@property
def species_index_map(self) -> list[int]:
...
class GraphEngineScratchPad:
ID: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.GRAPH_ENGINE_SCRATCHPAD: 0>
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone(self) -> ...:
...
def initialize(self, engine: ...) -> None:
...
def is_initialized(self) -> bool:
...
@property
def has_initialized(self) -> bool:
...
@property
def local_abundance_cache(self) -> list[float]:
...
@property
def most_recent_rhs_calculation(self) -> ... | None:
...
@property
def stepDerivativesCache(self) -> dict[int, ...]:
...
class MultiscalePartitioningEngineViewScratchPad:
ID: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: 1>
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone(self) -> ...:
...
def initialize(self) -> None:
...
def is_initialized(self) -> bool:
...
@property
def algebraic_species(self) -> list[fourdst._phys.atomic.Species]:
...
@property
def composition_cache(self) -> dict[int, fourdst._phys.composition.Composition]:
...
@property
def dynamic_species(self) -> list[fourdst._phys.atomic.Species]:
...
@property
def has_initialized(self) -> bool:
...
@property
def qse_groups(self) -> list[...]:
...
class ScratchPadType:
"""
Members:
GRAPH_ENGINE_SCRATCHPAD
MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD
ADAPTIVE_ENGINE_VIEW_SCRATCHPAD
DEFINED_ENGINE_VIEW_SCRATCHPAD
"""
ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: 2>
DEFINED_ENGINE_VIEW_SCRATCHPAD: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.DEFINED_ENGINE_VIEW_SCRATCHPAD: 3>
GRAPH_ENGINE_SCRATCHPAD: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.GRAPH_ENGINE_SCRATCHPAD: 0>
MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: typing.ClassVar[ScratchPadType] # value = <ScratchPadType.MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: 1>
__members__: typing.ClassVar[dict[str, ScratchPadType]] # value = {'GRAPH_ENGINE_SCRATCHPAD': <ScratchPadType.GRAPH_ENGINE_SCRATCHPAD: 0>, 'MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD': <ScratchPadType.MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: 1>, 'ADAPTIVE_ENGINE_VIEW_SCRATCHPAD': <ScratchPadType.ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: 2>, 'DEFINED_ENGINE_VIEW_SCRATCHPAD': <ScratchPadType.DEFINED_ENGINE_VIEW_SCRATCHPAD: 3>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class StateBlob:
@staticmethod
def error_to_string(arg0: StateBlobError) -> str:
...
def __init__(self) -> None:
...
def __repr__(self) -> str:
...
def clone_structure(self) -> StateBlob:
...
def enroll(self, arg0: ScratchPadType) -> None:
...
def get(self, arg0: ScratchPadType) -> ...:
...
def get_registered_scratchpads(self) -> set[ScratchPadType]:
...
def get_status(self, arg0: ScratchPadType) -> ...:
...
def get_status_map(self) -> dict[ScratchPadType, ...]:
...
class StateBlobError:
"""
Members:
SCRATCHPAD_OUT_OF_BOUNDS
SCRATCHPAD_NOT_FOUND
SCRATCHPAD_BAD_CAST
SCRATCHPAD_NOT_INITIALIZED
SCRATCHPAD_TYPE_COLLISION
SCRATCHPAD_UNKNOWN_ERROR
"""
SCRATCHPAD_BAD_CAST: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_BAD_CAST: 1>
SCRATCHPAD_NOT_FOUND: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_NOT_FOUND: 0>
SCRATCHPAD_NOT_INITIALIZED: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_NOT_INITIALIZED: 2>
SCRATCHPAD_OUT_OF_BOUNDS: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_OUT_OF_BOUNDS: 4>
SCRATCHPAD_TYPE_COLLISION: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_TYPE_COLLISION: 3>
SCRATCHPAD_UNKNOWN_ERROR: typing.ClassVar[StateBlobError] # value = <StateBlobError.SCRATCHPAD_UNKNOWN_ERROR: 5>
__members__: typing.ClassVar[dict[str, StateBlobError]] # value = {'SCRATCHPAD_OUT_OF_BOUNDS': <StateBlobError.SCRATCHPAD_OUT_OF_BOUNDS: 4>, 'SCRATCHPAD_NOT_FOUND': <StateBlobError.SCRATCHPAD_NOT_FOUND: 0>, 'SCRATCHPAD_BAD_CAST': <StateBlobError.SCRATCHPAD_BAD_CAST: 1>, 'SCRATCHPAD_NOT_INITIALIZED': <StateBlobError.SCRATCHPAD_NOT_INITIALIZED: 2>, 'SCRATCHPAD_TYPE_COLLISION': <StateBlobError.SCRATCHPAD_TYPE_COLLISION: 3>, 'SCRATCHPAD_UNKNOWN_ERROR': <StateBlobError.SCRATCHPAD_UNKNOWN_ERROR: 5>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: ScratchPadType # value = <ScratchPadType.ADAPTIVE_ENGINE_VIEW_SCRATCHPAD: 2>
ALREADY_REGISTERED: ADFunRegistrationResult # value = <ADFunRegistrationResult.ALREADY_REGISTERED: 1>
DEFINED_ENGINE_VIEW_SCRATCHPAD: ScratchPadType # value = <ScratchPadType.DEFINED_ENGINE_VIEW_SCRATCHPAD: 3>
GRAPH_ENGINE_SCRATCHPAD: ScratchPadType # value = <ScratchPadType.GRAPH_ENGINE_SCRATCHPAD: 0>
MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: ScratchPadType # value = <ScratchPadType.MULTISCALE_PARTITIONING_ENGINE_VIEW_SCRATCHPAD: 1>
SCRATCHPAD_BAD_CAST: StateBlobError # value = <StateBlobError.SCRATCHPAD_BAD_CAST: 1>
SCRATCHPAD_NOT_FOUND: StateBlobError # value = <StateBlobError.SCRATCHPAD_NOT_FOUND: 0>
SCRATCHPAD_NOT_INITIALIZED: StateBlobError # value = <StateBlobError.SCRATCHPAD_NOT_INITIALIZED: 2>
SCRATCHPAD_OUT_OF_BOUNDS: StateBlobError # value = <StateBlobError.SCRATCHPAD_OUT_OF_BOUNDS: 4>
SCRATCHPAD_TYPE_COLLISION: StateBlobError # value = <StateBlobError.SCRATCHPAD_TYPE_COLLISION: 3>
SCRATCHPAD_UNKNOWN_ERROR: StateBlobError # value = <StateBlobError.SCRATCHPAD_UNKNOWN_ERROR: 5>
SUCCESS: ADFunRegistrationResult # value = <ADFunRegistrationResult.SUCCESS: 0>

View File

@@ -0,0 +1,61 @@
"""
GridFire exceptions bindings
"""
from __future__ import annotations
__all__: list[str] = ['BadCollectionError', 'BadRHSEngineError', 'CVODESolverFailureError', 'DebugException', 'EngineError', 'FailedToPartitionEngineError', 'GridFireError', 'HashingError', 'IllConditionedJacobianError', 'InvalidQSESolutionError', 'JacobianError', 'KINSolSolverFailureError', 'MissingBaseReactionError', 'MissingKeyReactionError', 'MissingSeedSpeciesError', 'NetworkResizedError', 'PolicyError', 'ReactionError', 'ReactionParsingError', 'SUNDIALSError', 'ScratchPadError', 'SingularJacobianError', 'SolverError', 'StaleJacobianError', 'UnableToSetNetworkReactionsError', 'UninitializedJacobianError', 'UnknownJacobianError', 'UtilityError']
class BadCollectionError(EngineError):
pass
class BadRHSEngineError(EngineError):
pass
class CVODESolverFailureError(SUNDIALSError):
pass
class DebugException(GridFireError):
pass
class EngineError(GridFireError):
pass
class FailedToPartitionEngineError(EngineError):
pass
class GridFireError(Exception):
pass
class HashingError(UtilityError):
pass
class IllConditionedJacobianError(SolverError):
pass
class InvalidQSESolutionError(EngineError):
pass
class JacobianError(EngineError):
pass
class KINSolSolverFailureError(SUNDIALSError):
pass
class MissingBaseReactionError(PolicyError):
pass
class MissingKeyReactionError(PolicyError):
pass
class MissingSeedSpeciesError(PolicyError):
pass
class NetworkResizedError(EngineError):
pass
class PolicyError(GridFireError):
pass
class ReactionError(GridFireError):
pass
class ReactionParsingError(ReactionError):
pass
class SUNDIALSError(SolverError):
pass
class ScratchPadError(GridFireError):
pass
class SingularJacobianError(SolverError):
pass
class SolverError(GridFireError):
pass
class StaleJacobianError(JacobianError):
pass
class UnableToSetNetworkReactionsError(EngineError):
pass
class UninitializedJacobianError(JacobianError):
pass
class UnknownJacobianError(JacobianError):
pass
class UtilityError(GridFireError):
pass

View File

@@ -0,0 +1,14 @@
"""
GridFire io bindings
"""
from __future__ import annotations
__all__: list[str] = ['NetworkFileParser', 'ParsedNetworkData', 'SimpleReactionListFileParser']
class NetworkFileParser:
pass
class ParsedNetworkData:
pass
class SimpleReactionListFileParser(NetworkFileParser):
def parse(self, filename: str) -> ParsedNetworkData:
"""
Parse a simple reaction list file and return a ParsedNetworkData object.
"""

View File

@@ -0,0 +1,142 @@
"""
GridFire partition function bindings
"""
from __future__ import annotations
import collections.abc
import typing
__all__: list[str] = ['BasePartitionType', 'CompositePartitionFunction', 'GroundState', 'GroundStatePartitionFunction', 'PartitionFunction', 'RauscherThielemann', 'RauscherThielemannPartitionDataRecord', 'RauscherThielemannPartitionFunction', 'basePartitionTypeToString', 'stringToBasePartitionType']
class BasePartitionType:
"""
Members:
RauscherThielemann
GroundState
"""
GroundState: typing.ClassVar[BasePartitionType] # value = <BasePartitionType.GroundState: 1>
RauscherThielemann: typing.ClassVar[BasePartitionType] # value = <BasePartitionType.RauscherThielemann: 0>
__members__: typing.ClassVar[dict[str, BasePartitionType]] # value = {'RauscherThielemann': <BasePartitionType.RauscherThielemann: 0>, 'GroundState': <BasePartitionType.GroundState: 1>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class CompositePartitionFunction:
@typing.overload
def __init__(self, partitionFunctions: collections.abc.Sequence[BasePartitionType]) -> None:
"""
Create a composite partition function from a list of base partition types.
"""
@typing.overload
def __init__(self, arg0: CompositePartitionFunction) -> None:
"""
Copy constructor for CompositePartitionFunction.
"""
def evaluate(self, z: typing.SupportsInt, a: typing.SupportsInt, T9: typing.SupportsFloat) -> float:
"""
Evaluate the composite partition function for given Z, A, and T9.
"""
def evaluateDerivative(self, z: typing.SupportsInt, a: typing.SupportsInt, T9: typing.SupportsFloat) -> float:
"""
Evaluate the derivative of the composite partition function for given Z, A, and T9.
"""
def get_type(self) -> str:
"""
Get the type of the partition function (should return 'Composite').
"""
def supports(self, z: typing.SupportsInt, a: typing.SupportsInt) -> bool:
"""
Check if the composite partition function supports given Z and A.
"""
class GroundStatePartitionFunction(PartitionFunction):
def __init__(self) -> None:
...
def evaluate(self, z: typing.SupportsInt, a: typing.SupportsInt, T9: typing.SupportsFloat) -> float:
"""
Evaluate the ground state partition function for given Z, A, and T9.
"""
def evaluateDerivative(self, z: typing.SupportsInt, a: typing.SupportsInt, T9: typing.SupportsFloat) -> float:
"""
Evaluate the derivative of the ground state partition function for given Z, A, and T9.
"""
def get_type(self) -> str:
"""
Get the type of the partition function (should return 'GroundState').
"""
def supports(self, z: typing.SupportsInt, a: typing.SupportsInt) -> bool:
"""
Check if the ground state partition function supports given Z and A.
"""
class PartitionFunction:
pass
class RauscherThielemannPartitionDataRecord:
@property
def a(self) -> int:
"""
Mass number
"""
@property
def ground_state_spin(self) -> float:
"""
Ground state spin
"""
@property
def normalized_g_values(self) -> float:
"""
Normalized g-values for the first 24 energy levels
"""
@property
def z(self) -> int:
"""
Atomic number
"""
class RauscherThielemannPartitionFunction(PartitionFunction):
def __init__(self) -> None:
...
def evaluate(self, z: typing.SupportsInt, a: typing.SupportsInt, T9: typing.SupportsFloat) -> float:
"""
Evaluate the Rauscher-Thielemann partition function for given Z, A, and T9.
"""
def evaluateDerivative(self, z: typing.SupportsInt, a: typing.SupportsInt, T9: typing.SupportsFloat) -> float:
"""
Evaluate the derivative of the Rauscher-Thielemann partition function for given Z, A, and T9.
"""
def get_type(self) -> str:
"""
Get the type of the partition function (should return 'RauscherThielemann').
"""
def supports(self, z: typing.SupportsInt, a: typing.SupportsInt) -> bool:
"""
Check if the Rauscher-Thielemann partition function supports given Z and A.
"""
def basePartitionTypeToString(type: BasePartitionType) -> str:
"""
Convert BasePartitionType to string.
"""
def stringToBasePartitionType(typeStr: str) -> BasePartitionType:
"""
Convert string to BasePartitionType.
"""
GroundState: BasePartitionType # value = <BasePartitionType.GroundState: 1>
RauscherThielemann: BasePartitionType # value = <BasePartitionType.RauscherThielemann: 0>

View File

@@ -0,0 +1,769 @@
"""
GridFire network policy bindings
"""
from __future__ import annotations
import collections.abc
import fourdst._phys.atomic
import fourdst._phys.composition
import gridfire._gridfire.engine
import gridfire._gridfire.engine.scratchpads
import gridfire._gridfire.partition
import gridfire._gridfire.reaction
import typing
__all__: list[str] = ['CNOChainPolicy', 'CNOIChainPolicy', 'CNOIIChainPolicy', 'CNOIIIChainPolicy', 'CNOIVChainPolicy', 'ConstructionResults', 'HotCNOChainPolicy', 'HotCNOIChainPolicy', 'HotCNOIIChainPolicy', 'HotCNOIIIChainPolicy', 'INITIALIZED_UNVERIFIED', 'INITIALIZED_VERIFIED', 'MISSING_KEY_REACTION', 'MISSING_KEY_SPECIES', 'MainSequencePolicy', 'MainSequenceReactionChainPolicy', 'MultiReactionChainPolicy', 'NetworkPolicy', 'NetworkPolicyStatus', 'ProtonProtonChainPolicy', 'ProtonProtonIChainPolicy', 'ProtonProtonIIChainPolicy', 'ProtonProtonIIIChainPolicy', 'ReactionChainPolicy', 'TemperatureDependentChainPolicy', 'TripleAlphaChainPolicy', 'UNINITIALIZED', 'network_policy_status_to_string']
class CNOChainPolicy(MultiReactionChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class CNOIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class CNOIIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class CNOIIIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class CNOIVChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class ConstructionResults:
@property
def engine(self) -> gridfire._gridfire.engine.DynamicEngine:
...
@property
def scratch_blob(self) -> gridfire._gridfire.engine.scratchpads.StateBlob:
...
class HotCNOChainPolicy(MultiReactionChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class HotCNOIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class HotCNOIIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class HotCNOIIIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class MainSequencePolicy(NetworkPolicy):
@typing.overload
def __init__(self, composition: fourdst._phys.composition.Composition) -> None:
"""
Construct MainSequencePolicy from an existing composition.
"""
@typing.overload
def __init__(self, seed_species: collections.abc.Sequence[fourdst._phys.atomic.Species], mass_fractions: collections.abc.Sequence[typing.SupportsFloat]) -> None:
"""
Construct MainSequencePolicy from seed species and mass fractions.
"""
def construct(self) -> ConstructionResults:
"""
Construct the network according to the policy.
"""
def get_engine_stack(self) -> list[gridfire._gridfire.engine.DynamicEngine]:
...
def get_engine_types_stack(self) -> list[gridfire._gridfire.engine.EngineTypes]:
"""
Get the types of engines in the stack constructed by the network policy.
"""
def get_partition_function(self) -> gridfire._gridfire.partition.PartitionFunction:
...
def get_seed_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the set of seed reactions required by the network policy.
"""
def get_seed_species(self) -> set[fourdst._phys.atomic.Species]:
"""
Get the set of seed species required by the network policy.
"""
def get_stack_scratch_blob(self) -> gridfire._gridfire.engine.scratchpads.StateBlob:
...
def get_status(self) -> NetworkPolicyStatus:
"""
Get the current status of the network policy.
"""
def name(self) -> str:
"""
Get the name of the network policy.
"""
class MainSequenceReactionChainPolicy(MultiReactionChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class MultiReactionChainPolicy(ReactionChainPolicy):
pass
class NetworkPolicy:
pass
class NetworkPolicyStatus:
"""
Members:
UNINITIALIZED
INITIALIZED_UNVERIFIED
MISSING_KEY_REACTION
MISSING_KEY_SPECIES
INITIALIZED_VERIFIED
"""
INITIALIZED_UNVERIFIED: typing.ClassVar[NetworkPolicyStatus] # value = <NetworkPolicyStatus.INITIALIZED_UNVERIFIED: 1>
INITIALIZED_VERIFIED: typing.ClassVar[NetworkPolicyStatus] # value = <NetworkPolicyStatus.INITIALIZED_VERIFIED: 4>
MISSING_KEY_REACTION: typing.ClassVar[NetworkPolicyStatus] # value = <NetworkPolicyStatus.MISSING_KEY_REACTION: 2>
MISSING_KEY_SPECIES: typing.ClassVar[NetworkPolicyStatus] # value = <NetworkPolicyStatus.MISSING_KEY_SPECIES: 3>
UNINITIALIZED: typing.ClassVar[NetworkPolicyStatus] # value = <NetworkPolicyStatus.UNINITIALIZED: 0>
__members__: typing.ClassVar[dict[str, NetworkPolicyStatus]] # value = {'UNINITIALIZED': <NetworkPolicyStatus.UNINITIALIZED: 0>, 'INITIALIZED_UNVERIFIED': <NetworkPolicyStatus.INITIALIZED_UNVERIFIED: 1>, 'MISSING_KEY_REACTION': <NetworkPolicyStatus.MISSING_KEY_REACTION: 2>, 'MISSING_KEY_SPECIES': <NetworkPolicyStatus.MISSING_KEY_SPECIES: 3>, 'INITIALIZED_VERIFIED': <NetworkPolicyStatus.INITIALIZED_VERIFIED: 4>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class ProtonProtonChainPolicy(MultiReactionChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class ProtonProtonIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class ProtonProtonIIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class ProtonProtonIIIChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
class ReactionChainPolicy:
pass
class TemperatureDependentChainPolicy(ReactionChainPolicy):
pass
class TripleAlphaChainPolicy(TemperatureDependentChainPolicy):
def __eq__(self, other: ReactionChainPolicy) -> bool:
"""
Check equality with another ReactionChainPolicy.
"""
def __hash__(self) -> int:
...
def __init__(self) -> None:
...
def __ne__(self, other: ReactionChainPolicy) -> bool:
"""
Check inequality with another ReactionChainPolicy.
"""
def __repr__(self) -> str:
...
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the reaction chain contains a reaction with the given ID.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the reaction chain contains the given reaction.
"""
def get_reactions(self) -> gridfire._gridfire.reaction.ReactionSet:
"""
Get the ReactionSet representing this reaction chain.
"""
def hash(self, seed: typing.SupportsInt) -> int:
"""
Compute a hash value for the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
@typing.overload
def name(self) -> str:
"""
Get the name of the reaction chain policy.
"""
def network_policy_status_to_string(status: NetworkPolicyStatus) -> str:
"""
Convert a NetworkPolicyStatus enum value to its string representation.
"""
INITIALIZED_UNVERIFIED: NetworkPolicyStatus # value = <NetworkPolicyStatus.INITIALIZED_UNVERIFIED: 1>
INITIALIZED_VERIFIED: NetworkPolicyStatus # value = <NetworkPolicyStatus.INITIALIZED_VERIFIED: 4>
MISSING_KEY_REACTION: NetworkPolicyStatus # value = <NetworkPolicyStatus.MISSING_KEY_REACTION: 2>
MISSING_KEY_SPECIES: NetworkPolicyStatus # value = <NetworkPolicyStatus.MISSING_KEY_SPECIES: 3>
UNINITIALIZED: NetworkPolicyStatus # value = <NetworkPolicyStatus.UNINITIALIZED: 0>

View File

@@ -0,0 +1,249 @@
"""
GridFire reaction bindings
"""
from __future__ import annotations
import collections.abc
import fourdst._phys.atomic
import fourdst._phys.composition
import typing
__all__: list[str] = ['LogicalReaclibReaction', 'RateCoefficientSet', 'ReaclibReaction', 'ReactionSet', 'get_all_reactions', 'packReactionSet']
class LogicalReaclibReaction(ReaclibReaction):
@typing.overload
def __init__(self, reactions: collections.abc.Sequence[ReaclibReaction]) -> None:
"""
Construct a LogicalReaclibReaction from a vector of ReaclibReaction objects.
"""
@typing.overload
def __init__(self, reactions: collections.abc.Sequence[ReaclibReaction], is_reverse: bool) -> None:
"""
Construct a LogicalReaclibReaction from a vector of ReaclibReaction objects.
"""
def __len__(self) -> int:
"""
Overload len() to return the number of source rates.
"""
def add_reaction(self, reaction: ReaclibReaction) -> None:
"""
Add another Reaction source to this logical reaction.
"""
def calculate_forward_rate_log_derivative(self, T9: typing.SupportsFloat, rho: typing.SupportsFloat, Ye: typing.SupportsFloat, mue: typing.SupportsFloat, Composition: fourdst._phys.composition.Composition) -> float:
"""
Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K).
"""
def calculate_rate(self, T9: typing.SupportsFloat, rho: typing.SupportsFloat, Ye: typing.SupportsFloat, mue: typing.SupportsFloat, Y: collections.abc.Sequence[typing.SupportsFloat], index_to_species_map: collections.abc.Mapping[typing.SupportsInt, fourdst._phys.atomic.Species]) -> float:
"""
Calculate the reaction rate at a given temperature T9 (in units of 10^9 K). Note that for a reaclib reaction only T9 is actually used, all other parameters are there for interface compatibility.
"""
def size(self) -> int:
"""
Get the number of source rates contributing to this logical reaction.
"""
def sources(self) -> list[str]:
"""
Get the list of source labels for the aggregated rates.
"""
class RateCoefficientSet:
def __init__(self, a0: typing.SupportsFloat, a1: typing.SupportsFloat, a2: typing.SupportsFloat, a3: typing.SupportsFloat, a4: typing.SupportsFloat, a5: typing.SupportsFloat, a6: typing.SupportsFloat) -> None:
"""
Construct a RateCoefficientSet with the given parameters.
"""
class ReaclibReaction:
__hash__: typing.ClassVar[None] = None
def __eq__(self, arg0: ReaclibReaction) -> bool:
"""
Equality operator for reactions based on their IDs.
"""
def __init__(self, id: str, peName: str, chapter: typing.SupportsInt, reactants: collections.abc.Sequence[fourdst._phys.atomic.Species], products: collections.abc.Sequence[fourdst._phys.atomic.Species], qValue: typing.SupportsFloat, label: str, sets: RateCoefficientSet, reverse: bool = False) -> None:
"""
Construct a Reaction with the given parameters.
"""
def __neq__(self, arg0: ReaclibReaction) -> bool:
"""
Inequality operator for reactions based on their IDs.
"""
def __repr__(self) -> str:
...
def all_species(self) -> set[fourdst._phys.atomic.Species]:
"""
Get all species involved in the reaction (both reactants and products) as a set.
"""
def calculate_rate(self, T9: typing.SupportsFloat, rho: typing.SupportsFloat, Y: collections.abc.Sequence[typing.SupportsFloat]) -> float:
"""
Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).
"""
def chapter(self) -> int:
"""
Get the REACLIB chapter number defining the reaction structure.
"""
def contains(self, species: fourdst._phys.atomic.Species) -> bool:
"""
Check if the reaction contains a specific species.
"""
def contains_product(self, arg0: fourdst._phys.atomic.Species) -> bool:
"""
Check if the reaction contains a specific product species.
"""
def contains_reactant(self, arg0: fourdst._phys.atomic.Species) -> bool:
"""
Check if the reaction contains a specific reactant species.
"""
def excess_energy(self) -> float:
"""
Calculate the excess energy from the mass difference of reactants and products.
"""
def hash(self, seed: typing.SupportsInt = 0) -> int:
"""
Compute a hash for the reaction based on its ID.
"""
def id(self) -> str:
"""
Get the unique identifier of the reaction.
"""
def is_reverse(self) -> bool:
"""
Check if this is a reverse reaction rate.
"""
def num_species(self) -> int:
"""
Count the number of species in the reaction.
"""
def peName(self) -> str:
"""
Get the reaction name in (projectile, ejectile) notation (e.g., 'p(p,g)d').
"""
def product_species(self) -> set[fourdst._phys.atomic.Species]:
"""
Get the product species of the reaction as a set.
"""
def products(self) -> list[fourdst._phys.atomic.Species]:
"""
Get a list of product species in the reaction.
"""
def qValue(self) -> float:
"""
Get the Q-value of the reaction in MeV.
"""
def rateCoefficients(self) -> RateCoefficientSet:
"""
get the set of rate coefficients.
"""
def reactant_species(self) -> set[fourdst._phys.atomic.Species]:
"""
Get the reactant species of the reaction as a set.
"""
def reactants(self) -> list[fourdst._phys.atomic.Species]:
"""
Get a list of reactant species in the reaction.
"""
def sourceLabel(self) -> str:
"""
Get the source label for the rate data (e.g., 'wc12w', 'st08').
"""
@typing.overload
def stoichiometry(self, species: fourdst._phys.atomic.Species) -> int:
"""
Get the stoichiometry of the reaction as a map from species to their coefficients.
"""
@typing.overload
def stoichiometry(self) -> dict[fourdst._phys.atomic.Species, int]:
"""
Get the stoichiometry of the reaction as a map from species to their coefficients.
"""
class ReactionSet:
__hash__: typing.ClassVar[None] = None
@staticmethod
def from_clones(reactions: collections.abc.Sequence[...]) -> ReactionSet:
"""
Create a ReactionSet that takes ownership of the reactions by cloning the input reactions.
"""
def __eq__(self, LogicalReactionSet: ReactionSet) -> bool:
"""
Equality operator for LogicalReactionSets based on their contents.
"""
def __getitem__(self, index: typing.SupportsInt) -> ...:
"""
Get a LogicalReaclibReaction by index.
"""
def __getitem___(self, id: str) -> ...:
"""
Get a LogicalReaclibReaction by its ID.
"""
@typing.overload
def __init__(self, reactions: collections.abc.Sequence[...]) -> None:
"""
Construct a LogicalReactionSet from a vector of LogicalReaclibReaction objects.
"""
@typing.overload
def __init__(self) -> None:
"""
Default constructor for an empty LogicalReactionSet.
"""
@typing.overload
def __init__(self, other: ReactionSet) -> None:
"""
Copy constructor for LogicalReactionSet.
"""
def __len__(self) -> int:
"""
Overload len() to return the number of LogicalReactions.
"""
def __ne__(self, LogicalReactionSet: ReactionSet) -> bool:
"""
Inequality operator for LogicalReactionSets based on their contents.
"""
def __repr__(self) -> str:
...
def add_reaction(self, reaction: ...) -> None:
"""
Add a LogicalReaclibReaction to the set.
"""
def clear(self) -> None:
"""
Remove all LogicalReactions from the set.
"""
@typing.overload
def contains(self, id: str) -> bool:
"""
Check if the set contains a specific LogicalReaclibReaction.
"""
@typing.overload
def contains(self, reaction: ...) -> bool:
"""
Check if the set contains a specific Reaction.
"""
def contains_product(self, species: fourdst._phys.atomic.Species) -> bool:
"""
Check if any reaction in the set has the species as a product.
"""
def contains_reactant(self, species: fourdst._phys.atomic.Species) -> bool:
"""
Check if any reaction in the set has the species as a reactant.
"""
def contains_species(self, species: fourdst._phys.atomic.Species) -> bool:
"""
Check if any reaction in the set involves the given species.
"""
def getReactionSetSpecies(self) -> set[fourdst._phys.atomic.Species]:
"""
Get all species involved in the reactions of the set as a set of Species objects.
"""
def hash(self, seed: typing.SupportsInt = 0) -> int:
"""
Compute a hash for the LogicalReactionSet based on its contents.
"""
def remove_reaction(self, reaction: ...) -> None:
"""
Remove a LogicalReaclibReaction from the set.
"""
def size(self) -> int:
"""
Get the number of LogicalReactions in the set.
"""
def get_all_reactions() -> ReactionSet:
"""
Get all reactions from the REACLIB database.
"""
def packReactionSet(reactionSet: ReactionSet) -> ReactionSet:
"""
Convert a ReactionSet to a LogicalReactionSet by aggregating reactions with the same peName.
"""

View File

@@ -0,0 +1,68 @@
"""
GridFire plasma screening bindings
"""
from __future__ import annotations
import collections.abc
import fourdst._phys.atomic
import gridfire._gridfire.reaction
import typing
__all__: list[str] = ['BARE', 'BareScreeningModel', 'ScreeningModel', 'ScreeningType', 'WEAK', 'WeakScreeningModel', 'selectScreeningModel']
class BareScreeningModel:
def __init__(self) -> None:
...
def calculateScreeningFactors(self, reactions: gridfire._gridfire.reaction.ReactionSet, species: collections.abc.Sequence[fourdst._phys.atomic.Species], Y: collections.abc.Sequence[typing.SupportsFloat], T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> list[float]:
"""
Calculate the bare plasma screening factors. This always returns 1.0 (bare)
"""
class ScreeningModel:
pass
class ScreeningType:
"""
Members:
BARE
WEAK
"""
BARE: typing.ClassVar[ScreeningType] # value = <ScreeningType.BARE: 0>
WEAK: typing.ClassVar[ScreeningType] # value = <ScreeningType.WEAK: 1>
__members__: typing.ClassVar[dict[str, ScreeningType]] # value = {'BARE': <ScreeningType.BARE: 0>, 'WEAK': <ScreeningType.WEAK: 1>}
def __eq__(self, other: typing.Any) -> bool:
...
def __getstate__(self) -> int:
...
def __hash__(self) -> int:
...
def __index__(self) -> int:
...
def __init__(self, value: typing.SupportsInt) -> None:
...
def __int__(self) -> int:
...
def __ne__(self, other: typing.Any) -> bool:
...
def __repr__(self) -> str:
...
def __setstate__(self, state: typing.SupportsInt) -> None:
...
def __str__(self) -> str:
...
@property
def name(self) -> str:
...
@property
def value(self) -> int:
...
class WeakScreeningModel:
def __init__(self) -> None:
...
def calculateScreeningFactors(self, reactions: gridfire._gridfire.reaction.ReactionSet, species: collections.abc.Sequence[fourdst._phys.atomic.Species], Y: collections.abc.Sequence[typing.SupportsFloat], T9: typing.SupportsFloat, rho: typing.SupportsFloat) -> list[float]:
"""
Calculate the weak plasma screening factors using the Salpeter (1954) model.
"""
def selectScreeningModel(type: ScreeningType) -> ScreeningModel:
"""
Select a screening model based on the specified type. Returns a pointer to the selected model.
"""
BARE: ScreeningType # value = <ScreeningType.BARE: 0>
WEAK: ScreeningType # value = <ScreeningType.WEAK: 1>

Some files were not shown because too many files have changed in this diff Show More