refactor(reaction): refactored to an abstract reaction class in prep for weak reactions
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "types/bindings.h"
|
||||
#include "partition/bindings.h"
|
||||
#include "expectations/bindings.h"
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <iostream>
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
@@ -152,7 +150,7 @@ void register_engine_bindings(py::module &m) {
|
||||
);
|
||||
}
|
||||
|
||||
void register_base_engine_bindings(pybind11::module &m) {
|
||||
void register_base_engine_bindings(const pybind11::module &m) {
|
||||
|
||||
py::class_<gridfire::StepDerivatives<double>>(m, "StepDerivatives")
|
||||
.def_readonly("dYdt", &gridfire::StepDerivatives<double>::dydt, "The right-hand side (dY/dt) of the ODE system.")
|
||||
@@ -165,15 +163,15 @@ void register_base_engine_bindings(pybind11::module &m) {
|
||||
con_stype_register_graph_engine_bindings(m);
|
||||
}
|
||||
|
||||
void abs_stype_register_engine_bindings(pybind11::module &m) {
|
||||
void abs_stype_register_engine_bindings(const pybind11::module &m) {
|
||||
py::class_<gridfire::Engine, PyEngine>(m, "Engine");
|
||||
}
|
||||
|
||||
void abs_stype_register_dynamic_engine_bindings(pybind11::module &m) {
|
||||
void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m) {
|
||||
const auto a = py::class_<gridfire::DynamicEngine, PyDynamicEngine>(m, "DynamicEngine");
|
||||
}
|
||||
|
||||
void con_stype_register_graph_engine_bindings(pybind11::module &m) {
|
||||
void con_stype_register_graph_engine_bindings(const pybind11::module &m) {
|
||||
py::enum_<gridfire::NetworkBuildDepth>(m, "NetworkBuildDepth")
|
||||
.value("Full", gridfire::NetworkBuildDepth::Full, "Full network build depth")
|
||||
.value("Shallow", gridfire::NetworkBuildDepth::Shallow, "Shallow network build depth")
|
||||
@@ -199,7 +197,7 @@ void con_stype_register_graph_engine_bindings(pybind11::module &m) {
|
||||
py::arg("depth") = gridfire::NetworkBuildDepth::Full,
|
||||
"Initialize GraphEngine with a composition, partition function and build depth."
|
||||
);
|
||||
py_dynamic_engine_bindings.def(py::init<const gridfire::reaction::LogicalReactionSet &>(),
|
||||
py_dynamic_engine_bindings.def(py::init<const gridfire::reaction::ReactionSet &>(),
|
||||
py::arg("reactions"),
|
||||
"Initialize GraphEngine with a set of reactions."
|
||||
);
|
||||
@@ -267,7 +265,7 @@ void con_stype_register_graph_engine_bindings(pybind11::module &m) {
|
||||
registerDynamicEngineDefs<gridfire::GraphEngine, gridfire::DynamicEngine>(py_dynamic_engine_bindings);
|
||||
}
|
||||
|
||||
void register_engine_view_bindings(pybind11::module &m) {
|
||||
void register_engine_view_bindings(const pybind11::module &m) {
|
||||
auto py_defined_engine_view_bindings = py::class_<gridfire::DefinedEngineView, gridfire::DynamicEngine>(m, "DefinedEngineView");
|
||||
|
||||
py_defined_engine_view_bindings.def(py::init<std::vector<std::string>, gridfire::DynamicEngine&>(),
|
||||
|
||||
@@ -4,13 +4,13 @@
|
||||
|
||||
void register_engine_bindings(pybind11::module &m);
|
||||
|
||||
void register_base_engine_bindings(pybind11::module &m);
|
||||
void register_base_engine_bindings(const pybind11::module &m);
|
||||
|
||||
void register_engine_view_bindings(pybind11::module &m);
|
||||
void register_engine_view_bindings(const pybind11::module &m);
|
||||
|
||||
void abs_stype_register_engine_bindings(pybind11::module &m);
|
||||
void abs_stype_register_dynamic_engine_bindings(pybind11::module &m);
|
||||
void abs_stype_register_engine_bindings(const pybind11::module &m);
|
||||
void abs_stype_register_dynamic_engine_bindings(const pybind11::module &m);
|
||||
|
||||
void con_stype_register_graph_engine_bindings(pybind11::module &m);
|
||||
void con_stype_register_graph_engine_bindings(const pybind11::module &m);
|
||||
|
||||
|
||||
|
||||
@@ -23,10 +23,9 @@ const std::vector<fourdst::atomic::Species>& PyEngine::getNetworkSpecies() const
|
||||
/*
|
||||
* get_override() looks for a Python method that overrides this C++ one.
|
||||
*/
|
||||
py::function override = py::get_override(this, "getNetworkSpecies");
|
||||
|
||||
if (override) {
|
||||
py::object result = override();
|
||||
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
|
||||
const py::object result = override();
|
||||
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
|
||||
return m_species_cache;
|
||||
}
|
||||
@@ -57,10 +56,9 @@ const std::vector<fourdst::atomic::Species>& PyDynamicEngine::getNetworkSpecies(
|
||||
/*
|
||||
* get_override() looks for a Python method that overrides this C++ one.
|
||||
*/
|
||||
py::function override = py::get_override(this, "getNetworkSpecies");
|
||||
|
||||
if (override) {
|
||||
py::object result = override();
|
||||
if (const py::function override = py::get_override(this, "getNetworkSpecies")) {
|
||||
const py::object result = override();
|
||||
m_species_cache = result.cast<std::vector<fourdst::atomic::Species>>();
|
||||
return m_species_cache;
|
||||
}
|
||||
@@ -129,15 +127,15 @@ double PyDynamicEngine::calculateMolarReactionFlow(const gridfire::reaction::Rea
|
||||
);
|
||||
}
|
||||
|
||||
const gridfire::reaction::LogicalReactionSet& PyDynamicEngine::getNetworkReactions() const {
|
||||
const gridfire::reaction::ReactionSet& PyDynamicEngine::getNetworkReactions() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const gridfire::reaction::LogicalReactionSet&,
|
||||
const gridfire::reaction::ReactionSet&,
|
||||
gridfire::DynamicEngine,
|
||||
getNetworkReactions
|
||||
);
|
||||
}
|
||||
|
||||
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::LogicalReactionSet& reactions) {
|
||||
void PyDynamicEngine::setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
gridfire::DynamicEngine,
|
||||
@@ -199,7 +197,7 @@ gridfire::screening::ScreeningType PyDynamicEngine::getScreeningModel() const {
|
||||
);
|
||||
}
|
||||
|
||||
int PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
|
||||
size_t PyDynamicEngine::getSpeciesIndex(const fourdst::atomic::Species &species) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
int,
|
||||
gridfire::DynamicEngine,
|
||||
|
||||
@@ -27,15 +27,16 @@ public:
|
||||
void generateStoichiometryMatrix() override;
|
||||
int getStoichiometryMatrixEntry(int speciesIndex, int reactionIndex) const override;
|
||||
double calculateMolarReactionFlow(const gridfire::reaction::Reaction &reaction, const std::vector<double> &Y, double T9, double rho) const override;
|
||||
const gridfire::reaction::LogicalReactionSet& getNetworkReactions() const override;
|
||||
void setNetworkReactions(const gridfire::reaction::LogicalReactionSet& reactions) override;
|
||||
const gridfire::reaction::ReactionSet& getNetworkReactions() const override;
|
||||
void setNetworkReactions(const gridfire::reaction::ReactionSet& reactions) override;
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesTimescales(const std::vector<double> &Y, double T9, double rho) const override;
|
||||
std::expected<std::unordered_map<fourdst::atomic::Species, double>, gridfire::expectations::StaleEngineError> getSpeciesDestructionTimescales(const std::vector<double> &Y, double T9, double rho) const override;
|
||||
fourdst::composition::Composition update(const gridfire::NetIn &netIn) override;
|
||||
bool isStale(const gridfire::NetIn &netIn) override;
|
||||
void setScreeningModel(gridfire::screening::ScreeningType model) override;
|
||||
gridfire::screening::ScreeningType getScreeningModel() const override;
|
||||
int getSpeciesIndex(const fourdst::atomic::Species &species) const override;
|
||||
|
||||
size_t getSpeciesIndex(const fourdst::atomic::Species &species) const override;
|
||||
std::vector<double> mapNetInToMolarAbundanceVector(const gridfire::NetIn &netIn) const override;
|
||||
gridfire::PrimingReport primeEngine(const gridfire::NetIn &netIn) override;
|
||||
gridfire::BuildDepthType getDepth() const override {
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
@@ -10,7 +6,7 @@ namespace py = pybind11;
|
||||
|
||||
#include "gridfire/exceptions/exceptions.h"
|
||||
|
||||
void register_exception_bindings(py::module &m) {
|
||||
void register_exception_bindings(const py::module &m) {
|
||||
py::register_exception<gridfire::exceptions::EngineError>(m, "GridFireEngineError");
|
||||
|
||||
// TODO: Make it so that we can grab the stale state in python
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_exception_bindings(pybind11::module &m);
|
||||
void register_exception_bindings(const pybind11::module &m);
|
||||
|
||||
@@ -8,7 +8,7 @@ namespace py = pybind11;
|
||||
|
||||
#include "gridfire/expectations/expectations.h"
|
||||
|
||||
void register_expectation_bindings(py::module &m) {
|
||||
void register_expectation_bindings(const py::module &m) {
|
||||
py::enum_<gridfire::expectations::EngineErrorTypes>(m, "EngineErrorTypes")
|
||||
.value("FAILURE", gridfire::expectations::EngineErrorTypes::FAILURE)
|
||||
.value("INDEX", gridfire::expectations::EngineErrorTypes::INDEX)
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_expectation_bindings(pybind11::module &m);
|
||||
void register_expectation_bindings(const pybind11::module &m);
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "bindings.h"
|
||||
@@ -12,12 +10,12 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void register_io_bindings(py::module &m) {
|
||||
py::class_<gridfire::io::ParsedNetworkData>(m, "ParsedNetworkData");
|
||||
auto register_io_bindings(const py::module &m) -> void {
|
||||
auto ParsedNetworkData = py::class_<gridfire::io::ParsedNetworkData>(m, "ParsedNetworkData");
|
||||
|
||||
py::class_<gridfire::io::NetworkFileParser, PyNetworkFileParser>(m, "NetworkFileParser");
|
||||
auto NetworkFileParser = py::class_<gridfire::io::NetworkFileParser, PyNetworkFileParser>(m, "NetworkFileParser");
|
||||
|
||||
py::class_<gridfire::io::SimpleReactionListFileParser, gridfire::io::NetworkFileParser>(m, "SimpleReactionListFileParser")
|
||||
auto SimpleReactionListFileParser = py::class_<gridfire::io::SimpleReactionListFileParser, gridfire::io::NetworkFileParser>(m, "SimpleReactionListFileParser")
|
||||
.def("parse", &gridfire::io::SimpleReactionListFileParser::parse,
|
||||
py::arg("filename"),
|
||||
"Parse a simple reaction list file and return a ParsedNetworkData object.");
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_io_bindings(pybind11::module &m);
|
||||
void register_io_bindings(const pybind11::module &m);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
@@ -19,7 +19,7 @@ namespace py = pybind11;
|
||||
|
||||
void register_partition_bindings(pybind11::module &m) {
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<PF, PyPartitionFunction>(m, "PartitionFunction");
|
||||
auto TrampPartitionFunction = py::class_<PF, PyPartitionFunction>(m, "PartitionFunction");
|
||||
|
||||
register_partition_types_bindings(m);
|
||||
register_ground_state_partition_bindings(m);
|
||||
@@ -44,7 +44,7 @@ void register_partition_types_bindings(pybind11::module &m) {
|
||||
}, py::arg("typeStr"), "Convert string to BasePartitionType.");
|
||||
}
|
||||
|
||||
void register_ground_state_partition_bindings(pybind11::module &m) {
|
||||
void register_ground_state_partition_bindings(const pybind11::module &m) {
|
||||
using GSPF = gridfire::partition::GroundStatePartitionFunction;
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<GSPF, PF>(m, "GroundStatePartitionFunction")
|
||||
@@ -62,7 +62,7 @@ void register_ground_state_partition_bindings(pybind11::module &m) {
|
||||
"Get the type of the partition function (should return 'GroundState').");
|
||||
}
|
||||
|
||||
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m) {
|
||||
void register_rauscher_thielemann_partition_data_record_bindings(const pybind11::module &m) {
|
||||
py::class_<gridfire::partition::record::RauscherThielemannPartitionDataRecord>(m, "RauscherThielemannPartitionDataRecord")
|
||||
.def_readonly("z", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::z, "Atomic number")
|
||||
.def_readonly("a", &gridfire::partition::record::RauscherThielemannPartitionDataRecord::a, "Mass number")
|
||||
@@ -71,7 +71,7 @@ void register_rauscher_thielemann_partition_data_record_bindings(pybind11::modul
|
||||
}
|
||||
|
||||
|
||||
void register_rauscher_thielemann_partition_bindings(pybind11::module &m) {
|
||||
void register_rauscher_thielemann_partition_bindings(const pybind11::module &m) {
|
||||
using RTPF = gridfire::partition::RauscherThielemannPartitionFunction;
|
||||
using PF = gridfire::partition::PartitionFunction;
|
||||
py::class_<RTPF, PF>(m, "RauscherThielemannPartitionFunction")
|
||||
@@ -89,7 +89,7 @@ void register_rauscher_thielemann_partition_bindings(pybind11::module &m) {
|
||||
"Get the type of the partition function (should return 'RauscherThielemann').");
|
||||
}
|
||||
|
||||
void register_composite_partition_bindings(pybind11::module &m) {
|
||||
void register_composite_partition_bindings(const pybind11::module &m) {
|
||||
py::class_<gridfire::partition::CompositePartitionFunction>(m, "CompositePartitionFunction")
|
||||
.def(py::init<const std::vector<gridfire::partition::BasePartitionType>&>(),
|
||||
py::arg("partitionFunctions"),
|
||||
|
||||
@@ -6,11 +6,11 @@ void register_partition_bindings(pybind11::module &m);
|
||||
|
||||
void register_partition_types_bindings(pybind11::module &m);
|
||||
|
||||
void register_ground_state_partition_bindings(pybind11::module &m);
|
||||
void register_ground_state_partition_bindings(const pybind11::module &m);
|
||||
|
||||
void register_rauscher_thielemann_partition_data_record_bindings(pybind11::module &m);
|
||||
void register_rauscher_thielemann_partition_data_record_bindings(const pybind11::module &m);
|
||||
|
||||
void register_rauscher_thielemann_partition_bindings(pybind11::module &m);
|
||||
void register_rauscher_thielemann_partition_bindings(const pybind11::module &m);
|
||||
|
||||
void register_composite_partition_bindings(pybind11::module &m);
|
||||
void register_composite_partition_bindings(const pybind11::module &m);
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
|
||||
|
||||
class PyPartitionFunction final : public gridfire::partition::PartitionFunction {
|
||||
double evaluate(int z, int a, double T9) const override;
|
||||
double evaluateDerivative(int z, int a, double T9) const override;
|
||||
bool supports(int z, int a) const override;
|
||||
std::string type() const override;
|
||||
std::unique_ptr<gridfire::partition::PartitionFunction> clone() const override;
|
||||
[[nodiscard]] double evaluate(int z, int a, double T9) const override;
|
||||
[[nodiscard]] double evaluateDerivative(int z, int a, double T9) const override;
|
||||
[[nodiscard]] bool supports(int z, int a) const override;
|
||||
[[nodiscard]] std::string type() const override;
|
||||
[[nodiscard]] std::unique_ptr<PartitionFunction> clone() const override;
|
||||
};
|
||||
@@ -1,6 +1,6 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
@@ -22,150 +22,332 @@ void register_reaction_bindings(py::module &m) {
|
||||
);
|
||||
|
||||
using fourdst::atomic::Species;
|
||||
py::class_<gridfire::reaction::Reaction>(m, "Reaction")
|
||||
.def(py::init<const std::string_view, const std::string_view, int, const std::vector<Species>&, const std::vector<Species>&, double, std::string_view, gridfire::reaction::RateCoefficientSet, bool>(),
|
||||
py::arg("id"), py::arg("peName"), py::arg("chapter"),
|
||||
py::arg("reactants"), py::arg("products"), py::arg("qValue"),
|
||||
py::arg("label"), py::arg("sets"), py::arg("reverse") = false,
|
||||
"Construct a Reaction with the given parameters.")
|
||||
.def("calculate_rate", static_cast<double (gridfire::reaction::Reaction::*)(double) const>(&gridfire::reaction::Reaction::calculate_rate),
|
||||
py::arg("T9"), "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).")
|
||||
.def("peName", &gridfire::reaction::Reaction::peName,
|
||||
"Get the reaction name in (projectile, ejectile) notation (e.g., 'p(p,g)d').")
|
||||
.def("chapter", &gridfire::reaction::Reaction::chapter,
|
||||
"Get the REACLIB chapter number defining the reaction structure.")
|
||||
.def("sourceLabel", &gridfire::reaction::Reaction::sourceLabel,
|
||||
"Get the source label for the rate data (e.g., 'wc12w', 'st08').")
|
||||
.def("rateCoefficients", &gridfire::reaction::Reaction::rateCoefficients,
|
||||
"get the set of rate coefficients.")
|
||||
.def("contains", &gridfire::reaction::Reaction::contains,
|
||||
py::arg("species"), "Check if the reaction contains a specific species.")
|
||||
.def("contains_reactant", &gridfire::reaction::Reaction::contains_reactant,
|
||||
"Check if the reaction contains a specific reactant species.")
|
||||
.def("contains_product", &gridfire::reaction::Reaction::contains_product,
|
||||
"Check if the reaction contains a specific product species.")
|
||||
.def("all_species", &gridfire::reaction::Reaction::all_species,
|
||||
"Get all species involved in the reaction (both reactants and products) as a set.")
|
||||
.def("reactant_species", &gridfire::reaction::Reaction::reactant_species,
|
||||
"Get the reactant species of the reaction as a set.")
|
||||
.def("product_species", &gridfire::reaction::Reaction::product_species,
|
||||
"Get the product species of the reaction as a set.")
|
||||
.def("num_species", &gridfire::reaction::Reaction::num_species,
|
||||
"Count the number of species in the reaction.")
|
||||
.def("stoichiometry", static_cast<int (gridfire::reaction::Reaction::*)(const Species&) const>(&gridfire::reaction::Reaction::stoichiometry),
|
||||
py::arg("species"),
|
||||
"Get the stoichiometry of the reaction as a map from species to their coefficients.")
|
||||
.def("stoichiometry", static_cast<std::unordered_map<Species, int> (gridfire::reaction::Reaction::*)() const>(&gridfire::reaction::Reaction::stoichiometry),
|
||||
"Get the stoichiometry of the reaction as a map from species to their coefficients.")
|
||||
.def("id", &gridfire::reaction::Reaction::id,
|
||||
"Get the unique identifier of the reaction.")
|
||||
.def("qValue", &gridfire::reaction::Reaction::qValue,
|
||||
"Get the Q-value of the reaction in MeV.")
|
||||
.def("reactants", &gridfire::reaction::Reaction::reactants,
|
||||
"Get a list of reactant species in the reaction.")
|
||||
.def("products", &gridfire::reaction::Reaction::products,
|
||||
"Get a list of product species in the reaction.")
|
||||
.def("is_reverse", &gridfire::reaction::Reaction::is_reverse,
|
||||
"Check if this is a reverse reaction rate.")
|
||||
.def("excess_energy", &gridfire::reaction::Reaction::excess_energy,
|
||||
"Calculate the excess energy from the mass difference of reactants and products.")
|
||||
.def("__eq__", &gridfire::reaction::Reaction::operator==,
|
||||
"Equality operator for reactions based on their IDs.")
|
||||
.def("__neq__", &gridfire::reaction::Reaction::operator!=,
|
||||
"Inequality operator for reactions based on their IDs.")
|
||||
.def("hash", &gridfire::reaction::Reaction::hash,
|
||||
py::arg("seed") = 0,
|
||||
"Compute a hash for the reaction based on its ID.")
|
||||
.def("__repr__", [](const gridfire::reaction::Reaction& self) {
|
||||
std::stringstream ss;
|
||||
ss << self; // Use the existing operator<< for Reaction
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
py::class_<gridfire::reaction::LogicalReaction, gridfire::reaction::Reaction>(m, "LogicalReaction")
|
||||
.def(py::init<const std::vector<gridfire::reaction::Reaction>>(),
|
||||
py::arg("reactions"),
|
||||
"Construct a LogicalReaction from a vector of Reaction objects.")
|
||||
.def("add_reaction", &gridfire::reaction::LogicalReaction::add_reaction,
|
||||
py::arg("reaction"),
|
||||
"Add another Reaction source to this logical reaction.")
|
||||
.def("size", &gridfire::reaction::LogicalReaction::size,
|
||||
"Get the number of source rates contributing to this logical reaction.")
|
||||
.def("__len__", &gridfire::reaction::LogicalReaction::size,
|
||||
"Overload len() to return the number of source rates.")
|
||||
.def("sources", &gridfire::reaction::LogicalReaction::sources,
|
||||
"Get the list of source labels for the aggregated rates.")
|
||||
.def("calculate_rate", static_cast<double (gridfire::reaction::LogicalReaction::*)(double) const>(&gridfire::reaction::LogicalReaction::calculate_rate),
|
||||
py::arg("T9"), "Calculate the reaction rate at a given temperature T9 (in units of 10^9 K).")
|
||||
.def("calculate_forward_rate_log_derivative", &gridfire::reaction::LogicalReaction::calculate_forward_rate_log_derivative,
|
||||
py::arg("T9"), "Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K).");
|
||||
|
||||
py::class_<gridfire::reaction::LogicalReactionSet>(m, "LogicalReactionSet")
|
||||
.def(py::init<const std::vector<gridfire::reaction::LogicalReaction>>(),
|
||||
py::arg("reactions"),
|
||||
"Construct a LogicalReactionSet from a vector of LogicalReaction objects.")
|
||||
.def(py::init<>(),
|
||||
"Default constructor for an empty LogicalReactionSet.")
|
||||
.def(py::init<const gridfire::reaction::LogicalReactionSet&>(),
|
||||
py::arg("other"),
|
||||
"Copy constructor for LogicalReactionSet.")
|
||||
.def("add_reaction", &gridfire::reaction::LogicalReactionSet::add_reaction,
|
||||
py::arg("reaction"),
|
||||
"Add a LogicalReaction to the set.")
|
||||
.def("remove_reaction", &gridfire::reaction::LogicalReactionSet::remove_reaction,
|
||||
py::arg("reaction"),
|
||||
"Remove a LogicalReaction from the set.")
|
||||
.def("contains", py::overload_cast<const std::string_view&>(&gridfire::reaction::LogicalReactionSet::contains, py::const_),
|
||||
py::class_<gridfire::reaction::ReaclibReaction>(m, "ReaclibReaction")
|
||||
.def(
|
||||
py::init<
|
||||
const std::string_view,
|
||||
const std::string_view,
|
||||
int,
|
||||
const std::vector<Species>&,
|
||||
const std::vector<Species>&,
|
||||
double, std::string_view,
|
||||
gridfire::reaction::RateCoefficientSet,
|
||||
bool
|
||||
>(),
|
||||
py::arg("id"),
|
||||
"Check if the set contains a specific LogicalReaction.")
|
||||
.def("contains", py::overload_cast<const gridfire::reaction::Reaction&>(&gridfire::reaction::LogicalReactionSet::contains, py::const_),
|
||||
py::arg("reaction"),
|
||||
"Check if the set contains a specific Reaction.")
|
||||
.def("size", &gridfire::reaction::LogicalReactionSet::size,
|
||||
"Get the number of LogicalReactions in the set.")
|
||||
.def("__len__", &gridfire::reaction::LogicalReactionSet::size,
|
||||
"Overload len() to return the number of LogicalReactions.")
|
||||
.def("clear", &gridfire::reaction::LogicalReactionSet::clear,
|
||||
"Remove all LogicalReactions from the set.")
|
||||
.def("containes_species", &gridfire::reaction::LogicalReactionSet::contains_species,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set involves the given species.")
|
||||
.def("contains_reactant", &gridfire::reaction::LogicalReactionSet::contains_reactant,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set has the species as a reactant.")
|
||||
.def("contains_product", &gridfire::reaction::LogicalReactionSet::contains_product,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set has the species as a product.")
|
||||
.def("__getitem__", py::overload_cast<size_t>(&gridfire::reaction::LogicalReactionSet::operator[], py::const_),
|
||||
py::arg("index"),
|
||||
"Get a LogicalReaction by index.")
|
||||
.def("__getitem___", py::overload_cast<const std::string_view&>(&gridfire::reaction::LogicalReactionSet::operator[], py::const_),
|
||||
py::arg("id"),
|
||||
"Get a LogicalReaction by its ID.")
|
||||
.def("__eq__", &gridfire::reaction::LogicalReactionSet::operator==,
|
||||
py::arg("LogicalReactionSet"),
|
||||
"Equality operator for LogicalReactionSets based on their contents.")
|
||||
.def("__ne__", &gridfire::reaction::LogicalReactionSet::operator!=,
|
||||
py::arg("LogicalReactionSet"),
|
||||
"Inequality operator for LogicalReactionSets based on their contents.")
|
||||
.def("hash", &gridfire::reaction::LogicalReactionSet::hash,
|
||||
py::arg("seed") = 0,
|
||||
"Compute a hash for the LogicalReactionSet based on its contents."
|
||||
py::arg("peName"),
|
||||
py::arg("chapter"),
|
||||
py::arg("reactants"),
|
||||
py::arg("products"),
|
||||
py::arg("qValue"),
|
||||
py::arg("label"),
|
||||
py::arg("sets"),
|
||||
py::arg("reverse") = false,
|
||||
"Construct a Reaction with the given parameters."
|
||||
)
|
||||
.def("__repr__", [](const gridfire::reaction::LogicalReactionSet& self) {
|
||||
std::stringstream ss;
|
||||
ss << self;
|
||||
return ss.str();
|
||||
})
|
||||
.def("getReactionSetSpecies", &gridfire::reaction::LogicalReactionSet::getReactionSetSpecies,
|
||||
"Get all species involved in the reactions of the set as a set of Species objects.");
|
||||
.def(
|
||||
"calculate_rate",
|
||||
[](const gridfire::reaction::ReaclibReaction& self, const double T9, const double rho, const std::vector<double>& Y) -> double {
|
||||
return self.calculate_rate(T9, rho, Y);
|
||||
},
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
py::arg("Y"),
|
||||
"Calculate the reaction rate at a given temperature T9 (in units of 10^9 K)."
|
||||
)
|
||||
.def(
|
||||
"peName",
|
||||
&gridfire::reaction::ReaclibReaction::peName,
|
||||
"Get the reaction name in (projectile, ejectile) notation (e.g., 'p(p,g)d')."
|
||||
)
|
||||
.def(
|
||||
"chapter",
|
||||
&gridfire::reaction::ReaclibReaction::chapter,
|
||||
"Get the REACLIB chapter number defining the reaction structure."
|
||||
)
|
||||
.def(
|
||||
"sourceLabel",
|
||||
&gridfire::reaction::ReaclibReaction::sourceLabel,
|
||||
"Get the source label for the rate data (e.g., 'wc12w', 'st08')."
|
||||
)
|
||||
.def(
|
||||
"rateCoefficients",
|
||||
&gridfire::reaction::ReaclibReaction::rateCoefficients,
|
||||
"get the set of rate coefficients."
|
||||
)
|
||||
.def(
|
||||
"contains",
|
||||
&gridfire::reaction::ReaclibReaction::contains,
|
||||
py::arg("species"),
|
||||
"Check if the reaction contains a specific species."
|
||||
)
|
||||
.def(
|
||||
"contains_reactant",
|
||||
&gridfire::reaction::ReaclibReaction::contains_reactant,
|
||||
"Check if the reaction contains a specific reactant species."
|
||||
)
|
||||
.def(
|
||||
"contains_product",
|
||||
&gridfire::reaction::ReaclibReaction::contains_product,
|
||||
"Check if the reaction contains a specific product species."
|
||||
)
|
||||
.def(
|
||||
"all_species",
|
||||
&gridfire::reaction::ReaclibReaction::all_species,
|
||||
"Get all species involved in the reaction (both reactants and products) as a set."
|
||||
)
|
||||
.def(
|
||||
"reactant_species",
|
||||
&gridfire::reaction::ReaclibReaction::reactant_species,
|
||||
"Get the reactant species of the reaction as a set."
|
||||
)
|
||||
.def(
|
||||
"product_species",
|
||||
&gridfire::reaction::ReaclibReaction::product_species,
|
||||
"Get the product species of the reaction as a set."
|
||||
)
|
||||
.def(
|
||||
"num_species",
|
||||
&gridfire::reaction::ReaclibReaction::num_species,
|
||||
"Count the number of species in the reaction."
|
||||
)
|
||||
.def(
|
||||
"stoichiometry",
|
||||
[](const gridfire::reaction::ReaclibReaction& self, const Species& species) -> int {
|
||||
return self.stoichiometry(species);
|
||||
},
|
||||
py::arg("species"),
|
||||
"Get the stoichiometry of the reaction as a map from species to their coefficients."
|
||||
)
|
||||
.def(
|
||||
"stoichiometry",
|
||||
[](const gridfire::reaction::ReaclibReaction& self) -> std::unordered_map<Species, int> {
|
||||
return self.stoichiometry();
|
||||
},
|
||||
"Get the stoichiometry of the reaction as a map from species to their coefficients."
|
||||
)
|
||||
.def(
|
||||
"id",
|
||||
&gridfire::reaction::ReaclibReaction::id,
|
||||
"Get the unique identifier of the reaction."
|
||||
)
|
||||
.def(
|
||||
"qValue",
|
||||
&gridfire::reaction::ReaclibReaction::qValue,
|
||||
"Get the Q-value of the reaction in MeV."
|
||||
)
|
||||
.def(
|
||||
"reactants",
|
||||
&gridfire::reaction::ReaclibReaction::reactants,
|
||||
"Get a list of reactant species in the reaction."
|
||||
)
|
||||
.def(
|
||||
"products",
|
||||
&gridfire::reaction::ReaclibReaction::products,
|
||||
"Get a list of product species in the reaction."
|
||||
)
|
||||
.def(
|
||||
"is_reverse",
|
||||
&gridfire::reaction::ReaclibReaction::is_reverse,
|
||||
"Check if this is a reverse reaction rate."
|
||||
)
|
||||
.def(
|
||||
"excess_energy",
|
||||
&gridfire::reaction::ReaclibReaction::excess_energy,
|
||||
"Calculate the excess energy from the mass difference of reactants and products."
|
||||
)
|
||||
.def(
|
||||
"__eq__",
|
||||
&gridfire::reaction::ReaclibReaction::operator==,
|
||||
"Equality operator for reactions based on their IDs."
|
||||
)
|
||||
.def(
|
||||
"__neq__",
|
||||
&gridfire::reaction::ReaclibReaction::operator!=,
|
||||
"Inequality operator for reactions based on their IDs."
|
||||
)
|
||||
.def(
|
||||
"hash",
|
||||
&gridfire::reaction::ReaclibReaction::hash,
|
||||
py::arg("seed") = 0,
|
||||
"Compute a hash for the reaction based on its ID."
|
||||
)
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const gridfire::reaction::ReaclibReaction& self) {
|
||||
std::stringstream ss;
|
||||
ss << self; // Use the existing operator<< for Reaction
|
||||
return ss.str();
|
||||
}
|
||||
);
|
||||
|
||||
m.def("packReactionSetToLogicalReactionSet",
|
||||
&gridfire::reaction::packReactionSetToLogicalReactionSet,
|
||||
py::class_<gridfire::reaction::LogicalReaclibReaction, gridfire::reaction::ReaclibReaction>(m, "LogicalReaclibReaction")
|
||||
.def(
|
||||
py::init<const std::vector<gridfire::reaction::Reaction>>(),
|
||||
py::arg("reactions"),
|
||||
"Construct a LogicalReaclibReaction from a vector of Reaction objects."
|
||||
)
|
||||
.def(
|
||||
"add_reaction",
|
||||
&gridfire::reaction::LogicalReaclibReaction::add_reaction,
|
||||
py::arg("reaction"),
|
||||
"Add another Reaction source to this logical reaction."
|
||||
)
|
||||
.def(
|
||||
"size",
|
||||
&gridfire::reaction::LogicalReaclibReaction::size,
|
||||
"Get the number of source rates contributing to this logical reaction."
|
||||
)
|
||||
.def(
|
||||
"__len__",
|
||||
&gridfire::reaction::LogicalReaclibReaction::size,
|
||||
"Overload len() to return the number of source rates."
|
||||
)
|
||||
.def(
|
||||
"sources",
|
||||
&gridfire::reaction::LogicalReaclibReaction::sources,
|
||||
"Get the list of source labels for the aggregated rates."
|
||||
)
|
||||
.def(
|
||||
"calculate_rate",
|
||||
[](const gridfire::reaction::LogicalReaclibReaction& self, const double T9, const double rho, const std::vector<double>& Y) -> double {
|
||||
return self.calculate_rate(T9, rho, Y);
|
||||
},
|
||||
py::arg("T9"),
|
||||
"Calculate the reaction rate at a given temperature T9 (in units of 10^9 K)."
|
||||
)
|
||||
.def(
|
||||
"calculate_forward_rate_log_derivative",
|
||||
&gridfire::reaction::LogicalReaclibReaction::calculate_forward_rate_log_derivative,
|
||||
py::arg("T9"),
|
||||
"Calculate the forward rate log derivative at a given temperature T9 (in units of 10^9 K)."
|
||||
);
|
||||
|
||||
py::class_<gridfire::reaction::ReactionSet>(m, "ReactionSet")
|
||||
// TODO: Fix the constructor to accept a vector of unique ptrs to Reaclib Reactions
|
||||
.def(
|
||||
py::init<const std::vector<gridfire::reaction::Reaction>>(),
|
||||
py::arg("reactions"),
|
||||
"Construct a LogicalReactionSet from a vector of LogicalReaclibReaction objects."
|
||||
)
|
||||
.def(
|
||||
py::init<>(),
|
||||
"Default constructor for an empty LogicalReactionSet."
|
||||
)
|
||||
.def(
|
||||
py::init<const gridfire::reaction::ReactionSet&>(),
|
||||
py::arg("other"),
|
||||
"Copy constructor for LogicalReactionSet."
|
||||
)
|
||||
.def(
|
||||
"add_reaction",
|
||||
py::overload_cast<const gridfire::reaction::Reaction&>(&gridfire::reaction::ReactionSet::add_reaction),
|
||||
py::arg("reaction"),
|
||||
"Add a LogicalReaclibReaction to the set."
|
||||
)
|
||||
.def(
|
||||
"remove_reaction",
|
||||
&gridfire::reaction::ReactionSet::remove_reaction,
|
||||
py::arg("reaction"),
|
||||
"Remove a LogicalReaclibReaction from the set."
|
||||
)
|
||||
.def(
|
||||
"contains",
|
||||
py::overload_cast<const std::string_view&>(&gridfire::reaction::ReactionSet::contains, py::const_),
|
||||
py::arg("id"),
|
||||
"Check if the set contains a specific LogicalReaclibReaction."
|
||||
)
|
||||
.def(
|
||||
"contains",
|
||||
py::overload_cast<const gridfire::reaction::Reaction&>(&gridfire::reaction::ReactionSet::contains, py::const_),
|
||||
py::arg("reaction"),
|
||||
"Check if the set contains a specific Reaction."
|
||||
)
|
||||
.def(
|
||||
"size",
|
||||
&gridfire::reaction::ReactionSet::size,
|
||||
"Get the number of LogicalReactions in the set."
|
||||
)
|
||||
.def(
|
||||
"__len__", &gridfire::reaction::ReactionSet::size,
|
||||
"Overload len() to return the number of LogicalReactions."
|
||||
)
|
||||
.def(
|
||||
"clear",
|
||||
&gridfire::reaction::ReactionSet::clear,
|
||||
"Remove all LogicalReactions from the set."
|
||||
)
|
||||
.def("contains_species",
|
||||
&gridfire::reaction::ReactionSet::contains_species,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set involves the given species."
|
||||
)
|
||||
.def(
|
||||
"contains_reactant",
|
||||
&gridfire::reaction::ReactionSet::contains_reactant,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set has the species as a reactant."
|
||||
)
|
||||
.def(
|
||||
"contains_product",
|
||||
&gridfire::reaction::ReactionSet::contains_product,
|
||||
py::arg("species"),
|
||||
"Check if any reaction in the set has the species as a product."
|
||||
)
|
||||
.def(
|
||||
"__getitem__",
|
||||
py::overload_cast<size_t>(&gridfire::reaction::ReactionSet::operator[], py::const_),
|
||||
py::arg("index"),
|
||||
"Get a LogicalReaclibReaction by index."
|
||||
)
|
||||
.def(
|
||||
"__getitem___",
|
||||
py::overload_cast<const std::string_view&>(&gridfire::reaction::ReactionSet::operator[], py::const_),
|
||||
py::arg("id"),
|
||||
"Get a LogicalReaclibReaction by its ID."
|
||||
)
|
||||
.def(
|
||||
"__eq__",
|
||||
&gridfire::reaction::ReactionSet::operator==,
|
||||
py::arg("LogicalReactionSet"),
|
||||
"Equality operator for LogicalReactionSets based on their contents."
|
||||
)
|
||||
.def(
|
||||
"__ne__",
|
||||
&gridfire::reaction::ReactionSet::operator!=,
|
||||
py::arg("LogicalReactionSet"),
|
||||
"Inequality operator for LogicalReactionSets based on their contents."
|
||||
)
|
||||
.def(
|
||||
"hash",
|
||||
&gridfire::reaction::ReactionSet::hash,
|
||||
py::arg("seed") = 0,
|
||||
"Compute a hash for the LogicalReactionSet based on its contents."
|
||||
)
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const gridfire::reaction::ReactionSet& self) {
|
||||
std::stringstream ss;
|
||||
ss << self;
|
||||
return ss.str();
|
||||
}
|
||||
)
|
||||
.def(
|
||||
"getReactionSetSpecies",
|
||||
&gridfire::reaction::ReactionSet::getReactionSetSpecies,
|
||||
"Get all species involved in the reactions of the set as a set of Species objects."
|
||||
);
|
||||
|
||||
m.def(
|
||||
"packReactionSet",
|
||||
&gridfire::reaction::packReactionSet,
|
||||
py::arg("reactionSet"),
|
||||
"Convert a ReactionSet to a LogicalReactionSet by aggregating reactions with the same peName."
|
||||
);
|
||||
|
||||
m.def("get_all_reactions", &gridfire::reaclib::get_all_reactions,
|
||||
"Get all reactions from the REACLIB database.");
|
||||
m.def(
|
||||
"get_all_reactions",
|
||||
&gridfire::reaclib::get_all_reaclib_reactions,
|
||||
"Get all reactions from the REACLIB database."
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "bindings.h"
|
||||
@@ -13,30 +12,53 @@
|
||||
namespace py = pybind11;
|
||||
|
||||
void register_screening_bindings(py::module &m) {
|
||||
py::class_<gridfire::screening::ScreeningModel, PyScreening>(m, "ScreeningModel");
|
||||
auto screening_model = py::class_<gridfire::screening::ScreeningModel, PyScreening>(m, "ScreeningModel");
|
||||
|
||||
py::enum_<gridfire::screening::ScreeningType>(m, "ScreeningType")
|
||||
.value("BARE", gridfire::screening::ScreeningType::BARE)
|
||||
.value("WEAK", gridfire::screening::ScreeningType::WEAK)
|
||||
.export_values();
|
||||
|
||||
m.def("selectScreeningModel", &gridfire::screening::selectScreeningModel,
|
||||
py::arg("type"),
|
||||
"Select a screening model based on the specified type. Returns a pointer to the selected model.");
|
||||
m.def(
|
||||
"selectScreeningModel",
|
||||
&gridfire::screening::selectScreeningModel,
|
||||
py::arg("type"),
|
||||
"Select a screening model based on the specified type. Returns a pointer to the selected model."
|
||||
);
|
||||
|
||||
py::class_<gridfire::screening::BareScreeningModel>(m, "BareScreeningModel")
|
||||
.def(py::init<>())
|
||||
.def("calculateScreeningFactors",
|
||||
py::overload_cast<const gridfire::reaction::LogicalReactionSet&, const std::vector<fourdst::atomic::Species>&, const std::vector<double>&, double, double>(&gridfire::screening::BareScreeningModel::calculateScreeningFactors, py::const_),
|
||||
py::arg("reactions"), py::arg("species"), py::arg("Y"), py::arg("T9"), py::arg("rho"),
|
||||
"Calculate the bare plasma screening factors. This always returns 1.0 (bare)"
|
||||
);
|
||||
py::overload_cast<
|
||||
const gridfire::reaction::ReactionSet&,
|
||||
const std::vector<fourdst::atomic::Species>&,
|
||||
const std::vector<double>&,
|
||||
double,
|
||||
double
|
||||
>(&gridfire::screening::BareScreeningModel::calculateScreeningFactors, py::const_),
|
||||
py::arg("reactions"),
|
||||
py::arg("species"),
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Calculate the bare plasma screening factors. This always returns 1.0 (bare)"
|
||||
);
|
||||
|
||||
py::class_<gridfire::screening::WeakScreeningModel>(m, "WeakScreeningModel")
|
||||
.def(py::init<>())
|
||||
.def("calculateScreeningFactors",
|
||||
py::overload_cast<const gridfire::reaction::LogicalReactionSet&, const std::vector<fourdst::atomic::Species>&, const std::vector<double>&, double, double>(&gridfire::screening::WeakScreeningModel::calculateScreeningFactors, py::const_),
|
||||
py::arg("reactions"), py::arg("species"), py::arg("Y"), py::arg("T9"), py::arg("rho"),
|
||||
"Calculate the weak plasma screening factors using the Salpeter (1954) model."
|
||||
);
|
||||
py::overload_cast<
|
||||
const gridfire::reaction::ReactionSet&,
|
||||
const std::vector<fourdst::atomic::Species>&,
|
||||
const std::vector<double>&,
|
||||
double,
|
||||
double
|
||||
>(&gridfire::screening::WeakScreeningModel::calculateScreeningFactors, py::const_),
|
||||
py::arg("reactions"),
|
||||
py::arg("species"),
|
||||
py::arg("Y"),
|
||||
py::arg("T9"),
|
||||
py::arg("rho"),
|
||||
"Calculate the weak plasma screening factors using the Salpeter (1954) model."
|
||||
);
|
||||
}
|
||||
@@ -13,7 +13,13 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
std::vector<double> PyScreening::calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<double> &Y, const double T9, const double rho) const {
|
||||
std::vector<double> PyScreening::calculateScreeningFactors(
|
||||
const gridfire::reaction::ReactionSet &reactions,
|
||||
const std::vector<fourdst::atomic::Species> &species,
|
||||
const std::vector<double> &Y,
|
||||
const double T9,
|
||||
const double rho
|
||||
) const {
|
||||
PYBIND11_OVERLOAD_PURE(
|
||||
std::vector<double>, // Return type
|
||||
gridfire::screening::ScreeningModel,
|
||||
@@ -22,7 +28,13 @@ std::vector<double> PyScreening::calculateScreeningFactors(const gridfire::react
|
||||
}
|
||||
|
||||
using ADDouble = gridfire::screening::ScreeningModel::ADDouble;
|
||||
std::vector<ADDouble> PyScreening::calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<ADDouble> &Y, const ADDouble T9, const ADDouble rho) const {
|
||||
std::vector<ADDouble> PyScreening::calculateScreeningFactors(
|
||||
const gridfire::reaction::ReactionSet &reactions,
|
||||
const std::vector<fourdst::atomic::Species> &species,
|
||||
const std::vector<ADDouble> &Y,
|
||||
const ADDouble T9,
|
||||
const ADDouble rho
|
||||
) const {
|
||||
PYBIND11_OVERLOAD_PURE(
|
||||
std::vector<ADDouble>, // Return type
|
||||
gridfire::screening::ScreeningModel,
|
||||
|
||||
@@ -2,15 +2,24 @@
|
||||
|
||||
#include "gridfire/screening/screening.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "cppad/cppad.hpp"
|
||||
|
||||
class PyScreening final : public gridfire::screening::ScreeningModel {
|
||||
std::vector<double> calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<double> &Y, const double T9, const double rho) const override;
|
||||
std::vector<ADDouble> calculateScreeningFactors(const gridfire::reaction::LogicalReactionSet &reactions, const std::vector<fourdst::atomic::Species> &species, const std::vector<ADDouble> &Y, const ADDouble T9, const ADDouble rho) const override;
|
||||
[[nodiscard]] std::vector<double> calculateScreeningFactors(
|
||||
const gridfire::reaction::ReactionSet &reactions,
|
||||
const std::vector<fourdst::atomic::Species> &species,
|
||||
const std::vector<double> &Y,
|
||||
double T9,
|
||||
double rho
|
||||
) const override;
|
||||
|
||||
[[nodiscard]] std::vector<ADDouble> calculateScreeningFactors(
|
||||
const gridfire::reaction::ReactionSet &reactions,
|
||||
const std::vector<fourdst::atomic::Species> &species,
|
||||
const std::vector<ADDouble> &Y,
|
||||
ADDouble T9,
|
||||
ADDouble rho
|
||||
) const override;
|
||||
};
|
||||
@@ -2,7 +2,6 @@
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
|
||||
#include <boost/numeric/ublas/vector.hpp>
|
||||
|
||||
@@ -30,7 +29,7 @@ void register_solver_bindings(const py::module &m) {
|
||||
);
|
||||
|
||||
py_direct_network_solver.def("set_callback",
|
||||
[](gridfire::solver::DirectNetworkSolver &self, gridfire::solver::DirectNetworkSolver::TimestepCallback cb) {
|
||||
[](gridfire::solver::DirectNetworkSolver &self, const gridfire::solver::DirectNetworkSolver::TimestepCallback& cb) {
|
||||
self.set_callback(cb);
|
||||
},
|
||||
py::arg("callback"),
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
|
||||
@@ -11,5 +11,5 @@ class PyDynamicNetworkSolverStrategy final : public gridfire::solver::DynamicNet
|
||||
explicit PyDynamicNetworkSolverStrategy(gridfire::DynamicEngine &engine) : gridfire::solver::DynamicNetworkSolverStrategy(engine) {}
|
||||
gridfire::NetOut evaluate(const gridfire::NetIn &netIn) override;
|
||||
void set_callback(const std::any &callback) override;
|
||||
std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
|
||||
[[nodiscard]] std::vector<std::tuple<std::string, std::string>> describe_callback_context() const override;
|
||||
};
|
||||
@@ -8,7 +8,7 @@ namespace py = pybind11;
|
||||
|
||||
#include "gridfire/network.h"
|
||||
|
||||
void register_type_bindings(pybind11::module &m) {
|
||||
void register_type_bindings(const pybind11::module &m) {
|
||||
py::class_<gridfire::NetIn>(m, "NetIn")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("composition", &gridfire::NetIn::composition)
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_type_bindings(pybind11::module &m);
|
||||
void register_type_bindings(const pybind11::module &m);
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc. if needed directly
|
||||
|
||||
#include "bindings.h"
|
||||
|
||||
Reference in New Issue
Block a user