diff --git a/build-python/meson.build b/build-python/meson.build index 48236f5..a09fd03 100644 --- a/build-python/meson.build +++ b/build-python/meson.build @@ -9,6 +9,11 @@ py_mod = py_installation.extension_module( meson.project_source_root() + '/src/python/const/bindings.cpp', meson.project_source_root() + '/src/python/config/bindings.cpp', meson.project_source_root() + '/src/python/eos/bindings.cpp', + meson.project_source_root() + '/src/python/mfem/bindings.cpp', + meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.cpp', + meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp', + meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp', + meson.project_source_root() + '/src/python/polytrope/bindings.cpp', ], dependencies : [ pybind11_dep, @@ -16,7 +21,10 @@ py_mod = py_installation.extension_module( config_dep, composition_dep, eos_dep, - species_weight_dep + species_weight_dep, + mfem_dep, + polysolver_dep, + trampoline_dep ], cpp_args : ['-UNDEBUG'], # Example: Ensure assertions are enabled if needed install : true, diff --git a/src/polytrope/solver/private/polySolver.cpp b/src/polytrope/solver/private/polySolver.cpp index 3b44d55..372830e 100644 --- a/src/polytrope/solver/private/polySolver.cpp +++ b/src/polytrope/solver/private/polySolver.cpp @@ -39,8 +39,8 @@ #include "utilities.h" #include "quill/LogMacros.h" -namespace serif { -namespace polytrope { + +namespace serif::polytrope { namespace laneEmden { @@ -382,7 +382,6 @@ void PolySolver::setInitialGuess() const { serif::probe::glVisView(*m_theta, m_mesh, "θ init"); serif::probe::glVisView(*m_phi, m_mesh, "φ init"); } - std::cout << "HERE" << std::endl; } @@ -497,5 +496,5 @@ solverBundle PolySolver::setupNewtonSolver() const { return solver; } -} // namespace polytrope -} // namespace serif +} // namespace serif::polytrope + diff --git a/src/polytrope/solver/public/polySolver.h b/src/polytrope/solver/public/polySolver.h index 8dcadd3..3e56cb2 100644 --- a/src/polytrope/solver/public/polySolver.h +++ b/src/polytrope/solver/public/polySolver.h @@ -271,7 +271,14 @@ public: // Public methods * @return A reference to the `mfem::GridFunction` storing the \f$\theta\f$ solution. * @note The solution is populated after `solve()` has been successfully called. */ - mfem::GridFunction& getSolution() const { return *m_theta; } + mfem::GridFunction& getTheta() const { return *m_theta; } + + /** + * @brief Gets a reference to the solution grid function for \f$\phi\f$. + * @return A reference to the `mfem::GridFunction` storing the \f$\phi\f$ solution. + * @note The solution is populated after `solve()` has been successfully called. + */ + mfem::GridFunction& getPhi() const { return *m_phi; } private: // Private Attributes // --- Configuration and Logging --- diff --git a/src/polytrope/utils/private/utilities.cpp b/src/polytrope/utils/private/utilities.cpp index 74ae125..d4a121f 100644 --- a/src/polytrope/utils/private/utilities.cpp +++ b/src/polytrope/utils/private/utilities.cpp @@ -119,7 +119,6 @@ namespace serif::utilities { mfem::IsoparametricTransformation T; mesh->GetElementTransformation(ne, &T); phi_gf.GetCurl(T, curl_mag_vec); - std::cout << "HERE" << std::endl; } mfem::L2_FECollection fac(order, dim); mfem::FiniteElementSpace fs(mesh, &fac); diff --git a/src/polytrope/utils/public/integrators.h b/src/polytrope/utils/public/integrators.h index 4d787e1..a821edd 100644 --- a/src/polytrope/utils/public/integrators.h +++ b/src/polytrope/utils/public/integrators.h @@ -31,27 +31,26 @@ * @brief A collection of utilities for working with MFEM and solving the lane-emden equation. */ -namespace serif { -namespace polytrope { -/** +namespace serif::polytrope { + /** * @namespace polyMFEMUtils * @brief A namespace for utilities for working with MFEM and solving the lane-emden equation. */ -namespace polyMFEMUtils { - /** + namespace polyMFEMUtils { + /** * @brief A class for nonlinear power integrator. */ - class NonlinearPowerIntegrator: public mfem::NonlinearFormIntegrator { - public: - /** + class NonlinearPowerIntegrator: public mfem::NonlinearFormIntegrator { + public: + /** * @brief Constructor for NonlinearPowerIntegrator. * * @param coeff The function coefficient. * @param n The polytropic index. */ - NonlinearPowerIntegrator(double n); + NonlinearPowerIntegrator(double n); - /** + /** * @brief Assembles the element vector. * * @param el The finite element. @@ -59,8 +58,8 @@ namespace polyMFEMUtils { * @param elfun The element function. * @param elvect The element vector to be assembled. */ - virtual void AssembleElementVector(const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::Vector &elvect) override; - /** + virtual void AssembleElementVector(const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::Vector &elvect) override; + /** * @brief Assembles the element gradient. * * @param el The finite element. @@ -68,40 +67,39 @@ namespace polyMFEMUtils { * @param elfun The element function. * @param elmat The element matrix to be assembled. */ - virtual void AssembleElementGrad (const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::DenseMatrix &elmat) override; - private: - serif::config::Config& m_config = serif::config::Config::getInstance(); - serif::probe::LogManager& m_logManager = serif::probe::LogManager::getInstance(); - quill::Logger* m_logger = m_logManager.getLogger("log"); - double m_polytropicIndex; - double m_epsilon; - static constexpr double m_regularizationRadius = 0.15; ///< Regularization radius for the epsilon function, used to avoid singularities in the power law. - static constexpr double m_regularizationCoeff = 1.0/6.0; ///< Coefficient for the regularization term, used to ensure smoothness in the power law. - }; + virtual void AssembleElementGrad (const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::DenseMatrix &elmat) override; + private: + serif::config::Config& m_config = serif::config::Config::getInstance(); + serif::probe::LogManager& m_logManager = serif::probe::LogManager::getInstance(); + quill::Logger* m_logger = m_logManager.getLogger("log"); + double m_polytropicIndex; + double m_epsilon; + static constexpr double m_regularizationRadius = 0.15; ///< Regularization radius for the epsilon function, used to avoid singularities in the power law. + static constexpr double m_regularizationCoeff = 1.0/6.0; ///< Coefficient for the regularization term, used to ensure smoothness in the power law. + }; - inline double dfmod(const double epsilon, const double n) { - if (n == 0.0) { - return 0.0; + inline double dfmod(const double epsilon, const double n) { + if (n == 0.0) { + return 0.0; + } + if (n == 1.0) { + return 1.0; + } + return n * std::pow(epsilon, n - 1.0); } - if (n == 1.0) { - return 1.0; + + inline double fmod(const double theta, const double n, const double epsilon) { + if (n == 0.0) { + return 1.0; + } + // For n != 0 + const double y_prime_at_epsilon = dfmod(epsilon, n); // Uses the robust dfmod + const double y_at_epsilon = std::pow(epsilon, n); // epsilon^n + + // f_mod(theta) = y_at_epsilon + y_prime_at_epsilon * (theta - epsilon) + return y_at_epsilon + y_prime_at_epsilon * (theta - epsilon); } - return n * std::pow(epsilon, n - 1.0); - } - - inline double fmod(const double theta, const double n, const double epsilon) { - if (n == 0.0) { - return 1.0; - } - // For n != 0 - const double y_prime_at_epsilon = dfmod(epsilon, n); // Uses the robust dfmod - const double y_at_epsilon = std::pow(epsilon, n); // epsilon^n - - // f_mod(theta) = y_at_epsilon + y_prime_at_epsilon * (theta - epsilon) - return y_at_epsilon + y_prime_at_epsilon * (theta - epsilon); - } -} // namespace polyMFEMUtils -} // namespace polytrope -} // namespace serif \ No newline at end of file + } // namespace polyMFEMUtils +} diff --git a/src/polytrope/utils/public/polytropeOperator.h b/src/polytrope/utils/public/polytropeOperator.h index eefa505..71209ad 100644 --- a/src/polytrope/utils/public/polytropeOperator.h +++ b/src/polytrope/utils/public/polytropeOperator.h @@ -26,8 +26,7 @@ #include "probe.h" -namespace serif { -namespace polytrope { +namespace serif::polytrope { /** * @brief Represents the Schur complement operator used in the solution process. @@ -299,7 +298,7 @@ private: std::unique_ptr m_M; ///< Bilinear form M, coupling θ and φ. std::unique_ptr m_Q; ///< Bilinear form Q, coupling φ and θ. std::unique_ptr m_D; ///< Bilinear form D, acting on φ. - std::unique_ptr m_S; + std::unique_ptr m_S; ///< Bilinear form S, used for least squares stabilization. std::unique_ptr m_f; ///< Nonlinear form f, acting on θ. // --- Full Matrix Representations (owned, derived from forms) --- @@ -395,5 +394,4 @@ private: void update_preconditioner(const mfem::Operator &grad) const; }; -} // namespace polytrope -} // namespace serif \ No newline at end of file +} // namespace serif::polytrope \ No newline at end of file diff --git a/src/python/bindings.cpp b/src/python/bindings.cpp index 411c3f6..610f55a 100644 --- a/src/python/bindings.cpp +++ b/src/python/bindings.cpp @@ -6,6 +6,8 @@ #include "composition/bindings.h" #include "config/bindings.h" #include "eos/bindings.h" +#include "mfem/bindings.h" +#include "polytrope/bindings.h" PYBIND11_MODULE(serif, m) { m.doc() = "Python bindings for the SERiF project"; @@ -21,4 +23,10 @@ PYBIND11_MODULE(serif, m) { auto eosMod = m.def_submodule("eos", "EOS-module bindings"); register_eos_bindings(eosMod); + + auto mfemMod = m.def_submodule("mfem", "MFEM bindings"); + register_mfem_bindings(mfemMod); + + auto polytropeMod = m.def_submodule("polytrope", "Polytrope-module bindings"); + register_polytrope_bindings(polytropeMod); } \ No newline at end of file diff --git a/src/python/eos/bindings.cpp b/src/python/eos/bindings.cpp index 40e31d9..e8c6314 100644 --- a/src/python/eos/bindings.cpp +++ b/src/python/eos/bindings.cpp @@ -5,12 +5,9 @@ #include #include "helm.h" -// #include "resourceManager.h" #include "bindings.h" #include "EOSio.h" #include "helm.h" -#include "../../eos/public/EOSio.h" -#include "../../eos/public/helm.h" namespace serif::eos { class EOSio; diff --git a/src/python/meson.build b/src/python/meson.build index c865fc2..6c8df90 100644 --- a/src/python/meson.build +++ b/src/python/meson.build @@ -1,4 +1,6 @@ subdir('composition') subdir('const') subdir('config') + +subdir('mfem') subdir('eos') \ No newline at end of file diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp b/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp new file mode 100644 index 0000000..818d2a4 --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp @@ -0,0 +1,51 @@ +#include "PyCoefficient.h" + +#include +#include +#include // Needed for std::function +#include + +#include "mfem.hpp" + +namespace py = pybind11; +using namespace mfem; + +namespace serif::pybind { + real_t PyCoefficient::Eval(ElementTransformation &T, const IntegrationPoint &ip) { + PYBIND11_OVERRIDE_PURE( + real_t, /* Return type */ + Coefficient, /* Base class */ + Eval, /* Method name */ + T, ip /* Arguments */ + ); + } + + // Override virtual SetTime method + void PyCoefficient::SetTime(real_t t) { + PYBIND11_OVERRIDE( + void, + Coefficient, + SetTime, + t + ); + } + + void PyVectorCoefficient::Eval(Vector &V, ElementTransformation &T, const IntegrationPoint &ip) { + PYBIND11_OVERRIDE_PURE( + void, /* Return type */ + VectorCoefficient, /* Base class */ + Eval, /* Method name */ + V, T, ip /* Arguments */ + ); + } + + // Override the virtual SetTime method + void PyVectorCoefficient::SetTime(real_t t) { + PYBIND11_OVERRIDE( + void, + VectorCoefficient, + SetTime, + t + ); + } +} \ No newline at end of file diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.h b/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.h new file mode 100644 index 0000000..ee83a91 --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.h @@ -0,0 +1,162 @@ +/** + * @file PyCoefficient.h + * @brief Defines pybind11 trampoline classes for mfem::Coefficient and mfem::VectorCoefficient. + * + * These trampoline classes allow Python classes to inherit from mfem::Coefficient + * and mfem::VectorCoefficient, enabling Python-defined coefficients to be used + * within the C++ MFEM library. + */ +#include +#include +#include // Needed for std::function +#include + +#include "mfem.hpp" + +namespace py = pybind11; +using namespace mfem; + +/** + * @namespace serif::pybind + * @brief Contains pybind11 helper classes and trampoline classes for interfacing C++ with Python. + */ +namespace serif::pybind { + /** + * @brief Trampoline class for mfem::Coefficient. + * + * This class allows Python classes to inherit from mfem::Coefficient and override + * its virtual methods. This is essential for creating custom coefficients in Python + * that can be used by MFEM's C++ backend. + * + * @see mfem::Coefficient + * + * @par Python Usage Example: + * @code{.py} + * import mfem.ser_ext as mfem + * + * class MyPythonCoefficient(mfem.Coefficient): + * def __init__(self): + * super().__init__() # Call the base C++ constructor + * + * def Eval(self, T, ip): + * # T is an mfem.ElementTransformation + * # ip is an mfem.IntegrationPoint + * # Example: return a constant value + * return 1.0 + * + * def SetTime(self, t): + * # Optionally handle time-dependent coefficients + * super().SetTime(t) # Call base class method + * print(f"Time set to: {t}") + * + * # Using the Python coefficient + * py_coeff = MyPythonCoefficient() + * # py_coeff can now be passed to MFEM functions expecting an mfem::Coefficient + * @endcode + */ + class PyCoefficient : public Coefficient { + public: + using Coefficient::Coefficient; /**< Inherit constructors from mfem::Coefficient. */ + + /** + * @brief Evaluate the coefficient at a given IntegrationPoint in an ElementTransformation. + * + * This method is called by MFEM when the value of the coefficient is needed. + * If a Python class inherits from PyCoefficient, it *must* override this method. + * + * @param T The element transformation. + * @param ip The integration point. + * @return The value of the coefficient at the given point. + * + * @note This method forwards the call to the Python override. + * PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this. + */ + real_t Eval(ElementTransformation &T, const IntegrationPoint &ip) override; + + /** + * @brief Set the current time for time-dependent coefficients. + * + * This method is called by MFEM to update the time for time-dependent coefficients. + * Python classes inheriting from PyCoefficient can override this method to implement + * time-dependent behavior. + * + * @param t The current time. + * + * @note This method forwards the call to the Python override if one exists. + * PYBIND11_OVERRIDE is used in the .cpp file to handle this. + */ + void SetTime(real_t t) override; + }; + + /** + * @brief Trampoline class for mfem::VectorCoefficient. + * + * This class allows Python classes to inherit from mfem::VectorCoefficient and override + * its virtual methods. This is essential for creating custom vector-valued coefficients + * in Python that can be used by MFEM's C++ backend. + * + * @see mfem::VectorCoefficient + * + * @par Python Usage Example: + * @code{.py} + * import mfem.ser_ext as mfem + * import numpy as np + * + * class MyPythonVectorCoefficient(mfem.VectorCoefficient): + * def __init__(self, dim): + * super().__init__(dim) # Call the base C++ constructor, pass vector dimension + * self.dim = dim + * + * def Eval(self, V, T, ip): + * # V is an mfem.Vector (output parameter, must be filled) + * # T is an mfem.ElementTransformation + * # ip is an mfem.IntegrationPoint + * # Example: return a constant vector [1.0, 2.0, ...] + * for i in range(self.dim): + * V[i] = float(i + 1) + * + * def SetTime(self, t): + * super().SetTime(t) + * print(f"VectorCoefficient time set to: {t}") + * + * # Using the Python vector coefficient + * vec_dim = 2 + * py_vec_coeff = MyPythonVectorCoefficient(vec_dim) + * # py_vec_coeff can now be passed to MFEM functions expecting an mfem::VectorCoefficient + * @endcode + */ + class PyVectorCoefficient : public VectorCoefficient { + public: + using VectorCoefficient::VectorCoefficient; /**< Inherit constructors from mfem::VectorCoefficient. */ + + /** + * @brief Evaluate the vector coefficient at a given IntegrationPoint in an ElementTransformation. + * + * This method is called by MFEM when the value of the vector coefficient is needed. + * If a Python class inherits from PyVectorCoefficient, it *must* override this method. + * The result should be stored in the output Vector @p V. + * + * @param V Output vector to store the result. Its size should match the coefficient's dimension. + * @param T The element transformation. + * @param ip The integration point. + * + * @note This method forwards the call to the Python override. + * PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this. + */ + void Eval(Vector &V, ElementTransformation &T, const IntegrationPoint &ip) override; + + /** + * @brief Set the current time for time-dependent vector coefficients. + * + * This method is called by MFEM to update the time for time-dependent vector coefficients. + * Python classes inheriting from PyVectorCoefficient can override this method to implement + * time-dependent behavior. + * + * @param t The current time. + * + * @note This method forwards the call to the Python override if one exists. + * PYBIND11_OVERRIDE is used in the .cpp file to handle this. + */ + void SetTime(real_t t) override; + }; +} \ No newline at end of file diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/meson.build b/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/meson.build new file mode 100644 index 0000000..f51c37c --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/meson.build @@ -0,0 +1,5 @@ +PyCoefficient_sources = files( + 'PyCoefficient.cpp' +) + +trampoline_sources += PyCoefficient_sources diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp new file mode 100644 index 0000000..418d448 --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp @@ -0,0 +1,137 @@ +#include "PyMFEMTrampolines/Operator/Matrix/PyMatrix.h" +#include "mfem.hpp" + +namespace serif::pybind { + + // --- Trampolines for new mfem::Matrix pure virtual methods --- + + mfem::real_t& PyMatrix::Elem(int i, int j) { + PYBIND11_OVERRIDE_PURE( + mfem::real_t&, /* Return type */ + mfem::Matrix, /* C++ base class */ + Elem, /* C++ function name */ + i, /* Argument 1 */ + j /* Argument 2 */ + ); + } + + const mfem::real_t& PyMatrix::Elem(int i, int j) const { + PYBIND11_OVERRIDE_PURE( + const mfem::real_t&, + mfem::Matrix, + Elem, + i, + j + ); + } + + mfem::MatrixInverse* PyMatrix::Inverse() const { + PYBIND11_OVERRIDE_PURE( + mfem::MatrixInverse*, + mfem::Matrix, + Inverse + ); + } + + // --- Trampoline for new mfem::Matrix regular virtual methods --- + + void PyMatrix::Finalize(int skip_zeros) { + PYBIND11_OVERRIDE( + void, + mfem::Matrix, + Finalize, + skip_zeros + ); + } + + // --- Trampolines for inherited mfem::Operator virtual methods --- + + void PyMatrix::Mult(const mfem::Vector &x, mfem::Vector &y) const { + // This remains PURE as mfem::Matrix does not implement it. + PYBIND11_OVERRIDE_PURE( + void, + mfem::Matrix, + Mult, + x, + y + ); + } + + void PyMatrix::MultTranspose(const mfem::Vector &x, mfem::Vector &y) const { + PYBIND11_OVERRIDE( + void, + mfem::Matrix, + MultTranspose, + x, + y + ); + } + + void PyMatrix::AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const { + PYBIND11_OVERRIDE( + void, + mfem::Matrix, + AddMult, + x, + y, + a + ); + } + + void PyMatrix::AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const { + PYBIND11_OVERRIDE( + void, + mfem::Matrix, + AddMultTranspose, + x, + y, + a + ); + } + + mfem::Operator& PyMatrix::GetGradient(const mfem::Vector &x) const { + PYBIND11_OVERRIDE( + mfem::Operator&, + mfem::Matrix, + GetGradient, + x + ); + } + + void PyMatrix::AssembleDiagonal(mfem::Vector &diag) const { + PYBIND11_OVERRIDE( + void, + mfem::Matrix, + AssembleDiagonal, + diag + ); + } + + const mfem::Operator* PyMatrix::GetProlongation() const { + PYBIND11_OVERRIDE( + const mfem::Operator*, + mfem::Matrix, + GetProlongation + ); + } + + const mfem::Operator* PyMatrix::GetRestriction() const { + PYBIND11_OVERRIDE( + const mfem::Operator*, + mfem::Matrix, + GetRestriction + ); + } + + void PyMatrix::RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) { + PYBIND11_OVERRIDE( + void, + mfem::Matrix, + RecoverFEMSolution, + X, + b, + x + ); + } + +} // namespace serif::pybind diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.h b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.h new file mode 100644 index 0000000..ad11f6e --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.h @@ -0,0 +1,253 @@ +/** + * @file PyMatrix.h + * @brief Defines a pybind11 trampoline class for mfem::Matrix. + * + * This trampoline class allows Python classes to inherit from mfem::Matrix, + * enabling Python-defined matrices to be used within the C++ MFEM library, + * including overriding methods from its base class mfem::Operator. + */ +#pragma once + +#include "mfem.hpp" +#include "pybind11/pybind11.h" + +namespace serif::pybind { + + /** + * @brief A trampoline class for mfem::Matrix. + * + * This class allows Python classes to inherit from mfem::Matrix and correctly + * override its virtual functions, including those inherited from its base, + * mfem::Operator. This is useful for creating custom matrix types in Python + * that can interact seamlessly with MFEM's C++ components. + * + * @see mfem::Matrix + * @see mfem::Operator + * @see PyOperator + * + * @par Python Usage Example: + * @code{.py} + * import mfem.ser_ext as mfem + * import numpy as np + * + * class MyPythonMatrix(mfem.Matrix): + * def __init__(self, size=0): + * super().__init__(size) # Call base C++ constructor + * if size > 0: + * # Example: Initialize with some data if size is provided + * # Note: Direct data manipulation might be complex due to + * # C++ ownership. This example is conceptual. + * # For a real-world scenario, you might manage data in Python + * # and use Eval or Mult to reflect its state. + * self._py_data = np.zeros((size, size)) + * else: + * self._py_data = None + * + * # --- mfem.Matrix pure virtual overrides --- + * def Elem(self, i, j): + * # This is problematic for direct override for write access + * # due to returning a reference. + * # Typically, you'd manage data and use it in Mult/GetRow etc. + * # For read access (const version), it's more straightforward. + * if self._py_data is not None: + * return self._py_data[i, j] + * raise IndexError("Matrix data not initialized or index out of bounds") + * + * # const Elem version for read access + * # def GetElem(self, i, j): # A helper in Python if Elem is tricky + * # return self._py_data[i,j] + * + * def Inverse(self): + * # Return an mfem.MatrixInverse object + * # This is a simplified example; a real inverse is complex. + * print("MyPythonMatrix.Inverse() called, returning dummy inverse.") + * # For a real implementation, you'd compute and return an actual + * # mfem.MatrixInverse or a Python-derived version of it. + * # This might involve creating a new PyMatrix representing the inverse. + * identity_inv = mfem.DenseMatrix(self.Width()) + * for i in range(self.Width()): + * identity_inv[i,i] = 1.0 + * return mfem.MatrixInverse(identity_inv) # Example + * + * # --- mfem.Matrix regular virtual overrides --- + * def Finalize(self, skip_zeros=1): + * super().Finalize(skip_zeros) # Call base + * print(f"MyPythonMatrix.Finalize({skip_zeros}) called.") + * + * # --- mfem.Operator virtual overrides --- + * def Mult(self, x, y): + * # x is mfem.Vector (input), y is mfem.Vector (output) + * if self._py_data is not None: + * x_np = x.GetDataArray() + * y_np = np.dot(self._py_data, x_np) + * y.SetSize(self._py_data.shape[0]) + * y.Assign(y_np) + * else: + * # Fallback to base class if no Python data + * # super().Mult(x,y) # This might not work as expected if base is abstract + * # or if the C++ mfem.Matrix itself doesn't have data. + * # For a sparse matrix, this would call its sparse Mult. + * # For a dense matrix, it would use its data. + * # If this PyMatrix is purely Python defined, it must implement Mult. + * raise NotImplementedError("Mult not implemented or data not set") + * print("MyPythonMatrix.Mult called.") + * + * # Using the Python Matrix + * mat_size = 3 + * py_mat = MyPythonMatrix(mat_size) + * py_mat._py_data = np.array([[1,2,3],[4,5,6],[7,8,9]], dtype=float) + * + * x_vec = mfem.Vector(mat_size) + * x_vec.Assign(np.array([1,1,1], dtype=float)) + * y_vec = mfem.Vector(mat_size) + * + * py_mat.Mult(x_vec, y_vec) + * print("Result y:", y_vec.GetDataArray()) + * + * # inv_op = py_mat.Inverse() # Be cautious with dummy implementations + * # print("Inverse type:", type(inv_op)) + * + * # Note: Overriding Elem(i,j) to return a C++ reference from Python + * # for write access (mfem::real_t&) is complex and often not directly feasible + * # in a straightforward way with pybind11 for typical Python data structures. + * # Python users would typically interact via methods like Set, Add, or by + * # providing data that the C++ side can access (e.g., via GetData). + * # The const version of Elem (read-only) is more manageable. + * @endcode + */ + class PyMatrix : public mfem::Matrix { + public: + // Inherit constructors from the base mfem::Matrix class. + using mfem::Matrix::Matrix; + + // --- Trampolines for new mfem::Matrix pure virtual methods --- + /** + * @brief Access element (i,j) for read/write. + * Pure virtual in mfem::Matrix. Must be overridden in Python. + * @param i Row index. + * @param j Column index. + * @return Reference to the matrix element. + * @note Returning a C++ reference from Python for write access can be complex. + * Consider alternative data management strategies in Python. + * PYBIND11_OVERRIDE_PURE is used in the .cpp file. + */ + mfem::real_t& Elem(int i, int j) override; + + /** + * @brief Access element (i,j) for read-only. + * Pure virtual in mfem::Matrix. Must be overridden in Python. + * @param i Row index. + * @param j Column index. + * @return Const reference to the matrix element. + * @note PYBIND11_OVERRIDE_PURE is used in the .cpp file. + */ + const mfem::real_t& Elem(int i, int j) const override; + + /** + * @brief Get the inverse of the matrix. + * Pure virtual in mfem::Matrix. Must be overridden in Python. + * The caller is responsible for deleting the returned MatrixInverse object. + * @return Pointer to an mfem::MatrixInverse object. + * @note PYBIND11_OVERRIDE_PURE is used in the .cpp file. + */ + mfem::MatrixInverse* Inverse() const override; + + // --- Trampoline for new mfem::Matrix regular virtual methods --- + /** + * @brief Finalize matrix assembly. + * For sparse matrices, this typically involves finalizing the sparse structure. + * Can be overridden in Python. + * @param skip_zeros See mfem::SparseMatrix::Finalize documentation. + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + void Finalize(int skip_zeros) override; + + // --- Trampolines for inherited mfem::Operator virtual methods --- + // These must be repeated here to allow Python classes inheriting from + // Matrix to override methods originally from the Operator base class. + + /** + * @brief Perform the operator action: y = A*x. + * Inherited from mfem::Operator. Can be overridden in Python. + * If not overridden, mfem::Matrix's default implementation is used. + * @param x The input vector. + * @param y The output vector (result of A*x). + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + void Mult(const mfem::Vector &x, mfem::Vector &y) const override; + + /** + * @brief Perform the transpose operator action: y = A^T*x. + * Inherited from mfem::Operator. Can be overridden in Python. + * @param x The input vector. + * @param y The output vector (result of A^T*x). + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override; + + /** + * @brief Perform the action y += a*(A*x). + * Inherited from mfem::Operator. Can be overridden in Python. + * @param x The input vector. + * @param y The vector to which a*(A*x) is added. + * @param a Scalar multiplier (defaults to 1.0). + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override; + + /** + * @brief Perform the action y += a*(A^T*x). + * Inherited from mfem::Operator. Can be overridden in Python. + * @param x The input vector. + * @param y The vector to which a*(A^T*x) is added. + * @param a Scalar multiplier (defaults to 1.0). + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override; + + /** + * @brief Get the gradient operator (Jacobian) at a given point x. + * Inherited from mfem::Operator. Can be overridden in Python. + * For a linear matrix operator, the gradient is typically the matrix itself. + * @param x The point at which to evaluate the gradient (often unused for linear operators). + * @return A reference to the gradient operator. + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + mfem::Operator& GetGradient(const mfem::Vector &x) const override; + + /** + * @brief Assemble the diagonal of the operator. + * Inherited from mfem::Operator. Can be overridden in Python. + * @param diag Output vector to store the diagonal entries. + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + void AssembleDiagonal(mfem::Vector &diag) const override; + + /** + * @brief Get the prolongation operator. + * Inherited from mfem::Operator. Can be overridden in Python. + * @return A const pointer to the prolongation operator, or nullptr if not applicable. + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + const mfem::Operator* GetProlongation() const override; + + /** + * @brief Get the restriction operator. + * Inherited from mfem::Operator. Can be overridden in Python. + * @return A const pointer to the restriction operator, or nullptr if not applicable. + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + const mfem::Operator* GetRestriction() const override; + + /** + * @brief Recover the FEM solution. + * Inherited from mfem::Operator. Can be overridden in Python. + * @param X The reduced solution vector. + * @param b The right-hand side vector. + * @param x Output vector for the full FEM solution. + * @note PYBIND11_OVERRIDE is used in the .cpp file. + */ + void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override; + }; + +} // namespace serif::pybind \ No newline at end of file diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/meson.build b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/meson.build new file mode 100644 index 0000000..878a773 --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/meson.build @@ -0,0 +1,5 @@ +PyMatrix_sources = files( + 'PyMatrix.cpp' +) + +trampoline_sources += PyMatrix_sources diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.cpp b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.cpp new file mode 100644 index 0000000..8b4804b --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.cpp @@ -0,0 +1,100 @@ +#include "PyMFEMTrampolines/Operator/PyOperator.h" +#include "mfem.hpp" + +namespace serif::pybind { + + // --- Pure virtual function implementation --- + // This override is mandatory for the trampoline class to be concrete. + void PyOperator::Mult(const mfem::Vector &x, mfem::Vector &y) const { + PYBIND11_OVERRIDE_PURE( + void, /* Return type */ + mfem::Operator, /* C++ base class */ + Mult, /* C++ function name */ + x, /* Argument 1 */ + y /* Argument 2 */ + ); + } + + // --- Regular virtual function implementations --- + // These overrides allow Python subclasses to optionally provide their own + // implementation for these methods. If they don't, the default C++ + // implementation from mfem::Operator will be called. + + void PyOperator::MultTranspose(const mfem::Vector &x, mfem::Vector &y) const { + PYBIND11_OVERRIDE( + void, + mfem::Operator, + MultTranspose, + x, + y + ); + } + + void PyOperator::AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const { + PYBIND11_OVERRIDE( + void, + mfem::Operator, + AddMult, + x, + y, + a + ); + } + + void PyOperator::AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const { + PYBIND11_OVERRIDE( + void, + mfem::Operator, + AddMultTranspose, + x, + y, + a + ); + } + + mfem::Operator& PyOperator::GetGradient(const mfem::Vector &x) const { + PYBIND11_OVERRIDE( + mfem::Operator&, + mfem::Operator, + GetGradient, + x + ); + } + + void PyOperator::AssembleDiagonal(mfem::Vector &diag) const { + PYBIND11_OVERRIDE( + void, + mfem::Operator, + AssembleDiagonal, + diag + ); + } + + const mfem::Operator* PyOperator::GetProlongation() const { + PYBIND11_OVERRIDE( + const mfem::Operator*, + mfem::Operator, + GetProlongation + ); + } + + const mfem::Operator* PyOperator::GetRestriction() const { + PYBIND11_OVERRIDE( + const mfem::Operator*, + mfem::Operator, + GetRestriction + ); + } + + void PyOperator::RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) { + PYBIND11_OVERRIDE( + void, + mfem::Operator, + RecoverFEMSolution, + X, + b, + x + ); + } + +} // namespace serif::pybind::mfem diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.h b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.h new file mode 100644 index 0000000..ba8ca52 --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.h @@ -0,0 +1,210 @@ +/** + * @file PyOperator.h + * @brief Defines a pybind11 trampoline class for mfem::Operator. + * + * This trampoline class allows Python classes to inherit from mfem::Operator, + * enabling Python-defined operators to be used within the C++ MFEM library. + */ +#pragma once + +#include "mfem.hpp" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" // Needed for vectors, maps, sets, strings + +namespace py = pybind11; + +/** + * @namespace serif::pybind + * @brief Contains pybind11 helper classes and trampoline classes for interfacing C++ with Python. + */ +namespace serif::pybind { + + /** + * @brief Trampoline class for mfem::Operator. + * + * This class allows Python classes to inherit from mfem::Operator and correctly + * override its virtual functions. When a virtual function is called from C++, + * the trampoline ensures the call is forwarded to the Python implementation if one exists. + * This is crucial for integrating Python-defined linear or non-linear operators + * into MFEM's C++-based solvers and algorithms. + * + * @see mfem::Operator + * + * @par Python Usage Example: + * @code{.py} + * import mfem.ser_ext as mfem + * import numpy as np + * + * class MyPythonOperator(mfem.Operator): + * def __init__(self, size): + * super().__init__(size) # Call the base C++ constructor + * # Or super().__init__() if default constructing and setting height/width later + * self.matrix = np.random.rand(size, size) # Example: store a dense matrix + * + * # Must override Mult + * def Mult(self, x, y): + * # x is an mfem.Vector (input) + * # y is an mfem.Vector (output, y = A*x) + * # Ensure y is correctly sized if not already + * if y.Size() != self.Height(): + * y.SetSize(self.Height()) + * + * # Example: y = self.matrix * x + * # This is a conceptual illustration. For actual matrix-vector products + * # with numpy, you'd convert mfem.Vector to numpy array or iterate. + * x_np = x.GetDataArray() # Get a numpy view if configured + * y_np = np.dot(self.matrix, x_np) + * y.Assign(y_np) # Assign result back to mfem.Vector + * print("MyPythonOperator.Mult called") + * + * # Optionally override other methods like MultTranspose, GetGradient, etc. + * def MultTranspose(self, x, y): + * if y.Size() != self.Width(): + * y.SetSize(self.Width()) + * # Example: y = self.matrix.T * x + * x_np = x.GetDataArray() + * y_np = np.dot(self.matrix.T, x_np) + * y.Assign(y_np) + * print("MyPythonOperator.MultTranspose called") + * + * # Using the Python operator + * op_size = 5 + * py_op = MyPythonOperator(op_size) + * + * x_vec = mfem.Vector(op_size) + * x_vec.Assign(np.arange(op_size, dtype=float)) + * y_vec = mfem.Vector(op_size) + * + * py_op.Mult(x_vec, y_vec) + * print("Result y:", y_vec.GetDataArray()) + * # py_op can now be passed to MFEM functions expecting an mfem::Operator + * @endcode + */ + class PyOperator : public mfem::Operator { + public: + // Inherit the constructors from the base mfem::Operator class. + // This allows Python classes to call e.g., super().__init__(size). + using mfem::Operator::Operator; + + // --- Trampoline declarations for all overridable virtual functions --- + + /** + * @brief Perform the operator action: y = A*x. + * + * This is a pure virtual function in mfem::Operator and **must** be overridden + * by any Python class inheriting from PyOperator. + * + * @param x The input vector. + * @param y The output vector (result of A*x). + * @note This method forwards the call to the Python override. + * PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this. + */ + void Mult(const mfem::Vector &x, mfem::Vector &y) const override; + + /** + * @brief Perform the transpose operator action: y = A^T*x. + * + * Optional override. If not overridden, MFEM's base implementation + * (which may raise an error or be a no-op) will be used. + * + * @param x The input vector. + * @param y The output vector (result of A^T*x). + * @note This method forwards the call to the Python override if one exists. + * PYBIND11_OVERRIDE is used in the .cpp file to handle this. + */ + void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override; + + /** + * @brief Perform the action y += a*(A*x). + * + * Optional override. + * + * @param x The input vector. + * @param y The vector to which a*(A*x) is added. + * @param a Scalar multiplier (defaults to 1.0). + * @note This method forwards the call to the Python override if one exists. + */ + void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override; + + /** + * @brief Perform the action y += a*(A^T*x). + * + * Optional override. + * + * @param x The input vector. + * @param y The vector to which a*(A^T*x) is added. + * @param a Scalar multiplier (defaults to 1.0). + * @note This method forwards the call to the Python override if one exists. + */ + void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override; + + /** + * @brief Get the gradient operator (Jacobian) at a given point x. + * + * For non-linear operators, this method should return the linearization (Jacobian) + * of the operator at the point `x`. The returned Operator is owned by this + * Operator and should not be deleted by the caller. + * Optional override. + * + * @param x The point at which to evaluate the gradient. + * @return A reference to the gradient operator. + * @note This method forwards the call to the Python override if one exists. + */ + Operator& GetGradient(const mfem::Vector &x) const override; + + /** + * @brief Assemble the diagonal of the operator. + * + * For discrete operators (e.g., matrices), this method should compute and store + * the diagonal entries of the operator in the vector `diag`. + * Optional override. + * + * @param diag Output vector to store the diagonal entries. + * @note This method forwards the call to the Python override if one exists. + */ + void AssembleDiagonal(mfem::Vector &diag) const override; + + /** + * @brief Get the prolongation operator. + * + * Used in multilevel methods (e.g., AMG). Returns a pointer to the prolongation + * operator (interpolation from a coarser level to this level). + * The returned Operator is typically owned by this Operator. + * Optional override. + * + * @return A const pointer to the prolongation operator, or nullptr if not applicable. + * @note This method forwards the call to the Python override if one exists. + */ + const mfem::Operator* GetProlongation() const override; + + /** + * @brief Get the restriction operator. + * + * Used in multilevel methods (e.g., AMG). Returns a pointer to the restriction + * operator (projection from this level to a coarser level). + * Typically, this is the transpose of the prolongation operator. + * The returned Operator is typically owned by this Operator. + * Optional override. + * + * @return A const pointer to the restriction operator, or nullptr if not applicable. + * @note This method forwards the call to the Python override if one exists. + */ + const mfem::Operator* GetRestriction() const override; + + /** + * @brief Recover the FEM solution. + * + * For operators that are part of a system solve (e.g., static condensation), + * this method can be used to reconstruct the full finite element solution `x` + * from a reduced solution `X` and the right-hand side `b`. + * Optional override. + * + * @param X The reduced solution vector. + * @param b The right-hand side vector. + * @param x Output vector for the full FEM solution. + * @note This method forwards the call to the Python override if one exists. + */ + void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override; + }; + +} // namespace serif::pybind \ No newline at end of file diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/meson.build b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/meson.build new file mode 100644 index 0000000..b56454d --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/meson.build @@ -0,0 +1,7 @@ +PyOperator_sources = files( + 'PyOperator.cpp' +) + +trampoline_sources += PyOperator_sources + +subdir('Matrix') \ No newline at end of file diff --git a/src/python/mfem/Trampoline/PyMFEMTrampolines/meson.build b/src/python/mfem/Trampoline/PyMFEMTrampolines/meson.build new file mode 100644 index 0000000..28c1f56 --- /dev/null +++ b/src/python/mfem/Trampoline/PyMFEMTrampolines/meson.build @@ -0,0 +1,2 @@ +subdir('Operator') +subdir('Coefficient') \ No newline at end of file diff --git a/src/python/mfem/Trampoline/meson.build b/src/python/mfem/Trampoline/meson.build new file mode 100644 index 0000000..277e203 --- /dev/null +++ b/src/python/mfem/Trampoline/meson.build @@ -0,0 +1,23 @@ +trampoline_sources = [] + +subdir('PyMFEMTrampolines') + +dependencies = [ + mfem_dep, + pybind11_dep, + python3_dep, +] + +trampoline_lib = static_library( + 'mfem_trampolines', + trampoline_sources, + include_directories: include_directories('.'), + dependencies: dependencies, + install: false, +) + +trampoline_dep = declare_dependency( + link_with: trampoline_lib, + include_directories: ('.'), + dependencies: dependencies, +) \ No newline at end of file diff --git a/src/python/mfem/bindings.cpp b/src/python/mfem/bindings.cpp new file mode 100644 index 0000000..0d1163a --- /dev/null +++ b/src/python/mfem/bindings.cpp @@ -0,0 +1,862 @@ +#include +#include +#include +#include // For operator overloads +#include +#include + +// Include your trampoline class header. The implementation will be in a separate .cpp file. +#include "bindings.h" + +#include "Trampoline/PyMFEMTrampolines/Operator/PyOperator.h" +#include "Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.h" +#include "Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.h" + +#include "mfem.hpp" + + +namespace py = pybind11; +using namespace mfem; + +// This function registers all the mfem-related classes to the python module +void register_mfem_bindings(py::module &mfem_submodule) { + register_operator_bindings(mfem_submodule); + register_matrix_bindings(mfem_submodule); + + register_vector_bindings(mfem_submodule); + register_array_bindings(mfem_submodule); + register_table_bindings(mfem_submodule); + + register_mesh_bindings(mfem_submodule); + + auto formsModule = mfem_submodule.def_submodule("forms", "MFEM forms module"); + register_bilinear_form_bindings(formsModule); + register_mixed_bilinear_form_bindings(formsModule); + + auto fecModule = mfem_submodule.def_submodule("fec", "MFEM finite element collection module"); + register_basis_type_bindings(fecModule); + register_finite_element_collection_bindings(fecModule); + register_H1_FECollection_bindings(fecModule); + register_RT_FECollection_bindings(fecModule); + register_ND_FECollection_bindings(fecModule); + + auto fesModule = mfem_submodule.def_submodule("fes", "MFEM finite element space module"); + register_finite_element_space_bindings(fesModule); + + register_coefficient_bindings(mfem_submodule); + register_intrule_bindings(mfem_submodule); + register_eltrans_bindings(mfem_submodule); + + register_grid_function_bindings(mfem_submodule); +} + +void register_operator_bindings(py::module &mfem_submodule) { + // Use the PyOperator trampoline when binding mfem::Operator + // This allows Python classes to inherit from mfem::Operator + py::class_(mfem_submodule, "Operator") + // NOTE: We DO NOT define an __init__ method because Operator is abstract. + // Python users will instantiate concrete derived classes instead. + + // --- Bind Properties --- + .def_property_readonly("height", &Operator::Height, "Get the height (number of rows) of the Operator.") + .def_property_readonly("width", &Operator::Width, "Get the width (number of columns) of the Operator.") + + // --- Bind Core Virtual Methods --- + // We bind the methods of the C++ base class so they can be called from Python. + // The trampoline handles redirecting to a Python override if one exists. + + // Mult: y = A(x) + .def("Mult", py::overload_cast(&Operator::Mult, py::const_), + py::arg("x"), py::arg("y"), "Calculates y = A(x). y must be pre-allocated.") + + // Pythonic overload for Mult that returns a new vector + .def("Mult", [](const Operator &op, const Vector &x) { + Vector y(op.Height()); + op.Mult(x, y); + return y; + }, py::arg("x"), "Calculates and returns a new vector y = A(x).") + + // MultTranspose: y = A^T(x) + .def("MultTranspose", py::overload_cast(&Operator::MultTranspose, py::const_), + py::arg("x"), py::arg("y"), "Calculates y = A^T(x). y must be pre-allocated.") + + .def("MultTranspose", [](const Operator &op, const Vector &x) { + Vector y(op.Width()); + op.MultTranspose(x, y); + return y; + }, py::arg("x"), "Calculates and returns a new vector y = A^T(x).") + + // Additive versions + .def("AddMult", &Operator::AddMult, py::arg("x"), py::arg("y"), py::arg("a") = 1.0, "Performs y += a * A(x).") + .def("AddMultTranspose", &Operator::AddMultTranspose, py::arg("x"), py::arg("y"), py::arg("a") = 1.0, "Performs y += a * A^T(x).") + + // Other core virtual methods + .def("AssembleDiagonal", &Operator::AssembleDiagonal, py::arg("diag"), "Assembles the operator diagonal into the given Vector.") + .def("RecoverFEMSolution", &Operator::RecoverFEMSolution, py::arg("X"), py::arg("b"), py::arg("x"), "Recovers the full FE solution.") + .def("GetGradient", &Operator::GetGradient, py::return_value_policy::reference, "Returns the Gradient of a non-linear operator.") + + // Methods returning other operators (e.g., for parallel/BCs) + .def("GetProlongation", &Operator::GetProlongation, py::return_value_policy::reference, "Returns the prolongation operator.") + .def("GetRestriction", &Operator::GetRestriction, py::return_value_policy::reference, "Returns the restriction operator.") + + // --- Pythonic Operator Overloading --- + .def("__matmul__", [](const Operator &op, const Vector &x) { + Vector y(op.Height()); + op.Mult(x, y); + return y; + }, py::is_operator()); +} + +void register_matrix_bindings(py::module &mfem_submodule) { + py::class_(mfem_submodule, "Matrix") + // No constructor since it's an abstract base class. + .def_property_readonly("is_square", &Matrix::IsSquare, + "Returns true if the matrix is square.") + + .def("finalize", &Matrix::Finalize, py::arg("skip_zeros") = 1, + "Finalizes the matrix initialization.") + + .def("inverse", &Matrix::Inverse, + "Returns a pointer to (an approximation) of the matrix inverse.", + py::return_value_policy::take_ownership) // The caller owns the returned pointer + + // Pythonic element access: mat[i, j] + .def("__getitem__", [](const Matrix &m, py::tuple t) { + if (t.size() != 2) { + throw py::index_error("Matrix index must be a 2-tuple (i, j)"); + } + return m.Elem(t[0].cast(), t[1].cast()); + }) + + // Pythonic element assignment: mat[i, j] = value + .def("__setitem__", [](Matrix &m, py::tuple t, real_t value) { + if (t.size() != 2) { + throw py::index_error("Matrix index must be a 2-tuple (i, j)"); + } + m.Elem(t[0].cast(), t[1].cast()) = value; + }) + + .def("__repr__", [](const Matrix &m) { + return ""; + }); +} + +void register_vector_bindings(py::module &mfem_submodule) { + // Register the mfem::Vector class + py::class_(mfem_submodule, "Vector") + .def(py::init(), py::arg("size")) + .def(py::init(), py::arg("other")) + .def(py::init([](py::array_t arr) { + py::buffer_info info = arr.request(); + if (info.ndim != 1) { + throw std::runtime_error("Vector(): expected a 1-D numpy array"); + } + mfem::Vector v(info.size); + std::memcpy(v.GetData(), info.ptr, info.size * sizeof(double)); + return v; + }), py::arg("array")) + .def("GetData", &mfem::Vector::GetData, py::return_value_policy::reference_internal) + .def("Size", &mfem::Vector::Size) + .def("__getitem__", [](const mfem::Vector &v, int i) { return v[i]; }) + .def("__len__", &mfem::Vector::Size) + .def("__setitem__", [](mfem::Vector &v, int i, double value) { v[i] = value; }) + .def("__repr__", [](const mfem::Vector &v) { + return ""; + }) + .def("as_numpy", [](mfem::Vector &self) { + return py::array_t(self.Size(), self.GetData()); + }); +} + +void register_array_bindings(py::module &mfem_submodule) { + py::class_>(mfem_submodule, "IntArray") + // --- Constructors --- + .def(py::init<>(), "Default constructor.") + .def(py::init(), py::arg("size"), "Constructor with size.") + .def(py::init([](const std::vector &v) { + auto *arr = new Array(v.size()); + for (size_t i = 0; i < v.size(); ++i) { + (*arr)[i] = v[i]; + } + return arr; + }), py::arg("list"), "Constructor from a Python list.") + + // --- Pythonic Features --- + .def("__len__", &Array::Size) + .def("__getitem__", [](const Array &self, int i) { + if (i < 0) i += self.Size(); // Handle negative indices + if (i < 0 || i >= self.Size()) throw py::index_error(); + return self[i]; + }) + .def("__setitem__", [](Array &self, int i, int value) { + if (i < 0) i += self.Size(); // Handle negative indices + if (i < 0 || i >= self.Size()) throw py::index_error(); + self[i] = value; + }) + .def("__iter__", [](Array &self) { + return py::make_iterator(self.begin(), self.end()); + }, py::keep_alive<0, 1>()) // Keep array alive while iterator is used + .def("__repr__", [](const Array &self) { + std::stringstream ss; + ss << "["; + for (int i = 0; i < self.Size(); ++i) { + ss << self[i] << (i == self.Size() - 1 ? "" : ", "); + } + ss << "]"; + return ss.str(); + }) + + // --- Core Methods --- + .def("Size", &Array::Size) + .def("SetSize", py::overload_cast(&Array::SetSize), py::arg("size")) + .def("Append", py::overload_cast(&Array::Append), py::arg("el")) + .def("Last", py::overload_cast<>(&Array::Last, py::const_)) + .def("DeleteLast", &Array::DeleteLast) + .def("DeleteAll", &Array::DeleteAll) + .def("Sort", [](Array &self) { self.Sort(); }) + .def("Unique", &Array::Unique) + .def("Assign", [](Array &self, const Array &other) { + self.Assign(other.GetData()); + }) + .def("as_numpy", [](Array &self) { + // Create a Python-owned copy to avoid memory issues + return py::array_t(self.Size(), self.GetData()); + }); +} + +// Main function to register BilinearForm +void register_bilinear_form_bindings(py::module &mfem_submodule) { + + // It's good practice to bind enums used by the class + bind_assembly_level_enum(mfem_submodule); + + // Bind the mfem::BilinearForm class, inheriting from mfem::Matrix + // No trampoline is needed because this is a concrete class. + py::class_(mfem_submodule, "BilinearForm") + + // --- Constructor --- + .def(py::init(), py::arg("fespace"), + // The keep_alive policy ensures the Python object for the + // FiniteElementSpace is not garbage-collected while the + // BilinearForm is still using it. + py::keep_alive<1, 2>(), + "Constructs a bilinear form on the given FiniteElementSpace.") + + // --- Setup Methods --- + .def("SetAssemblyLevel", &BilinearForm::SetAssemblyLevel, py::arg("assembly_level"), + "Set the assembly level (e.g., LEGACY, FULL, PARTIAL, NONE).") + .def("EnableStaticCondensation", &BilinearForm::EnableStaticCondensation, + "Enable static condensation to reduce system size.") + .def("EnableHybridization", &BilinearForm::EnableHybridization, + "Enable hybridization.") + .def("UsePrecomputedSparsity", &BilinearForm::UsePrecomputedSparsity, py::arg("ps") = 1, + "Enable use of precomputed sparsity pattern.") + .def("SetDiagonalPolicy", &BilinearForm::SetDiagonalPolicy, py::arg("policy"), + "Set the policy for handling diagonal entries of essential DOFs.") + + // --- Integrator Methods --- + .def("AddDomainIntegrator", py::overload_cast(&BilinearForm::AddDomainIntegrator), + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds a domain integrator to the form.") + .def("AddBoundaryIntegrator", py::overload_cast(&BilinearForm::AddBoundaryIntegrator), + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds a boundary integrator to the form.") + .def("AddInteriorFaceIntegrator", &BilinearForm::AddInteriorFaceIntegrator, + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds an interior face integrator (e.g., for DG methods).") + .def("AddBdrFaceIntegrator", py::overload_cast(&BilinearForm::AddBdrFaceIntegrator), + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds a boundary face integrator.") + + // --- Assembly and System Formulation --- + .def("Assemble", &BilinearForm::Assemble, py::arg("skip_zeros") = 1, + "Assembles the bilinear form into a sparse matrix.") + .def("AssembleDiagonal", &BilinearForm::AssembleDiagonal, py::arg("diag"), + "Assembles the diagonal of the operator into a Vector.") + .def("FormLinearSystem", + [](BilinearForm &self, const Array &ess_tdof_list, Vector &x, + Vector &b, OperatorHandle &A, Vector &X, Vector &B, int copy_interior) { + self.FormLinearSystem(ess_tdof_list, x, b, A, X, B, copy_interior); + }, + py::arg("ess_tdof_list"), py::arg("x"), py::arg("b"), py::arg("A"), + py::arg("X"), py::arg("B"), py::arg("copy_interior") = 0, + "Forms the linear system AX=B, applying boundary conditions and other transformations.") + .def("FormSystemMatrix", + [](BilinearForm &self, const Array &ess_tdof_list, OperatorHandle &A) { + self.FormSystemMatrix(ess_tdof_list, A); + }, + py::arg("ess_tdof_list"), py::arg("A"), + "Forms the system matrix A, applying necessary transformations.") + .def("RecoverFEMSolution", py::overload_cast(&BilinearForm::RecoverFEMSolution), + py::arg("X"), py::arg("b"), py::arg("x"), + "Recovers the full FE solution vector after solving a linear system.") + + // --- Accessor Methods --- + .def("FESpace", py::overload_cast<>(&BilinearForm::FESpace, py::const_), py::return_value_policy::reference_internal, + "Returns a pointer to the associated FiniteElementSpace.") + .def("SpMat", py::overload_cast<>(&BilinearForm::SpMat), py::return_value_policy::reference_internal, + "Returns a reference to the internal sparse matrix.") + .def("Update", &BilinearForm::Update, py::arg("nfes") = nullptr, + "Update the BilinearForm after the FE space has changed."); + + // You will also need to bind the OperatorHandle and DiagonalPolicy enum + // if you haven't already. Example for DiagonalPolicy: + py::enum_(mfem_submodule, "DiagonalPolicy") + .value("DIAG_ZERO", mfem::Matrix::DiagonalPolicy::DIAG_ZERO) + .value("DIAG_ONE", mfem::Matrix::DiagonalPolicy::DIAG_ONE) + .value("DIAG_KEEP", mfem::Matrix::DiagonalPolicy::DIAG_KEEP) + .export_values(); +} + +// Helper function to bind the AssemblyLevel enum +void bind_assembly_level_enum(py::module &m) { + py::enum_(m, "AssemblyLevel") + .value("LEGACY", AssemblyLevel::LEGACY) + .value("FULL", AssemblyLevel::FULL) + .value("ELEMENT", AssemblyLevel::ELEMENT) + .value("PARTIAL", AssemblyLevel::PARTIAL) + .value("NONE", AssemblyLevel::NONE) + .export_values(); +} + +// Main function to register MixedBilinearForm +void register_mixed_bilinear_form_bindings(py::module &mfem_submodule) { + + // Bind the mfem::MixedBilinearForm class, inheriting from mfem::Matrix + // No trampoline is needed because this is a concrete class. + py::class_(mfem_submodule, "MixedBilinearForm") + + // --- Constructor --- + .def(py::init(), + py::arg("trial_fespace"), py::arg("test_fespace"), + // Keep alive policies ensure the FE space objects are not garbage + // collected while the MixedBilinearForm is still using them. + py::keep_alive<1, 2>(), py::keep_alive<1, 3>(), + "Constructs a mixed bilinear form on the given trial and test FE spaces.") + + // --- Setup Methods --- + .def("SetAssemblyLevel", &MixedBilinearForm::SetAssemblyLevel, py::arg("assembly_level"), + "Set the assembly level (e.g., LEGACY, FULL, PARTIAL, NONE).") + + // --- Integrator Methods --- + .def("AddDomainIntegrator", py::overload_cast(&MixedBilinearForm::AddDomainIntegrator), + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds a domain integrator to the form.") + .def("AddBoundaryIntegrator", py::overload_cast(&MixedBilinearForm::AddBoundaryIntegrator), + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds a boundary integrator to the form.") + .def("AddInteriorFaceIntegrator", &MixedBilinearForm::AddInteriorFaceIntegrator, + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds an interior face integrator.") + .def("AddBdrFaceIntegrator", py::overload_cast(&MixedBilinearForm::AddBdrFaceIntegrator), + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds a boundary face integrator.") + .def("AddTraceFaceIntegrator", &MixedBilinearForm::AddTraceFaceIntegrator, + py::arg("bfi"), py::keep_alive<1, 2>(), + "Adds a trace face integrator.") + + + // --- Assembly and System Formulation --- + .def("Assemble", &MixedBilinearForm::Assemble, py::arg("skip_zeros") = 1, + "Assembles the mixed bilinear form into a sparse matrix.") + .def("FormRectangularSystemMatrix", + [](MixedBilinearForm &self, const Array &trial_tdof_list, const Array &test_tdof_list, OperatorHandle &A) { + self.FormRectangularSystemMatrix(trial_tdof_list, test_tdof_list, A); + }, + py::arg("trial_tdof_list"), py::arg("test_tdof_list"), py::arg("A"), + "Forms the rectangular system matrix A, applying necessary transformations.") + .def("FormRectangularLinearSystem", + [](MixedBilinearForm &self, const Array &trial_tdof_list, const Array &test_tdof_list, + Vector &x, Vector &b, OperatorHandle &A, Vector &X, Vector &B) { + self.FormRectangularLinearSystem(trial_tdof_list, test_tdof_list, x, b, A, X, B); + }, + py::arg("trial_tdof_list"), py::arg("test_tdof_list"), py::arg("x"), py::arg("b"), + py::arg("A"), py::arg("X"), py::arg("B"), + "Forms the rectangular linear system AX=B.") + + // --- Accessor Methods --- + .def("TrialFESpace", py::overload_cast<>(&MixedBilinearForm::TrialFESpace, py::const_), + py::return_value_policy::reference_internal, + "Returns a pointer to the associated trial FiniteElementSpace.") + .def("TestFESpace", py::overload_cast<>(&MixedBilinearForm::TestFESpace, py::const_), + py::return_value_policy::reference_internal, + "Returns a pointer to the associated test FiniteElementSpace.") + .def("SpMat", py::overload_cast<>(&MixedBilinearForm::SpMat), + py::return_value_policy::reference_internal, + "Returns a reference to the internal sparse matrix.") + .def("Update", &MixedBilinearForm::Update, + "Update the MixedBilinearForm after the FE spaces have changed."); +} + +// This function can be called from your main registration function. +void register_mesh_bindings(py::module &mfem_submodule) { + // Bind the mfem::Mesh class. No trampoline needed. + py::class_(mfem_submodule, "Mesh") + + // --- Constructors & Loading --- + // Default constructor for creating an empty mesh object + .def(py::init<>()) + // Constructor to load from a file path + .def(py::init(), + py::arg("filename"), py::arg("generate_edges") = 0, + py::arg("refine") = 1, py::arg("fix_orientation") = true) + // Static factory method for a more Pythonic loading interface + .def_static("LoadFromFile", &Mesh::LoadFromFile, + py::arg("filename"), py::arg("generate_edges") = 0, + py::arg("refine") = 1, py::arg("fix_orientation") = true, + "Creates a mesh by reading a file.") + + // --- Basic Properties & Stats --- + .def_property_readonly("dim", &Mesh::Dimension) + .def_property_readonly("space_dim", &Mesh::SpaceDimension) + .def_property_readonly("nv", &Mesh::GetNV, "Number of Vertices") + .def_property_readonly("ne", &Mesh::GetNE, "Number of Elements") + .def_property_readonly("nbe", &Mesh::GetNBE, "Number of Boundary Elements") + .def_property_readonly("n_edges", &Mesh::GetNEdges) + .def_property_readonly("n_faces", &Mesh::GetNFaces) + .def_readonly("attributes", &Mesh::attributes) + .def_readonly("bdr_attributes", &Mesh::bdr_attributes) + .def("GetBoundingBox", + [](Mesh &self, int ref) { + Vector min, max; + self.GetBoundingBox(min, max, ref); + // Here you might want to return a tuple of numpy arrays instead + return py::make_tuple(min, max); + }, py::arg("ref") = 2, + "Returns the min and max corners of the mesh bounding box.") + + // --- Connectivity Data --- + .def("GetElementVertices", + [](const Mesh &self, int i) { + Array v; + self.GetElementVertices(i, v); + return v; + }, py::arg("i"), "Returns a list of vertex indices for element i.") + .def("GetBdrElementVertices", + [](const Mesh &self, int i) { + Array v; + self.GetBdrElementVertices(i, v); + return v; + }, py::arg("i"), "Returns a list of vertex indices for boundary element i.") + .def("GetFaceElements", + [](const Mesh &self, int i) { + int e1, e2; + self.GetFaceElements(i, &e1, &e2); + return py::make_tuple(e1, e2); + }, py::arg("face_idx"), "Returns a tuple of the two elements sharing a face.") + .def("GetBdrElementFace", + [](const Mesh &self, int i) { + int f, o; + self.GetBdrElementFace(i, &f, &o); + return py::make_tuple(f, o); + }, py::arg("bdr_elem_idx"), "Returns the face index and orientation for a boundary element.") + .def("ElementToEdgeTable", &Mesh::ElementToEdgeTable, py::return_value_policy::reference_internal) + .def("ElementToFaceTable", &Mesh::ElementToFaceTable, py::return_value_policy::reference_internal) + + // --- Coordinate Transformations & Curvature --- + .def("GetNodes", py::overload_cast<>(&Mesh::GetNodes, py::const_), py::return_value_policy::reference_internal, + "Returns the GridFunction for the mesh nodes (if any).") + .def("SetCurvature", &Mesh::SetCurvature, py::arg("order"), py::arg("discontinuous") = false, + py::arg("space_dim") = -1, py::arg("ordering") = 1, + "Set the curvature of the mesh, creating a high-order nodal GridFunction.") + // Use py::overload_cast to resolve ambiguity for overloaded methods + .def("GetElementTransformation", py::overload_cast(&Mesh::GetElementTransformation), py::arg("i"), + py::return_value_policy::reference_internal) + .def("GetFaceElementTransformations", py::overload_cast(&Mesh::GetFaceElementTransformations), py::arg("i"), + py::arg("mask")=31, py::return_value_policy::reference_internal) + + // --- I/O & Visualization --- + .def("Save", &Mesh::Save, py::arg("filename"), py::arg("precision") = 16); + + // Bind the VTKFormat enum used in PrintVTU + py::enum_(mfem_submodule, "VTKFormat") + .value("ASCII", VTKFormat::ASCII) + .value("BINARY", VTKFormat::BINARY) + .value("BINARY32", VTKFormat::BINARY32) + .export_values(); +} + +void register_table_bindings(pybind11::module &mfem_submodule) { + // Bind mfem::Table first, as it's a return type for some Mesh methods. + py::class_(mfem_submodule, "Table") + .def("GetRow", [](const Table &self, int row) { + Array row_data; + self.GetRow(row, row_data); + // pybind11 will automatically convert Array to a Python list + return row_data; + }, py::arg("row"), "Get a row of the table as a list.") + .def("RowSize", &Table::RowSize, py::arg("row"), "Get the number of entries in a specific row.") + .def_property_readonly("height", &Table::Size, "Number of rows in the table.") + .def_property_readonly("width", &Table::Width, "Number of columns in the table.") + .def("__len__", &Table::Size) + .def("__getitem__", [](const Table &self, int row) { + if (row < 0 || row >= self.Size()) { + throw py::index_error("Row index out of bounds"); + } + Array row_data; + self.GetRow(row, row_data); + return row_data; + }) + .def("__repr__", [](const Table &self) { + return ""; + }); +} + +void register_finite_element_collection_bindings(pybind11::module &mfem_submodule) { + py::class_(mfem_submodule, "FiniteElementCollection") + // No constructor for abstract base classes + .def("Name", &FiniteElementCollection::Name) + .def("GetOrder", &FiniteElementCollection::GetOrder); +} + +// Binds the mfem::BasisType class and its internal anonymous enum values +// as static properties of a Python class. This should be called before +// binding any class that uses BasisType values in its constructor. +void register_basis_type_bindings(py::module &m) { + py::class_ basis_type(m, "BasisType", "Possible basis types."); + + basis_type.def_property_readonly_static("Invalid", [](py::object) { return BasisType::Invalid; }); + basis_type.def_property_readonly_static("GaussLegendre", [](py::object) { return BasisType::GaussLegendre; }); + basis_type.def_property_readonly_static("GaussLobatto", [](py::object) { return BasisType::GaussLobatto; }); + basis_type.def_property_readonly_static("Positive", [](py::object) { return BasisType::Positive; }); + basis_type.def_property_readonly_static("OpenUniform", [](py::object) { return BasisType::OpenUniform; }); + basis_type.def_property_readonly_static("ClosedUniform", [](py::object) { return BasisType::ClosedUniform; }); + basis_type.def_property_readonly_static("OpenHalfUniform", [](py::object) { return BasisType::OpenHalfUniform; }); + basis_type.def_property_readonly_static("Serendipity", [](py::object) { return BasisType::Serendipity; }); + basis_type.def_property_readonly_static("ClosedGL", [](py::object) { return BasisType::ClosedGL; }); + basis_type.def_property_readonly_static("IntegratedGLL", [](py::object) { return BasisType::IntegratedGLL; }); +} + + + +void register_H1_FECollection_bindings(py::module &m) { + py::class_(m, "H1_FECollection") + .def(py::init([](int p, int dim) { + // The lambda explicitly calls the constructor, avoiding the + // default argument parsing issue. + return new H1_FECollection(p, dim); + }), + py::arg("p"), + py::arg("dim") = 3, + "Constructs an H1 finite element collection.") + .def("GetBasisType", &H1_FECollection::GetBasisType); +} + + +void register_RT_FECollection_bindings(py::module &m) { + py::class_(m, "RT_FECollection") + .def(py::init(), + py::arg("p"), py::arg("dim"), + "Constructs a Raviart-Thomas H(div)-conforming FE collection.") + .def("GetClosedBasisType", &RT_FECollection::GetClosedBasisType) + .def("GetOpenBasisType", &RT_FECollection::GetOpenBasisType); +} + +void register_ND_FECollection_bindings(py::module &m) { + py::class_(m, "ND_FECollection") + .def(py::init(), + py::arg("p"), py::arg("dim"), + "Constructs a Nedelec H(curl)-conforming FE collection.") + .def("GetClosedBasisType", &ND_FECollection::GetClosedBasisType) + .def("GetOpenBasisType", &ND_FECollection::GetOpenBasisType); +} + + +void register_finite_element_space_bindings(py::module &mfem_submodule) { + // Bind dependent enums first + bind_ordering_enum(mfem_submodule); + + // Bind the mfem::FiniteElementSpace class + py::class_(mfem_submodule, "FiniteElementSpace") + // --- Constructors --- + .def(py::init(), + py::arg("mesh"), py::arg("fec"), py::arg("vdim") = 1, py::arg("ordering") = Ordering::byNODES, + // Keep alive policies prevent the mesh and fec from being + // garbage collected by Python while the C++ FESpace object exists. + py::keep_alive<1, 2>(), py::keep_alive<1, 3>()) + + // --- Core Properties & Stats --- + .def_property_readonly("ndofs", &FiniteElementSpace::GetNDofs, "Number of local scalar degrees of freedom.") + .def_property_readonly("vdim", &FiniteElementSpace::GetVDim, "Vector dimension of the space.") + .def_property_readonly("vsize", &FiniteElementSpace::GetVSize, "Total number of local vector degrees of freedom.") + .def_property_readonly("true_vsize", &FiniteElementSpace::GetTrueVSize, "Number of true (conforming) vector degrees of freedom.") + .def("GetMesh", &FiniteElementSpace::GetMesh, py::return_value_policy::reference_internal, "Get the associated Mesh.") + .def("FEColl", &FiniteElementSpace::FEColl, py::return_value_policy::reference_internal, "Get the associated FiniteElementCollection.") + + // --- DOF Management --- + .def("GetElementDofs", + [](const FiniteElementSpace &self, int i) { + Array dofs; + self.GetElementDofs(i, dofs); + return dofs; + }, py::arg("elem_idx"), "Get the local scalar DOFs for a given element.") + .def("GetElementVDofs", + [](const FiniteElementSpace &self, int i) { + Array vdofs; + self.GetElementVDofs(i, vdofs); + return vdofs; + }, py::arg("elem_idx"), "Get the local vector DOFs for a given element.") + .def("GetBdrElementDofs", + [](const FiniteElementSpace &self, int i) { + Array dofs; + self.GetBdrElementDofs(i, dofs); + return dofs; + }, py::arg("bdr_elem_idx"), "Get the local scalar DOFs for a given boundary element.") + .def("GetEssentialVDofs", + [](const FiniteElementSpace &self, const Array &bdr_attr_is_ess, int component) { + Array ess_vdofs; + self.GetEssentialVDofs(bdr_attr_is_ess, ess_vdofs, component); + return ess_vdofs; + }, py::arg("bdr_attr_is_ess"), py::arg("component") = -1, + "Get a list of essential (Dirichlet) vector DOFs based on boundary attributes.") + .def("GetEssentialTrueDofs", + [](const FiniteElementSpace &self, const Array &bdr_attr_is_ess, int component) { + Array ess_tdof_list; + self.GetEssentialTrueDofs(bdr_attr_is_ess, ess_tdof_list, component); + return ess_tdof_list; + }, py::arg("bdr_attr_is_ess"), py::arg("component") = -1, + "Get a list of essential true (conforming) DOFs for use in linear systems.") + + // --- Updating & Operators --- + .def("Update", &FiniteElementSpace::Update, py::arg("want_transform") = true, + "Update the FE space after the mesh has been modified (e.g., refined).") + .def("GetUpdateOperator", py::overload_cast<>(&FiniteElementSpace::GetUpdateOperator), + py::return_value_policy::reference_internal, + "Get the operator that maps GridFunctions from the old space to the new space after an Update.") + .def("GetProlongationMatrix", &FiniteElementSpace::GetProlongationMatrix, + py::return_value_policy::reference_internal, "Get the P operator (true DOFs to local DOFs).") + .def("GetRestrictionOperator", &FiniteElementSpace::GetRestrictionOperator, + py::return_value_policy::reference_internal, "Get the R operator (local DOFs to true DOFs).") + + .def("__repr__", [](const FiniteElementSpace &fes) { + std::stringstream ss; + ss << ""; + return ss.str(); + }); +} + +void bind_ordering_enum(py::module &mfem_submodule) { + py::enum_(mfem_submodule, "Ordering") + .value("byNODES", Ordering::byNODES) + .value("byVDIM", Ordering::byVDIM) + .export_values(); +} + +/// Binds mfem::GridFunction +void register_grid_function_bindings(py::module &m) { + // Assumes that Vector, FiniteElementSpace, Coefficient, VectorCoefficient, + // IntegrationRule, and ElementTransformation are already bound. + + py::class_(m, "GridFunction") + .def(py::init(), py::arg("fespace"), + py::keep_alive<1, 2>(), // Keep FE space alive + "Construct a GridFunction on a given FiniteElementSpace.") + + .def(py::init(), + py::arg("fespace"), py::arg("data"), + py::keep_alive<1, 2>(), + "Construct a GridFunction using previously allocated data.") + + .def("FESpace", py::overload_cast<>(&GridFunction::FESpace, py::const_), + py::return_value_policy::reference_internal, + "Returns the associated FiniteElementSpace.") + .def("Update", &GridFunction::Update, + "Update the GridFunction after its FE space has been modified.") + .def("SetSpace", &GridFunction::SetSpace, py::arg("fespace"), + py::keep_alive<1, 2>(), "Associate a new FE space with the GridFunction.") + + // --- Projection Methods --- + .def("ProjectCoefficient", + py::overload_cast(&GridFunction::ProjectCoefficient), + py::arg("coeff"), + "Project a scalar Coefficient onto the GridFunction.") + .def("ProjectCoefficient", + py::overload_cast(&GridFunction::ProjectCoefficient), + py::arg("vcoeff"), + "Project a vector Coefficient onto the GridFunction.") + .def("ProjectBdrCoefficient", + py::overload_cast &>(&GridFunction::ProjectBdrCoefficient), + py::arg("coeff"), py::arg("attr"), + "Project a scalar Coefficient onto the boundary degrees of freedom.") + .def("ProjectBdrCoefficient", + py::overload_cast &>(&GridFunction::ProjectBdrCoefficient), + py::arg("vcoeff"), py::arg("attr"), + "Project a vector Coefficient onto the boundary degrees of freedom.") + + // --- Evaluation Methods --- + .def("GetValue", py::overload_cast(&GridFunction::GetValue, py::const_), + py::arg("T"), py::arg("ip"), py::arg("comp") = 0, py::arg("tr") = nullptr, + "Get the scalar value at a point described by an ElementTransformation.") + .def("GetVectorValue", + [](const GridFunction &self, ElementTransformation &T, const IntegrationPoint &ip) { + Vector val; + self.GetVectorValue(T, ip, val); + return val; + }, + py::arg("T"), py::arg("ip"), + "Get the vector value at a point described by an ElementTransformation.") + .def("GetValues", + [](const GridFunction &self, ElementTransformation &T, const IntegrationRule &ir, int comp) { + Vector vals; + self.GetValues(T, ir, vals, comp); + return vals; + }, + py::arg("T"), py::arg("ir"), py::arg("comp") = 0, + "Get scalar values at all points of an IntegrationRule.") + .def("GetVectorValues", + [](const GridFunction &self, ElementTransformation &T, const IntegrationRule &ir) { + DenseMatrix vals; + self.GetVectorValues(T, ir, vals); + return vals; // pybind11 handles DenseMatrix + }, + py::arg("T"), py::arg("ir"), + "Get vector values at all points of an IntegrationRule.") + + // --- True DOF (constrained system) Methods --- + .def("GetTrueDofs", &GridFunction::GetTrueDofs, py::arg("tv"), + "Extract the true-dofs from the GridFunction.") + .def("SetFromTrueDofs", &GridFunction::SetFromTrueDofs, py::arg("tv"), + "Set the GridFunction from a true-dof vector.") + + // --- I/O --- + .def("Save", py::overload_cast(&GridFunction::Save, py::const_), + py::arg("fname"), py::arg("precision") = 16) + .def("__repr__", [](const GridFunction &gf) { + std::stringstream ss; + ss << ""; + return ss.str(); + }) + .def("as_numpy", [](const GridFunction &self) { + // Convert the GridFunction to a numpy array + return py::array_t(self.Size(), self.GetData()); + }, "Convert the GridFunction to a numpy array."); +} + +// This function registers the coefficient classes to the python module +void register_coefficient_bindings(py::module &m) { + + // --- Bind abstract base classes using trampolines --- + py::class_ coefficient(m, "Coefficient"); + coefficient + .def(py::init<>()) + .def("SetTime", &Coefficient::SetTime, py::arg("t")) + .def("GetTime", &Coefficient::GetTime) + .def("Eval", py::overload_cast(&Coefficient::Eval), + "Evaluate the coefficient at a point in an element.", + py::arg("T"), py::arg("ip")); + + py::class_ vector_coefficient(m, "VectorCoefficient"); + vector_coefficient + .def(py::init(), py::arg("vdim")) + .def("SetTime", &VectorCoefficient::SetTime, py::arg("t")) + .def("GetTime", &VectorCoefficient::GetTime) + .def("GetVDim", &VectorCoefficient::GetVDim) + .def("Eval", py::overload_cast(&VectorCoefficient::Eval), + "Evaluate the vector coefficient at a point in an element.", + py::arg("V"), py::arg("T"), py::arg("ip")); + + // --- Bind useful concrete classes --- + + // ConstantCoefficient + py::class_(m, "ConstantCoefficient") + .def(py::init(), py::arg("c") = 1.0) + .def_readwrite("constant", &ConstantCoefficient::constant); + + // FunctionCoefficient (allows using Python functions as coefficients) + py::class_(m, "FunctionCoefficient") + .def(py::init>(), + py::arg("F"), "Create a coefficient from a Python function of space.") + .def(py::init>(), + py::arg("TDF"), "Create a coefficient from a Python function of space and time."); + + // VectorConstantCoefficient + py::class_(m, "VectorConstantCoefficient") + .def(py::init(), py::arg("v")); + + // VectorFunctionCoefficient + py::class_(m, "VectorFunctionCoefficient") + .def(py::init>(), + py::arg("dim"), py::arg("F"), "Create a vector coefficient from a Python function of space.") + .def(py::init>(), + py::arg("dim"), py::arg("TDF"), "Create a vector coefficient from a Python function of space and time."); + + // GridFunctionCoefficient (useful for source terms that depend on the solution) + py::class_(m, "GridFunctionCoefficient") + .def(py::init<>()) + .def(py::init(), py::arg("gf")) + .def("SetGridFunction", &GridFunctionCoefficient::SetGridFunction, py::arg("gf")) + .def("GetGridFunction", &GridFunctionCoefficient::GetGridFunction, py::return_value_policy::reference); +} + +/// Binds mfem::IntegrationPoint and mfem::IntegrationRule +void register_intrule_bindings(py::module &m) { + py::class_(m, "IntegrationPoint") + .def(py::init<>()) + .def_readwrite("x", &IntegrationPoint::x) + .def_readwrite("y", &IntegrationPoint::y) + .def_readwrite("z", &IntegrationPoint::z) + .def_readwrite("weight", &IntegrationPoint::weight) + .def("__repr__", [](const IntegrationPoint &ip) { + std::stringstream ss; + ss << "IP(x=" << ip.x << ", y=" << ip.y << ", z=" << ip.z << ", w=" << ip.weight << ")"; + return ss.str(); + }); + + py::class_(m, "IntegrationRule") + .def(py::init<>()) + .def(py::init(), py::arg("NumPoints")) + .def("GetNPoints", &IntegrationRule::GetNPoints) + .def("GetOrder", &IntegrationRule::GetOrder) + .def("__len__", &IntegrationRule::GetNPoints) + .def("__getitem__", [](const IntegrationRule &self, int i) { + if (i < 0 || i >= self.GetNPoints()) throw py::index_error(); + return self.IntPoint(i); + }) + .def("__iter__", [](const IntegrationRule &self) { + return py::make_iterator(self.begin(), self.end()); + }, py::keep_alive<0, 1>()); +} + +/// Binds mfem::ElementTransformation and related classes +void register_eltrans_bindings(py::module &m) { + // Bind the base class. Users get pointers to this, but never create it. + // It is not abstract in the C++ sense, but we treat it as such for bindings. + py::class_ eltrans(m, "ElementTransformation"); + eltrans + .def_readonly("Attribute", &ElementTransformation::Attribute) + .def_readonly("ElementNo", &ElementTransformation::ElementNo) + .def("SetIntPoint", &ElementTransformation::SetIntPoint, py::arg("ip")) + .def("Transform", py::overload_cast(&ElementTransformation::Transform), + py::arg("ip"), py::arg("transip")) + .def("Weight", &ElementTransformation::Weight) + // Properties that return references to internal data should use reference_internal policy + .def_property_readonly("Jacobian", &ElementTransformation::Jacobian, py::return_value_policy::reference_internal) + .def_property_readonly("InverseJacobian", &ElementTransformation::InverseJacobian, py::return_value_policy::reference_internal) + .def_property_readonly("AdjugateJacobian", &ElementTransformation::AdjugateJacobian, py::return_value_policy::reference_internal); + + // Bind IsoparametricTransformation, which is a concrete type of ElementTransformation + py::class_(m, "IsoparametricTransformation") + .def(py::init<>()); + + // Bind FaceElementTransformations, crucial for DG methods + py::class_(m, "FaceElementTransformations") + .def(py::init<>()) + .def_readonly("Elem1No", &FaceElementTransformations::Elem1No) + .def_readonly("Elem2No", &FaceElementTransformations::Elem2No) + .def_property_readonly("Elem1", [](FaceElementTransformations &self) { return self.Elem1; }, py::return_value_policy::reference) + .def_property_readonly("Elem2", [](FaceElementTransformations &self) { return self.Elem2; }, py::return_value_policy::reference) + .def("GetElement1IntPoint", &FaceElementTransformations::GetElement1IntPoint) + .def("GetElement2IntPoint", &FaceElementTransformations::GetElement2IntPoint); + + // Bind the enum used for selecting transformations + py::enum_(eltrans, "ConfigMasks") + .value("HAVE_ELEM1", FaceElementTransformations::HAVE_ELEM1) + .value("HAVE_ELEM2", FaceElementTransformations::HAVE_ELEM2) + .value("HAVE_LOC1", FaceElementTransformations::HAVE_LOC1) + .value("HAVE_LOC2", FaceElementTransformations::HAVE_LOC2) + .value("HAVE_FACE", FaceElementTransformations::HAVE_FACE) + .export_values(); +} diff --git a/src/python/mfem/bindings.h b/src/python/mfem/bindings.h new file mode 100644 index 0000000..83e2e98 --- /dev/null +++ b/src/python/mfem/bindings.h @@ -0,0 +1,307 @@ +/** + * @file bindings.h + * @brief Declares functions to register MFEM core library components with pybind11. + * + * This header file lists the functions responsible for creating Python bindings + * for various parts of the MFEM library. Each function typically registers + * a set of related classes, enums, or functionalities to a pybind11::module, + * which is expected to be a submodule named `mfem` within the main `serif` Python module. + * + * @see /Users/tboudreaux/Programming/SERiF/src/python/bindings.cpp for how these are used. + */ +#pragma once + +#include + +/** + * @brief Registers all core MFEM bindings to the given Python submodule. + * + * This function serves as the main entry point for exposing MFEM functionalities + * to Python. It calls various other `register_*_bindings` and `bind_*_enum` + * functions to populate the `mfem_submodule`. + * + * @param mfem_submodule The pybind11 module (typically `serif.mfem`) to which + * MFEM bindings will be added. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * # Now mfem.Operator, mfem.Vector, mfem.Mesh, etc., are accessible. + * vec = mfem.Vector(10) + * print(vec.Size()) + * @endcode + */ +void register_mfem_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::Operator and related classes. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * # Assuming PyOperator trampoline is bound + * # op = mfem.Operator() # Or a derived class like mfem.DenseMatrix + * @endcode + */ +void register_operator_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::Matrix and its derived classes (e.g., mfem::DenseMatrix, mfem::SparseMatrix). + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * mat = mfem.DenseMatrix(2, 2) + * mat[0,0] = 1.0 + * mat[0,1] = 2.0 + * mat[1,0] = 3.0 + * mat[1,1] = 4.0 + * mat.Print() + * @endcode + */ +void register_matrix_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::Vector. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * vec = mfem.Vector(5) + * vec[0] = 1.5 + * print(vec.Size(), vec[0]) + * @endcode + */ +void register_vector_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::Array. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * arr_int = mfem.intArray(5) # Assuming intArray is a typedef or specific binding + * arr_int[0] = 10 + * print(arr_int.Size(), arr_int[0]) + * @endcode + */ +void register_array_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Binds the mfem::AssemblyLevel enum. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * # level = mfem.AssemblyLevel.LEGACY # Or other enum values + * # print(level) + * @endcode + */ +void bind_assembly_level_enum(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::BilinearForm and related functionalities. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * # Assuming FiniteElementSpace (fes) is created + * # bform = mfem.BilinearForm(fes) + * # bform.AddDomainIntegrator(mfem.MassIntegrator()) # Assuming MassIntegrator is bound + * # bform.Assemble() + * # A = bform.SpMat() + * @endcode + */ +void register_bilinear_form_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::MixedBilinearForm and related functionalities. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * # Assuming trial_fes and test_fes are FiniteElementSpaces + * # mbform = mfem.MixedBilinearForm(trial_fes, test_fes) + * # mbform.AddDomainIntegrator(mfem.VectorFEMassIntegrator()) # Example + * # mbform.Assemble() + * @endcode + */ +void register_mixed_bilinear_form_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::Table. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * table = mfem.Table() + * # ... use table methods ... + * @endcode + */ +void register_table_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::Mesh. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * mesh = mfem.Mesh.MakeCartesian1D(10) # Example constructor + * print(mesh.Dimension(), mesh.GetNE()) + * @endcode + */ +void register_mesh_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::BasisType enum and related constants. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * basis_type = mfem.BasisType.GaussLobatto + * print(basis_type) + * @endcode + */ +void register_basis_type_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::FiniteElementCollection base class. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * # fec = mfem.FiniteElementCollection() # Typically use derived classes + * @endcode + */ +void register_finite_element_collection_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::H1_FECollection. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * order = 1 + * dim = 2 + * fec = mfem.H1_FECollection(order, dim) + * print(fec.GetName()) + * @endcode + */ +void register_H1_FECollection_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::RT_FECollection. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * order = 1 + * dim = 2 + * fec = mfem.RT_FECollection(order-1, dim) # RT_FECollection uses p-1 for order p + * print(fec.GetName()) + * @endcode + */ +void register_RT_FECollection_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::ND_FECollection (Nedelec finite elements). + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * order = 1 + * dim = 3 + * fec = mfem.ND_FECollection(order, dim) + * print(fec.GetName()) + * @endcode + */ +void register_ND_FECollection_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Binds the mfem::Ordering::Type enum. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * ordering = mfem.Ordering.byNODES + * print(ordering) + * @endcode + */ +void bind_ordering_enum(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::FiniteElementSpace. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * mesh = mfem.Mesh.MakeCartesian1D(5) + * fec = mfem.H1_FECollection(1, mesh.Dimension()) + * fes = mfem.FiniteElementSpace(mesh, fec) + * print(fes.GetNDofs()) + * @endcode + */ +void register_finite_element_space_bindings(pybind11::module &mfem_submodule); + +/** + * @brief Registers mfem::Coefficient, mfem::VectorCoefficient and related classes/trampolines. + * @param m The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * const_coeff = mfem.ConstantCoefficient(2.5) + * # vec_coeff = mfem.VectorConstantCoefficient(mfem.Vector([1.0, 2.0])) + * + * # Using a Python-derived coefficient (if PyCoefficient trampoline is bound) + * class MyCoeff(mfem.Coefficient): + * def Eval(self, T, ip): + * return 1.0 + * my_c = MyCoeff() + * @endcode + */ +void register_coefficient_bindings(pybind11::module &m); + +/** + * @brief Registers mfem::ElementTransformation. + * @param m The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * # ElementTransformation objects are usually obtained from Mesh or FiniteElementSpace + * # mesh = mfem.Mesh.MakeCartesian1D(1) + * # el_trans = mesh.GetElementTransformation(0) + * # print(el_trans.ElementNo) + * @endcode + */ +void register_eltrans_bindings(pybind11::module &m); + +/** + * @brief Registers mfem::IntegrationRule and mfem::IntegrationPoint. + * @param m The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * # Get a standard integration rule + * ir = mfem.IntRules.Get(mfem.Geometry.SEGMENT, 3) # Order 3 for a segment + * for i in range(ir.GetNPoints()): + * ip = ir.IntPoint(i) + * # print(f"Point {i}: coords {ip.x}, weight {ip.weight}") + * @endcode + */ +void register_intrule_bindings(pybind11::module &m); + +/** + * @brief Registers mfem::GridFunction. + * @param mfem_submodule The `serif.mfem` Python submodule. + * @par Python Usage Example: + * @code{.py} + * import serif.mfem as mfem + * mesh = mfem.Mesh.MakeCartesian1D(5) + * fec = mfem.H1_FECollection(1, mesh.Dimension()) + * fes = mfem.FiniteElementSpace(mesh, fec) + * gf = mfem.GridFunction(fes) + * gf.Assign(0.0) # Set all values to 0 + * print(gf.Size()) + * @endcode + */ +void register_grid_function_bindings(pybind11::module &mfem_submodule); + diff --git a/src/python/mfem/meson.build b/src/python/mfem/meson.build new file mode 100644 index 0000000..6fae5c1 --- /dev/null +++ b/src/python/mfem/meson.build @@ -0,0 +1,26 @@ +subdir('Trampoline') + +# Define the library +bindings_sources = files( + 'bindings.cpp', +) +bindings_headers = files( + 'bindings.h', +) + +dependencies = [ + config_dep, + resourceManager_dep, + python3_dep, + pybind11_dep, + mpi_dep, + trampoline_dep, +] + +shared_module('py_mfem', + bindings_sources, + include_directories: include_directories('.'), + cpp_args: ['-fvisibility=default'], + install : true, + dependencies: dependencies, +) diff --git a/src/python/polytrope/bindings.cpp b/src/python/polytrope/bindings.cpp new file mode 100644 index 0000000..478fef0 --- /dev/null +++ b/src/python/polytrope/bindings.cpp @@ -0,0 +1,26 @@ +#include +#include // Needed for vectors, maps, sets, strings +#include // Needed for binding std::vector, std::map etc if needed directly +#include + +#include "bindings.h" +#include "EOSio.h" +#include "helm.h" +#include "polySolver.h" +#include "mfem.hpp" + +namespace py = pybind11; + +void register_polytrope_bindings(pybind11::module &polytrope_submodule) { + py::class_(polytrope_submodule, "PolySolver") + .def(py::init(), py::arg("polytropic_index"), py::arg("FEM_order")) + .def("solve", &serif::polytrope::PolySolver::solve, "Solve the polytrope equation.") + .def("get_theta", &serif::polytrope::PolySolver::getTheta, py::return_value_policy::reference_internal) + .def("get_phi", &serif::polytrope::PolySolver::getPhi, py::return_value_policy::reference_internal) + .def("get_order", &serif::polytrope::PolySolver::getOrder) + .def("get_n", &serif::polytrope::PolySolver::getN); + + py::class_(polytrope_submodule, "PolytropeOperator") + .def("Mult", &serif::polytrope::PolytropeOperator::Mult); + +} diff --git a/src/python/polytrope/bindings.h b/src/python/polytrope/bindings.h new file mode 100644 index 0000000..ef17094 --- /dev/null +++ b/src/python/polytrope/bindings.h @@ -0,0 +1,31 @@ +/** + * @file bindings.h + * @brief Declares the function to register polytrope module C++ components with pybind11. + * + * This file contains the declaration for `register_polytrope_bindings`, which is responsible + * for creating Python bindings for classes and functions within the `serif::polytrope` C++ + * namespace. These bindings will be accessible in Python under the `serif.polytrope` submodule. + */ +#pragma once + +#include + +/** + * @brief Registers C++ classes and functions from the `serif::polytrope` namespace to Python. + * + * This function takes a pybind11::module object, representing the `serif.polytrope` Python submodule, + * and adds bindings for various components like `PolytropeOperator`, `PolySolver`, etc. + * This allows these C++ components to be instantiated and used directly from Python. + * + * @param polytrope_submodule The pybind11 module (typically `serif.polytrope`) to which + * the polytrope C++ bindings will be added. + * + * @par Python Usage Example: + * After these bindings are registered and the Python module is imported: + * @code{.py} + * from serif.polytrope import PolySolver + * polytrope = PolySolver(1.5, 1) + * polytrope.solve() + * @endcode + */ +void register_polytrope_bindings(pybind11::module &polytrope_submodule); diff --git a/src/python/polytrope/meson.build b/src/python/polytrope/meson.build new file mode 100644 index 0000000..45ce25a --- /dev/null +++ b/src/python/polytrope/meson.build @@ -0,0 +1,19 @@ +# Define the library +bindings_sources = files('bindings.cpp') +bindings_headers = files('bindings.h') + +dependencies = [ + polysolver_dep, + config_dep, + resourceManager_dep, + python3_dep, + pybind11_dep, +] + +shared_module('py_polytrope', + bindings_sources, + include_directories: include_directories('.'), + cpp_args: ['-fvisibility=default'], + install : true, + dependencies: dependencies, +) diff --git a/tests/python/polytrope/loadMesh.py b/tests/python/polytrope/loadMesh.py new file mode 100644 index 0000000..a0f316d --- /dev/null +++ b/tests/python/polytrope/loadMesh.py @@ -0,0 +1,11 @@ +from serif.mfem import Mesh +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test loading a mesh using MFEM's mesh loading called from Python") + parser.add_argument("path", type=str, help="path to mesh") + + args = parser.parse_args() + + mesh = Mesh(args.path) + print(mesh.nv) diff --git a/tests/python/polytrope/runPolytrope.py b/tests/python/polytrope/runPolytrope.py new file mode 100644 index 0000000..8d9d2ee --- /dev/null +++ b/tests/python/polytrope/runPolytrope.py @@ -0,0 +1,18 @@ +from serif import config +from serif.polytrope import PolySolver +from serif.mfem import Matrix +from serif.mfem import Operator +from serif.mfem.forms import BilinearForm + +config.loadConfig('../../testsConfig.yaml') +n = config.get("Tests:Poly:Index", 0.0) + +polytrope = PolySolver(n, 1) +polytrope.solve() +theta = polytrope.get_theta() +print(theta) + +FESpace = theta.FESpace() +print(FESpace) + +print(theta.as_numpy())