feat(pythonInterface/network): added network interface from python module
This commit is contained in:
@@ -14,6 +14,7 @@ py_mod = py_installation.extension_module(
|
||||
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp',
|
||||
meson.project_source_root() + '/src/python/polytrope/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/network/bindings.cpp',
|
||||
],
|
||||
dependencies : [
|
||||
pybind11_dep,
|
||||
@@ -24,7 +25,8 @@ py_mod = py_installation.extension_module(
|
||||
species_weight_dep,
|
||||
mfem_dep,
|
||||
polysolver_dep,
|
||||
trampoline_dep
|
||||
trampoline_dep,
|
||||
network_dep,
|
||||
],
|
||||
cpp_args : ['-UNDEBUG'], # Example: Ensure assertions are enabled if needed
|
||||
install : true,
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "eos/bindings.h"
|
||||
#include "mfem/bindings.h"
|
||||
#include "polytrope/bindings.h"
|
||||
#include "network/bindings.h"
|
||||
|
||||
PYBIND11_MODULE(serif, m) {
|
||||
m.doc() = "Python bindings for the SERiF project";
|
||||
@@ -29,4 +30,7 @@ PYBIND11_MODULE(serif, m) {
|
||||
|
||||
auto polytropeMod = m.def_submodule("polytrope", "Polytrope-module bindings");
|
||||
register_polytrope_bindings(polytropeMod);
|
||||
|
||||
auto networkMod = m.def_submodule("network", "Network-module bindings");
|
||||
register_network_bindings(networkMod);
|
||||
}
|
||||
@@ -3,4 +3,6 @@ subdir('const')
|
||||
subdir('config')
|
||||
|
||||
subdir('mfem')
|
||||
subdir('eos')
|
||||
subdir('eos')
|
||||
subdir('polytrope')
|
||||
subdir('network')
|
||||
125
src/python/network/bindings.cpp
Normal file
125
src/python/network/bindings.cpp
Normal file
@@ -0,0 +1,125 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
#include <pybind11/numpy.h>
|
||||
|
||||
#include <string>
|
||||
#include "bindings.h"
|
||||
|
||||
#include "approx8.h"
|
||||
#include "network.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
void register_network_bindings(pybind11::module &network_submodule) {
|
||||
py::enum_<serif::network::NetworkFormat>(network_submodule, "NetworkFormat")
|
||||
.value("APPROX8", serif::network::NetworkFormat::APPROX8)
|
||||
.value("UNKNOWN", serif::network::NetworkFormat::UNKNOWN)
|
||||
.def("__str__", [](const serif::network::NetworkFormat format) {
|
||||
return serif::network::FormatStringLookup[format];
|
||||
});
|
||||
|
||||
py::class_<serif::network::NetIn>(network_submodule, "NetIn")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("composition", &serif::network::NetIn::composition)
|
||||
.def_readwrite("tMax", &serif::network::NetIn::tMax)
|
||||
.def_readwrite("dt0", &serif::network::NetIn::dt0)
|
||||
.def_readwrite("temperature", &serif::network::NetIn::temperature)
|
||||
.def_readwrite("density", &serif::network::NetIn::density)
|
||||
.def_readwrite("energy", &serif::network::NetIn::energy)
|
||||
.def("__repr__", [](const serif::network::NetIn &netIn) {
|
||||
std::stringstream ss;
|
||||
ss << "NetIn(composition=" << netIn.composition
|
||||
<< ", tMax=" << netIn.tMax
|
||||
<< ", dt0=" << netIn.dt0
|
||||
<< ", temperature=" << netIn.temperature
|
||||
<< ", density=" << netIn.density
|
||||
<< ", energy=" << netIn.energy << ")";
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
py::class_<serif::network::NetOut>(network_submodule, "NetOut")
|
||||
.def_readonly("composition", &serif::network::NetOut::composition)
|
||||
.def_readonly("num_steps", &serif::network::NetOut::num_steps)
|
||||
.def_readonly("energy", &serif::network::NetOut::energy)
|
||||
.def("__repr__", [](const serif::network::NetOut &netOut) {
|
||||
std::stringstream ss;
|
||||
ss << "NetOut(composition=" << netOut.composition
|
||||
<< ", num_steps=" << netOut.num_steps
|
||||
<< ", energy=" << netOut.energy << ")";
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
py::class_<serif::network::Network>(network_submodule, "Network")
|
||||
.def(py::init<serif::network::NetworkFormat>(), py::arg("format"))
|
||||
.def("evaluate", &serif::network::Network::evaluate, py::arg("netIn"))
|
||||
.def("getFormat", &serif::network::Network::getFormat)
|
||||
.def("setFormat", &serif::network::Network::setFormat, py::arg("format"))
|
||||
.def("__repr__", [](const serif::network::Network &network) {
|
||||
std::stringstream ss;
|
||||
ss << "Network(format=" << serif::network::FormatStringLookup[network.getFormat()] << ")";
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
auto approx8Module = network_submodule.def_submodule("approx8", "Approx8 nuclear reaction network module");
|
||||
register_approx8_bindings(approx8Module);
|
||||
|
||||
}
|
||||
|
||||
void register_approx8_bindings(pybind11::module &network_submodule) {
|
||||
using namespace serif::network::approx8;
|
||||
|
||||
py::class_<vec7>(network_submodule, "vec7")
|
||||
.def("__getitem__", [](const vec7 &v, const size_t i) {
|
||||
if (i >= v.size()) throw py::index_error();
|
||||
return v[i];
|
||||
})
|
||||
.def("__setitem__", [](vec7 &v, const size_t i, const double value) {
|
||||
if (i >= v.size()) throw py::index_error();
|
||||
v[i] = value;
|
||||
})
|
||||
.def("__len__", [](const vec7 &v) { return v.size(); })
|
||||
.def("__repr__", [](const vec7 &v) {
|
||||
std::stringstream ss;
|
||||
ss << "vec7(";
|
||||
for (size_t i = 0; i < v.size(); ++i) {
|
||||
ss << v[i];
|
||||
if (i < v.size() - 1) ss << ", ";
|
||||
}
|
||||
ss << ")";
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
network_submodule.def("sum_product", &sum_product, "Find the sum of the element wise product of two vec7 arrays", py::arg("a"), py::arg("b"));
|
||||
network_submodule.def("get_T9_array", &get_T9_array, "Return an array of T9 terms for the nuclear reaction rate fit", py::arg("T"));
|
||||
network_submodule.def("rate_fit", &rate_fit, "Evaluate the nuclear reaction rate given the T9 array and coefficients", py::arg("T9"), py::arg("coef"));
|
||||
network_submodule.def("pp_rate", &pp_rate, "Calculate the rate for the reaction p + p -> d", py::arg("T9"));
|
||||
network_submodule.def("dp_rate", &dp_rate, "Calculate the rate for the reaction p + d -> he3", py::arg("T9"));
|
||||
network_submodule.def("he3he3_rate", &he3he3_rate, "Calculate the rate for the reaction he3 + he3 -> he4 + 2p", py::arg("T9"));
|
||||
network_submodule.def("he3he4_rate", &he3he4_rate, "Calculate the rate for the reaction he3(he3,2p)he4", py::arg("T9"));
|
||||
network_submodule.def("triple_alpha_rate", &triple_alpha_rate, "Calculate the rate for the reaction he4 + he4 + he4 -> c12", py::arg("T9"));
|
||||
network_submodule.def("c12p_rate", &c12p_rate, "Calculate the rate for the reaction c12 + p -> n13", py::arg("T9"));
|
||||
network_submodule.def("c12a_rate", &c12a_rate, "Calculate the rate for the reaction c12 + he4 -> o16", py::arg("T9"));
|
||||
network_submodule.def("n14p_rate", &n14p_rate, "Calculate the rate for the reaction n14(p,g)o15 - o15 + p -> c12", py::arg("T9"));
|
||||
network_submodule.def("n14a_rate", &n14a_rate, "Calculate the rate for the reaction n14(a,g)f18 assumed to go on to ne20", py::arg("T9"));
|
||||
network_submodule.def("n15pa_rate", &n15pa_rate, "Calculate the rate for the reaction n15(p,a)c12 (CNO I)", py::arg("T9"));
|
||||
network_submodule.def("n15pg_rate", &n15pg_rate, "Calculate the rate for the reaction n15(p,g)o16 (CNO II)", py::arg("T9"));
|
||||
network_submodule.def("n15pg_frac", &n15pg_frac, "Calculate the fraction for the reaction n15(p,g)o16", py::arg("T9"));
|
||||
network_submodule.def("o16p_rate", &o16p_rate, "Calculate the rate for the reaction o16(p, g)f17 then f17 -> o17(p,a)n14", py::arg("T9"));
|
||||
network_submodule.def("o16a_rate", &o16a_rate, "Calculate the rate for the reaction o16(a,g)ne20", py::arg("T9"));
|
||||
network_submodule.def("ne20a_rate", &ne20a_rate, "Calculate the rate for the reaction ne20(a,g)mg24", py::arg("T9"));
|
||||
network_submodule.def("c12c12_rate", &c12c12_rate, "Calculate the rate for the reaction c12(c12,a)ne20", py::arg("T9"));
|
||||
network_submodule.def("c12o16_rate", &c12o16_rate, "Calculate the rate for the reaction c12(o16,a)mg24", py::arg("T9"));
|
||||
|
||||
py::class_<Approx8Network>(network_submodule, "Approx8Network")
|
||||
.def(py::init<>())
|
||||
.def("evaluate", &Approx8Network::evaluate, py::arg("netIn"), "Evaluate the Approx8 nuclear reaction network with the given input")
|
||||
.def("setStiff", &Approx8Network::setStiff, py::arg("stiff"), "Set whether to use a stiff solver or not")
|
||||
.def("isStiff", &Approx8Network::isStiff, "Get whether the network is set to use a stiff solver or not")
|
||||
.def("__repr__", [](const Approx8Network &network) {
|
||||
std::stringstream ss;
|
||||
ss << "Approx8Network(stiff=" << (network.isStiff() ? "True" : "False") << ")";
|
||||
return ss.str();
|
||||
});
|
||||
}
|
||||
7
src/python/network/bindings.h
Normal file
7
src/python/network/bindings.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_network_bindings(pybind11::module &network_submodule);
|
||||
|
||||
void register_approx8_bindings(pybind11::module &network_submodule);
|
||||
17
src/python/network/meson.build
Normal file
17
src/python/network/meson.build
Normal file
@@ -0,0 +1,17 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
network_dep,
|
||||
]
|
||||
|
||||
shared_module('py_network',
|
||||
bindings_sources,
|
||||
include_directories: include_directories('.'),
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
)
|
||||
27
tests/python/network/evaluateNetwork.py
Normal file
27
tests/python/network/evaluateNetwork.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from serif.network.approx8 import Approx8Network
|
||||
from serif.network import NetIn
|
||||
from serif.composition import Composition
|
||||
|
||||
comp = Composition(
|
||||
["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
|
||||
)
|
||||
comp.setMassFraction(
|
||||
["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"],
|
||||
[0.708, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4]
|
||||
)
|
||||
comp.finalize(True)
|
||||
|
||||
netIn = NetIn()
|
||||
netIn.composition = comp
|
||||
netIn.temperature = 1e7
|
||||
netIn.density = 1e2
|
||||
netIn.energy = 0.0
|
||||
netIn.tMax = 3.15e17
|
||||
netIn.dt0 = 1e12
|
||||
|
||||
net = Approx8Network()
|
||||
netOut = net.evaluate(netIn)
|
||||
|
||||
print(netIn)
|
||||
print(netOut)
|
||||
print(net)
|
||||
Reference in New Issue
Block a user