feat(python): added robust python bindings covering the entire codebase

This commit is contained in:
2025-07-23 16:26:30 -04:00
parent 6a22cb65b8
commit f20bffc411
134 changed files with 2202 additions and 170 deletions

View File

@@ -0,0 +1,113 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <iostream>
#include <memory>
#include "bindings.h"
#include "gridfire/partition/partition.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
#include "trampoline/py_partition.h"
namespace py = pybind11;
void register_partition_bindings(pybind11::module &m) {
using PF = gridfire::partition::PartitionFunction;
py::class_<PF, PyPartitionFunction>(m, "PartitionFunction");
register_partition_types_bindings(m);
register_ground_state_partition_bindings(m);
register_rauscher_thielemann_partition_data_record_bindings(m);
register_rauscher_thielemann_partition_bindings(m);
register_composite_partition_bindings(m);
}
void register_partition_types_bindings(pybind11::module &m) {
py::enum_<gridfire::partition::BasePartitionType>(m, "BasePartitionType")
.value("RauscherThielemann", gridfire::partition::BasePartitionType::RauscherThielemann)
.value("GroundState", gridfire::partition::BasePartitionType::GroundState)
.export_values();
m.def("basePartitionTypeToString", [](gridfire::partition::BasePartitionType type) {
return gridfire::partition::basePartitionTypeToString[type];
}, py::arg("type"), "Convert BasePartitionType to string.");
m.def("stringToBasePartitionType", [](const std::string &typeStr) {
return gridfire::partition::stringToBasePartitionType[typeStr];
}, py::arg("typeStr"), "Convert string to BasePartitionType.");
}
void register_ground_state_partition_bindings(pybind11::module &m) {
using GSPF = gridfire::partition::GroundStatePartitionFunction;
using PF = gridfire::partition::PartitionFunction;
py::class_<GSPF, PF>(m, "GroundStatePartitionFunction")
.def(py::init<>())
.def("evaluate", &gridfire::partition::GroundStatePartitionFunction::evaluate,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the ground state partition function for given Z, A, and T9.")
.def("evaluateDerivative", &gridfire::partition::GroundStatePartitionFunction::evaluateDerivative,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the derivative of the ground state partition function for given Z, A, and T9.")
.def("supports", &gridfire::partition::GroundStatePartitionFunction::supports,
py::arg("z"), py::arg("a"),
"Check if the ground state partition function supports given Z and A.")
.def("get_type", &gridfire::partition::GroundStatePartitionFunction::type,
"Get the type of the partition function (should return 'GroundState').");
}
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m) {
py::class_<gridfire::partition::record::RauscherThielemannPartitionDataRecord>(m, "RauscherThielemannPartitionDataRecord")
.def_readonly("z", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::z, "Atomic number")
.def_readonly("a", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::a, "Mass number")
.def_readonly("ground_state_spin", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::ground_state_spin, "Ground state spin")
.def_readonly("normalized_g_values", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::normalized_g_values, "Normalized g-values for the first 24 energy levels");
}
void register_rauscher_thielemann_partition_bindings(pybind11::module &m) {
using RTPF = gridfire::partition::RauscherThielemannPartitionFunction;
using PF = gridfire::partition::PartitionFunction;
py::class_<RTPF, PF>(m, "RauscherThielemannPartitionFunction")
.def(py::init<>())
.def("evaluate", &gridfire::partition::RauscherThielemannPartitionFunction::evaluate,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the Rauscher-Thielemann partition function for given Z, A, and T9.")
.def("evaluateDerivative", &gridfire::partition::RauscherThielemannPartitionFunction::evaluateDerivative,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the derivative of the Rauscher-Thielemann partition function for given Z, A, and T9.")
.def("supports", &gridfire::partition::RauscherThielemannPartitionFunction::supports,
py::arg("z"), py::arg("a"),
"Check if the Rauscher-Thielemann partition function supports given Z and A.")
.def("get_type", &gridfire::partition::RauscherThielemannPartitionFunction::type,
"Get the type of the partition function (should return 'RauscherThielemann').");
}
void register_composite_partition_bindings(pybind11::module &m) {
py::class_<gridfire::partition::CompositePartitionFunction>(m, "CompositePartitionFunction")
.def(py::init<const std::vector<gridfire::partition::BasePartitionType>&>(),
py::arg("partitionFunctions"),
"Create a composite partition function from a list of base partition types.")
.def(py::init<const gridfire::partition::CompositePartitionFunction&>(),
"Copy constructor for CompositePartitionFunction.")
.def("evaluate", &gridfire::partition::CompositePartitionFunction::evaluate,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the composite partition function for given Z, A, and T9.")
.def("evaluateDerivative", &gridfire::partition::CompositePartitionFunction::evaluateDerivative,
py::arg("z"), py::arg("a"), py::arg("T9"),
"Evaluate the derivative of the composite partition function for given Z, A, and T9.")
.def("supports", &gridfire::partition::CompositePartitionFunction::supports,
py::arg("z"), py::arg("a"),
"Check if the composite partition function supports given Z and A.")
.def("get_type", &gridfire::partition::CompositePartitionFunction::type,
"Get the type of the partition function (should return 'Composite').");
}

View File

@@ -0,0 +1,16 @@
#pragma once
#include <pybind11/pybind11.h>
void register_partition_bindings(pybind11::module &m);
void register_partition_types_bindings(pybind11::module &m);
void register_ground_state_partition_bindings(pybind11::module &m);
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m);
void register_rauscher_thielemann_partition_bindings(pybind11::module &m);
void register_composite_partition_bindings(pybind11::module &m);

View File

@@ -0,0 +1,19 @@
subdir('trampoline')
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_partition',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -0,0 +1,21 @@
gf_partition_trampoline_sources = files('py_partition.cpp')
gf_partition_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_partition_trampoline_lib = static_library(
'partition_trampolines',
gf_partition_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_partition_trapoline_dependencies,
install: false,
)
gr_partition_trampoline_dep = declare_dependency(
link_with: gf_partition_trampoline_lib,
include_directories: ('.'),
dependencies: gf_partition_trapoline_dependencies,
)

View File

@@ -0,0 +1,57 @@
#include "py_partition.h"
#include "gridfire/partition/partition.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include <string>
#include <memory>
namespace py = pybind11;
double PyPartitionFunction::evaluate(int z, int a, double T9) const {
PYBIND11_OVERRIDE_PURE(
double,
gridfire::partition::PartitionFunction,
evaluate,
z, a, T9
);
}
double PyPartitionFunction::evaluateDerivative(int z, int a, double T9) const {
PYBIND11_OVERRIDE_PURE(
double,
gridfire::partition::PartitionFunction,
evaluateDerivative,
z, a, T9
);
}
bool PyPartitionFunction::supports(int z, int a) const {
PYBIND11_OVERRIDE_PURE(
bool,
gridfire::partition::PartitionFunction,
supports,
z, a
);
}
std::string PyPartitionFunction::type() const {
PYBIND11_OVERRIDE_PURE(
std::string,
gridfire::partition::PartitionFunction,
type
);
}
std::unique_ptr<gridfire::partition::PartitionFunction> PyPartitionFunction::clone() const {
PYBIND11_OVERRIDE_PURE(
std::unique_ptr<gridfire::partition::PartitionFunction>,
gridfire::partition::PartitionFunction,
clone
);
}

View File

@@ -0,0 +1,15 @@
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "gridfire/partition/partition.h"
class PyPartitionFunction final : public gridfire::partition::PartitionFunction {
double evaluate(int z, int a, double T9) const override;
double evaluateDerivative(int z, int a, double T9) const override;
bool supports(int z, int a) const override;
std::string type() const override;
std::unique_ptr<gridfire::partition::PartitionFunction> clone() const override;
};