feat(python): Repaired python bindings
Python bindings have now been brought back up to feature pairity with C++. Further, stubs have been added for all python features so that code completion will work
This commit is contained in:
233
src/python/policy/bindings.cpp
Normal file
233
src/python/policy/bindings.cpp
Normal file
@@ -0,0 +1,233 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
#include <memory>
|
||||
#include "bindings.h"
|
||||
#include "trampoline/py_policy.h"
|
||||
|
||||
#include "gridfire/policy/policy.h"
|
||||
|
||||
PYBIND11_DECLARE_HOLDER_TYPE(T, std::unique_ptr<T>, true) // Declare unique_ptr as a holder type for pybind11
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
concept IsReactionChainPolicy = std::is_base_of_v<gridfire::policy::ReactionChainPolicy, T>;
|
||||
|
||||
template <typename T>
|
||||
concept IsNetworkPolicy = std::is_base_of_v<gridfire::policy::NetworkPolicy, T>;
|
||||
|
||||
template<IsReactionChainPolicy T, IsReactionChainPolicy BaseT>
|
||||
void registerReactionChainPolicyDefs(py::class_<T, BaseT>& pyClass) {
|
||||
pyClass.def(
|
||||
"get_reactions",
|
||||
&T::get_reactions,
|
||||
"Get the ReactionSet representing this reaction chain."
|
||||
)
|
||||
.def(
|
||||
"contains",
|
||||
py::overload_cast<const std::string&>(&T::contains, py::const_),
|
||||
py::arg("id"),
|
||||
"Check if the reaction chain contains a reaction with the given ID."
|
||||
)
|
||||
.def(
|
||||
"contains",
|
||||
py::overload_cast<const gridfire::reaction::Reaction&>(&T::contains, py::const_),
|
||||
py::arg("reaction"),
|
||||
"Check if the reaction chain contains the given reaction."
|
||||
)
|
||||
.def(
|
||||
"name",
|
||||
&T::name,
|
||||
"Get the name of the reaction chain policy."
|
||||
)
|
||||
.def(
|
||||
"hash",
|
||||
&T::hash,
|
||||
py::arg("seed"),
|
||||
"Compute a hash value for the reaction chain policy."
|
||||
)
|
||||
.def(
|
||||
"__eq__",
|
||||
&T::operator==,
|
||||
py::arg("other"),
|
||||
"Check equality with another ReactionChainPolicy."
|
||||
)
|
||||
.def(
|
||||
"__ne__",
|
||||
&T::operator!=,
|
||||
py::arg("other"),
|
||||
"Check inequality with another ReactionChainPolicy."
|
||||
)
|
||||
.def("__hash__", [](const T &self) {
|
||||
return self.hash(0);
|
||||
}
|
||||
)
|
||||
.def("__repr__", [](const T &self) {
|
||||
std::stringstream ss;
|
||||
ss << self;
|
||||
return ss.str();
|
||||
});
|
||||
}
|
||||
|
||||
template<IsNetworkPolicy T, IsNetworkPolicy BaseT>
|
||||
void registerNetworkPolicyDefs(py::class_<T, BaseT> pyClass) {
|
||||
pyClass.def(
|
||||
"name",
|
||||
&T::name,
|
||||
"Get the name of the network policy."
|
||||
)
|
||||
.def(
|
||||
"get_seed_species",
|
||||
&T::get_seed_species,
|
||||
"Get the set of seed species required by the network policy."
|
||||
)
|
||||
.def(
|
||||
"get_seed_reactions",
|
||||
&T::get_seed_reactions,
|
||||
"Get the set of seed reactions required by the network policy."
|
||||
)
|
||||
.def(
|
||||
"get_status",
|
||||
&T::get_status,
|
||||
"Get the current status of the network policy."
|
||||
)
|
||||
.def(
|
||||
"get_engine_types_stack",
|
||||
&T::get_engine_types_stack,
|
||||
"Get the types of engines in the stack constructed by the network policy."
|
||||
)
|
||||
.def(
|
||||
"construct",
|
||||
&T::construct,
|
||||
py::return_value_policy::reference,
|
||||
"Construct the network according to the policy."
|
||||
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void register_policy_bindings(pybind11::module &m) {
|
||||
register_reaction_chain_policy_bindings(m);
|
||||
register_network_policy_bindings(m);
|
||||
}
|
||||
|
||||
void register_reaction_chain_policy_bindings(pybind11::module &m) {
|
||||
using namespace gridfire::policy;
|
||||
|
||||
py::class_<ReactionChainPolicy, PyReactionChainPolicy> py_reactionChainPolicy(m, "ReactionChainPolicy");
|
||||
py::class_<MultiReactionChainPolicy, ReactionChainPolicy> py_multiChainPolicy(m, "MultiReactionChainPolicy");
|
||||
py::class_<TemperatureDependentChainPolicy, ReactionChainPolicy> py_tempDepChainPolicy(m, "TemperatureDependentChainPolicy");
|
||||
|
||||
|
||||
py::class_<ProtonProtonIChainPolicy, TemperatureDependentChainPolicy> py_ppI(m, "ProtonProtonIChainPolicy");
|
||||
py_ppI.def(py::init<>());
|
||||
py_ppI.def("name", &ProtonProtonIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<ProtonProtonIIChainPolicy, TemperatureDependentChainPolicy> py_ppII(m, "ProtonProtonIIChainPolicy");
|
||||
py_ppII.def(py::init<>());
|
||||
py_ppII.def("name", &ProtonProtonIIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<ProtonProtonIIIChainPolicy, TemperatureDependentChainPolicy> py_ppIII(m, "ProtonProtonIIIChainPolicy");
|
||||
py_ppIII.def(py::init<>());
|
||||
py_ppIII.def("name", &ProtonProtonIIIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<ProtonProtonChainPolicy, MultiReactionChainPolicy> py_ppChain(m, "ProtonProtonChainPolicy");
|
||||
py_ppChain.def(py::init<>());
|
||||
py_ppChain.def("name", &ProtonProtonChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
registerReactionChainPolicyDefs(py_ppI);
|
||||
registerReactionChainPolicyDefs(py_ppII);
|
||||
registerReactionChainPolicyDefs(py_ppIII);
|
||||
registerReactionChainPolicyDefs(py_ppChain);
|
||||
|
||||
py::class_<CNOIChainPolicy, TemperatureDependentChainPolicy> py_cnoI(m, "CNOIChainPolicy");
|
||||
py_cnoI.def(py::init<>());
|
||||
py_cnoI.def("name", &CNOIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<CNOIIChainPolicy, TemperatureDependentChainPolicy> py_cnoII(m, "CNOIIChainPolicy");
|
||||
py_cnoII.def(py::init<>());
|
||||
py_cnoII.def("name", &CNOIIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<CNOIIIChainPolicy, TemperatureDependentChainPolicy> py_cnoIII(m, "CNOIIIChainPolicy");
|
||||
py_cnoIII.def(py::init<>());
|
||||
py_cnoIII.def("name", &CNOIIIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<CNOIVChainPolicy, TemperatureDependentChainPolicy> py_cnoIV(m, "CNOIVChainPolicy");
|
||||
py_cnoIV.def(py::init<>());
|
||||
py_cnoIV.def("name", &CNOIVChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<CNOChainPolicy, MultiReactionChainPolicy> py_cnoChain(m, "CNOChainPolicy");
|
||||
py_cnoChain.def(py::init<>());
|
||||
py_cnoChain.def("name", &CNOChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
registerReactionChainPolicyDefs(py_cnoI);
|
||||
registerReactionChainPolicyDefs(py_cnoII);
|
||||
registerReactionChainPolicyDefs(py_cnoIII);
|
||||
registerReactionChainPolicyDefs(py_cnoIV);
|
||||
registerReactionChainPolicyDefs(py_cnoChain);
|
||||
|
||||
py::class_<HotCNOIChainPolicy, TemperatureDependentChainPolicy> py_hotCNOI(m, "HotCNOIChainPolicy");
|
||||
py_hotCNOI.def(py::init<>());
|
||||
py_hotCNOI.def("name", &HotCNOIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<HotCNOIIChainPolicy, TemperatureDependentChainPolicy> py_hotCNOII(m, "HotCNOIIChainPolicy");
|
||||
py_hotCNOII.def(py::init<>());
|
||||
py_hotCNOII.def("name", &HotCNOIIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<HotCNOIIIChainPolicy, TemperatureDependentChainPolicy> py_hotCNOIII(m, "HotCNOIIIChainPolicy");
|
||||
py_hotCNOIII.def(py::init<>());
|
||||
py_hotCNOIII.def("name", &HotCNOIIIChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
py::class_<HotCNOChainPolicy, MultiReactionChainPolicy> py_hotCNOChain(m, "HotCNOChainPolicy");
|
||||
py_hotCNOChain.def(py::init<>());
|
||||
py_hotCNOChain.def("name", &HotCNOChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
registerReactionChainPolicyDefs(py_hotCNOI);
|
||||
registerReactionChainPolicyDefs(py_hotCNOII);
|
||||
registerReactionChainPolicyDefs(py_hotCNOIII);
|
||||
registerReactionChainPolicyDefs(py_hotCNOChain);
|
||||
|
||||
py::class_<TripleAlphaChainPolicy, TemperatureDependentChainPolicy> py_tripleAlpha(m, "TripleAlphaChainPolicy");
|
||||
py_tripleAlpha.def(py::init<>());
|
||||
py_tripleAlpha.def("name", &TripleAlphaChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
registerReactionChainPolicyDefs(py_tripleAlpha);
|
||||
|
||||
py::class_<MainSequenceReactionChainPolicy, MultiReactionChainPolicy> py_mainSeq(m, "MainSequenceReactionChainPolicy");
|
||||
py_mainSeq.def(py::init<>());
|
||||
py_mainSeq.def("name", &MainSequenceReactionChainPolicy::name, "Get the name of the reaction chain policy.");
|
||||
|
||||
registerReactionChainPolicyDefs(py_mainSeq);
|
||||
|
||||
}
|
||||
|
||||
void register_network_policy_bindings(pybind11::module &m) {
|
||||
py::enum_<gridfire::policy::NetworkPolicyStatus>(m, "NetworkPolicyStatus")
|
||||
.value("UNINITIALIZED", gridfire::policy::NetworkPolicyStatus::UNINITIALIZED)
|
||||
.value("INITIALIZED_UNVERIFIED", gridfire::policy::NetworkPolicyStatus::INITIALIZED_UNVERIFIED)
|
||||
.value("MISSING_KEY_REACTION", gridfire::policy::NetworkPolicyStatus::MISSING_KEY_REACTION)
|
||||
.value("MISSING_KEY_SPECIES", gridfire::policy::NetworkPolicyStatus::MISSING_KEY_SPECIES)
|
||||
.value("INITIALIZED_VERIFIED", gridfire::policy::NetworkPolicyStatus::INITIALIZED_VERIFIED)
|
||||
.export_values();
|
||||
|
||||
py::class_<gridfire::policy::NetworkPolicy, PyNetworkPolicy> py_networkPolicy(m, "NetworkPolicy");
|
||||
py::class_<gridfire::policy::MainSequencePolicy, gridfire::policy::NetworkPolicy> py_mainSeqPolicy(m, "MainSequencePolicy");
|
||||
py_mainSeqPolicy.def(
|
||||
py::init<const fourdst::composition::Composition&>(),
|
||||
py::arg("composition"),
|
||||
"Construct MainSequencePolicy from an existing composition."
|
||||
);
|
||||
py_mainSeqPolicy.def(
|
||||
py::init<std::vector<fourdst::atomic::Species>, const std::vector<double>&>(),
|
||||
py::arg("seed_species"),
|
||||
py::arg("mass_fractions"),
|
||||
"Construct MainSequencePolicy from seed species and mass fractions."
|
||||
);
|
||||
|
||||
registerNetworkPolicyDefs(py_mainSeqPolicy);
|
||||
}
|
||||
7
src/python/policy/bindings.h
Normal file
7
src/python/policy/bindings.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_policy_bindings(pybind11::module& m);
|
||||
void register_reaction_chain_policy_bindings(pybind11::module& m);
|
||||
void register_network_policy_bindings(pybind11::module& m);
|
||||
19
src/python/policy/meson.build
Normal file
19
src/python/policy/meson.build
Normal file
@@ -0,0 +1,19 @@
|
||||
# Define the library
|
||||
subdir('trampoline')
|
||||
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
gridfire_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_gf_policy',
|
||||
bindings_sources,
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
include_directories: include_directories('.')
|
||||
)
|
||||
21
src/python/policy/trampoline/meson.build
Normal file
21
src/python/policy/trampoline/meson.build
Normal file
@@ -0,0 +1,21 @@
|
||||
gf_policy_trampoline_sources = files('py_policy.cpp')
|
||||
|
||||
gf_policy_trapoline_dependencies = [
|
||||
gridfire_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
gf_policy_trampoline_lib = static_library(
|
||||
'policy_trampolines',
|
||||
gf_policy_trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: gf_policy_trapoline_dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
gr_policy_trampoline_dep = declare_dependency(
|
||||
link_with: gf_policy_trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: gf_policy_trapoline_dependencies,
|
||||
)
|
||||
149
src/python/policy/trampoline/py_policy.cpp
Normal file
149
src/python/policy/trampoline/py_policy.cpp
Normal file
@@ -0,0 +1,149 @@
|
||||
#include "py_policy.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
#include "fourdst/atomic/atomicSpecies.h"
|
||||
|
||||
#include "gridfire/reaction/reaction.h"
|
||||
#include "gridfire/engine/engine.h"
|
||||
|
||||
#include "gridfire/policy/policy.h"
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
std::string PyNetworkPolicy::name() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::string,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
name
|
||||
);
|
||||
}
|
||||
|
||||
const std::set<fourdst::atomic::Species>& PyNetworkPolicy::get_seed_species() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const std::set<fourdst::atomic::Species>&,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
get_seed_species
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::reaction::ReactionSet& PyNetworkPolicy::get_seed_reactions() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::reaction::ReactionSet&,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
get_seed_reactions
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::engine::DynamicEngine& PyNetworkPolicy::construct() {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::engine::DynamicEngine&,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
construct
|
||||
);
|
||||
}
|
||||
|
||||
gridfire::policy::NetworkPolicyStatus PyNetworkPolicy::get_status() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
gridfire::policy::NetworkPolicyStatus,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
getStatus
|
||||
);
|
||||
}
|
||||
|
||||
const std::vector<std::unique_ptr<gridfire::engine::DynamicEngine>> &PyNetworkPolicy::get_engine_stack() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const std::vector<std::unique_ptr<gridfire::engine::DynamicEngine>> &,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
get_engine_stack
|
||||
);
|
||||
}
|
||||
|
||||
std::vector<gridfire::engine::EngineTypes> PyNetworkPolicy::get_engine_types_stack() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::vector<gridfire::engine::EngineTypes>,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
get_engine_types_stack
|
||||
);
|
||||
}
|
||||
|
||||
const std::unique_ptr<gridfire::partition::PartitionFunction>& PyNetworkPolicy::get_partition_function() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const std::unique_ptr<gridfire::partition::PartitionFunction>&,
|
||||
gridfire::policy::NetworkPolicy,
|
||||
get_partition_function
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::reaction::ReactionSet &PyReactionChainPolicy::get_reactions() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::reaction::ReactionSet &,
|
||||
gridfire::policy::ReactionChainPolicy,
|
||||
get_reactions
|
||||
);
|
||||
}
|
||||
|
||||
bool PyReactionChainPolicy::contains(const std::string &id) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
bool,
|
||||
gridfire::policy::ReactionChainPolicy,
|
||||
contains,
|
||||
id
|
||||
);
|
||||
}
|
||||
|
||||
bool PyReactionChainPolicy::contains(const gridfire::reaction::Reaction &reaction) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
bool,
|
||||
gridfire::policy::ReactionChainPolicy,
|
||||
contains,
|
||||
reaction
|
||||
);
|
||||
}
|
||||
|
||||
std::unique_ptr<gridfire::policy::ReactionChainPolicy> PyReactionChainPolicy::clone() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::unique_ptr<gridfire::policy::ReactionChainPolicy>,
|
||||
gridfire::policy::ReactionChainPolicy,
|
||||
clone
|
||||
);
|
||||
}
|
||||
|
||||
std::string PyReactionChainPolicy::name() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
std::string,
|
||||
gridfire::policy::ReactionChainPolicy,
|
||||
name
|
||||
);
|
||||
}
|
||||
|
||||
uint64_t PyReactionChainPolicy::hash(uint64_t seed) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
uint64_t,
|
||||
gridfire::policy::ReactionChainPolicy,
|
||||
hash,
|
||||
seed
|
||||
);
|
||||
}
|
||||
|
||||
bool PyReactionChainPolicy::operator==(const ReactionChainPolicy &other) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
bool,
|
||||
gridfire::policy::ReactionChainPolicy,
|
||||
operator==,
|
||||
other
|
||||
);
|
||||
}
|
||||
|
||||
bool PyReactionChainPolicy::operator!=(const ReactionChainPolicy &other) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
bool,
|
||||
gridfire::policy::ReactionChainPolicy,
|
||||
operator!=,
|
||||
other
|
||||
);
|
||||
}
|
||||
44
src/python/policy/trampoline/py_policy.h
Normal file
44
src/python/policy/trampoline/py_policy.h
Normal file
@@ -0,0 +1,44 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "gridfire/policy/policy.h"
|
||||
|
||||
class PyNetworkPolicy final : public gridfire::policy::NetworkPolicy {
|
||||
public:
|
||||
[[nodiscard]] std::string name() const override;
|
||||
|
||||
[[nodiscard]] const std::set<fourdst::atomic::Species>& get_seed_species() const override;
|
||||
|
||||
[[nodiscard]] const gridfire::reaction::ReactionSet& get_seed_reactions() const override;
|
||||
|
||||
[[nodiscard]] gridfire::engine::DynamicEngine& construct() override;
|
||||
|
||||
[[nodiscard]] gridfire::policy::NetworkPolicyStatus get_status() const override;
|
||||
|
||||
[[nodiscard]] const std::vector<std::unique_ptr<gridfire::engine::DynamicEngine>> &get_engine_stack() const override;
|
||||
|
||||
[[nodiscard]] std::vector<gridfire::engine::EngineTypes> get_engine_types_stack() const override;
|
||||
|
||||
[[nodiscard]] const std::unique_ptr<gridfire::partition::PartitionFunction>& get_partition_function() const override;
|
||||
};
|
||||
|
||||
class PyReactionChainPolicy final : public gridfire::policy::ReactionChainPolicy {
|
||||
public:
|
||||
[[nodiscard]] const gridfire::reaction::ReactionSet & get_reactions() const override;
|
||||
|
||||
[[nodiscard]] bool contains(const std::string &id) const override;
|
||||
|
||||
[[nodiscard]] bool contains(const gridfire::reaction::Reaction &reaction) const override;
|
||||
|
||||
[[nodiscard]] std::unique_ptr<ReactionChainPolicy> clone() const override;
|
||||
|
||||
[[nodiscard]] std::string name() const override;
|
||||
|
||||
[[nodiscard]] uint64_t hash(uint64_t seed) const override;
|
||||
|
||||
[[nodiscard]] bool operator==(const ReactionChainPolicy &other) const override;
|
||||
|
||||
[[nodiscard]] bool operator!=(const ReactionChainPolicy &other) const override;
|
||||
};
|
||||
Reference in New Issue
Block a user