feat(python): Python Bindings
Python Bindings are working again
This commit is contained in:
@@ -8,7 +8,6 @@
|
||||
|
||||
#include "gridfire/policy/policy.h"
|
||||
|
||||
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
@@ -103,9 +102,30 @@ namespace {
|
||||
.def(
|
||||
"construct",
|
||||
&T::construct,
|
||||
py::return_value_policy::reference,
|
||||
"Construct the network according to the policy."
|
||||
|
||||
)
|
||||
.def(
|
||||
"get_engine_stack",
|
||||
[](const T &self) {
|
||||
const auto& stack = self.get_engine_stack();
|
||||
std::vector<gridfire::engine::DynamicEngine*> engine_ptrs;
|
||||
engine_ptrs.reserve(stack.size());
|
||||
for (const auto& engine_uptr : stack) {
|
||||
engine_ptrs.push_back(engine_uptr.get());
|
||||
}
|
||||
|
||||
return engine_ptrs;
|
||||
},
|
||||
py::return_value_policy::reference_internal
|
||||
)
|
||||
.def(
|
||||
"get_stack_scratch_blob",
|
||||
&T::get_stack_scratch_blob
|
||||
)
|
||||
.def(
|
||||
"get_partition_function",
|
||||
&T::get_partition_function
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -215,6 +235,26 @@ void register_network_policy_bindings(pybind11::module &m) {
|
||||
.value("INITIALIZED_VERIFIED", gridfire::policy::NetworkPolicyStatus::INITIALIZED_VERIFIED)
|
||||
.export_values();
|
||||
|
||||
m.def("network_policy_status_to_string",
|
||||
&gridfire::policy::NetworkPolicyStatusToString,
|
||||
py::arg("status"),
|
||||
"Convert a NetworkPolicyStatus enum value to its string representation."
|
||||
);
|
||||
|
||||
py::class_<gridfire::policy::ConstructionResults>(m, "ConstructionResults")
|
||||
.def_property_readonly("engine",
|
||||
[](const gridfire::policy::ConstructionResults &self) -> const gridfire::engine::DynamicEngine& {
|
||||
return self.engine;
|
||||
},
|
||||
py::return_value_policy::reference
|
||||
)
|
||||
.def_property_readonly("scratch_blob",
|
||||
[](const gridfire::policy::ConstructionResults &self) {
|
||||
return self.scratch_blob.get();
|
||||
},
|
||||
py::return_value_policy::reference_internal
|
||||
);
|
||||
|
||||
py::class_<gridfire::policy::NetworkPolicy, PyNetworkPolicy> py_networkPolicy(m, "NetworkPolicy");
|
||||
py::class_<gridfire::policy::MainSequencePolicy, gridfire::policy::NetworkPolicy> py_mainSeqPolicy(m, "MainSequencePolicy");
|
||||
py_mainSeqPolicy.def(
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Define the library
|
||||
subdir('trampoline')
|
||||
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_policy',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
gf_policy_trampoline_sources = files('py_policy.cpp')
|
||||
|
||||
gf_policy_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_policy_trampoline_lib = static_library(
|
||||
'policy_trampolines',
|
||||
gf_policy_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_policy_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_policy_trampoline_dep = declare_dependency(
|
||||
link_with: gf_policy_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_policy_trapoline_dependencies,
|
||||
)
|
||||
@@ -39,9 +39,9 @@ const gridfire::reaction::ReactionSet& PyNetworkPolicy::get_seed_reactions() con
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::engine::DynamicEngine& PyNetworkPolicy::construct() {
|
||||
gridfire::policy::ConstructionResults PyNetworkPolicy::construct() {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::engine::DynamicEngine&,
|
||||
gridfire::policy::ConstructionResults,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
construct
|
||||
);
|
||||
@@ -79,6 +79,14 @@ const std::unique_ptr<gridfire::partition::PartitionFunction>& PyNetworkPolicy::
|
||||
);
|
||||
}
|
||||
|
||||
std::unique_ptr<gridfire::engine::scratch::StateBlob> PyNetworkPolicy::get_stack_scratch_blob() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::unique_ptr<gridfire::engine::scratch::StateBlob>,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
get_stack_scratch_blob
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::reaction::ReactionSet &PyReactionChainPolicy::get_reactions() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::reaction::ReactionSet &,
|
||||
|
||||
@@ -13,7 +13,7 @@ public:
|
||||
|
||||
[[nodiscard]] const gridfire::reaction::ReactionSet& get_seed_reactions() const override;
|
||||
|
||||
[[nodiscard]] gridfire::engine::DynamicEngine& construct() override;
|
||||
[[nodiscard]] gridfire::policy::ConstructionResults construct() override;
|
||||
|
||||
[[nodiscard]] gridfire::policy::NetworkPolicyStatus get_status() const override;
|
||||
|
||||
@@ -22,6 +22,8 @@ public:
|
||||
[[nodiscard]] std::vector<gridfire::engine::EngineTypes> get_engine_types_stack() const override;
|
||||
|
||||
[[nodiscard]] const std::unique_ptr<gridfire::partition::PartitionFunction>& get_partition_function() const override;
|
||||
|
||||
[[nodiscard]] std::unique_ptr<gridfire::engine::scratch::StateBlob> get_stack_scratch_blob() const override;
|
||||
};
|
||||
|
||||
class PyReactionChainPolicy final : public gridfire::policy::ReactionChainPolicy {
|
||||
|
||||
Reference in New Issue
Block a user