feat(python): added robust python bindings covering the entire codebase
This commit is contained in:
113
src/python/partition/bindings.cpp
Normal file
113
src/python/partition/bindings.cpp
Normal file
@@ -0,0 +1,113 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
#include "gridfire/partition/partition.h"
|
||||
|
||||
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
|
||||
|
||||
#include "trampoline/py_partition.h"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
void register_partition_bindings(pybind11::module &m) {
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<PF, PyPartitionFunction>(m, "PartitionFunction");
|
||||
|
||||
register_partition_types_bindings(m);
|
||||
register_ground_state_partition_bindings(m);
|
||||
register_rauscher_thielemann_partition_data_record_bindings(m);
|
||||
register_rauscher_thielemann_partition_bindings(m);
|
||||
|
||||
register_composite_partition_bindings(m);
|
||||
}
|
||||
|
||||
void register_partition_types_bindings(pybind11::module &m) {
|
||||
py::enum_<gridfire::partition::BasePartitionType>(m, "BasePartitionType")
|
||||
.value("RauscherThielemann", gridfire::partition::BasePartitionType::RauscherThielemann)
|
||||
.value("GroundState", gridfire::partition::BasePartitionType::GroundState)
|
||||
.export_values();
|
||||
|
||||
m.def("basePartitionTypeToString", [](gridfire::partition::BasePartitionType type) {
|
||||
return gridfire::partition::basePartitionTypeToString[type];
|
||||
}, py::arg("type"), "Convert BasePartitionType to string.");
|
||||
|
||||
m.def("stringToBasePartitionType", [](const std::string &typeStr) {
|
||||
return gridfire::partition::stringToBasePartitionType[typeStr];
|
||||
}, py::arg("typeStr"), "Convert string to BasePartitionType.");
|
||||
}
|
||||
|
||||
void register_ground_state_partition_bindings(pybind11::module &m) {
|
||||
using GSPF = gridfire::partition::GroundStatePartitionFunction;
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<GSPF, PF>(m, "GroundStatePartitionFunction")
|
||||
.def(py::init<>())
|
||||
.def("evaluate", &gridfire::partition::GroundStatePartitionFunction::evaluate,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the ground state partition function for given Z, A, and T9.")
|
||||
.def("evaluateDerivative", &gridfire::partition::GroundStatePartitionFunction::evaluateDerivative,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the derivative of the ground state partition function for given Z, A, and T9.")
|
||||
.def("supports", &gridfire::partition::GroundStatePartitionFunction::supports,
|
||||
py::arg("z"), py::arg("a"),
|
||||
"Check if the ground state partition function supports given Z and A.")
|
||||
.def("get_type", &gridfire::partition::GroundStatePartitionFunction::type,
|
||||
"Get the type of the partition function (should return 'GroundState').");
|
||||
}
|
||||
|
||||
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m) {
|
||||
py::class_<gridfire::partition::record::RauscherThielemannPartitionDataRecord>(m, "RauscherThielemannPartitionDataRecord")
|
||||
.def_readonly("z", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::z, "Atomic number")
|
||||
.def_readonly("a", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::a, "Mass number")
|
||||
.def_readonly("ground_state_spin", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::ground_state_spin, "Ground state spin")
|
||||
.def_readonly("normalized_g_values", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::normalized_g_values, "Normalized g-values for the first 24 energy levels");
|
||||
}
|
||||
|
||||
|
||||
void register_rauscher_thielemann_partition_bindings(pybind11::module &m) {
|
||||
using RTPF = gridfire::partition::RauscherThielemannPartitionFunction;
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<RTPF, PF>(m, "RauscherThielemannPartitionFunction")
|
||||
.def(py::init<>())
|
||||
.def("evaluate", &gridfire::partition::RauscherThielemannPartitionFunction::evaluate,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the Rauscher-Thielemann partition function for given Z, A, and T9.")
|
||||
.def("evaluateDerivative", &gridfire::partition::RauscherThielemannPartitionFunction::evaluateDerivative,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the derivative of the Rauscher-Thielemann partition function for given Z, A, and T9.")
|
||||
.def("supports", &gridfire::partition::RauscherThielemannPartitionFunction::supports,
|
||||
py::arg("z"), py::arg("a"),
|
||||
"Check if the Rauscher-Thielemann partition function supports given Z and A.")
|
||||
.def("get_type", &gridfire::partition::RauscherThielemannPartitionFunction::type,
|
||||
"Get the type of the partition function (should return 'RauscherThielemann').");
|
||||
}
|
||||
|
||||
void register_composite_partition_bindings(pybind11::module &m) {
|
||||
py::class_<gridfire::partition::CompositePartitionFunction>(m, "CompositePartitionFunction")
|
||||
.def(py::init<const std::vector<gridfire::partition::BasePartitionType>&>(),
|
||||
py::arg("partitionFunctions"),
|
||||
"Create a composite partition function from a list of base partition types.")
|
||||
.def(py::init<const gridfire::partition::CompositePartitionFunction&>(),
|
||||
"Copy constructor for CompositePartitionFunction.")
|
||||
.def("evaluate", &gridfire::partition::CompositePartitionFunction::evaluate,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the composite partition function for given Z, A, and T9.")
|
||||
.def("evaluateDerivative", &gridfire::partition::CompositePartitionFunction::evaluateDerivative,
|
||||
py::arg("z"), py::arg("a"), py::arg("T9"),
|
||||
"Evaluate the derivative of the composite partition function for given Z, A, and T9.")
|
||||
.def("supports", &gridfire::partition::CompositePartitionFunction::supports,
|
||||
py::arg("z"), py::arg("a"),
|
||||
"Check if the composite partition function supports given Z and A.")
|
||||
.def("get_type", &gridfire::partition::CompositePartitionFunction::type,
|
||||
"Get the type of the partition function (should return 'Composite').");
|
||||
}
|
||||
|
||||
|
||||
|
||||
16
src/python/partition/bindings.h
Normal file
16
src/python/partition/bindings.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_partition_bindings(pybind11::module &m);
|
||||
|
||||
void register_partition_types_bindings(pybind11::module &m);
|
||||
|
||||
void register_ground_state_partition_bindings(pybind11::module &m);
|
||||
|
||||
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m);
|
||||
|
||||
void register_rauscher_thielemann_partition_bindings(pybind11::module &m);
|
||||
|
||||
void register_composite_partition_bindings(pybind11::module &m);
|
||||
|
||||
19
src/python/partition/meson.build
Normal file
19
src/python/partition/meson.build
Normal file
@@ -0,0 +1,19 @@
|
||||
subdir('trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_partition',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
21
src/python/partition/trampoline/meson.build
Normal file
21
src/python/partition/trampoline/meson.build
Normal file
@@ -0,0 +1,21 @@
|
||||
gf_partition_trampoline_sources = files('py_partition.cpp')
|
||||
|
||||
gf_partition_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_partition_trampoline_lib = static_library(
|
||||
'partition_trampolines',
|
||||
gf_partition_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_partition_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_partition_trampoline_dep = declare_dependency(
|
||||
link_with: gf_partition_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_partition_trapoline_dependencies,
|
||||
)
|
||||
57
src/python/partition/trampoline/py_partition.cpp
Normal file
57
src/python/partition/trampoline/py_partition.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
#include "py_partition.h"
|
||||
|
||||
#include "gridfire/partition/partition.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
double PyPartitionFunction::evaluate(int z, int a, double T9) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
double,
|
||||
gridfire::partition::PartitionFunction,
|
||||
evaluate,
|
||||
z, a, T9
|
||||
);
|
||||
}
|
||||
|
||||
double PyPartitionFunction::evaluateDerivative(int z, int a, double T9) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
double,
|
||||
gridfire::partition::PartitionFunction,
|
||||
evaluateDerivative,
|
||||
z, a, T9
|
||||
);
|
||||
}
|
||||
|
||||
bool PyPartitionFunction::supports(int z, int a) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
bool,
|
||||
gridfire::partition::PartitionFunction,
|
||||
supports,
|
||||
z, a
|
||||
);
|
||||
}
|
||||
|
||||
std::string PyPartitionFunction::type() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::string,
|
||||
gridfire::partition::PartitionFunction,
|
||||
type
|
||||
);
|
||||
}
|
||||
|
||||
std::unique_ptr<gridfire::partition::PartitionFunction> PyPartitionFunction::clone() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::unique_ptr<gridfire::partition::PartitionFunction>,
|
||||
gridfire::partition::PartitionFunction,
|
||||
clone
|
||||
);
|
||||
}
|
||||
|
||||
15
src/python/partition/trampoline/py_partition.h
Normal file
15
src/python/partition/trampoline/py_partition.h
Normal file
@@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "gridfire/partition/partition.h"
|
||||
|
||||
|
||||
class PyPartitionFunction final : public gridfire::partition::PartitionFunction {
|
||||
double evaluate(int z, int a, double T9) const override;
|
||||
double evaluateDerivative(int z, int a, double T9) const override;
|
||||
bool supports(int z, int a) const override;
|
||||
std::string type() const override;
|
||||
std::unique_ptr<gridfire::partition::PartitionFunction> clone() const override;
|
||||
};
|
||||
Reference in New Issue
Block a user