Merge pull request #67 from tboudreaux/feature/pythonInterface/poly
MFEM and polytrope python bindings
This commit is contained in:
@@ -9,6 +9,11 @@ py_mod = py_installation.extension_module(
|
||||
meson.project_source_root() + '/src/python/const/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/config/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/eos/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp',
|
||||
meson.project_source_root() + '/src/python/polytrope/bindings.cpp',
|
||||
],
|
||||
dependencies : [
|
||||
pybind11_dep,
|
||||
@@ -16,7 +21,10 @@ py_mod = py_installation.extension_module(
|
||||
config_dep,
|
||||
composition_dep,
|
||||
eos_dep,
|
||||
species_weight_dep
|
||||
species_weight_dep,
|
||||
mfem_dep,
|
||||
polysolver_dep,
|
||||
trampoline_dep
|
||||
],
|
||||
cpp_args : ['-UNDEBUG'], # Example: Ensure assertions are enabled if needed
|
||||
install : true,
|
||||
|
||||
@@ -39,8 +39,8 @@
|
||||
#include "utilities.h"
|
||||
#include "quill/LogMacros.h"
|
||||
|
||||
namespace serif {
|
||||
namespace polytrope {
|
||||
|
||||
namespace serif::polytrope {
|
||||
|
||||
namespace laneEmden {
|
||||
|
||||
@@ -382,7 +382,6 @@ void PolySolver::setInitialGuess() const {
|
||||
serif::probe::glVisView(*m_theta, m_mesh, "θ init");
|
||||
serif::probe::glVisView(*m_phi, m_mesh, "φ init");
|
||||
}
|
||||
std::cout << "HERE" << std::endl;
|
||||
|
||||
}
|
||||
|
||||
@@ -497,5 +496,5 @@ solverBundle PolySolver::setupNewtonSolver() const {
|
||||
return solver;
|
||||
}
|
||||
|
||||
} // namespace polytrope
|
||||
} // namespace serif
|
||||
} // namespace serif::polytrope
|
||||
|
||||
|
||||
@@ -271,7 +271,14 @@ public: // Public methods
|
||||
* @return A reference to the `mfem::GridFunction` storing the \f$\theta\f$ solution.
|
||||
* @note The solution is populated after `solve()` has been successfully called.
|
||||
*/
|
||||
mfem::GridFunction& getSolution() const { return *m_theta; }
|
||||
mfem::GridFunction& getTheta() const { return *m_theta; }
|
||||
|
||||
/**
|
||||
* @brief Gets a reference to the solution grid function for \f$\phi\f$.
|
||||
* @return A reference to the `mfem::GridFunction` storing the \f$\phi\f$ solution.
|
||||
* @note The solution is populated after `solve()` has been successfully called.
|
||||
*/
|
||||
mfem::GridFunction& getPhi() const { return *m_phi; }
|
||||
|
||||
private: // Private Attributes
|
||||
// --- Configuration and Logging ---
|
||||
|
||||
@@ -119,7 +119,6 @@ namespace serif::utilities {
|
||||
mfem::IsoparametricTransformation T;
|
||||
mesh->GetElementTransformation(ne, &T);
|
||||
phi_gf.GetCurl(T, curl_mag_vec);
|
||||
std::cout << "HERE" << std::endl;
|
||||
}
|
||||
mfem::L2_FECollection fac(order, dim);
|
||||
mfem::FiniteElementSpace fs(mesh, &fac);
|
||||
|
||||
@@ -31,27 +31,26 @@
|
||||
* @brief A collection of utilities for working with MFEM and solving the lane-emden equation.
|
||||
*/
|
||||
|
||||
namespace serif {
|
||||
namespace polytrope {
|
||||
/**
|
||||
namespace serif::polytrope {
|
||||
/**
|
||||
* @namespace polyMFEMUtils
|
||||
* @brief A namespace for utilities for working with MFEM and solving the lane-emden equation.
|
||||
*/
|
||||
namespace polyMFEMUtils {
|
||||
/**
|
||||
namespace polyMFEMUtils {
|
||||
/**
|
||||
* @brief A class for nonlinear power integrator.
|
||||
*/
|
||||
class NonlinearPowerIntegrator: public mfem::NonlinearFormIntegrator {
|
||||
public:
|
||||
/**
|
||||
class NonlinearPowerIntegrator: public mfem::NonlinearFormIntegrator {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor for NonlinearPowerIntegrator.
|
||||
*
|
||||
* @param coeff The function coefficient.
|
||||
* @param n The polytropic index.
|
||||
*/
|
||||
NonlinearPowerIntegrator(double n);
|
||||
NonlinearPowerIntegrator(double n);
|
||||
|
||||
/**
|
||||
/**
|
||||
* @brief Assembles the element vector.
|
||||
*
|
||||
* @param el The finite element.
|
||||
@@ -59,8 +58,8 @@ namespace polyMFEMUtils {
|
||||
* @param elfun The element function.
|
||||
* @param elvect The element vector to be assembled.
|
||||
*/
|
||||
virtual void AssembleElementVector(const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::Vector &elvect) override;
|
||||
/**
|
||||
virtual void AssembleElementVector(const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::Vector &elvect) override;
|
||||
/**
|
||||
* @brief Assembles the element gradient.
|
||||
*
|
||||
* @param el The finite element.
|
||||
@@ -68,40 +67,39 @@ namespace polyMFEMUtils {
|
||||
* @param elfun The element function.
|
||||
* @param elmat The element matrix to be assembled.
|
||||
*/
|
||||
virtual void AssembleElementGrad (const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::DenseMatrix &elmat) override;
|
||||
private:
|
||||
serif::config::Config& m_config = serif::config::Config::getInstance();
|
||||
serif::probe::LogManager& m_logManager = serif::probe::LogManager::getInstance();
|
||||
quill::Logger* m_logger = m_logManager.getLogger("log");
|
||||
double m_polytropicIndex;
|
||||
double m_epsilon;
|
||||
static constexpr double m_regularizationRadius = 0.15; ///< Regularization radius for the epsilon function, used to avoid singularities in the power law.
|
||||
static constexpr double m_regularizationCoeff = 1.0/6.0; ///< Coefficient for the regularization term, used to ensure smoothness in the power law.
|
||||
};
|
||||
virtual void AssembleElementGrad (const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::DenseMatrix &elmat) override;
|
||||
private:
|
||||
serif::config::Config& m_config = serif::config::Config::getInstance();
|
||||
serif::probe::LogManager& m_logManager = serif::probe::LogManager::getInstance();
|
||||
quill::Logger* m_logger = m_logManager.getLogger("log");
|
||||
double m_polytropicIndex;
|
||||
double m_epsilon;
|
||||
static constexpr double m_regularizationRadius = 0.15; ///< Regularization radius for the epsilon function, used to avoid singularities in the power law.
|
||||
static constexpr double m_regularizationCoeff = 1.0/6.0; ///< Coefficient for the regularization term, used to ensure smoothness in the power law.
|
||||
};
|
||||
|
||||
inline double dfmod(const double epsilon, const double n) {
|
||||
if (n == 0.0) {
|
||||
return 0.0;
|
||||
inline double dfmod(const double epsilon, const double n) {
|
||||
if (n == 0.0) {
|
||||
return 0.0;
|
||||
}
|
||||
if (n == 1.0) {
|
||||
return 1.0;
|
||||
}
|
||||
return n * std::pow(epsilon, n - 1.0);
|
||||
}
|
||||
if (n == 1.0) {
|
||||
return 1.0;
|
||||
|
||||
inline double fmod(const double theta, const double n, const double epsilon) {
|
||||
if (n == 0.0) {
|
||||
return 1.0;
|
||||
}
|
||||
// For n != 0
|
||||
const double y_prime_at_epsilon = dfmod(epsilon, n); // Uses the robust dfmod
|
||||
const double y_at_epsilon = std::pow(epsilon, n); // epsilon^n
|
||||
|
||||
// f_mod(theta) = y_at_epsilon + y_prime_at_epsilon * (theta - epsilon)
|
||||
return y_at_epsilon + y_prime_at_epsilon * (theta - epsilon);
|
||||
}
|
||||
return n * std::pow(epsilon, n - 1.0);
|
||||
}
|
||||
|
||||
inline double fmod(const double theta, const double n, const double epsilon) {
|
||||
if (n == 0.0) {
|
||||
return 1.0;
|
||||
}
|
||||
// For n != 0
|
||||
const double y_prime_at_epsilon = dfmod(epsilon, n); // Uses the robust dfmod
|
||||
const double y_at_epsilon = std::pow(epsilon, n); // epsilon^n
|
||||
|
||||
// f_mod(theta) = y_at_epsilon + y_prime_at_epsilon * (theta - epsilon)
|
||||
return y_at_epsilon + y_prime_at_epsilon * (theta - epsilon);
|
||||
}
|
||||
|
||||
|
||||
} // namespace polyMFEMUtils
|
||||
} // namespace polytrope
|
||||
} // namespace serif
|
||||
} // namespace polyMFEMUtils
|
||||
}
|
||||
|
||||
@@ -26,8 +26,7 @@
|
||||
|
||||
#include "probe.h"
|
||||
|
||||
namespace serif {
|
||||
namespace polytrope {
|
||||
namespace serif::polytrope {
|
||||
|
||||
/**
|
||||
* @brief Represents the Schur complement operator used in the solution process.
|
||||
@@ -299,7 +298,7 @@ private:
|
||||
std::unique_ptr<mfem::MixedBilinearForm> m_M; ///< Bilinear form M, coupling θ and φ.
|
||||
std::unique_ptr<mfem::MixedBilinearForm> m_Q; ///< Bilinear form Q, coupling φ and θ.
|
||||
std::unique_ptr<mfem::BilinearForm> m_D; ///< Bilinear form D, acting on φ.
|
||||
std::unique_ptr<mfem::BilinearForm> m_S;
|
||||
std::unique_ptr<mfem::BilinearForm> m_S; ///< Bilinear form S, used for least squares stabilization.
|
||||
std::unique_ptr<mfem::NonlinearForm> m_f; ///< Nonlinear form f, acting on θ.
|
||||
|
||||
// --- Full Matrix Representations (owned, derived from forms) ---
|
||||
@@ -395,5 +394,4 @@ private:
|
||||
void update_preconditioner(const mfem::Operator &grad) const;
|
||||
};
|
||||
|
||||
} // namespace polytrope
|
||||
} // namespace serif
|
||||
} // namespace serif::polytrope
|
||||
@@ -6,6 +6,8 @@
|
||||
#include "composition/bindings.h"
|
||||
#include "config/bindings.h"
|
||||
#include "eos/bindings.h"
|
||||
#include "mfem/bindings.h"
|
||||
#include "polytrope/bindings.h"
|
||||
|
||||
PYBIND11_MODULE(serif, m) {
|
||||
m.doc() = "Python bindings for the SERiF project";
|
||||
@@ -21,4 +23,10 @@ PYBIND11_MODULE(serif, m) {
|
||||
|
||||
auto eosMod = m.def_submodule("eos", "EOS-module bindings");
|
||||
register_eos_bindings(eosMod);
|
||||
|
||||
auto mfemMod = m.def_submodule("mfem", "MFEM bindings");
|
||||
register_mfem_bindings(mfemMod);
|
||||
|
||||
auto polytropeMod = m.def_submodule("polytrope", "Polytrope-module bindings");
|
||||
register_polytrope_bindings(polytropeMod);
|
||||
}
|
||||
@@ -5,12 +5,9 @@
|
||||
|
||||
#include <string>
|
||||
#include "helm.h"
|
||||
// #include "resourceManager.h"
|
||||
#include "bindings.h"
|
||||
#include "EOSio.h"
|
||||
#include "helm.h"
|
||||
#include "../../eos/public/EOSio.h"
|
||||
#include "../../eos/public/helm.h"
|
||||
|
||||
namespace serif::eos {
|
||||
class EOSio;
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
subdir('composition')
|
||||
subdir('const')
|
||||
subdir('config')
|
||||
|
||||
subdir('mfem')
|
||||
subdir('eos')
|
||||
@@ -0,0 +1,51 @@
|
||||
#include "PyCoefficient.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
#include <memory>
|
||||
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mfem;
|
||||
|
||||
namespace serif::pybind {
|
||||
real_t PyCoefficient::Eval(ElementTransformation &T, const IntegrationPoint &ip) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
real_t, /* Return type */
|
||||
Coefficient, /* Base class */
|
||||
Eval, /* Method name */
|
||||
T, ip /* Arguments */
|
||||
);
|
||||
}
|
||||
|
||||
// Override virtual SetTime method
|
||||
void PyCoefficient::SetTime(real_t t) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
Coefficient,
|
||||
SetTime,
|
||||
t
|
||||
);
|
||||
}
|
||||
|
||||
void PyVectorCoefficient::Eval(Vector &V, ElementTransformation &T, const IntegrationPoint &ip) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void, /* Return type */
|
||||
VectorCoefficient, /* Base class */
|
||||
Eval, /* Method name */
|
||||
V, T, ip /* Arguments */
|
||||
);
|
||||
}
|
||||
|
||||
// Override the virtual SetTime method
|
||||
void PyVectorCoefficient::SetTime(real_t t) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
VectorCoefficient,
|
||||
SetTime,
|
||||
t
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
/**
|
||||
* @file PyCoefficient.h
|
||||
* @brief Defines pybind11 trampoline classes for mfem::Coefficient and mfem::VectorCoefficient.
|
||||
*
|
||||
* These trampoline classes allow Python classes to inherit from mfem::Coefficient
|
||||
* and mfem::VectorCoefficient, enabling Python-defined coefficients to be used
|
||||
* within the C++ MFEM library.
|
||||
*/
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
#include <memory>
|
||||
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mfem;
|
||||
|
||||
/**
|
||||
* @namespace serif::pybind
|
||||
* @brief Contains pybind11 helper classes and trampoline classes for interfacing C++ with Python.
|
||||
*/
|
||||
namespace serif::pybind {
|
||||
/**
|
||||
* @brief Trampoline class for mfem::Coefficient.
|
||||
*
|
||||
* This class allows Python classes to inherit from mfem::Coefficient and override
|
||||
* its virtual methods. This is essential for creating custom coefficients in Python
|
||||
* that can be used by MFEM's C++ backend.
|
||||
*
|
||||
* @see mfem::Coefficient
|
||||
*
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import mfem.ser_ext as mfem
|
||||
*
|
||||
* class MyPythonCoefficient(mfem.Coefficient):
|
||||
* def __init__(self):
|
||||
* super().__init__() # Call the base C++ constructor
|
||||
*
|
||||
* def Eval(self, T, ip):
|
||||
* # T is an mfem.ElementTransformation
|
||||
* # ip is an mfem.IntegrationPoint
|
||||
* # Example: return a constant value
|
||||
* return 1.0
|
||||
*
|
||||
* def SetTime(self, t):
|
||||
* # Optionally handle time-dependent coefficients
|
||||
* super().SetTime(t) # Call base class method
|
||||
* print(f"Time set to: {t}")
|
||||
*
|
||||
* # Using the Python coefficient
|
||||
* py_coeff = MyPythonCoefficient()
|
||||
* # py_coeff can now be passed to MFEM functions expecting an mfem::Coefficient
|
||||
* @endcode
|
||||
*/
|
||||
class PyCoefficient : public Coefficient {
|
||||
public:
|
||||
using Coefficient::Coefficient; /**< Inherit constructors from mfem::Coefficient. */
|
||||
|
||||
/**
|
||||
* @brief Evaluate the coefficient at a given IntegrationPoint in an ElementTransformation.
|
||||
*
|
||||
* This method is called by MFEM when the value of the coefficient is needed.
|
||||
* If a Python class inherits from PyCoefficient, it *must* override this method.
|
||||
*
|
||||
* @param T The element transformation.
|
||||
* @param ip The integration point.
|
||||
* @return The value of the coefficient at the given point.
|
||||
*
|
||||
* @note This method forwards the call to the Python override.
|
||||
* PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this.
|
||||
*/
|
||||
real_t Eval(ElementTransformation &T, const IntegrationPoint &ip) override;
|
||||
|
||||
/**
|
||||
* @brief Set the current time for time-dependent coefficients.
|
||||
*
|
||||
* This method is called by MFEM to update the time for time-dependent coefficients.
|
||||
* Python classes inheriting from PyCoefficient can override this method to implement
|
||||
* time-dependent behavior.
|
||||
*
|
||||
* @param t The current time.
|
||||
*
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
* PYBIND11_OVERRIDE is used in the .cpp file to handle this.
|
||||
*/
|
||||
void SetTime(real_t t) override;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Trampoline class for mfem::VectorCoefficient.
|
||||
*
|
||||
* This class allows Python classes to inherit from mfem::VectorCoefficient and override
|
||||
* its virtual methods. This is essential for creating custom vector-valued coefficients
|
||||
* in Python that can be used by MFEM's C++ backend.
|
||||
*
|
||||
* @see mfem::VectorCoefficient
|
||||
*
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import mfem.ser_ext as mfem
|
||||
* import numpy as np
|
||||
*
|
||||
* class MyPythonVectorCoefficient(mfem.VectorCoefficient):
|
||||
* def __init__(self, dim):
|
||||
* super().__init__(dim) # Call the base C++ constructor, pass vector dimension
|
||||
* self.dim = dim
|
||||
*
|
||||
* def Eval(self, V, T, ip):
|
||||
* # V is an mfem.Vector (output parameter, must be filled)
|
||||
* # T is an mfem.ElementTransformation
|
||||
* # ip is an mfem.IntegrationPoint
|
||||
* # Example: return a constant vector [1.0, 2.0, ...]
|
||||
* for i in range(self.dim):
|
||||
* V[i] = float(i + 1)
|
||||
*
|
||||
* def SetTime(self, t):
|
||||
* super().SetTime(t)
|
||||
* print(f"VectorCoefficient time set to: {t}")
|
||||
*
|
||||
* # Using the Python vector coefficient
|
||||
* vec_dim = 2
|
||||
* py_vec_coeff = MyPythonVectorCoefficient(vec_dim)
|
||||
* # py_vec_coeff can now be passed to MFEM functions expecting an mfem::VectorCoefficient
|
||||
* @endcode
|
||||
*/
|
||||
class PyVectorCoefficient : public VectorCoefficient {
|
||||
public:
|
||||
using VectorCoefficient::VectorCoefficient; /**< Inherit constructors from mfem::VectorCoefficient. */
|
||||
|
||||
/**
|
||||
* @brief Evaluate the vector coefficient at a given IntegrationPoint in an ElementTransformation.
|
||||
*
|
||||
* This method is called by MFEM when the value of the vector coefficient is needed.
|
||||
* If a Python class inherits from PyVectorCoefficient, it *must* override this method.
|
||||
* The result should be stored in the output Vector @p V.
|
||||
*
|
||||
* @param V Output vector to store the result. Its size should match the coefficient's dimension.
|
||||
* @param T The element transformation.
|
||||
* @param ip The integration point.
|
||||
*
|
||||
* @note This method forwards the call to the Python override.
|
||||
* PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this.
|
||||
*/
|
||||
void Eval(Vector &V, ElementTransformation &T, const IntegrationPoint &ip) override;
|
||||
|
||||
/**
|
||||
* @brief Set the current time for time-dependent vector coefficients.
|
||||
*
|
||||
* This method is called by MFEM to update the time for time-dependent vector coefficients.
|
||||
* Python classes inheriting from PyVectorCoefficient can override this method to implement
|
||||
* time-dependent behavior.
|
||||
*
|
||||
* @param t The current time.
|
||||
*
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
* PYBIND11_OVERRIDE is used in the .cpp file to handle this.
|
||||
*/
|
||||
void SetTime(real_t t) override;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
PyCoefficient_sources = files(
|
||||
'PyCoefficient.cpp'
|
||||
)
|
||||
|
||||
trampoline_sources += PyCoefficient_sources
|
||||
@@ -0,0 +1,137 @@
|
||||
#include "PyMFEMTrampolines/Operator/Matrix/PyMatrix.h"
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace serif::pybind {
|
||||
|
||||
// --- Trampolines for new mfem::Matrix pure virtual methods ---
|
||||
|
||||
mfem::real_t& PyMatrix::Elem(int i, int j) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
mfem::real_t&, /* Return type */
|
||||
mfem::Matrix, /* C++ base class */
|
||||
Elem, /* C++ function name */
|
||||
i, /* Argument 1 */
|
||||
j /* Argument 2 */
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::real_t& PyMatrix::Elem(int i, int j) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const mfem::real_t&,
|
||||
mfem::Matrix,
|
||||
Elem,
|
||||
i,
|
||||
j
|
||||
);
|
||||
}
|
||||
|
||||
mfem::MatrixInverse* PyMatrix::Inverse() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
mfem::MatrixInverse*,
|
||||
mfem::Matrix,
|
||||
Inverse
|
||||
);
|
||||
}
|
||||
|
||||
// --- Trampoline for new mfem::Matrix regular virtual methods ---
|
||||
|
||||
void PyMatrix::Finalize(int skip_zeros) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
Finalize,
|
||||
skip_zeros
|
||||
);
|
||||
}
|
||||
|
||||
// --- Trampolines for inherited mfem::Operator virtual methods ---
|
||||
|
||||
void PyMatrix::Mult(const mfem::Vector &x, mfem::Vector &y) const {
|
||||
// This remains PURE as mfem::Matrix does not implement it.
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
Mult,
|
||||
x,
|
||||
y
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::MultTranspose(const mfem::Vector &x, mfem::Vector &y) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
MultTranspose,
|
||||
x,
|
||||
y
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
AddMult,
|
||||
x,
|
||||
y,
|
||||
a
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
AddMultTranspose,
|
||||
x,
|
||||
y,
|
||||
a
|
||||
);
|
||||
}
|
||||
|
||||
mfem::Operator& PyMatrix::GetGradient(const mfem::Vector &x) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
mfem::Operator&,
|
||||
mfem::Matrix,
|
||||
GetGradient,
|
||||
x
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::AssembleDiagonal(mfem::Vector &diag) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
AssembleDiagonal,
|
||||
diag
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::Operator* PyMatrix::GetProlongation() const {
|
||||
PYBIND11_OVERRIDE(
|
||||
const mfem::Operator*,
|
||||
mfem::Matrix,
|
||||
GetProlongation
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::Operator* PyMatrix::GetRestriction() const {
|
||||
PYBIND11_OVERRIDE(
|
||||
const mfem::Operator*,
|
||||
mfem::Matrix,
|
||||
GetRestriction
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
RecoverFEMSolution,
|
||||
X,
|
||||
b,
|
||||
x
|
||||
);
|
||||
}
|
||||
|
||||
} // namespace serif::pybind
|
||||
@@ -0,0 +1,253 @@
|
||||
/**
|
||||
* @file PyMatrix.h
|
||||
* @brief Defines a pybind11 trampoline class for mfem::Matrix.
|
||||
*
|
||||
* This trampoline class allows Python classes to inherit from mfem::Matrix,
|
||||
* enabling Python-defined matrices to be used within the C++ MFEM library,
|
||||
* including overriding methods from its base class mfem::Operator.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "mfem.hpp"
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace serif::pybind {
|
||||
|
||||
/**
|
||||
* @brief A trampoline class for mfem::Matrix.
|
||||
*
|
||||
* This class allows Python classes to inherit from mfem::Matrix and correctly
|
||||
* override its virtual functions, including those inherited from its base,
|
||||
* mfem::Operator. This is useful for creating custom matrix types in Python
|
||||
* that can interact seamlessly with MFEM's C++ components.
|
||||
*
|
||||
* @see mfem::Matrix
|
||||
* @see mfem::Operator
|
||||
* @see PyOperator
|
||||
*
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import mfem.ser_ext as mfem
|
||||
* import numpy as np
|
||||
*
|
||||
* class MyPythonMatrix(mfem.Matrix):
|
||||
* def __init__(self, size=0):
|
||||
* super().__init__(size) # Call base C++ constructor
|
||||
* if size > 0:
|
||||
* # Example: Initialize with some data if size is provided
|
||||
* # Note: Direct data manipulation might be complex due to
|
||||
* # C++ ownership. This example is conceptual.
|
||||
* # For a real-world scenario, you might manage data in Python
|
||||
* # and use Eval or Mult to reflect its state.
|
||||
* self._py_data = np.zeros((size, size))
|
||||
* else:
|
||||
* self._py_data = None
|
||||
*
|
||||
* # --- mfem.Matrix pure virtual overrides ---
|
||||
* def Elem(self, i, j):
|
||||
* # This is problematic for direct override for write access
|
||||
* # due to returning a reference.
|
||||
* # Typically, you'd manage data and use it in Mult/GetRow etc.
|
||||
* # For read access (const version), it's more straightforward.
|
||||
* if self._py_data is not None:
|
||||
* return self._py_data[i, j]
|
||||
* raise IndexError("Matrix data not initialized or index out of bounds")
|
||||
*
|
||||
* # const Elem version for read access
|
||||
* # def GetElem(self, i, j): # A helper in Python if Elem is tricky
|
||||
* # return self._py_data[i,j]
|
||||
*
|
||||
* def Inverse(self):
|
||||
* # Return an mfem.MatrixInverse object
|
||||
* # This is a simplified example; a real inverse is complex.
|
||||
* print("MyPythonMatrix.Inverse() called, returning dummy inverse.")
|
||||
* # For a real implementation, you'd compute and return an actual
|
||||
* # mfem.MatrixInverse or a Python-derived version of it.
|
||||
* # This might involve creating a new PyMatrix representing the inverse.
|
||||
* identity_inv = mfem.DenseMatrix(self.Width())
|
||||
* for i in range(self.Width()):
|
||||
* identity_inv[i,i] = 1.0
|
||||
* return mfem.MatrixInverse(identity_inv) # Example
|
||||
*
|
||||
* # --- mfem.Matrix regular virtual overrides ---
|
||||
* def Finalize(self, skip_zeros=1):
|
||||
* super().Finalize(skip_zeros) # Call base
|
||||
* print(f"MyPythonMatrix.Finalize({skip_zeros}) called.")
|
||||
*
|
||||
* # --- mfem.Operator virtual overrides ---
|
||||
* def Mult(self, x, y):
|
||||
* # x is mfem.Vector (input), y is mfem.Vector (output)
|
||||
* if self._py_data is not None:
|
||||
* x_np = x.GetDataArray()
|
||||
* y_np = np.dot(self._py_data, x_np)
|
||||
* y.SetSize(self._py_data.shape[0])
|
||||
* y.Assign(y_np)
|
||||
* else:
|
||||
* # Fallback to base class if no Python data
|
||||
* # super().Mult(x,y) # This might not work as expected if base is abstract
|
||||
* # or if the C++ mfem.Matrix itself doesn't have data.
|
||||
* # For a sparse matrix, this would call its sparse Mult.
|
||||
* # For a dense matrix, it would use its data.
|
||||
* # If this PyMatrix is purely Python defined, it must implement Mult.
|
||||
* raise NotImplementedError("Mult not implemented or data not set")
|
||||
* print("MyPythonMatrix.Mult called.")
|
||||
*
|
||||
* # Using the Python Matrix
|
||||
* mat_size = 3
|
||||
* py_mat = MyPythonMatrix(mat_size)
|
||||
* py_mat._py_data = np.array([[1,2,3],[4,5,6],[7,8,9]], dtype=float)
|
||||
*
|
||||
* x_vec = mfem.Vector(mat_size)
|
||||
* x_vec.Assign(np.array([1,1,1], dtype=float))
|
||||
* y_vec = mfem.Vector(mat_size)
|
||||
*
|
||||
* py_mat.Mult(x_vec, y_vec)
|
||||
* print("Result y:", y_vec.GetDataArray())
|
||||
*
|
||||
* # inv_op = py_mat.Inverse() # Be cautious with dummy implementations
|
||||
* # print("Inverse type:", type(inv_op))
|
||||
*
|
||||
* # Note: Overriding Elem(i,j) to return a C++ reference from Python
|
||||
* # for write access (mfem::real_t&) is complex and often not directly feasible
|
||||
* # in a straightforward way with pybind11 for typical Python data structures.
|
||||
* # Python users would typically interact via methods like Set, Add, or by
|
||||
* # providing data that the C++ side can access (e.g., via GetData).
|
||||
* # The const version of Elem (read-only) is more manageable.
|
||||
* @endcode
|
||||
*/
|
||||
class PyMatrix : public mfem::Matrix {
|
||||
public:
|
||||
// Inherit constructors from the base mfem::Matrix class.
|
||||
using mfem::Matrix::Matrix;
|
||||
|
||||
// --- Trampolines for new mfem::Matrix pure virtual methods ---
|
||||
/**
|
||||
* @brief Access element (i,j) for read/write.
|
||||
* Pure virtual in mfem::Matrix. Must be overridden in Python.
|
||||
* @param i Row index.
|
||||
* @param j Column index.
|
||||
* @return Reference to the matrix element.
|
||||
* @note Returning a C++ reference from Python for write access can be complex.
|
||||
* Consider alternative data management strategies in Python.
|
||||
* PYBIND11_OVERRIDE_PURE is used in the .cpp file.
|
||||
*/
|
||||
mfem::real_t& Elem(int i, int j) override;
|
||||
|
||||
/**
|
||||
* @brief Access element (i,j) for read-only.
|
||||
* Pure virtual in mfem::Matrix. Must be overridden in Python.
|
||||
* @param i Row index.
|
||||
* @param j Column index.
|
||||
* @return Const reference to the matrix element.
|
||||
* @note PYBIND11_OVERRIDE_PURE is used in the .cpp file.
|
||||
*/
|
||||
const mfem::real_t& Elem(int i, int j) const override;
|
||||
|
||||
/**
|
||||
* @brief Get the inverse of the matrix.
|
||||
* Pure virtual in mfem::Matrix. Must be overridden in Python.
|
||||
* The caller is responsible for deleting the returned MatrixInverse object.
|
||||
* @return Pointer to an mfem::MatrixInverse object.
|
||||
* @note PYBIND11_OVERRIDE_PURE is used in the .cpp file.
|
||||
*/
|
||||
mfem::MatrixInverse* Inverse() const override;
|
||||
|
||||
// --- Trampoline for new mfem::Matrix regular virtual methods ---
|
||||
/**
|
||||
* @brief Finalize matrix assembly.
|
||||
* For sparse matrices, this typically involves finalizing the sparse structure.
|
||||
* Can be overridden in Python.
|
||||
* @param skip_zeros See mfem::SparseMatrix::Finalize documentation.
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
void Finalize(int skip_zeros) override;
|
||||
|
||||
// --- Trampolines for inherited mfem::Operator virtual methods ---
|
||||
// These must be repeated here to allow Python classes inheriting from
|
||||
// Matrix to override methods originally from the Operator base class.
|
||||
|
||||
/**
|
||||
* @brief Perform the operator action: y = A*x.
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* If not overridden, mfem::Matrix's default implementation is used.
|
||||
* @param x The input vector.
|
||||
* @param y The output vector (result of A*x).
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
void Mult(const mfem::Vector &x, mfem::Vector &y) const override;
|
||||
|
||||
/**
|
||||
* @brief Perform the transpose operator action: y = A^T*x.
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* @param x The input vector.
|
||||
* @param y The output vector (result of A^T*x).
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override;
|
||||
|
||||
/**
|
||||
* @brief Perform the action y += a*(A*x).
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* @param x The input vector.
|
||||
* @param y The vector to which a*(A*x) is added.
|
||||
* @param a Scalar multiplier (defaults to 1.0).
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
||||
|
||||
/**
|
||||
* @brief Perform the action y += a*(A^T*x).
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* @param x The input vector.
|
||||
* @param y The vector to which a*(A^T*x) is added.
|
||||
* @param a Scalar multiplier (defaults to 1.0).
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
||||
|
||||
/**
|
||||
* @brief Get the gradient operator (Jacobian) at a given point x.
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* For a linear matrix operator, the gradient is typically the matrix itself.
|
||||
* @param x The point at which to evaluate the gradient (often unused for linear operators).
|
||||
* @return A reference to the gradient operator.
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
mfem::Operator& GetGradient(const mfem::Vector &x) const override;
|
||||
|
||||
/**
|
||||
* @brief Assemble the diagonal of the operator.
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* @param diag Output vector to store the diagonal entries.
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
void AssembleDiagonal(mfem::Vector &diag) const override;
|
||||
|
||||
/**
|
||||
* @brief Get the prolongation operator.
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* @return A const pointer to the prolongation operator, or nullptr if not applicable.
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
const mfem::Operator* GetProlongation() const override;
|
||||
|
||||
/**
|
||||
* @brief Get the restriction operator.
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* @return A const pointer to the restriction operator, or nullptr if not applicable.
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
const mfem::Operator* GetRestriction() const override;
|
||||
|
||||
/**
|
||||
* @brief Recover the FEM solution.
|
||||
* Inherited from mfem::Operator. Can be overridden in Python.
|
||||
* @param X The reduced solution vector.
|
||||
* @param b The right-hand side vector.
|
||||
* @param x Output vector for the full FEM solution.
|
||||
* @note PYBIND11_OVERRIDE is used in the .cpp file.
|
||||
*/
|
||||
void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override;
|
||||
};
|
||||
|
||||
} // namespace serif::pybind
|
||||
@@ -0,0 +1,5 @@
|
||||
PyMatrix_sources = files(
|
||||
'PyMatrix.cpp'
|
||||
)
|
||||
|
||||
trampoline_sources += PyMatrix_sources
|
||||
@@ -0,0 +1,100 @@
|
||||
#include "PyMFEMTrampolines/Operator/PyOperator.h"
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace serif::pybind {
|
||||
|
||||
// --- Pure virtual function implementation ---
|
||||
// This override is mandatory for the trampoline class to be concrete.
|
||||
void PyOperator::Mult(const mfem::Vector &x, mfem::Vector &y) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void, /* Return type */
|
||||
mfem::Operator, /* C++ base class */
|
||||
Mult, /* C++ function name */
|
||||
x, /* Argument 1 */
|
||||
y /* Argument 2 */
|
||||
);
|
||||
}
|
||||
|
||||
// --- Regular virtual function implementations ---
|
||||
// These overrides allow Python subclasses to optionally provide their own
|
||||
// implementation for these methods. If they don't, the default C++
|
||||
// implementation from mfem::Operator will be called.
|
||||
|
||||
void PyOperator::MultTranspose(const mfem::Vector &x, mfem::Vector &y) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
MultTranspose,
|
||||
x,
|
||||
y
|
||||
);
|
||||
}
|
||||
|
||||
void PyOperator::AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
AddMult,
|
||||
x,
|
||||
y,
|
||||
a
|
||||
);
|
||||
}
|
||||
|
||||
void PyOperator::AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
AddMultTranspose,
|
||||
x,
|
||||
y,
|
||||
a
|
||||
);
|
||||
}
|
||||
|
||||
mfem::Operator& PyOperator::GetGradient(const mfem::Vector &x) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
mfem::Operator&,
|
||||
mfem::Operator,
|
||||
GetGradient,
|
||||
x
|
||||
);
|
||||
}
|
||||
|
||||
void PyOperator::AssembleDiagonal(mfem::Vector &diag) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
AssembleDiagonal,
|
||||
diag
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::Operator* PyOperator::GetProlongation() const {
|
||||
PYBIND11_OVERRIDE(
|
||||
const mfem::Operator*,
|
||||
mfem::Operator,
|
||||
GetProlongation
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::Operator* PyOperator::GetRestriction() const {
|
||||
PYBIND11_OVERRIDE(
|
||||
const mfem::Operator*,
|
||||
mfem::Operator,
|
||||
GetRestriction
|
||||
);
|
||||
}
|
||||
|
||||
void PyOperator::RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
RecoverFEMSolution,
|
||||
X,
|
||||
b,
|
||||
x
|
||||
);
|
||||
}
|
||||
|
||||
} // namespace serif::pybind::mfem
|
||||
@@ -0,0 +1,210 @@
|
||||
/**
|
||||
* @file PyOperator.h
|
||||
* @brief Defines a pybind11 trampoline class for mfem::Operator.
|
||||
*
|
||||
* This trampoline class allows Python classes to inherit from mfem::Operator,
|
||||
* enabling Python-defined operators to be used within the C++ MFEM library.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "mfem.hpp"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h" // Needed for vectors, maps, sets, strings
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
/**
|
||||
* @namespace serif::pybind
|
||||
* @brief Contains pybind11 helper classes and trampoline classes for interfacing C++ with Python.
|
||||
*/
|
||||
namespace serif::pybind {
|
||||
|
||||
/**
|
||||
* @brief Trampoline class for mfem::Operator.
|
||||
*
|
||||
* This class allows Python classes to inherit from mfem::Operator and correctly
|
||||
* override its virtual functions. When a virtual function is called from C++,
|
||||
* the trampoline ensures the call is forwarded to the Python implementation if one exists.
|
||||
* This is crucial for integrating Python-defined linear or non-linear operators
|
||||
* into MFEM's C++-based solvers and algorithms.
|
||||
*
|
||||
* @see mfem::Operator
|
||||
*
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import mfem.ser_ext as mfem
|
||||
* import numpy as np
|
||||
*
|
||||
* class MyPythonOperator(mfem.Operator):
|
||||
* def __init__(self, size):
|
||||
* super().__init__(size) # Call the base C++ constructor
|
||||
* # Or super().__init__() if default constructing and setting height/width later
|
||||
* self.matrix = np.random.rand(size, size) # Example: store a dense matrix
|
||||
*
|
||||
* # Must override Mult
|
||||
* def Mult(self, x, y):
|
||||
* # x is an mfem.Vector (input)
|
||||
* # y is an mfem.Vector (output, y = A*x)
|
||||
* # Ensure y is correctly sized if not already
|
||||
* if y.Size() != self.Height():
|
||||
* y.SetSize(self.Height())
|
||||
*
|
||||
* # Example: y = self.matrix * x
|
||||
* # This is a conceptual illustration. For actual matrix-vector products
|
||||
* # with numpy, you'd convert mfem.Vector to numpy array or iterate.
|
||||
* x_np = x.GetDataArray() # Get a numpy view if configured
|
||||
* y_np = np.dot(self.matrix, x_np)
|
||||
* y.Assign(y_np) # Assign result back to mfem.Vector
|
||||
* print("MyPythonOperator.Mult called")
|
||||
*
|
||||
* # Optionally override other methods like MultTranspose, GetGradient, etc.
|
||||
* def MultTranspose(self, x, y):
|
||||
* if y.Size() != self.Width():
|
||||
* y.SetSize(self.Width())
|
||||
* # Example: y = self.matrix.T * x
|
||||
* x_np = x.GetDataArray()
|
||||
* y_np = np.dot(self.matrix.T, x_np)
|
||||
* y.Assign(y_np)
|
||||
* print("MyPythonOperator.MultTranspose called")
|
||||
*
|
||||
* # Using the Python operator
|
||||
* op_size = 5
|
||||
* py_op = MyPythonOperator(op_size)
|
||||
*
|
||||
* x_vec = mfem.Vector(op_size)
|
||||
* x_vec.Assign(np.arange(op_size, dtype=float))
|
||||
* y_vec = mfem.Vector(op_size)
|
||||
*
|
||||
* py_op.Mult(x_vec, y_vec)
|
||||
* print("Result y:", y_vec.GetDataArray())
|
||||
* # py_op can now be passed to MFEM functions expecting an mfem::Operator
|
||||
* @endcode
|
||||
*/
|
||||
class PyOperator : public mfem::Operator {
|
||||
public:
|
||||
// Inherit the constructors from the base mfem::Operator class.
|
||||
// This allows Python classes to call e.g., super().__init__(size).
|
||||
using mfem::Operator::Operator;
|
||||
|
||||
// --- Trampoline declarations for all overridable virtual functions ---
|
||||
|
||||
/**
|
||||
* @brief Perform the operator action: y = A*x.
|
||||
*
|
||||
* This is a pure virtual function in mfem::Operator and **must** be overridden
|
||||
* by any Python class inheriting from PyOperator.
|
||||
*
|
||||
* @param x The input vector.
|
||||
* @param y The output vector (result of A*x).
|
||||
* @note This method forwards the call to the Python override.
|
||||
* PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this.
|
||||
*/
|
||||
void Mult(const mfem::Vector &x, mfem::Vector &y) const override;
|
||||
|
||||
/**
|
||||
* @brief Perform the transpose operator action: y = A^T*x.
|
||||
*
|
||||
* Optional override. If not overridden, MFEM's base implementation
|
||||
* (which may raise an error or be a no-op) will be used.
|
||||
*
|
||||
* @param x The input vector.
|
||||
* @param y The output vector (result of A^T*x).
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
* PYBIND11_OVERRIDE is used in the .cpp file to handle this.
|
||||
*/
|
||||
void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override;
|
||||
|
||||
/**
|
||||
* @brief Perform the action y += a*(A*x).
|
||||
*
|
||||
* Optional override.
|
||||
*
|
||||
* @param x The input vector.
|
||||
* @param y The vector to which a*(A*x) is added.
|
||||
* @param a Scalar multiplier (defaults to 1.0).
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
*/
|
||||
void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
||||
|
||||
/**
|
||||
* @brief Perform the action y += a*(A^T*x).
|
||||
*
|
||||
* Optional override.
|
||||
*
|
||||
* @param x The input vector.
|
||||
* @param y The vector to which a*(A^T*x) is added.
|
||||
* @param a Scalar multiplier (defaults to 1.0).
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
*/
|
||||
void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
||||
|
||||
/**
|
||||
* @brief Get the gradient operator (Jacobian) at a given point x.
|
||||
*
|
||||
* For non-linear operators, this method should return the linearization (Jacobian)
|
||||
* of the operator at the point `x`. The returned Operator is owned by this
|
||||
* Operator and should not be deleted by the caller.
|
||||
* Optional override.
|
||||
*
|
||||
* @param x The point at which to evaluate the gradient.
|
||||
* @return A reference to the gradient operator.
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
*/
|
||||
Operator& GetGradient(const mfem::Vector &x) const override;
|
||||
|
||||
/**
|
||||
* @brief Assemble the diagonal of the operator.
|
||||
*
|
||||
* For discrete operators (e.g., matrices), this method should compute and store
|
||||
* the diagonal entries of the operator in the vector `diag`.
|
||||
* Optional override.
|
||||
*
|
||||
* @param diag Output vector to store the diagonal entries.
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
*/
|
||||
void AssembleDiagonal(mfem::Vector &diag) const override;
|
||||
|
||||
/**
|
||||
* @brief Get the prolongation operator.
|
||||
*
|
||||
* Used in multilevel methods (e.g., AMG). Returns a pointer to the prolongation
|
||||
* operator (interpolation from a coarser level to this level).
|
||||
* The returned Operator is typically owned by this Operator.
|
||||
* Optional override.
|
||||
*
|
||||
* @return A const pointer to the prolongation operator, or nullptr if not applicable.
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
*/
|
||||
const mfem::Operator* GetProlongation() const override;
|
||||
|
||||
/**
|
||||
* @brief Get the restriction operator.
|
||||
*
|
||||
* Used in multilevel methods (e.g., AMG). Returns a pointer to the restriction
|
||||
* operator (projection from this level to a coarser level).
|
||||
* Typically, this is the transpose of the prolongation operator.
|
||||
* The returned Operator is typically owned by this Operator.
|
||||
* Optional override.
|
||||
*
|
||||
* @return A const pointer to the restriction operator, or nullptr if not applicable.
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
*/
|
||||
const mfem::Operator* GetRestriction() const override;
|
||||
|
||||
/**
|
||||
* @brief Recover the FEM solution.
|
||||
*
|
||||
* For operators that are part of a system solve (e.g., static condensation),
|
||||
* this method can be used to reconstruct the full finite element solution `x`
|
||||
* from a reduced solution `X` and the right-hand side `b`.
|
||||
* Optional override.
|
||||
*
|
||||
* @param X The reduced solution vector.
|
||||
* @param b The right-hand side vector.
|
||||
* @param x Output vector for the full FEM solution.
|
||||
* @note This method forwards the call to the Python override if one exists.
|
||||
*/
|
||||
void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override;
|
||||
};
|
||||
|
||||
} // namespace serif::pybind
|
||||
@@ -0,0 +1,7 @@
|
||||
PyOperator_sources = files(
|
||||
'PyOperator.cpp'
|
||||
)
|
||||
|
||||
trampoline_sources += PyOperator_sources
|
||||
|
||||
subdir('Matrix')
|
||||
2
src/python/mfem/Trampoline/PyMFEMTrampolines/meson.build
Normal file
2
src/python/mfem/Trampoline/PyMFEMTrampolines/meson.build
Normal file
@@ -0,0 +1,2 @@
|
||||
subdir('Operator')
|
||||
subdir('Coefficient')
|
||||
23
src/python/mfem/Trampoline/meson.build
Normal file
23
src/python/mfem/Trampoline/meson.build
Normal file
@@ -0,0 +1,23 @@
|
||||
trampoline_sources = []
|
||||
|
||||
subdir('PyMFEMTrampolines')
|
||||
|
||||
dependencies = [
|
||||
mfem_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
trampoline_lib = static_library(
|
||||
'mfem_trampolines',
|
||||
trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
trampoline_dep = declare_dependency(
|
||||
link_with: trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: dependencies,
|
||||
)
|
||||
862
src/python/mfem/bindings.cpp
Normal file
862
src/python/mfem/bindings.cpp
Normal file
@@ -0,0 +1,862 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/operators.h> // For operator overloads
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
// Include your trampoline class header. The implementation will be in a separate .cpp file.
|
||||
#include "bindings.h"
|
||||
|
||||
#include "Trampoline/PyMFEMTrampolines/Operator/PyOperator.h"
|
||||
#include "Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.h"
|
||||
#include "Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.h"
|
||||
|
||||
#include "mfem.hpp"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mfem;
|
||||
|
||||
// This function registers all the mfem-related classes to the python module
|
||||
void register_mfem_bindings(py::module &mfem_submodule) {
|
||||
register_operator_bindings(mfem_submodule);
|
||||
register_matrix_bindings(mfem_submodule);
|
||||
|
||||
register_vector_bindings(mfem_submodule);
|
||||
register_array_bindings(mfem_submodule);
|
||||
register_table_bindings(mfem_submodule);
|
||||
|
||||
register_mesh_bindings(mfem_submodule);
|
||||
|
||||
auto formsModule = mfem_submodule.def_submodule("forms", "MFEM forms module");
|
||||
register_bilinear_form_bindings(formsModule);
|
||||
register_mixed_bilinear_form_bindings(formsModule);
|
||||
|
||||
auto fecModule = mfem_submodule.def_submodule("fec", "MFEM finite element collection module");
|
||||
register_basis_type_bindings(fecModule);
|
||||
register_finite_element_collection_bindings(fecModule);
|
||||
register_H1_FECollection_bindings(fecModule);
|
||||
register_RT_FECollection_bindings(fecModule);
|
||||
register_ND_FECollection_bindings(fecModule);
|
||||
|
||||
auto fesModule = mfem_submodule.def_submodule("fes", "MFEM finite element space module");
|
||||
register_finite_element_space_bindings(fesModule);
|
||||
|
||||
register_coefficient_bindings(mfem_submodule);
|
||||
register_intrule_bindings(mfem_submodule);
|
||||
register_eltrans_bindings(mfem_submodule);
|
||||
|
||||
register_grid_function_bindings(mfem_submodule);
|
||||
}
|
||||
|
||||
void register_operator_bindings(py::module &mfem_submodule) {
|
||||
// Use the PyOperator trampoline when binding mfem::Operator
|
||||
// This allows Python classes to inherit from mfem::Operator
|
||||
py::class_<Operator, serif::pybind::PyOperator /* Trampoline */>(mfem_submodule, "Operator")
|
||||
// NOTE: We DO NOT define an __init__ method because Operator is abstract.
|
||||
// Python users will instantiate concrete derived classes instead.
|
||||
|
||||
// --- Bind Properties ---
|
||||
.def_property_readonly("height", &Operator::Height, "Get the height (number of rows) of the Operator.")
|
||||
.def_property_readonly("width", &Operator::Width, "Get the width (number of columns) of the Operator.")
|
||||
|
||||
// --- Bind Core Virtual Methods ---
|
||||
// We bind the methods of the C++ base class so they can be called from Python.
|
||||
// The trampoline handles redirecting to a Python override if one exists.
|
||||
|
||||
// Mult: y = A(x)
|
||||
.def("Mult", py::overload_cast<const Vector&, Vector&>(&Operator::Mult, py::const_),
|
||||
py::arg("x"), py::arg("y"), "Calculates y = A(x). y must be pre-allocated.")
|
||||
|
||||
// Pythonic overload for Mult that returns a new vector
|
||||
.def("Mult", [](const Operator &op, const Vector &x) {
|
||||
Vector y(op.Height());
|
||||
op.Mult(x, y);
|
||||
return y;
|
||||
}, py::arg("x"), "Calculates and returns a new vector y = A(x).")
|
||||
|
||||
// MultTranspose: y = A^T(x)
|
||||
.def("MultTranspose", py::overload_cast<const Vector&, Vector&>(&Operator::MultTranspose, py::const_),
|
||||
py::arg("x"), py::arg("y"), "Calculates y = A^T(x). y must be pre-allocated.")
|
||||
|
||||
.def("MultTranspose", [](const Operator &op, const Vector &x) {
|
||||
Vector y(op.Width());
|
||||
op.MultTranspose(x, y);
|
||||
return y;
|
||||
}, py::arg("x"), "Calculates and returns a new vector y = A^T(x).")
|
||||
|
||||
// Additive versions
|
||||
.def("AddMult", &Operator::AddMult, py::arg("x"), py::arg("y"), py::arg("a") = 1.0, "Performs y += a * A(x).")
|
||||
.def("AddMultTranspose", &Operator::AddMultTranspose, py::arg("x"), py::arg("y"), py::arg("a") = 1.0, "Performs y += a * A^T(x).")
|
||||
|
||||
// Other core virtual methods
|
||||
.def("AssembleDiagonal", &Operator::AssembleDiagonal, py::arg("diag"), "Assembles the operator diagonal into the given Vector.")
|
||||
.def("RecoverFEMSolution", &Operator::RecoverFEMSolution, py::arg("X"), py::arg("b"), py::arg("x"), "Recovers the full FE solution.")
|
||||
.def("GetGradient", &Operator::GetGradient, py::return_value_policy::reference, "Returns the Gradient of a non-linear operator.")
|
||||
|
||||
// Methods returning other operators (e.g., for parallel/BCs)
|
||||
.def("GetProlongation", &Operator::GetProlongation, py::return_value_policy::reference, "Returns the prolongation operator.")
|
||||
.def("GetRestriction", &Operator::GetRestriction, py::return_value_policy::reference, "Returns the restriction operator.")
|
||||
|
||||
// --- Pythonic Operator Overloading ---
|
||||
.def("__matmul__", [](const Operator &op, const Vector &x) {
|
||||
Vector y(op.Height());
|
||||
op.Mult(x, y);
|
||||
return y;
|
||||
}, py::is_operator());
|
||||
}
|
||||
|
||||
void register_matrix_bindings(py::module &mfem_submodule) {
|
||||
py::class_<Matrix, Operator, serif::pybind::PyMatrix>(mfem_submodule, "Matrix")
|
||||
// No constructor since it's an abstract base class.
|
||||
.def_property_readonly("is_square", &Matrix::IsSquare,
|
||||
"Returns true if the matrix is square.")
|
||||
|
||||
.def("finalize", &Matrix::Finalize, py::arg("skip_zeros") = 1,
|
||||
"Finalizes the matrix initialization.")
|
||||
|
||||
.def("inverse", &Matrix::Inverse,
|
||||
"Returns a pointer to (an approximation) of the matrix inverse.",
|
||||
py::return_value_policy::take_ownership) // The caller owns the returned pointer
|
||||
|
||||
// Pythonic element access: mat[i, j]
|
||||
.def("__getitem__", [](const Matrix &m, py::tuple t) {
|
||||
if (t.size() != 2) {
|
||||
throw py::index_error("Matrix index must be a 2-tuple (i, j)");
|
||||
}
|
||||
return m.Elem(t[0].cast<int>(), t[1].cast<int>());
|
||||
})
|
||||
|
||||
// Pythonic element assignment: mat[i, j] = value
|
||||
.def("__setitem__", [](Matrix &m, py::tuple t, real_t value) {
|
||||
if (t.size() != 2) {
|
||||
throw py::index_error("Matrix index must be a 2-tuple (i, j)");
|
||||
}
|
||||
m.Elem(t[0].cast<int>(), t[1].cast<int>()) = value;
|
||||
})
|
||||
|
||||
.def("__repr__", [](const Matrix &m) {
|
||||
return "<mfem.Matrix (Abstract) " +
|
||||
std::to_string(m.Height()) + "x" +
|
||||
std::to_string(m.Width()) + ">";
|
||||
});
|
||||
}
|
||||
|
||||
void register_vector_bindings(py::module &mfem_submodule) {
|
||||
// Register the mfem::Vector class
|
||||
py::class_<mfem::Vector>(mfem_submodule, "Vector")
|
||||
.def(py::init<int>(), py::arg("size"))
|
||||
.def(py::init<const mfem::Vector &>(), py::arg("other"))
|
||||
.def(py::init([](py::array_t<double> arr) {
|
||||
py::buffer_info info = arr.request();
|
||||
if (info.ndim != 1) {
|
||||
throw std::runtime_error("Vector(): expected a 1-D numpy array");
|
||||
}
|
||||
mfem::Vector v(info.size);
|
||||
std::memcpy(v.GetData(), info.ptr, info.size * sizeof(double));
|
||||
return v;
|
||||
}), py::arg("array"))
|
||||
.def("GetData", &mfem::Vector::GetData, py::return_value_policy::reference_internal)
|
||||
.def("Size", &mfem::Vector::Size)
|
||||
.def("__getitem__", [](const mfem::Vector &v, int i) { return v[i]; })
|
||||
.def("__len__", &mfem::Vector::Size)
|
||||
.def("__setitem__", [](mfem::Vector &v, int i, double value) { v[i] = value; })
|
||||
.def("__repr__", [](const mfem::Vector &v) {
|
||||
return "<mfem.Vector(size=" + std::to_string(v.Size()) + ")>";
|
||||
})
|
||||
.def("as_numpy", [](mfem::Vector &self) {
|
||||
return py::array_t<double>(self.Size(), self.GetData());
|
||||
});
|
||||
}
|
||||
|
||||
void register_array_bindings(py::module &mfem_submodule) {
|
||||
py::class_<Array<int>>(mfem_submodule, "IntArray")
|
||||
// --- Constructors ---
|
||||
.def(py::init<>(), "Default constructor.")
|
||||
.def(py::init<int>(), py::arg("size"), "Constructor with size.")
|
||||
.def(py::init([](const std::vector<int> &v) {
|
||||
auto *arr = new Array<int>(v.size());
|
||||
for (size_t i = 0; i < v.size(); ++i) {
|
||||
(*arr)[i] = v[i];
|
||||
}
|
||||
return arr;
|
||||
}), py::arg("list"), "Constructor from a Python list.")
|
||||
|
||||
// --- Pythonic Features ---
|
||||
.def("__len__", &Array<int>::Size)
|
||||
.def("__getitem__", [](const Array<int> &self, int i) {
|
||||
if (i < 0) i += self.Size(); // Handle negative indices
|
||||
if (i < 0 || i >= self.Size()) throw py::index_error();
|
||||
return self[i];
|
||||
})
|
||||
.def("__setitem__", [](Array<int> &self, int i, int value) {
|
||||
if (i < 0) i += self.Size(); // Handle negative indices
|
||||
if (i < 0 || i >= self.Size()) throw py::index_error();
|
||||
self[i] = value;
|
||||
})
|
||||
.def("__iter__", [](Array<int> &self) {
|
||||
return py::make_iterator(self.begin(), self.end());
|
||||
}, py::keep_alive<0, 1>()) // Keep array alive while iterator is used
|
||||
.def("__repr__", [](const Array<int> &self) {
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (int i = 0; i < self.Size(); ++i) {
|
||||
ss << self[i] << (i == self.Size() - 1 ? "" : ", ");
|
||||
}
|
||||
ss << "]";
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
// --- Core Methods ---
|
||||
.def("Size", &Array<int>::Size)
|
||||
.def("SetSize", py::overload_cast<int>(&Array<int>::SetSize), py::arg("size"))
|
||||
.def("Append", py::overload_cast<const int &>(&Array<int>::Append), py::arg("el"))
|
||||
.def("Last", py::overload_cast<>(&Array<int>::Last, py::const_))
|
||||
.def("DeleteLast", &Array<int>::DeleteLast)
|
||||
.def("DeleteAll", &Array<int>::DeleteAll)
|
||||
.def("Sort", [](Array<int> &self) { self.Sort(); })
|
||||
.def("Unique", &Array<int>::Unique)
|
||||
.def("Assign", [](Array<int> &self, const Array<int> &other) {
|
||||
self.Assign(other.GetData());
|
||||
})
|
||||
.def("as_numpy", [](Array<int> &self) {
|
||||
// Create a Python-owned copy to avoid memory issues
|
||||
return py::array_t<int>(self.Size(), self.GetData());
|
||||
});
|
||||
}
|
||||
|
||||
// Main function to register BilinearForm
|
||||
void register_bilinear_form_bindings(py::module &mfem_submodule) {
|
||||
|
||||
// It's good practice to bind enums used by the class
|
||||
bind_assembly_level_enum(mfem_submodule);
|
||||
|
||||
// Bind the mfem::BilinearForm class, inheriting from mfem::Matrix
|
||||
// No trampoline is needed because this is a concrete class.
|
||||
py::class_<BilinearForm, Matrix>(mfem_submodule, "BilinearForm")
|
||||
|
||||
// --- Constructor ---
|
||||
.def(py::init<FiniteElementSpace *>(), py::arg("fespace"),
|
||||
// The keep_alive policy ensures the Python object for the
|
||||
// FiniteElementSpace is not garbage-collected while the
|
||||
// BilinearForm is still using it.
|
||||
py::keep_alive<1, 2>(),
|
||||
"Constructs a bilinear form on the given FiniteElementSpace.")
|
||||
|
||||
// --- Setup Methods ---
|
||||
.def("SetAssemblyLevel", &BilinearForm::SetAssemblyLevel, py::arg("assembly_level"),
|
||||
"Set the assembly level (e.g., LEGACY, FULL, PARTIAL, NONE).")
|
||||
.def("EnableStaticCondensation", &BilinearForm::EnableStaticCondensation,
|
||||
"Enable static condensation to reduce system size.")
|
||||
.def("EnableHybridization", &BilinearForm::EnableHybridization,
|
||||
"Enable hybridization.")
|
||||
.def("UsePrecomputedSparsity", &BilinearForm::UsePrecomputedSparsity, py::arg("ps") = 1,
|
||||
"Enable use of precomputed sparsity pattern.")
|
||||
.def("SetDiagonalPolicy", &BilinearForm::SetDiagonalPolicy, py::arg("policy"),
|
||||
"Set the policy for handling diagonal entries of essential DOFs.")
|
||||
|
||||
// --- Integrator Methods ---
|
||||
.def("AddDomainIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddDomainIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a domain integrator to the form.")
|
||||
.def("AddBoundaryIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddBoundaryIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a boundary integrator to the form.")
|
||||
.def("AddInteriorFaceIntegrator", &BilinearForm::AddInteriorFaceIntegrator,
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds an interior face integrator (e.g., for DG methods).")
|
||||
.def("AddBdrFaceIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddBdrFaceIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a boundary face integrator.")
|
||||
|
||||
// --- Assembly and System Formulation ---
|
||||
.def("Assemble", &BilinearForm::Assemble, py::arg("skip_zeros") = 1,
|
||||
"Assembles the bilinear form into a sparse matrix.")
|
||||
.def("AssembleDiagonal", &BilinearForm::AssembleDiagonal, py::arg("diag"),
|
||||
"Assembles the diagonal of the operator into a Vector.")
|
||||
.def("FormLinearSystem",
|
||||
[](BilinearForm &self, const Array<int> &ess_tdof_list, Vector &x,
|
||||
Vector &b, OperatorHandle &A, Vector &X, Vector &B, int copy_interior) {
|
||||
self.FormLinearSystem(ess_tdof_list, x, b, A, X, B, copy_interior);
|
||||
},
|
||||
py::arg("ess_tdof_list"), py::arg("x"), py::arg("b"), py::arg("A"),
|
||||
py::arg("X"), py::arg("B"), py::arg("copy_interior") = 0,
|
||||
"Forms the linear system AX=B, applying boundary conditions and other transformations.")
|
||||
.def("FormSystemMatrix",
|
||||
[](BilinearForm &self, const Array<int> &ess_tdof_list, OperatorHandle &A) {
|
||||
self.FormSystemMatrix(ess_tdof_list, A);
|
||||
},
|
||||
py::arg("ess_tdof_list"), py::arg("A"),
|
||||
"Forms the system matrix A, applying necessary transformations.")
|
||||
.def("RecoverFEMSolution", py::overload_cast<const Vector&, const Vector&, Vector&>(&BilinearForm::RecoverFEMSolution),
|
||||
py::arg("X"), py::arg("b"), py::arg("x"),
|
||||
"Recovers the full FE solution vector after solving a linear system.")
|
||||
|
||||
// --- Accessor Methods ---
|
||||
.def("FESpace", py::overload_cast<>(&BilinearForm::FESpace, py::const_), py::return_value_policy::reference_internal,
|
||||
"Returns a pointer to the associated FiniteElementSpace.")
|
||||
.def("SpMat", py::overload_cast<>(&BilinearForm::SpMat), py::return_value_policy::reference_internal,
|
||||
"Returns a reference to the internal sparse matrix.")
|
||||
.def("Update", &BilinearForm::Update, py::arg("nfes") = nullptr,
|
||||
"Update the BilinearForm after the FE space has changed.");
|
||||
|
||||
// You will also need to bind the OperatorHandle and DiagonalPolicy enum
|
||||
// if you haven't already. Example for DiagonalPolicy:
|
||||
py::enum_<mfem::Matrix::DiagonalPolicy>(mfem_submodule, "DiagonalPolicy")
|
||||
.value("DIAG_ZERO", mfem::Matrix::DiagonalPolicy::DIAG_ZERO)
|
||||
.value("DIAG_ONE", mfem::Matrix::DiagonalPolicy::DIAG_ONE)
|
||||
.value("DIAG_KEEP", mfem::Matrix::DiagonalPolicy::DIAG_KEEP)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
// Helper function to bind the AssemblyLevel enum
|
||||
void bind_assembly_level_enum(py::module &m) {
|
||||
py::enum_<AssemblyLevel>(m, "AssemblyLevel")
|
||||
.value("LEGACY", AssemblyLevel::LEGACY)
|
||||
.value("FULL", AssemblyLevel::FULL)
|
||||
.value("ELEMENT", AssemblyLevel::ELEMENT)
|
||||
.value("PARTIAL", AssemblyLevel::PARTIAL)
|
||||
.value("NONE", AssemblyLevel::NONE)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
// Main function to register MixedBilinearForm
|
||||
void register_mixed_bilinear_form_bindings(py::module &mfem_submodule) {
|
||||
|
||||
// Bind the mfem::MixedBilinearForm class, inheriting from mfem::Matrix
|
||||
// No trampoline is needed because this is a concrete class.
|
||||
py::class_<MixedBilinearForm, Matrix>(mfem_submodule, "MixedBilinearForm")
|
||||
|
||||
// --- Constructor ---
|
||||
.def(py::init<FiniteElementSpace *, FiniteElementSpace *>(),
|
||||
py::arg("trial_fespace"), py::arg("test_fespace"),
|
||||
// Keep alive policies ensure the FE space objects are not garbage
|
||||
// collected while the MixedBilinearForm is still using them.
|
||||
py::keep_alive<1, 2>(), py::keep_alive<1, 3>(),
|
||||
"Constructs a mixed bilinear form on the given trial and test FE spaces.")
|
||||
|
||||
// --- Setup Methods ---
|
||||
.def("SetAssemblyLevel", &MixedBilinearForm::SetAssemblyLevel, py::arg("assembly_level"),
|
||||
"Set the assembly level (e.g., LEGACY, FULL, PARTIAL, NONE).")
|
||||
|
||||
// --- Integrator Methods ---
|
||||
.def("AddDomainIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddDomainIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a domain integrator to the form.")
|
||||
.def("AddBoundaryIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddBoundaryIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a boundary integrator to the form.")
|
||||
.def("AddInteriorFaceIntegrator", &MixedBilinearForm::AddInteriorFaceIntegrator,
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds an interior face integrator.")
|
||||
.def("AddBdrFaceIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddBdrFaceIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a boundary face integrator.")
|
||||
.def("AddTraceFaceIntegrator", &MixedBilinearForm::AddTraceFaceIntegrator,
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a trace face integrator.")
|
||||
|
||||
|
||||
// --- Assembly and System Formulation ---
|
||||
.def("Assemble", &MixedBilinearForm::Assemble, py::arg("skip_zeros") = 1,
|
||||
"Assembles the mixed bilinear form into a sparse matrix.")
|
||||
.def("FormRectangularSystemMatrix",
|
||||
[](MixedBilinearForm &self, const Array<int> &trial_tdof_list, const Array<int> &test_tdof_list, OperatorHandle &A) {
|
||||
self.FormRectangularSystemMatrix(trial_tdof_list, test_tdof_list, A);
|
||||
},
|
||||
py::arg("trial_tdof_list"), py::arg("test_tdof_list"), py::arg("A"),
|
||||
"Forms the rectangular system matrix A, applying necessary transformations.")
|
||||
.def("FormRectangularLinearSystem",
|
||||
[](MixedBilinearForm &self, const Array<int> &trial_tdof_list, const Array<int> &test_tdof_list,
|
||||
Vector &x, Vector &b, OperatorHandle &A, Vector &X, Vector &B) {
|
||||
self.FormRectangularLinearSystem(trial_tdof_list, test_tdof_list, x, b, A, X, B);
|
||||
},
|
||||
py::arg("trial_tdof_list"), py::arg("test_tdof_list"), py::arg("x"), py::arg("b"),
|
||||
py::arg("A"), py::arg("X"), py::arg("B"),
|
||||
"Forms the rectangular linear system AX=B.")
|
||||
|
||||
// --- Accessor Methods ---
|
||||
.def("TrialFESpace", py::overload_cast<>(&MixedBilinearForm::TrialFESpace, py::const_),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Returns a pointer to the associated trial FiniteElementSpace.")
|
||||
.def("TestFESpace", py::overload_cast<>(&MixedBilinearForm::TestFESpace, py::const_),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Returns a pointer to the associated test FiniteElementSpace.")
|
||||
.def("SpMat", py::overload_cast<>(&MixedBilinearForm::SpMat),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Returns a reference to the internal sparse matrix.")
|
||||
.def("Update", &MixedBilinearForm::Update,
|
||||
"Update the MixedBilinearForm after the FE spaces have changed.");
|
||||
}
|
||||
|
||||
// This function can be called from your main registration function.
|
||||
void register_mesh_bindings(py::module &mfem_submodule) {
|
||||
// Bind the mfem::Mesh class. No trampoline needed.
|
||||
py::class_<Mesh>(mfem_submodule, "Mesh")
|
||||
|
||||
// --- Constructors & Loading ---
|
||||
// Default constructor for creating an empty mesh object
|
||||
.def(py::init<>())
|
||||
// Constructor to load from a file path
|
||||
.def(py::init<const std::string &, int, int, bool>(),
|
||||
py::arg("filename"), py::arg("generate_edges") = 0,
|
||||
py::arg("refine") = 1, py::arg("fix_orientation") = true)
|
||||
// Static factory method for a more Pythonic loading interface
|
||||
.def_static("LoadFromFile", &Mesh::LoadFromFile,
|
||||
py::arg("filename"), py::arg("generate_edges") = 0,
|
||||
py::arg("refine") = 1, py::arg("fix_orientation") = true,
|
||||
"Creates a mesh by reading a file.")
|
||||
|
||||
// --- Basic Properties & Stats ---
|
||||
.def_property_readonly("dim", &Mesh::Dimension)
|
||||
.def_property_readonly("space_dim", &Mesh::SpaceDimension)
|
||||
.def_property_readonly("nv", &Mesh::GetNV, "Number of Vertices")
|
||||
.def_property_readonly("ne", &Mesh::GetNE, "Number of Elements")
|
||||
.def_property_readonly("nbe", &Mesh::GetNBE, "Number of Boundary Elements")
|
||||
.def_property_readonly("n_edges", &Mesh::GetNEdges)
|
||||
.def_property_readonly("n_faces", &Mesh::GetNFaces)
|
||||
.def_readonly("attributes", &Mesh::attributes)
|
||||
.def_readonly("bdr_attributes", &Mesh::bdr_attributes)
|
||||
.def("GetBoundingBox",
|
||||
[](Mesh &self, int ref) {
|
||||
Vector min, max;
|
||||
self.GetBoundingBox(min, max, ref);
|
||||
// Here you might want to return a tuple of numpy arrays instead
|
||||
return py::make_tuple(min, max);
|
||||
}, py::arg("ref") = 2,
|
||||
"Returns the min and max corners of the mesh bounding box.")
|
||||
|
||||
// --- Connectivity Data ---
|
||||
.def("GetElementVertices",
|
||||
[](const Mesh &self, int i) {
|
||||
Array<int> v;
|
||||
self.GetElementVertices(i, v);
|
||||
return v;
|
||||
}, py::arg("i"), "Returns a list of vertex indices for element i.")
|
||||
.def("GetBdrElementVertices",
|
||||
[](const Mesh &self, int i) {
|
||||
Array<int> v;
|
||||
self.GetBdrElementVertices(i, v);
|
||||
return v;
|
||||
}, py::arg("i"), "Returns a list of vertex indices for boundary element i.")
|
||||
.def("GetFaceElements",
|
||||
[](const Mesh &self, int i) {
|
||||
int e1, e2;
|
||||
self.GetFaceElements(i, &e1, &e2);
|
||||
return py::make_tuple(e1, e2);
|
||||
}, py::arg("face_idx"), "Returns a tuple of the two elements sharing a face.")
|
||||
.def("GetBdrElementFace",
|
||||
[](const Mesh &self, int i) {
|
||||
int f, o;
|
||||
self.GetBdrElementFace(i, &f, &o);
|
||||
return py::make_tuple(f, o);
|
||||
}, py::arg("bdr_elem_idx"), "Returns the face index and orientation for a boundary element.")
|
||||
.def("ElementToEdgeTable", &Mesh::ElementToEdgeTable, py::return_value_policy::reference_internal)
|
||||
.def("ElementToFaceTable", &Mesh::ElementToFaceTable, py::return_value_policy::reference_internal)
|
||||
|
||||
// --- Coordinate Transformations & Curvature ---
|
||||
.def("GetNodes", py::overload_cast<>(&Mesh::GetNodes, py::const_), py::return_value_policy::reference_internal,
|
||||
"Returns the GridFunction for the mesh nodes (if any).")
|
||||
.def("SetCurvature", &Mesh::SetCurvature, py::arg("order"), py::arg("discontinuous") = false,
|
||||
py::arg("space_dim") = -1, py::arg("ordering") = 1,
|
||||
"Set the curvature of the mesh, creating a high-order nodal GridFunction.")
|
||||
// Use py::overload_cast to resolve ambiguity for overloaded methods
|
||||
.def("GetElementTransformation", py::overload_cast<int>(&Mesh::GetElementTransformation), py::arg("i"),
|
||||
py::return_value_policy::reference_internal)
|
||||
.def("GetFaceElementTransformations", py::overload_cast<int, int>(&Mesh::GetFaceElementTransformations), py::arg("i"),
|
||||
py::arg("mask")=31, py::return_value_policy::reference_internal)
|
||||
|
||||
// --- I/O & Visualization ---
|
||||
.def("Save", &Mesh::Save, py::arg("filename"), py::arg("precision") = 16);
|
||||
|
||||
// Bind the VTKFormat enum used in PrintVTU
|
||||
py::enum_<VTKFormat>(mfem_submodule, "VTKFormat")
|
||||
.value("ASCII", VTKFormat::ASCII)
|
||||
.value("BINARY", VTKFormat::BINARY)
|
||||
.value("BINARY32", VTKFormat::BINARY32)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
void register_table_bindings(pybind11::module &mfem_submodule) {
|
||||
// Bind mfem::Table first, as it's a return type for some Mesh methods.
|
||||
py::class_<Table>(mfem_submodule, "Table")
|
||||
.def("GetRow", [](const Table &self, int row) {
|
||||
Array<int> row_data;
|
||||
self.GetRow(row, row_data);
|
||||
// pybind11 will automatically convert Array<int> to a Python list
|
||||
return row_data;
|
||||
}, py::arg("row"), "Get a row of the table as a list.")
|
||||
.def("RowSize", &Table::RowSize, py::arg("row"), "Get the number of entries in a specific row.")
|
||||
.def_property_readonly("height", &Table::Size, "Number of rows in the table.")
|
||||
.def_property_readonly("width", &Table::Width, "Number of columns in the table.")
|
||||
.def("__len__", &Table::Size)
|
||||
.def("__getitem__", [](const Table &self, int row) {
|
||||
if (row < 0 || row >= self.Size()) {
|
||||
throw py::index_error("Row index out of bounds");
|
||||
}
|
||||
Array<int> row_data;
|
||||
self.GetRow(row, row_data);
|
||||
return row_data;
|
||||
})
|
||||
.def("__repr__", [](const Table &self) {
|
||||
return "<mfem.Table (" + std::to_string(self.Size()) + "x" +
|
||||
std::to_string(self.Width()) + ")>";
|
||||
});
|
||||
}
|
||||
|
||||
void register_finite_element_collection_bindings(pybind11::module &mfem_submodule) {
|
||||
py::class_<FiniteElementCollection>(mfem_submodule, "FiniteElementCollection")
|
||||
// No constructor for abstract base classes
|
||||
.def("Name", &FiniteElementCollection::Name)
|
||||
.def("GetOrder", &FiniteElementCollection::GetOrder);
|
||||
}
|
||||
|
||||
// Binds the mfem::BasisType class and its internal anonymous enum values
|
||||
// as static properties of a Python class. This should be called before
|
||||
// binding any class that uses BasisType values in its constructor.
|
||||
void register_basis_type_bindings(py::module &m) {
|
||||
py::class_<BasisType> basis_type(m, "BasisType", "Possible basis types.");
|
||||
|
||||
basis_type.def_property_readonly_static("Invalid", [](py::object) { return BasisType::Invalid; });
|
||||
basis_type.def_property_readonly_static("GaussLegendre", [](py::object) { return BasisType::GaussLegendre; });
|
||||
basis_type.def_property_readonly_static("GaussLobatto", [](py::object) { return BasisType::GaussLobatto; });
|
||||
basis_type.def_property_readonly_static("Positive", [](py::object) { return BasisType::Positive; });
|
||||
basis_type.def_property_readonly_static("OpenUniform", [](py::object) { return BasisType::OpenUniform; });
|
||||
basis_type.def_property_readonly_static("ClosedUniform", [](py::object) { return BasisType::ClosedUniform; });
|
||||
basis_type.def_property_readonly_static("OpenHalfUniform", [](py::object) { return BasisType::OpenHalfUniform; });
|
||||
basis_type.def_property_readonly_static("Serendipity", [](py::object) { return BasisType::Serendipity; });
|
||||
basis_type.def_property_readonly_static("ClosedGL", [](py::object) { return BasisType::ClosedGL; });
|
||||
basis_type.def_property_readonly_static("IntegratedGLL", [](py::object) { return BasisType::IntegratedGLL; });
|
||||
}
|
||||
|
||||
|
||||
|
||||
void register_H1_FECollection_bindings(py::module &m) {
|
||||
py::class_<H1_FECollection, FiniteElementCollection>(m, "H1_FECollection")
|
||||
.def(py::init([](int p, int dim) {
|
||||
// The lambda explicitly calls the constructor, avoiding the
|
||||
// default argument parsing issue.
|
||||
return new H1_FECollection(p, dim);
|
||||
}),
|
||||
py::arg("p"),
|
||||
py::arg("dim") = 3,
|
||||
"Constructs an H1 finite element collection.")
|
||||
.def("GetBasisType", &H1_FECollection::GetBasisType);
|
||||
}
|
||||
|
||||
|
||||
void register_RT_FECollection_bindings(py::module &m) {
|
||||
py::class_<RT_FECollection, FiniteElementCollection>(m, "RT_FECollection")
|
||||
.def(py::init<const int, const int>(),
|
||||
py::arg("p"), py::arg("dim"),
|
||||
"Constructs a Raviart-Thomas H(div)-conforming FE collection.")
|
||||
.def("GetClosedBasisType", &RT_FECollection::GetClosedBasisType)
|
||||
.def("GetOpenBasisType", &RT_FECollection::GetOpenBasisType);
|
||||
}
|
||||
|
||||
void register_ND_FECollection_bindings(py::module &m) {
|
||||
py::class_<ND_FECollection, FiniteElementCollection>(m, "ND_FECollection")
|
||||
.def(py::init<const int, const int>(),
|
||||
py::arg("p"), py::arg("dim"),
|
||||
"Constructs a Nedelec H(curl)-conforming FE collection.")
|
||||
.def("GetClosedBasisType", &ND_FECollection::GetClosedBasisType)
|
||||
.def("GetOpenBasisType", &ND_FECollection::GetOpenBasisType);
|
||||
}
|
||||
|
||||
|
||||
void register_finite_element_space_bindings(py::module &mfem_submodule) {
|
||||
// Bind dependent enums first
|
||||
bind_ordering_enum(mfem_submodule);
|
||||
|
||||
// Bind the mfem::FiniteElementSpace class
|
||||
py::class_<FiniteElementSpace>(mfem_submodule, "FiniteElementSpace")
|
||||
// --- Constructors ---
|
||||
.def(py::init<Mesh *, const FiniteElementCollection *, int, int>(),
|
||||
py::arg("mesh"), py::arg("fec"), py::arg("vdim") = 1, py::arg("ordering") = Ordering::byNODES,
|
||||
// Keep alive policies prevent the mesh and fec from being
|
||||
// garbage collected by Python while the C++ FESpace object exists.
|
||||
py::keep_alive<1, 2>(), py::keep_alive<1, 3>())
|
||||
|
||||
// --- Core Properties & Stats ---
|
||||
.def_property_readonly("ndofs", &FiniteElementSpace::GetNDofs, "Number of local scalar degrees of freedom.")
|
||||
.def_property_readonly("vdim", &FiniteElementSpace::GetVDim, "Vector dimension of the space.")
|
||||
.def_property_readonly("vsize", &FiniteElementSpace::GetVSize, "Total number of local vector degrees of freedom.")
|
||||
.def_property_readonly("true_vsize", &FiniteElementSpace::GetTrueVSize, "Number of true (conforming) vector degrees of freedom.")
|
||||
.def("GetMesh", &FiniteElementSpace::GetMesh, py::return_value_policy::reference_internal, "Get the associated Mesh.")
|
||||
.def("FEColl", &FiniteElementSpace::FEColl, py::return_value_policy::reference_internal, "Get the associated FiniteElementCollection.")
|
||||
|
||||
// --- DOF Management ---
|
||||
.def("GetElementDofs",
|
||||
[](const FiniteElementSpace &self, int i) {
|
||||
Array<int> dofs;
|
||||
self.GetElementDofs(i, dofs);
|
||||
return dofs;
|
||||
}, py::arg("elem_idx"), "Get the local scalar DOFs for a given element.")
|
||||
.def("GetElementVDofs",
|
||||
[](const FiniteElementSpace &self, int i) {
|
||||
Array<int> vdofs;
|
||||
self.GetElementVDofs(i, vdofs);
|
||||
return vdofs;
|
||||
}, py::arg("elem_idx"), "Get the local vector DOFs for a given element.")
|
||||
.def("GetBdrElementDofs",
|
||||
[](const FiniteElementSpace &self, int i) {
|
||||
Array<int> dofs;
|
||||
self.GetBdrElementDofs(i, dofs);
|
||||
return dofs;
|
||||
}, py::arg("bdr_elem_idx"), "Get the local scalar DOFs for a given boundary element.")
|
||||
.def("GetEssentialVDofs",
|
||||
[](const FiniteElementSpace &self, const Array<int> &bdr_attr_is_ess, int component) {
|
||||
Array<int> ess_vdofs;
|
||||
self.GetEssentialVDofs(bdr_attr_is_ess, ess_vdofs, component);
|
||||
return ess_vdofs;
|
||||
}, py::arg("bdr_attr_is_ess"), py::arg("component") = -1,
|
||||
"Get a list of essential (Dirichlet) vector DOFs based on boundary attributes.")
|
||||
.def("GetEssentialTrueDofs",
|
||||
[](const FiniteElementSpace &self, const Array<int> &bdr_attr_is_ess, int component) {
|
||||
Array<int> ess_tdof_list;
|
||||
self.GetEssentialTrueDofs(bdr_attr_is_ess, ess_tdof_list, component);
|
||||
return ess_tdof_list;
|
||||
}, py::arg("bdr_attr_is_ess"), py::arg("component") = -1,
|
||||
"Get a list of essential true (conforming) DOFs for use in linear systems.")
|
||||
|
||||
// --- Updating & Operators ---
|
||||
.def("Update", &FiniteElementSpace::Update, py::arg("want_transform") = true,
|
||||
"Update the FE space after the mesh has been modified (e.g., refined).")
|
||||
.def("GetUpdateOperator", py::overload_cast<>(&FiniteElementSpace::GetUpdateOperator),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Get the operator that maps GridFunctions from the old space to the new space after an Update.")
|
||||
.def("GetProlongationMatrix", &FiniteElementSpace::GetProlongationMatrix,
|
||||
py::return_value_policy::reference_internal, "Get the P operator (true DOFs to local DOFs).")
|
||||
.def("GetRestrictionOperator", &FiniteElementSpace::GetRestrictionOperator,
|
||||
py::return_value_policy::reference_internal, "Get the R operator (local DOFs to true DOFs).")
|
||||
|
||||
.def("__repr__", [](const FiniteElementSpace &fes) {
|
||||
std::stringstream ss;
|
||||
ss << "<mfem.FiniteElementSpace with " << fes.GetNDofs() << " DOFs, vdim=" << fes.GetVDim() << ">";
|
||||
return ss.str();
|
||||
});
|
||||
}
|
||||
|
||||
void bind_ordering_enum(py::module &mfem_submodule) {
|
||||
py::enum_<Ordering::Type>(mfem_submodule, "Ordering")
|
||||
.value("byNODES", Ordering::byNODES)
|
||||
.value("byVDIM", Ordering::byVDIM)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
/// Binds mfem::GridFunction
|
||||
void register_grid_function_bindings(py::module &m) {
|
||||
// Assumes that Vector, FiniteElementSpace, Coefficient, VectorCoefficient,
|
||||
// IntegrationRule, and ElementTransformation are already bound.
|
||||
|
||||
py::class_<GridFunction, Vector>(m, "GridFunction")
|
||||
.def(py::init<FiniteElementSpace *>(), py::arg("fespace"),
|
||||
py::keep_alive<1, 2>(), // Keep FE space alive
|
||||
"Construct a GridFunction on a given FiniteElementSpace.")
|
||||
|
||||
.def(py::init<FiniteElementSpace *, real_t *>(),
|
||||
py::arg("fespace"), py::arg("data"),
|
||||
py::keep_alive<1, 2>(),
|
||||
"Construct a GridFunction using previously allocated data.")
|
||||
|
||||
.def("FESpace", py::overload_cast<>(&GridFunction::FESpace, py::const_),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Returns the associated FiniteElementSpace.")
|
||||
.def("Update", &GridFunction::Update,
|
||||
"Update the GridFunction after its FE space has been modified.")
|
||||
.def("SetSpace", &GridFunction::SetSpace, py::arg("fespace"),
|
||||
py::keep_alive<1, 2>(), "Associate a new FE space with the GridFunction.")
|
||||
|
||||
// --- Projection Methods ---
|
||||
.def("ProjectCoefficient",
|
||||
py::overload_cast<Coefficient &>(&GridFunction::ProjectCoefficient),
|
||||
py::arg("coeff"),
|
||||
"Project a scalar Coefficient onto the GridFunction.")
|
||||
.def("ProjectCoefficient",
|
||||
py::overload_cast<VectorCoefficient &>(&GridFunction::ProjectCoefficient),
|
||||
py::arg("vcoeff"),
|
||||
"Project a vector Coefficient onto the GridFunction.")
|
||||
.def("ProjectBdrCoefficient",
|
||||
py::overload_cast<Coefficient &, const Array<int> &>(&GridFunction::ProjectBdrCoefficient),
|
||||
py::arg("coeff"), py::arg("attr"),
|
||||
"Project a scalar Coefficient onto the boundary degrees of freedom.")
|
||||
.def("ProjectBdrCoefficient",
|
||||
py::overload_cast<VectorCoefficient &, const Array<int> &>(&GridFunction::ProjectBdrCoefficient),
|
||||
py::arg("vcoeff"), py::arg("attr"),
|
||||
"Project a vector Coefficient onto the boundary degrees of freedom.")
|
||||
|
||||
// --- Evaluation Methods ---
|
||||
.def("GetValue", py::overload_cast<ElementTransformation &, const IntegrationPoint &, int, Vector *>(&GridFunction::GetValue, py::const_),
|
||||
py::arg("T"), py::arg("ip"), py::arg("comp") = 0, py::arg("tr") = nullptr,
|
||||
"Get the scalar value at a point described by an ElementTransformation.")
|
||||
.def("GetVectorValue",
|
||||
[](const GridFunction &self, ElementTransformation &T, const IntegrationPoint &ip) {
|
||||
Vector val;
|
||||
self.GetVectorValue(T, ip, val);
|
||||
return val;
|
||||
},
|
||||
py::arg("T"), py::arg("ip"),
|
||||
"Get the vector value at a point described by an ElementTransformation.")
|
||||
.def("GetValues",
|
||||
[](const GridFunction &self, ElementTransformation &T, const IntegrationRule &ir, int comp) {
|
||||
Vector vals;
|
||||
self.GetValues(T, ir, vals, comp);
|
||||
return vals;
|
||||
},
|
||||
py::arg("T"), py::arg("ir"), py::arg("comp") = 0,
|
||||
"Get scalar values at all points of an IntegrationRule.")
|
||||
.def("GetVectorValues",
|
||||
[](const GridFunction &self, ElementTransformation &T, const IntegrationRule &ir) {
|
||||
DenseMatrix vals;
|
||||
self.GetVectorValues(T, ir, vals);
|
||||
return vals; // pybind11 handles DenseMatrix
|
||||
},
|
||||
py::arg("T"), py::arg("ir"),
|
||||
"Get vector values at all points of an IntegrationRule.")
|
||||
|
||||
// --- True DOF (constrained system) Methods ---
|
||||
.def("GetTrueDofs", &GridFunction::GetTrueDofs, py::arg("tv"),
|
||||
"Extract the true-dofs from the GridFunction.")
|
||||
.def("SetFromTrueDofs", &GridFunction::SetFromTrueDofs, py::arg("tv"),
|
||||
"Set the GridFunction from a true-dof vector.")
|
||||
|
||||
// --- I/O ---
|
||||
.def("Save", py::overload_cast<const char *, int>(&GridFunction::Save, py::const_),
|
||||
py::arg("fname"), py::arg("precision") = 16)
|
||||
.def("__repr__", [](const GridFunction &gf) {
|
||||
std::stringstream ss;
|
||||
ss << "<mfem.GridFunction with " << gf.Size() << " DOFs>";
|
||||
return ss.str();
|
||||
})
|
||||
.def("as_numpy", [](const GridFunction &self) {
|
||||
// Convert the GridFunction to a numpy array
|
||||
return py::array_t<real_t>(self.Size(), self.GetData());
|
||||
}, "Convert the GridFunction to a numpy array.");
|
||||
}
|
||||
|
||||
// This function registers the coefficient classes to the python module
|
||||
void register_coefficient_bindings(py::module &m) {
|
||||
|
||||
// --- Bind abstract base classes using trampolines ---
|
||||
py::class_<Coefficient, serif::pybind::PyCoefficient> coefficient(m, "Coefficient");
|
||||
coefficient
|
||||
.def(py::init<>())
|
||||
.def("SetTime", &Coefficient::SetTime, py::arg("t"))
|
||||
.def("GetTime", &Coefficient::GetTime)
|
||||
.def("Eval", py::overload_cast<ElementTransformation &, const IntegrationPoint&>(&Coefficient::Eval),
|
||||
"Evaluate the coefficient at a point in an element.",
|
||||
py::arg("T"), py::arg("ip"));
|
||||
|
||||
py::class_<VectorCoefficient, serif::pybind::PyVectorCoefficient> vector_coefficient(m, "VectorCoefficient");
|
||||
vector_coefficient
|
||||
.def(py::init<int>(), py::arg("vdim"))
|
||||
.def("SetTime", &VectorCoefficient::SetTime, py::arg("t"))
|
||||
.def("GetTime", &VectorCoefficient::GetTime)
|
||||
.def("GetVDim", &VectorCoefficient::GetVDim)
|
||||
.def("Eval", py::overload_cast<Vector &, ElementTransformation &, const IntegrationPoint &>(&VectorCoefficient::Eval),
|
||||
"Evaluate the vector coefficient at a point in an element.",
|
||||
py::arg("V"), py::arg("T"), py::arg("ip"));
|
||||
|
||||
// --- Bind useful concrete classes ---
|
||||
|
||||
// ConstantCoefficient
|
||||
py::class_<ConstantCoefficient, Coefficient>(m, "ConstantCoefficient")
|
||||
.def(py::init<real_t>(), py::arg("c") = 1.0)
|
||||
.def_readwrite("constant", &ConstantCoefficient::constant);
|
||||
|
||||
// FunctionCoefficient (allows using Python functions as coefficients)
|
||||
py::class_<FunctionCoefficient, Coefficient>(m, "FunctionCoefficient")
|
||||
.def(py::init<std::function<real_t(const Vector &)>>(),
|
||||
py::arg("F"), "Create a coefficient from a Python function of space.")
|
||||
.def(py::init<std::function<real_t(const Vector &, real_t)>>(),
|
||||
py::arg("TDF"), "Create a coefficient from a Python function of space and time.");
|
||||
|
||||
// VectorConstantCoefficient
|
||||
py::class_<VectorConstantCoefficient, VectorCoefficient>(m, "VectorConstantCoefficient")
|
||||
.def(py::init<const Vector &>(), py::arg("v"));
|
||||
|
||||
// VectorFunctionCoefficient
|
||||
py::class_<VectorFunctionCoefficient, VectorCoefficient>(m, "VectorFunctionCoefficient")
|
||||
.def(py::init<int, std::function<void(const Vector &, Vector &)>>(),
|
||||
py::arg("dim"), py::arg("F"), "Create a vector coefficient from a Python function of space.")
|
||||
.def(py::init<int, std::function<void(const Vector &, real_t, Vector &)>>(),
|
||||
py::arg("dim"), py::arg("TDF"), "Create a vector coefficient from a Python function of space and time.");
|
||||
|
||||
// GridFunctionCoefficient (useful for source terms that depend on the solution)
|
||||
py::class_<GridFunctionCoefficient, Coefficient>(m, "GridFunctionCoefficient")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const GridFunction *>(), py::arg("gf"))
|
||||
.def("SetGridFunction", &GridFunctionCoefficient::SetGridFunction, py::arg("gf"))
|
||||
.def("GetGridFunction", &GridFunctionCoefficient::GetGridFunction, py::return_value_policy::reference);
|
||||
}
|
||||
|
||||
/// Binds mfem::IntegrationPoint and mfem::IntegrationRule
|
||||
void register_intrule_bindings(py::module &m) {
|
||||
py::class_<IntegrationPoint>(m, "IntegrationPoint")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("x", &IntegrationPoint::x)
|
||||
.def_readwrite("y", &IntegrationPoint::y)
|
||||
.def_readwrite("z", &IntegrationPoint::z)
|
||||
.def_readwrite("weight", &IntegrationPoint::weight)
|
||||
.def("__repr__", [](const IntegrationPoint &ip) {
|
||||
std::stringstream ss;
|
||||
ss << "IP(x=" << ip.x << ", y=" << ip.y << ", z=" << ip.z << ", w=" << ip.weight << ")";
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
py::class_<IntegrationRule>(m, "IntegrationRule")
|
||||
.def(py::init<>())
|
||||
.def(py::init<int>(), py::arg("NumPoints"))
|
||||
.def("GetNPoints", &IntegrationRule::GetNPoints)
|
||||
.def("GetOrder", &IntegrationRule::GetOrder)
|
||||
.def("__len__", &IntegrationRule::GetNPoints)
|
||||
.def("__getitem__", [](const IntegrationRule &self, int i) {
|
||||
if (i < 0 || i >= self.GetNPoints()) throw py::index_error();
|
||||
return self.IntPoint(i);
|
||||
})
|
||||
.def("__iter__", [](const IntegrationRule &self) {
|
||||
return py::make_iterator(self.begin(), self.end());
|
||||
}, py::keep_alive<0, 1>());
|
||||
}
|
||||
|
||||
/// Binds mfem::ElementTransformation and related classes
|
||||
void register_eltrans_bindings(py::module &m) {
|
||||
// Bind the base class. Users get pointers to this, but never create it.
|
||||
// It is not abstract in the C++ sense, but we treat it as such for bindings.
|
||||
py::class_<ElementTransformation> eltrans(m, "ElementTransformation");
|
||||
eltrans
|
||||
.def_readonly("Attribute", &ElementTransformation::Attribute)
|
||||
.def_readonly("ElementNo", &ElementTransformation::ElementNo)
|
||||
.def("SetIntPoint", &ElementTransformation::SetIntPoint, py::arg("ip"))
|
||||
.def("Transform", py::overload_cast<const IntegrationPoint &, Vector &>(&ElementTransformation::Transform),
|
||||
py::arg("ip"), py::arg("transip"))
|
||||
.def("Weight", &ElementTransformation::Weight)
|
||||
// Properties that return references to internal data should use reference_internal policy
|
||||
.def_property_readonly("Jacobian", &ElementTransformation::Jacobian, py::return_value_policy::reference_internal)
|
||||
.def_property_readonly("InverseJacobian", &ElementTransformation::InverseJacobian, py::return_value_policy::reference_internal)
|
||||
.def_property_readonly("AdjugateJacobian", &ElementTransformation::AdjugateJacobian, py::return_value_policy::reference_internal);
|
||||
|
||||
// Bind IsoparametricTransformation, which is a concrete type of ElementTransformation
|
||||
py::class_<IsoparametricTransformation, ElementTransformation>(m, "IsoparametricTransformation")
|
||||
.def(py::init<>());
|
||||
|
||||
// Bind FaceElementTransformations, crucial for DG methods
|
||||
py::class_<FaceElementTransformations, ElementTransformation>(m, "FaceElementTransformations")
|
||||
.def(py::init<>())
|
||||
.def_readonly("Elem1No", &FaceElementTransformations::Elem1No)
|
||||
.def_readonly("Elem2No", &FaceElementTransformations::Elem2No)
|
||||
.def_property_readonly("Elem1", [](FaceElementTransformations &self) { return self.Elem1; }, py::return_value_policy::reference)
|
||||
.def_property_readonly("Elem2", [](FaceElementTransformations &self) { return self.Elem2; }, py::return_value_policy::reference)
|
||||
.def("GetElement1IntPoint", &FaceElementTransformations::GetElement1IntPoint)
|
||||
.def("GetElement2IntPoint", &FaceElementTransformations::GetElement2IntPoint);
|
||||
|
||||
// Bind the enum used for selecting transformations
|
||||
py::enum_<FaceElementTransformations::ConfigMasks>(eltrans, "ConfigMasks")
|
||||
.value("HAVE_ELEM1", FaceElementTransformations::HAVE_ELEM1)
|
||||
.value("HAVE_ELEM2", FaceElementTransformations::HAVE_ELEM2)
|
||||
.value("HAVE_LOC1", FaceElementTransformations::HAVE_LOC1)
|
||||
.value("HAVE_LOC2", FaceElementTransformations::HAVE_LOC2)
|
||||
.value("HAVE_FACE", FaceElementTransformations::HAVE_FACE)
|
||||
.export_values();
|
||||
}
|
||||
307
src/python/mfem/bindings.h
Normal file
307
src/python/mfem/bindings.h
Normal file
@@ -0,0 +1,307 @@
|
||||
/**
|
||||
* @file bindings.h
|
||||
* @brief Declares functions to register MFEM core library components with pybind11.
|
||||
*
|
||||
* This header file lists the functions responsible for creating Python bindings
|
||||
* for various parts of the MFEM library. Each function typically registers
|
||||
* a set of related classes, enums, or functionalities to a pybind11::module,
|
||||
* which is expected to be a submodule named `mfem` within the main `serif` Python module.
|
||||
*
|
||||
* @see /Users/tboudreaux/Programming/SERiF/src/python/bindings.cpp for how these are used.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
/**
|
||||
* @brief Registers all core MFEM bindings to the given Python submodule.
|
||||
*
|
||||
* This function serves as the main entry point for exposing MFEM functionalities
|
||||
* to Python. It calls various other `register_*_bindings` and `bind_*_enum`
|
||||
* functions to populate the `mfem_submodule`.
|
||||
*
|
||||
* @param mfem_submodule The pybind11 module (typically `serif.mfem`) to which
|
||||
* MFEM bindings will be added.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* # Now mfem.Operator, mfem.Vector, mfem.Mesh, etc., are accessible.
|
||||
* vec = mfem.Vector(10)
|
||||
* print(vec.Size())
|
||||
* @endcode
|
||||
*/
|
||||
void register_mfem_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::Operator and related classes.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* # Assuming PyOperator trampoline is bound
|
||||
* # op = mfem.Operator() # Or a derived class like mfem.DenseMatrix
|
||||
* @endcode
|
||||
*/
|
||||
void register_operator_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::Matrix and its derived classes (e.g., mfem::DenseMatrix, mfem::SparseMatrix).
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* mat = mfem.DenseMatrix(2, 2)
|
||||
* mat[0,0] = 1.0
|
||||
* mat[0,1] = 2.0
|
||||
* mat[1,0] = 3.0
|
||||
* mat[1,1] = 4.0
|
||||
* mat.Print()
|
||||
* @endcode
|
||||
*/
|
||||
void register_matrix_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::Vector.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* vec = mfem.Vector(5)
|
||||
* vec[0] = 1.5
|
||||
* print(vec.Size(), vec[0])
|
||||
* @endcode
|
||||
*/
|
||||
void register_vector_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::Array.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* arr_int = mfem.intArray(5) # Assuming intArray is a typedef or specific binding
|
||||
* arr_int[0] = 10
|
||||
* print(arr_int.Size(), arr_int[0])
|
||||
* @endcode
|
||||
*/
|
||||
void register_array_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Binds the mfem::AssemblyLevel enum.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* # level = mfem.AssemblyLevel.LEGACY # Or other enum values
|
||||
* # print(level)
|
||||
* @endcode
|
||||
*/
|
||||
void bind_assembly_level_enum(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::BilinearForm and related functionalities.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* # Assuming FiniteElementSpace (fes) is created
|
||||
* # bform = mfem.BilinearForm(fes)
|
||||
* # bform.AddDomainIntegrator(mfem.MassIntegrator()) # Assuming MassIntegrator is bound
|
||||
* # bform.Assemble()
|
||||
* # A = bform.SpMat()
|
||||
* @endcode
|
||||
*/
|
||||
void register_bilinear_form_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::MixedBilinearForm and related functionalities.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* # Assuming trial_fes and test_fes are FiniteElementSpaces
|
||||
* # mbform = mfem.MixedBilinearForm(trial_fes, test_fes)
|
||||
* # mbform.AddDomainIntegrator(mfem.VectorFEMassIntegrator()) # Example
|
||||
* # mbform.Assemble()
|
||||
* @endcode
|
||||
*/
|
||||
void register_mixed_bilinear_form_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::Table.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* table = mfem.Table()
|
||||
* # ... use table methods ...
|
||||
* @endcode
|
||||
*/
|
||||
void register_table_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::Mesh.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* mesh = mfem.Mesh.MakeCartesian1D(10) # Example constructor
|
||||
* print(mesh.Dimension(), mesh.GetNE())
|
||||
* @endcode
|
||||
*/
|
||||
void register_mesh_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::BasisType enum and related constants.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* basis_type = mfem.BasisType.GaussLobatto
|
||||
* print(basis_type)
|
||||
* @endcode
|
||||
*/
|
||||
void register_basis_type_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::FiniteElementCollection base class.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* # fec = mfem.FiniteElementCollection() # Typically use derived classes
|
||||
* @endcode
|
||||
*/
|
||||
void register_finite_element_collection_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::H1_FECollection.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* order = 1
|
||||
* dim = 2
|
||||
* fec = mfem.H1_FECollection(order, dim)
|
||||
* print(fec.GetName())
|
||||
* @endcode
|
||||
*/
|
||||
void register_H1_FECollection_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::RT_FECollection.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* order = 1
|
||||
* dim = 2
|
||||
* fec = mfem.RT_FECollection(order-1, dim) # RT_FECollection uses p-1 for order p
|
||||
* print(fec.GetName())
|
||||
* @endcode
|
||||
*/
|
||||
void register_RT_FECollection_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::ND_FECollection (Nedelec finite elements).
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* order = 1
|
||||
* dim = 3
|
||||
* fec = mfem.ND_FECollection(order, dim)
|
||||
* print(fec.GetName())
|
||||
* @endcode
|
||||
*/
|
||||
void register_ND_FECollection_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Binds the mfem::Ordering::Type enum.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* ordering = mfem.Ordering.byNODES
|
||||
* print(ordering)
|
||||
* @endcode
|
||||
*/
|
||||
void bind_ordering_enum(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::FiniteElementSpace.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* mesh = mfem.Mesh.MakeCartesian1D(5)
|
||||
* fec = mfem.H1_FECollection(1, mesh.Dimension())
|
||||
* fes = mfem.FiniteElementSpace(mesh, fec)
|
||||
* print(fes.GetNDofs())
|
||||
* @endcode
|
||||
*/
|
||||
void register_finite_element_space_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::Coefficient, mfem::VectorCoefficient and related classes/trampolines.
|
||||
* @param m The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* const_coeff = mfem.ConstantCoefficient(2.5)
|
||||
* # vec_coeff = mfem.VectorConstantCoefficient(mfem.Vector([1.0, 2.0]))
|
||||
*
|
||||
* # Using a Python-derived coefficient (if PyCoefficient trampoline is bound)
|
||||
* class MyCoeff(mfem.Coefficient):
|
||||
* def Eval(self, T, ip):
|
||||
* return 1.0
|
||||
* my_c = MyCoeff()
|
||||
* @endcode
|
||||
*/
|
||||
void register_coefficient_bindings(pybind11::module &m);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::ElementTransformation.
|
||||
* @param m The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* # ElementTransformation objects are usually obtained from Mesh or FiniteElementSpace
|
||||
* # mesh = mfem.Mesh.MakeCartesian1D(1)
|
||||
* # el_trans = mesh.GetElementTransformation(0)
|
||||
* # print(el_trans.ElementNo)
|
||||
* @endcode
|
||||
*/
|
||||
void register_eltrans_bindings(pybind11::module &m);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::IntegrationRule and mfem::IntegrationPoint.
|
||||
* @param m The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* # Get a standard integration rule
|
||||
* ir = mfem.IntRules.Get(mfem.Geometry.SEGMENT, 3) # Order 3 for a segment
|
||||
* for i in range(ir.GetNPoints()):
|
||||
* ip = ir.IntPoint(i)
|
||||
* # print(f"Point {i}: coords {ip.x}, weight {ip.weight}")
|
||||
* @endcode
|
||||
*/
|
||||
void register_intrule_bindings(pybind11::module &m);
|
||||
|
||||
/**
|
||||
* @brief Registers mfem::GridFunction.
|
||||
* @param mfem_submodule The `serif.mfem` Python submodule.
|
||||
* @par Python Usage Example:
|
||||
* @code{.py}
|
||||
* import serif.mfem as mfem
|
||||
* mesh = mfem.Mesh.MakeCartesian1D(5)
|
||||
* fec = mfem.H1_FECollection(1, mesh.Dimension())
|
||||
* fes = mfem.FiniteElementSpace(mesh, fec)
|
||||
* gf = mfem.GridFunction(fes)
|
||||
* gf.Assign(0.0) # Set all values to 0
|
||||
* print(gf.Size())
|
||||
* @endcode
|
||||
*/
|
||||
void register_grid_function_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
26
src/python/mfem/meson.build
Normal file
26
src/python/mfem/meson.build
Normal file
@@ -0,0 +1,26 @@
|
||||
subdir('Trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files(
|
||||
'bindings.cpp',
|
||||
)
|
||||
bindings_headers = files(
|
||||
'bindings.h',
|
||||
)
|
||||
|
||||
dependencies = [
|
||||
config_dep,
|
||||
resourceManager_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
mpi_dep,
|
||||
trampoline_dep,
|
||||
]
|
||||
|
||||
shared_module('py_mfem',
|
||||
bindings_sources,
|
||||
include_directories: include_directories('.'),
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
)
|
||||
26
src/python/polytrope/bindings.cpp
Normal file
26
src/python/polytrope/bindings.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
#include <pybind11/numpy.h>
|
||||
|
||||
#include "bindings.h"
|
||||
#include "EOSio.h"
|
||||
#include "helm.h"
|
||||
#include "polySolver.h"
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void register_polytrope_bindings(pybind11::module &polytrope_submodule) {
|
||||
py::class_<serif::polytrope::PolySolver>(polytrope_submodule, "PolySolver")
|
||||
.def(py::init<double, int>(), py::arg("polytropic_index"), py::arg("FEM_order"))
|
||||
.def("solve", &serif::polytrope::PolySolver::solve, "Solve the polytrope equation.")
|
||||
.def("get_theta", &serif::polytrope::PolySolver::getTheta, py::return_value_policy::reference_internal)
|
||||
.def("get_phi", &serif::polytrope::PolySolver::getPhi, py::return_value_policy::reference_internal)
|
||||
.def("get_order", &serif::polytrope::PolySolver::getOrder)
|
||||
.def("get_n", &serif::polytrope::PolySolver::getN);
|
||||
|
||||
py::class_<serif::polytrope::PolytropeOperator, mfem::Operator>(polytrope_submodule, "PolytropeOperator")
|
||||
.def("Mult", &serif::polytrope::PolytropeOperator::Mult);
|
||||
|
||||
}
|
||||
31
src/python/polytrope/bindings.h
Normal file
31
src/python/polytrope/bindings.h
Normal file
@@ -0,0 +1,31 @@
|
||||
/**
|
||||
* @file bindings.h
|
||||
* @brief Declares the function to register polytrope module C++ components with pybind11.
|
||||
*
|
||||
* This file contains the declaration for `register_polytrope_bindings`, which is responsible
|
||||
* for creating Python bindings for classes and functions within the `serif::polytrope` C++
|
||||
* namespace. These bindings will be accessible in Python under the `serif.polytrope` submodule.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
/**
|
||||
* @brief Registers C++ classes and functions from the `serif::polytrope` namespace to Python.
|
||||
*
|
||||
* This function takes a pybind11::module object, representing the `serif.polytrope` Python submodule,
|
||||
* and adds bindings for various components like `PolytropeOperator`, `PolySolver`, etc.
|
||||
* This allows these C++ components to be instantiated and used directly from Python.
|
||||
*
|
||||
* @param polytrope_submodule The pybind11 module (typically `serif.polytrope`) to which
|
||||
* the polytrope C++ bindings will be added.
|
||||
*
|
||||
* @par Python Usage Example:
|
||||
* After these bindings are registered and the Python module is imported:
|
||||
* @code{.py}
|
||||
* from serif.polytrope import PolySolver
|
||||
* polytrope = PolySolver(1.5, 1)
|
||||
* polytrope.solve()
|
||||
* @endcode
|
||||
*/
|
||||
void register_polytrope_bindings(pybind11::module &polytrope_submodule);
|
||||
19
src/python/polytrope/meson.build
Normal file
19
src/python/polytrope/meson.build
Normal file
@@ -0,0 +1,19 @@
|
||||
# Define the library
|
||||
bindings_sources = files('bindings.cpp')
|
||||
bindings_headers = files('bindings.h')
|
||||
|
||||
dependencies = [
|
||||
polysolver_dep,
|
||||
config_dep,
|
||||
resourceManager_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
]
|
||||
|
||||
shared_module('py_polytrope',
|
||||
bindings_sources,
|
||||
include_directories: include_directories('.'),
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
)
|
||||
11
tests/python/polytrope/loadMesh.py
Normal file
11
tests/python/polytrope/loadMesh.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from serif.mfem import Mesh
|
||||
import argparse
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Test loading a mesh using MFEM's mesh loading called from Python")
|
||||
parser.add_argument("path", type=str, help="path to mesh")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mesh = Mesh(args.path)
|
||||
print(mesh.nv)
|
||||
18
tests/python/polytrope/runPolytrope.py
Normal file
18
tests/python/polytrope/runPolytrope.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from serif import config
|
||||
from serif.polytrope import PolySolver
|
||||
from serif.mfem import Matrix
|
||||
from serif.mfem import Operator
|
||||
from serif.mfem.forms import BilinearForm
|
||||
|
||||
config.loadConfig('../../testsConfig.yaml')
|
||||
n = config.get("Tests:Poly:Index", 0.0)
|
||||
|
||||
polytrope = PolySolver(n, 1)
|
||||
polytrope.solve()
|
||||
theta = polytrope.get_theta()
|
||||
print(theta)
|
||||
|
||||
FESpace = theta.FESpace()
|
||||
print(FESpace)
|
||||
|
||||
print(theta.as_numpy())
|
||||
Reference in New Issue
Block a user