Files
SERiF/src/python/eos/bindings.cpp
Emily Boudreaux cf25c54cda feat(eos): EOS now uses composition module
EOS code updated to make use of the composition module. This is part of a larger effort to move all composition handling onto the composition module. Note that the current implementation is a bit hacky, since it simply copies data back and forth to the already-used HELMEOSInput and HELMEOSOutput structs. In the future this can be more tightly connected to avoid the extra copies.
2025-06-16 15:00:33 -04:00

160 lines
8.4 KiB
C++

#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/numpy.h>

#include <memory>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

#include "helm.h"
#include "bindings.h"
#include "EOSio.h"
#include "helm.h"
// Forward declaration of the EOS table reader. The complete definition must
// come from one of the headers included above, because the bindings below use
// its members; this declaration only restates the name.
namespace serif {
namespace eos {
class EOSio;
} // namespace eos
} // namespace serif
namespace py = pybind11;

/**
 * @brief Register the EOS-related Python bindings on @p eos_submodule.
 *
 * Defines, in this order:
 *  - `EOSio`        - EOS table reader constructed from a filename.
 *  - `EOSTable`     - bound with no members (opaque to Python).
 *  - `HELMTable`    - read-only fields of a Helmholtz table plus `f`, a
 *                     zero-copy 2-D NumPy view of the table's `f` data.
 *  - `EOS`          - read-only fields of `helmholtz::HELMEOSOutput`.
 *  - `HELMEOSInput` - read/write `T`, `rho`, `abar`, `zbar`.
 *  - `get_helm_eos` - wraps `helmholtz::get_helm_EOS(q, table)`.
 *
 * @param eos_submodule Module on which the classes and the function are defined.
 */
