feat(python): Python Bindings

Python Bindings are working again
This commit is contained in:
2025-12-20 16:02:52 -05:00
parent d65c237b26
commit 11a596b75b
78 changed files with 4411 additions and 1110 deletions

View File

@@ -8,7 +8,6 @@
#include "gridfire/policy/policy.h"
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
namespace py = pybind11;
@@ -103,9 +102,30 @@ namespace {
.def(
"construct",
&T::construct,
py::return_value_policy::reference,
"Construct the network according to the policy."
)
.def(
"get_engine_stack",
[](const T &self) {
const auto& stack = self.get_engine_stack();
std::vector<gridfire::engine::DynamicEngine*> engine_ptrs;
engine_ptrs.reserve(stack.size());
for (const auto& engine_uptr : stack) {
engine_ptrs.push_back(engine_uptr.get());
}
return engine_ptrs;
},
py::return_value_policy::reference_internal
)
.def(
"get_stack_scratch_blob",
&T::get_stack_scratch_blob
)
.def(
"get_partition_function",
&T::get_partition_function
);
}
}
@@ -215,6 +235,26 @@ void register_network_policy_bindings(pybind11::module &m) {
.value("INITIALIZED_VERIFIED", gridfire::policy::NetworkPolicyStatus::INITIALIZED_VERIFIED)
.export_values();
m.def("network_policy_status_to_string",
&gridfire::policy::NetworkPolicyStatusToString,
py::arg("status"),
"Convert a NetworkPolicyStatus enum value to its string representation."
);
py::class_<gridfire::policy::ConstructionResults>(m, "ConstructionResults")
.def_property_readonly("engine",
[](const gridfire::policy::ConstructionResults &self) -> const gridfire::engine::DynamicEngine& {
return self.engine;
},
py::return_value_policy::reference
)
.def_property_readonly("scratch_blob",
[](const gridfire::policy::ConstructionResults &self) {
return self.scratch_blob.get();
},
py::return_value_policy::reference_internal
);
py::class_<gridfire::policy::NetworkPolicy, PyNetworkPolicy> py_networkPolicy(m, "NetworkPolicy");
py::class_<gridfire::policy::MainSequencePolicy, gridfire::policy::NetworkPolicy> py_mainSeqPolicy(m, "MainSequencePolicy");
py_mainSeqPolicy.def(

View File

@@ -1,19 +0,0 @@
# Define the library
subdir('trampoline')
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
gridfire_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_gf_policy',
bindings_sources,
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
include_directories: include_directories('.')
)

View File

@@ -1,21 +0,0 @@
gf_policy_trampoline_sources = files('py_policy.cpp')
gf_policy_trapoline_dependencies = [
gridfire_dep,
pybind11_dep,
python3_dep,
]
gf_policy_trampoline_lib = static_library(
'policy_trampolines',
gf_policy_trampoline_sources,
include_directories: include_directories('.'),
dependencies: gf_policy_trapoline_dependencies,
install: false,
)
gr_policy_trampoline_dep = declare_dependency(
link_with: gf_policy_trampoline_lib,
include_directories: ('.'),
dependencies: gf_policy_trapoline_dependencies,
)

View File

@@ -39,9 +39,9 @@ const gridfire::reaction::ReactionSet& PyNetworkPolicy::get_seed_reactions() con
);
}
gridfire::engine::DynamicEngine& PyNetworkPolicy::construct() {
gridfire::policy::ConstructionResults PyNetworkPolicy::construct() {
PYBIND11_OVERRIDE_PURE(
gridfire::engine::DynamicEngine&,
gridfire::policy::ConstructionResults,
gridfire::policy::NetworkPolicy,
construct
);
@@ -79,6 +79,14 @@ const std::unique_ptr<gridfire::partition::PartitionFunction>& PyNetworkPolicy::
);
}
std::unique_ptr<gridfire::engine::scratch::StateBlob> PyNetworkPolicy::get_stack_scratch_blob() const {
PYBIND11_OVERRIDE_PURE(
std::unique_ptr<gridfire::engine::scratch::StateBlob>,
gridfire::policy::NetworkPolicy,
get_stack_scratch_blob
);
}
const gridfire::reaction::ReactionSet &PyReactionChainPolicy::get_reactions() const {
PYBIND11_OVERRIDE_PURE(
const gridfire::reaction::ReactionSet &,

View File

@@ -13,7 +13,7 @@ public:
[[nodiscard]] const gridfire::reaction::ReactionSet& get_seed_reactions() const override;
[[nodiscard]] gridfire::engine::DynamicEngine& construct() override;
[[nodiscard]] gridfire::policy::ConstructionResults construct() override;
[[nodiscard]] gridfire::policy::NetworkPolicyStatus get_status() const override;
@@ -22,6 +22,8 @@ public:
[[nodiscard]] std::vector<gridfire::engine::EngineTypes> get_engine_types_stack() const override;
[[nodiscard]] const std::unique_ptr<gridfire::partition::PartitionFunction>& get_partition_function() const override;
[[nodiscard]] std::unique_ptr<gridfire::engine::scratch::StateBlob> get_stack_scratch_blob() const override;
};
class PyReactionChainPolicy final : public gridfire::policy::ReactionChainPolicy {