Merge pull request #73 from tboudreaux/feature/pythonInterface/network

Python bindings to network interface added
This commit is contained in:
2025-06-17 11:15:38 -04:00
committed by GitHub
7 changed files with 186 additions and 2 deletions

View File

@@ -14,6 +14,7 @@ py_mod = py_installation.extension_module(
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp',
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp',
meson.project_source_root() + '/src/python/polytrope/bindings.cpp',
meson.project_source_root() + '/src/python/network/bindings.cpp',
],
dependencies : [
pybind11_dep,
@@ -24,7 +25,8 @@ py_mod = py_installation.extension_module(
species_weight_dep,
mfem_dep,
polysolver_dep,
trampoline_dep
trampoline_dep,
network_dep,
],
cpp_args : ['-UNDEBUG'], # Example: Ensure assertions are enabled if needed
install : true,

View File

@@ -8,6 +8,7 @@
#include "eos/bindings.h"
#include "mfem/bindings.h"
#include "polytrope/bindings.h"
#include "network/bindings.h"
PYBIND11_MODULE(serif, m) {
m.doc() = "Python bindings for the SERiF project";
@@ -29,4 +30,7 @@ PYBIND11_MODULE(serif, m) {
auto polytropeMod = m.def_submodule("polytrope", "Polytrope-module bindings");
register_polytrope_bindings(polytropeMod);
auto networkMod = m.def_submodule("network", "Network-module bindings");
register_network_bindings(networkMod);
}

View File

@@ -3,4 +3,6 @@ subdir('const')
subdir('config')
subdir('mfem')
subdir('eos')
subdir('eos')
subdir('polytrope')
subdir('network')

View File

@@ -0,0 +1,125 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/numpy.h>
#include <string>
#include "bindings.h"
#include "approx8.h"
#include "network.h"
namespace py = pybind11;
void register_network_bindings(pybind11::module &network_submodule) {
py::enum_<serif::network::NetworkFormat>(network_submodule, "NetworkFormat")
.value("APPROX8", serif::network::NetworkFormat::APPROX8)
.value("UNKNOWN", serif::network::NetworkFormat::UNKNOWN)
.def("__str__", [](const serif::network::NetworkFormat format) {
return serif::network::FormatStringLookup[format];
});
py::class_<serif::network::NetIn>(network_submodule, "NetIn")
.def(py::init<>())
.def_readwrite("composition", &serif::network::NetIn::composition)
.def_readwrite("tMax", &serif::network::NetIn::tMax)
.def_readwrite("dt0", &serif::network::NetIn::dt0)
.def_readwrite("temperature", &serif::network::NetIn::temperature)
.def_readwrite("density", &serif::network::NetIn::density)
.def_readwrite("energy", &serif::network::NetIn::energy)
.def("__repr__", [](const serif::network::NetIn &netIn) {
std::stringstream ss;
ss << "NetIn(composition=" << netIn.composition
<< ", tMax=" << netIn.tMax
<< ", dt0=" << netIn.dt0
<< ", temperature=" << netIn.temperature
<< ", density=" << netIn.density
<< ", energy=" << netIn.energy << ")";
return ss.str();
});
py::class_<serif::network::NetOut>(network_submodule, "NetOut")
.def_readonly("composition", &serif::network::NetOut::composition)
.def_readonly("num_steps", &serif::network::NetOut::num_steps)
.def_readonly("energy", &serif::network::NetOut::energy)
.def("__repr__", [](const serif::network::NetOut &netOut) {
std::stringstream ss;
ss << "NetOut(composition=" << netOut.composition
<< ", num_steps=" << netOut.num_steps
<< ", energy=" << netOut.energy << ")";
return ss.str();
});
py::class_<serif::network::Network>(network_submodule, "Network")
.def(py::init<serif::network::NetworkFormat>(), py::arg("format"))
.def("evaluate", &serif::network::Network::evaluate, py::arg("netIn"))
.def("getFormat", &serif::network::Network::getFormat)
.def("setFormat", &serif::network::Network::setFormat, py::arg("format"))
.def("__repr__", [](const serif::network::Network &network) {
std::stringstream ss;
ss << "Network(format=" << serif::network::FormatStringLookup[network.getFormat()] << ")";
return ss.str();
});
auto approx8Module = network_submodule.def_submodule("approx8", "Approx8 nuclear reaction network module");
register_approx8_bindings(approx8Module);
}
void register_approx8_bindings(pybind11::module &network_submodule) {
using namespace serif::network::approx8;
py::class_<vec7>(network_submodule, "vec7")
.def("__getitem__", [](const vec7 &v, const size_t i) {
if (i >= v.size()) throw py::index_error();
return v[i];
})
.def("__setitem__", [](vec7 &v, const size_t i, const double value) {
if (i >= v.size()) throw py::index_error();
v[i] = value;
})
.def("__len__", [](const vec7 &v) { return v.size(); })
.def("__repr__", [](const vec7 &v) {
std::stringstream ss;
ss << "vec7(";
for (size_t i = 0; i < v.size(); ++i) {
ss << v[i];
if (i < v.size() - 1) ss << ", ";
}
ss << ")";
return ss.str();
});
network_submodule.def("sum_product", &sum_product, "Find the sum of the element wise product of two vec7 arrays", py::arg("a"), py::arg("b"));
network_submodule.def("get_T9_array", &get_T9_array, "Return an array of T9 terms for the nuclear reaction rate fit", py::arg("T"));
network_submodule.def("rate_fit", &rate_fit, "Evaluate the nuclear reaction rate given the T9 array and coefficients", py::arg("T9"), py::arg("coef"));
network_submodule.def("pp_rate", &pp_rate, "Calculate the rate for the reaction p + p -> d", py::arg("T9"));
network_submodule.def("dp_rate", &dp_rate, "Calculate the rate for the reaction p + d -> he3", py::arg("T9"));
network_submodule.def("he3he3_rate", &he3he3_rate, "Calculate the rate for the reaction he3 + he3 -> he4 + 2p", py::arg("T9"));
network_submodule.def("he3he4_rate", &he3he4_rate, "Calculate the rate for the reaction he3(he3,2p)he4", py::arg("T9"));
network_submodule.def("triple_alpha_rate", &triple_alpha_rate, "Calculate the rate for the reaction he4 + he4 + he4 -> c12", py::arg("T9"));
network_submodule.def("c12p_rate", &c12p_rate, "Calculate the rate for the reaction c12 + p -> n13", py::arg("T9"));
network_submodule.def("c12a_rate", &c12a_rate, "Calculate the rate for the reaction c12 + he4 -> o16", py::arg("T9"));
network_submodule.def("n14p_rate", &n14p_rate, "Calculate the rate for the reaction n14(p,g)o15 - o15 + p -> c12", py::arg("T9"));
network_submodule.def("n14a_rate", &n14a_rate, "Calculate the rate for the reaction n14(a,g)f18 assumed to go on to ne20", py::arg("T9"));
network_submodule.def("n15pa_rate", &n15pa_rate, "Calculate the rate for the reaction n15(p,a)c12 (CNO I)", py::arg("T9"));
network_submodule.def("n15pg_rate", &n15pg_rate, "Calculate the rate for the reaction n15(p,g)o16 (CNO II)", py::arg("T9"));
network_submodule.def("n15pg_frac", &n15pg_frac, "Calculate the fraction for the reaction n15(p,g)o16", py::arg("T9"));
network_submodule.def("o16p_rate", &o16p_rate, "Calculate the rate for the reaction o16(p, g)f17 then f17 -> o17(p,a)n14", py::arg("T9"));
network_submodule.def("o16a_rate", &o16a_rate, "Calculate the rate for the reaction o16(a,g)ne20", py::arg("T9"));
network_submodule.def("ne20a_rate", &ne20a_rate, "Calculate the rate for the reaction ne20(a,g)mg24", py::arg("T9"));
network_submodule.def("c12c12_rate", &c12c12_rate, "Calculate the rate for the reaction c12(c12,a)ne20", py::arg("T9"));
network_submodule.def("c12o16_rate", &c12o16_rate, "Calculate the rate for the reaction c12(o16,a)mg24", py::arg("T9"));
py::class_<Approx8Network>(network_submodule, "Approx8Network")
.def(py::init<>())
.def("evaluate", &Approx8Network::evaluate, py::arg("netIn"), "Evaluate the Approx8 nuclear reaction network with the given input")
.def("setStiff", &Approx8Network::setStiff, py::arg("stiff"), "Set whether to use a stiff solver or not")
.def("isStiff", &Approx8Network::isStiff, "Get whether the network is set to use a stiff solver or not")
.def("__repr__", [](const Approx8Network &network) {
std::stringstream ss;
ss << "Approx8Network(stiff=" << (network.isStiff() ? "True" : "False") << ")";
return ss.str();
});
}

View File

@@ -0,0 +1,7 @@
#pragma once
#include <pybind11/pybind11.h>
void register_network_bindings(pybind11::module &network_submodule);
void register_approx8_bindings(pybind11::module &network_submodule);

View File

@@ -0,0 +1,17 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
python3_dep,
pybind11_dep,
network_dep,
]
shared_module('py_network',
bindings_sources,
include_directories: include_directories('.'),
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
)

View File

@@ -0,0 +1,27 @@
from serif.network.approx8 import Approx8Network
from serif.network import NetIn
from serif.composition import Composition
comp = Composition(
["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
)
comp.setMassFraction(
["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"],
[0.708, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4]
)
comp.finalize(True)
netIn = NetIn()
netIn.composition = comp
netIn.temperature = 1e7
netIn.density = 1e2
netIn.energy = 0.0
netIn.tMax = 3.15e17
netIn.dt0 = 1e12
net = Approx8Network()
netOut = net.evaluate(netIn)
print(netIn)
print(netOut)
print(net)