diff --git a/build-python/meson.build b/build-python/meson.build index a09fd03..2af839d 100644 --- a/build-python/meson.build +++ b/build-python/meson.build @@ -14,6 +14,7 @@ py_mod = py_installation.extension_module( meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp', meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp', meson.project_source_root() + '/src/python/polytrope/bindings.cpp', + meson.project_source_root() + '/src/python/network/bindings.cpp', ], dependencies : [ pybind11_dep, @@ -24,7 +25,8 @@ py_mod = py_installation.extension_module( species_weight_dep, mfem_dep, polysolver_dep, - trampoline_dep + trampoline_dep, + network_dep, ], cpp_args : ['-UNDEBUG'], # Example: Ensure assertions are enabled if needed install : true, diff --git a/src/python/bindings.cpp b/src/python/bindings.cpp index 610f55a..f1d9e04 100644 --- a/src/python/bindings.cpp +++ b/src/python/bindings.cpp @@ -8,6 +8,7 @@ #include "eos/bindings.h" #include "mfem/bindings.h" #include "polytrope/bindings.h" +#include "network/bindings.h" PYBIND11_MODULE(serif, m) { m.doc() = "Python bindings for the SERiF project"; @@ -29,4 +30,7 @@ PYBIND11_MODULE(serif, m) { auto polytropeMod = m.def_submodule("polytrope", "Polytrope-module bindings"); register_polytrope_bindings(polytropeMod); + + auto networkMod = m.def_submodule("network", "Network-module bindings"); + register_network_bindings(networkMod); } \ No newline at end of file diff --git a/src/python/meson.build b/src/python/meson.build index 6c8df90..84a9847 100644 --- a/src/python/meson.build +++ b/src/python/meson.build @@ -3,4 +3,6 @@ subdir('const') subdir('config') subdir('mfem') -subdir('eos') \ No newline at end of file +subdir('eos') +subdir('polytrope') +subdir('network') \ No newline at end of file diff --git a/src/python/network/bindings.cpp b/src/python/network/bindings.cpp new file mode 100644 index 0000000..8d1c34d --- /dev/null +++ b/src/python/network/bindings.cpp @@ -0,0 +1,125 @@ +#include +#include // Needed for vectors, maps, sets, strings +#include // Needed for binding std::vector, std::map etc if needed directly +#include + +#include +#include "bindings.h" + +#include "approx8.h" +#include "network.h" + +namespace py = pybind11; + + +void register_network_bindings(pybind11::module &network_submodule) { + py::enum_(network_submodule, "NetworkFormat") + .value("APPROX8", serif::network::NetworkFormat::APPROX8) + .value("UNKNOWN", serif::network::NetworkFormat::UNKNOWN) + .def("__str__", [](const serif::network::NetworkFormat format) { + return serif::network::FormatStringLookup[format]; + }); + + py::class_(network_submodule, "NetIn") + .def(py::init<>()) + .def_readwrite("composition", &serif::network::NetIn::composition) + .def_readwrite("tMax", &serif::network::NetIn::tMax) + .def_readwrite("dt0", &serif::network::NetIn::dt0) + .def_readwrite("temperature", &serif::network::NetIn::temperature) + .def_readwrite("density", &serif::network::NetIn::density) + .def_readwrite("energy", &serif::network::NetIn::energy) + .def("__repr__", [](const serif::network::NetIn &netIn) { + std::stringstream ss; + ss << "NetIn(composition=" << netIn.composition + << ", tMax=" << netIn.tMax + << ", dt0=" << netIn.dt0 + << ", temperature=" << netIn.temperature + << ", density=" << netIn.density + << ", energy=" << netIn.energy << ")"; + return ss.str(); + }); + + py::class_(network_submodule, "NetOut") + .def_readonly("composition", &serif::network::NetOut::composition) + .def_readonly("num_steps", &serif::network::NetOut::num_steps) + .def_readonly("energy", &serif::network::NetOut::energy) + .def("__repr__", [](const serif::network::NetOut &netOut) { + std::stringstream ss; + ss << "NetOut(composition=" << netOut.composition + << ", num_steps=" << netOut.num_steps + << ", energy=" << netOut.energy << ")"; + return ss.str(); + }); + + py::class_(network_submodule, "Network") + .def(py::init(), py::arg("format")) + .def("evaluate", &serif::network::Network::evaluate, py::arg("netIn")) + .def("getFormat", &serif::network::Network::getFormat) + .def("setFormat", &serif::network::Network::setFormat, py::arg("format")) + .def("__repr__", [](const serif::network::Network &network) { + std::stringstream ss; + ss << "Network(format=" << serif::network::FormatStringLookup[network.getFormat()] << ")"; + return ss.str(); + }); + + auto approx8Module = network_submodule.def_submodule("approx8", "Approx8 nuclear reaction network module"); + register_approx8_bindings(approx8Module); + +} + +void register_approx8_bindings(pybind11::module &network_submodule) { + using namespace serif::network::approx8; + + py::class_(network_submodule, "vec7") + .def("__getitem__", [](const vec7 &v, const size_t i) { + if (i >= v.size()) throw py::index_error(); + return v[i]; + }) + .def("__setitem__", [](vec7 &v, const size_t i, const double value) { + if (i >= v.size()) throw py::index_error(); + v[i] = value; + }) + .def("__len__", [](const vec7 &v) { return v.size(); }) + .def("__repr__", [](const vec7 &v) { + std::stringstream ss; + ss << "vec7("; + for (size_t i = 0; i < v.size(); ++i) { + ss << v[i]; + if (i < v.size() - 1) ss << ", "; + } + ss << ")"; + return ss.str(); + }); + + network_submodule.def("sum_product", &sum_product, "Find the sum of the element wise product of two vec7 arrays", py::arg("a"), py::arg("b")); + network_submodule.def("get_T9_array", &get_T9_array, "Return an array of T9 terms for the nuclear reaction rate fit", py::arg("T")); + network_submodule.def("rate_fit", &rate_fit, "Evaluate the nuclear reaction rate given the T9 array and coefficients", py::arg("T9"), py::arg("coef")); + network_submodule.def("pp_rate", &pp_rate, "Calculate the rate for the reaction p + p -> d", py::arg("T9")); + network_submodule.def("dp_rate", &dp_rate, "Calculate the rate for the reaction p + d -> he3", py::arg("T9")); + network_submodule.def("he3he3_rate", &he3he3_rate, "Calculate the rate for the reaction he3 + he3 -> he4 + 2p", py::arg("T9")); + network_submodule.def("he3he4_rate", &he3he4_rate, "Calculate the rate for the reaction he3(he3,2p)he4", py::arg("T9")); + network_submodule.def("triple_alpha_rate", &triple_alpha_rate, "Calculate the rate for the reaction he4 + he4 + he4 -> c12", py::arg("T9")); + network_submodule.def("c12p_rate", &c12p_rate, "Calculate the rate for the reaction c12 + p -> n13", py::arg("T9")); + network_submodule.def("c12a_rate", &c12a_rate, "Calculate the rate for the reaction c12 + he4 -> o16", py::arg("T9")); + network_submodule.def("n14p_rate", &n14p_rate, "Calculate the rate for the reaction n14(p,g)o15 - o15 + p -> c12", py::arg("T9")); + network_submodule.def("n14a_rate", &n14a_rate, "Calculate the rate for the reaction n14(a,g)f18 assumed to go on to ne20", py::arg("T9")); + network_submodule.def("n15pa_rate", &n15pa_rate, "Calculate the rate for the reaction n15(p,a)c12 (CNO I)", py::arg("T9")); + network_submodule.def("n15pg_rate", &n15pg_rate, "Calculate the rate for the reaction n15(p,g)o16 (CNO II)", py::arg("T9")); + network_submodule.def("n15pg_frac", &n15pg_frac, "Calculate the fraction for the reaction n15(p,g)o16", py::arg("T9")); + network_submodule.def("o16p_rate", &o16p_rate, "Calculate the rate for the reaction o16(p, g)f17 then f17 -> o17(p,a)n14", py::arg("T9")); + network_submodule.def("o16a_rate", &o16a_rate, "Calculate the rate for the reaction o16(a,g)ne20", py::arg("T9")); + network_submodule.def("ne20a_rate", &ne20a_rate, "Calculate the rate for the reaction ne20(a,g)mg24", py::arg("T9")); + network_submodule.def("c12c12_rate", &c12c12_rate, "Calculate the rate for the reaction c12(c12,a)ne20", py::arg("T9")); + network_submodule.def("c12o16_rate", &c12o16_rate, "Calculate the rate for the reaction c12(o16,a)mg24", py::arg("T9")); + + py::class_(network_submodule, "Approx8Network") + .def(py::init<>()) + .def("evaluate", &Approx8Network::evaluate, py::arg("netIn"), "Evaluate the Approx8 nuclear reaction network with the given input") + .def("setStiff", &Approx8Network::setStiff, py::arg("stiff"), "Set whether to use a stiff solver or not") + .def("isStiff", &Approx8Network::isStiff, "Get whether the network is set to use a stiff solver or not") + .def("__repr__", [](const Approx8Network &network) { + std::stringstream ss; + ss << "Approx8Network(stiff=" << (network.isStiff() ? "True" : "False") << ")"; + return ss.str(); + }); +} \ No newline at end of file diff --git a/src/python/network/bindings.h b/src/python/network/bindings.h new file mode 100644 index 0000000..6dcc7f1 --- /dev/null +++ b/src/python/network/bindings.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +void register_network_bindings(pybind11::module &network_submodule); + +void register_approx8_bindings(pybind11::module &network_submodule); \ No newline at end of file diff --git a/src/python/network/meson.build b/src/python/network/meson.build new file mode 100644 index 0000000..47ea8a1 --- /dev/null +++ b/src/python/network/meson.build @@ -0,0 +1,17 @@ +# Define the library +bindings_sources = files('bindings.cpp') +bindings_headers = files('bindings.h') + +dependencies = [ + python3_dep, + pybind11_dep, + network_dep, +] + +shared_module('py_network', + bindings_sources, + include_directories: include_directories('.'), + cpp_args: ['-fvisibility=default'], + install : true, + dependencies: dependencies, +) diff --git a/tests/python/network/evaluateNetwork.py b/tests/python/network/evaluateNetwork.py new file mode 100644 index 0000000..cad0c29 --- /dev/null +++ b/tests/python/network/evaluateNetwork.py @@ -0,0 +1,27 @@ +from serif.network.approx8 import Approx8Network +from serif.network import NetIn +from serif.composition import Composition + +comp = Composition( + ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"] +) +comp.setMassFraction( + ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"], + [0.708, 2.94e-5, 0.276, 0.003, 0.0011, 9.62e-3, 1.62e-3, 5.16e-4] +) +comp.finalize(True) + +netIn = NetIn() +netIn.composition = comp +netIn.temperature = 1e7 +netIn.density = 1e2 +netIn.energy = 0.0 +netIn.tMax = 3.15e17 +netIn.dt0 = 1e12 + +net = Approx8Network() +netOut = net.evaluate(netIn) + +print(netIn) +print(netOut) +print(net)