void register_eos_bindings(pybind11::module &eos_submodule) {
    namespace helm = serif::eos::helmholtz;
    using serif::eos::EOSio;

    // --- EOSio: table reader ------------------------------------------------
    py::class_<EOSio>(eos_submodule, "EOSio")
        .def(py::init<std::string>(), py::arg("filename"))
        // .def("load", &EOSio::load)
        .def("getFormat", &EOSio::getFormat, "Get the format of the EOS table.")
        .def("getTable", [](EOSio &self) -> helm::HELMTable* {
            // getTable() hands back a variant by reference; only the HELMTable
            // alternative is exposed. Any other alternative becomes Python None.
            auto &table_variant = self.getTable();
            if (auto *holder = std::get_if<std::unique_ptr<helm::HELMTable>>(&table_variant)) {
                return holder->get();
            }
            return nullptr;
        },
        // The table is owned by the EOSio instance (through the unique_ptr), so
        // the returned wrapper must not take ownership and must keep `self` alive.
        py::return_value_policy::reference_internal,
        "Get the EOS table data.")
        .def("__repr__", [](const EOSio &eos) {
            return "<EOSio(filename='" + eos.getFilename() + "', format='" + eos.getFormatName() + "')>";
        });

    py::class_<serif::eos::EOSTable>(eos_submodule, "EOSTable");

    // --- HELMTable: loaded Helmholtz table -----------------------------------
    py::class_<helm::HELMTable>(eos_submodule, "HELMTable")
        .def_readonly("loaded", &helm::HELMTable::loaded)
        .def_readonly("imax", &helm::HELMTable::imax)
        .def_readonly("jmax", &helm::HELMTable::jmax)
        .def_readonly("t", &helm::HELMTable::t)
        .def_readonly("d", &helm::HELMTable::d)
        .def("__repr__", [](const helm::HELMTable &table) {
            return "<HELMTable(loaded=" + std::to_string(table.loaded) + ", imax=" + std::to_string(table.imax) +
                   ", jmax=" + std::to_string(table.jmax) + ")>";
        })
        // `f` is exposed as an (imax, jmax) row-major NumPy view over the table's
        // storage; no data is copied.
        //
        // The getter receives the Python wrapper (`self`) rather than a C++
        // reference, so that exactly this object becomes the array's base.
        // `py::cast(table)` on an lvalue reference resolves to the *copy*
        // return-value policy whenever pybind11 cannot find an existing wrapper.
        // For a struct holding a raw data pointer, that would be a shallow copy.
        //
        // NOTE(review): the returned array is writable, so Python code can modify
        // the table in place even though the scalar fields are read-only.
        .def_property_readonly("f", [](py::object self) -> py::array_t<double> {
            auto &table = self.cast<helm::HELMTable &>();

            if (table.imax <= 0 || table.jmax <= 0) {
                throw std::runtime_error("HELMTable dimensions (imax, jmax) are non-positive.");
            }
            // Both the row-pointer array and the first row must exist, because
            // the view is built from f[0] alone.
            if (!table.f || !table.f[0]) {
                throw std::runtime_error("HELMTable data buffer 'f' is null or not initialized.");
            }

            const auto rows = static_cast<py::ssize_t>(table.imax);
            const auto cols = static_cast<py::ssize_t>(table.jmax);
            // ASSUMPTION (carried over unchanged): all rows live in one contiguous
            // block starting at f[0], each row `jmax` doubles long. TODO confirm
            // against the table allocator.
            double *data_ptr = table.f[0];

            const std::vector<py::ssize_t> shape = {rows, cols};
            const std::vector<py::ssize_t> strides = {
                static_cast<py::ssize_t>(cols * sizeof(double)), // bytes to next row
                static_cast<py::ssize_t>(sizeof(double))         // bytes to next column
            };

            // Passing `self` as the base makes NumPy reference (not own) the
            // buffer, and keeps the HELMTable wrapper alive while the view exists.
            return py::array_t<double>(shape, strides, data_ptr, self);
        });

    // --- EOS: results of a Helmholtz EOS evaluation -------------------------
    py::class_<helm::HELMEOSOutput>(eos_submodule, "EOS")
        .def(py::init<>())
        .def_readonly("ye", &helm::HELMEOSOutput::ye)
        .def_readonly("etaele", &helm::HELMEOSOutput::etaele)
        .def_readonly("xnefer", &helm::HELMEOSOutput::xnefer)
        .def_readonly("ptot", &helm::HELMEOSOutput::ptot)
        .def_readonly("pgas", &helm::HELMEOSOutput::pgas)
        .def_readonly("prad", &helm::HELMEOSOutput::prad)
        .def_readonly("etot", &helm::HELMEOSOutput::etot)
        .def_readonly("egas", &helm::HELMEOSOutput::egas)
        .def_readonly("erad", &helm::HELMEOSOutput::erad)
        .def_readonly("stot", &helm::HELMEOSOutput::stot)
        .def_readonly("sgas", &helm::HELMEOSOutput::sgas)
        .def_readonly("srad", &helm::HELMEOSOutput::srad)
        .def_readonly("dpresdd", &helm::HELMEOSOutput::dpresdd)
        .def_readonly("dpresdt", &helm::HELMEOSOutput::dpresdt)
        .def_readonly("dpresda", &helm::HELMEOSOutput::dpresda)
        .def_readonly("dpresdz", &helm::HELMEOSOutput::dpresdz)
        .def_readonly("dentrdd", &helm::HELMEOSOutput::dentrdd)
        .def_readonly("dentrdt", &helm::HELMEOSOutput::dentrdt)
        .def_readonly("dentrda", &helm::HELMEOSOutput::dentrda)
        .def_readonly("dentrdz", &helm::HELMEOSOutput::dentrdz)
        .def_readonly("denerdd", &helm::HELMEOSOutput::denerdd)
        .def_readonly("denerdt", &helm::HELMEOSOutput::denerdt)
        .def_readonly("denerda", &helm::HELMEOSOutput::denerda)
        .def_readonly("denerdz", &helm::HELMEOSOutput::denerdz)
        .def_readonly("chiT", &helm::HELMEOSOutput::chiT)
        .def_readonly("chiRho", &helm::HELMEOSOutput::chiRho)
        .def_readonly("csound", &helm::HELMEOSOutput::csound)
        .def_readonly("grad_ad", &helm::HELMEOSOutput::grad_ad)
        .def_readonly("gamma1", &helm::HELMEOSOutput::gamma1)
        .def_readonly("gamma2", &helm::HELMEOSOutput::gamma2)
        .def_readonly("gamma3", &helm::HELMEOSOutput::gamma3)
        .def_readonly("cV", &helm::HELMEOSOutput::cV)
        .def_readonly("cP", &helm::HELMEOSOutput::cP)
        .def_readonly("dse", &helm::HELMEOSOutput::dse)
        .def_readonly("dpe", &helm::HELMEOSOutput::dpe)
        .def_readonly("dsp", &helm::HELMEOSOutput::dsp)
        .def("__repr__", [](const helm::HELMEOSOutput &) {
            return "<EOS (output from helmholtz eos)>";
        });

    // --- HELMEOSInput: thermodynamic state passed to get_helm_eos -----------
    py::class_<helm::HELMEOSInput>(eos_submodule, "HELMEOSInput")
        .def(py::init<>())
        .def_readwrite("T", &helm::HELMEOSInput::T)
        .def_readwrite("rho", &helm::HELMEOSInput::rho)
        .def_readwrite("abar", &helm::HELMEOSInput::abar)
        .def_readwrite("zbar", &helm::HELMEOSInput::zbar)
        .def("__repr__", [](const helm::HELMEOSInput &input) {
            return "<HELMEOSInput(T=" + std::to_string(input.T) +
                   ", rho=" + std::to_string(input.rho) +
                   ", abar=" + std::to_string(input.abar) +
                   ", zbar=" + std::to_string(input.zbar) + ")>";
        });

    // --- Free function: evaluate the Helmholtz EOS --------------------------
    eos_submodule.def("get_helm_eos",
                      &helm::get_helm_EOS,
                      py::arg("q"), py::arg("table"),
                      "Calculate the Helmholtz EOS components based on input parameters and table data.");
}