Merge pull request #67 from tboudreaux/feature/pythonInterface/poly

MFEM and polytrope python bindings
This commit is contained in:
2025-06-16 12:23:53 -04:00
committed by GitHub
28 changed files with 2331 additions and 60 deletions

View File

@@ -9,6 +9,11 @@ py_mod = py_installation.extension_module(
meson.project_source_root() + '/src/python/const/bindings.cpp',
meson.project_source_root() + '/src/python/config/bindings.cpp',
meson.project_source_root() + '/src/python/eos/bindings.cpp',
meson.project_source_root() + '/src/python/mfem/bindings.cpp',
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.cpp',
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp',
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp',
meson.project_source_root() + '/src/python/polytrope/bindings.cpp',
],
dependencies : [
pybind11_dep,
@@ -16,7 +21,10 @@ py_mod = py_installation.extension_module(
config_dep,
composition_dep,
eos_dep,
species_weight_dep
species_weight_dep,
mfem_dep,
polysolver_dep,
trampoline_dep
],
cpp_args : ['-UNDEBUG'], # Example: Ensure assertions are enabled if needed
install : true,

View File

@@ -39,8 +39,8 @@
#include "utilities.h"
#include "quill/LogMacros.h"
namespace serif {
namespace polytrope {
namespace serif::polytrope {
namespace laneEmden {
@@ -382,7 +382,6 @@ void PolySolver::setInitialGuess() const {
serif::probe::glVisView(*m_theta, m_mesh, "θ init");
serif::probe::glVisView(*m_phi, m_mesh, "φ init");
}
std::cout << "HERE" << std::endl;
}
@@ -497,5 +496,5 @@ solverBundle PolySolver::setupNewtonSolver() const {
return solver;
}
} // namespace polytrope
} // namespace serif
} // namespace serif::polytrope

View File

@@ -271,7 +271,14 @@ public: // Public methods
* @return A reference to the `mfem::GridFunction` storing the \f$\theta\f$ solution.
* @note The solution is populated after `solve()` has been successfully called.
*/
mfem::GridFunction& getSolution() const { return *m_theta; }
mfem::GridFunction& getTheta() const { return *m_theta; }
/**
* @brief Gets a reference to the solution grid function for \f$\phi\f$.
* @return A reference to the `mfem::GridFunction` storing the \f$\phi\f$ solution.
* @note The solution is populated after `solve()` has been successfully called.
*/
mfem::GridFunction& getPhi() const { return *m_phi; }
private: // Private Attributes
// --- Configuration and Logging ---

View File

@@ -119,7 +119,6 @@ namespace serif::utilities {
mfem::IsoparametricTransformation T;
mesh->GetElementTransformation(ne, &T);
phi_gf.GetCurl(T, curl_mag_vec);
std::cout << "HERE" << std::endl;
}
mfem::L2_FECollection fac(order, dim);
mfem::FiniteElementSpace fs(mesh, &fac);

View File

@@ -31,27 +31,26 @@
* @brief A collection of utilities for working with MFEM and solving the lane-emden equation.
*/
namespace serif {
namespace polytrope {
/**
namespace serif::polytrope {
/**
* @namespace polyMFEMUtils
* @brief A namespace for utilities for working with MFEM and solving the lane-emden equation.
*/
namespace polyMFEMUtils {
/**
namespace polyMFEMUtils {
/**
* @brief A class for nonlinear power integrator.
*/
class NonlinearPowerIntegrator: public mfem::NonlinearFormIntegrator {
public:
/**
class NonlinearPowerIntegrator: public mfem::NonlinearFormIntegrator {
public:
/**
* @brief Constructor for NonlinearPowerIntegrator.
*
* @param coeff The function coefficient.
* @param n The polytropic index.
*/
NonlinearPowerIntegrator(double n);
NonlinearPowerIntegrator(double n);
/**
/**
* @brief Assembles the element vector.
*
* @param el The finite element.
@@ -59,8 +58,8 @@ namespace polyMFEMUtils {
* @param elfun The element function.
* @param elvect The element vector to be assembled.
*/
virtual void AssembleElementVector(const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::Vector &elvect) override;
/**
virtual void AssembleElementVector(const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::Vector &elvect) override;
/**
* @brief Assembles the element gradient.
*
* @param el The finite element.
@@ -68,40 +67,39 @@ namespace polyMFEMUtils {
* @param elfun The element function.
* @param elmat The element matrix to be assembled.
*/
virtual void AssembleElementGrad (const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::DenseMatrix &elmat) override;
private:
serif::config::Config& m_config = serif::config::Config::getInstance();
serif::probe::LogManager& m_logManager = serif::probe::LogManager::getInstance();
quill::Logger* m_logger = m_logManager.getLogger("log");
double m_polytropicIndex;
double m_epsilon;
static constexpr double m_regularizationRadius = 0.15; ///< Regularization radius for the epsilon function, used to avoid singularities in the power law.
static constexpr double m_regularizationCoeff = 1.0/6.0; ///< Coefficient for the regularization term, used to ensure smoothness in the power law.
};
virtual void AssembleElementGrad (const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::DenseMatrix &elmat) override;
private:
serif::config::Config& m_config = serif::config::Config::getInstance();
serif::probe::LogManager& m_logManager = serif::probe::LogManager::getInstance();
quill::Logger* m_logger = m_logManager.getLogger("log");
double m_polytropicIndex;
double m_epsilon;
static constexpr double m_regularizationRadius = 0.15; ///< Regularization radius for the epsilon function, used to avoid singularities in the power law.
static constexpr double m_regularizationCoeff = 1.0/6.0; ///< Coefficient for the regularization term, used to ensure smoothness in the power law.
};
inline double dfmod(const double epsilon, const double n) {
if (n == 0.0) {
return 0.0;
inline double dfmod(const double epsilon, const double n) {
if (n == 0.0) {
return 0.0;
}
if (n == 1.0) {
return 1.0;
}
return n * std::pow(epsilon, n - 1.0);
}
if (n == 1.0) {
return 1.0;
inline double fmod(const double theta, const double n, const double epsilon) {
if (n == 0.0) {
return 1.0;
}
// For n != 0
const double y_prime_at_epsilon = dfmod(epsilon, n); // Uses the robust dfmod
const double y_at_epsilon = std::pow(epsilon, n); // epsilon^n
// f_mod(theta) = y_at_epsilon + y_prime_at_epsilon * (theta - epsilon)
return y_at_epsilon + y_prime_at_epsilon * (theta - epsilon);
}
return n * std::pow(epsilon, n - 1.0);
}
inline double fmod(const double theta, const double n, const double epsilon) {
if (n == 0.0) {
return 1.0;
}
// For n != 0
const double y_prime_at_epsilon = dfmod(epsilon, n); // Uses the robust dfmod
const double y_at_epsilon = std::pow(epsilon, n); // epsilon^n
// f_mod(theta) = y_at_epsilon + y_prime_at_epsilon * (theta - epsilon)
return y_at_epsilon + y_prime_at_epsilon * (theta - epsilon);
}
} // namespace polyMFEMUtils
} // namespace polytrope
} // namespace serif
} // namespace polyMFEMUtils
}

View File

@@ -26,8 +26,7 @@
#include "probe.h"
namespace serif {
namespace polytrope {
namespace serif::polytrope {
/**
* @brief Represents the Schur complement operator used in the solution process.
@@ -299,7 +298,7 @@ private:
std::unique_ptr<mfem::MixedBilinearForm> m_M; ///< Bilinear form M, coupling θ and φ.
std::unique_ptr<mfem::MixedBilinearForm> m_Q; ///< Bilinear form Q, coupling φ and θ.
std::unique_ptr<mfem::BilinearForm> m_D; ///< Bilinear form D, acting on φ.
std::unique_ptr<mfem::BilinearForm> m_S;
std::unique_ptr<mfem::BilinearForm> m_S; ///< Bilinear form S, used for least squares stabilization.
std::unique_ptr<mfem::NonlinearForm> m_f; ///< Nonlinear form f, acting on θ.
// --- Full Matrix Representations (owned, derived from forms) ---
@@ -395,5 +394,4 @@ private:
void update_preconditioner(const mfem::Operator &grad) const;
};
} // namespace polytrope
} // namespace serif
} // namespace serif::polytrope

View File

@@ -6,6 +6,8 @@
#include "composition/bindings.h"
#include "config/bindings.h"
#include "eos/bindings.h"
#include "mfem/bindings.h"
#include "polytrope/bindings.h"
PYBIND11_MODULE(serif, m) {
m.doc() = "Python bindings for the SERiF project";
@@ -21,4 +23,10 @@ PYBIND11_MODULE(serif, m) {
auto eosMod = m.def_submodule("eos", "EOS-module bindings");
register_eos_bindings(eosMod);
auto mfemMod = m.def_submodule("mfem", "MFEM bindings");
register_mfem_bindings(mfemMod);
auto polytropeMod = m.def_submodule("polytrope", "Polytrope-module bindings");
register_polytrope_bindings(polytropeMod);
}

View File

@@ -5,12 +5,9 @@
#include <string>
#include "helm.h"
// #include "resourceManager.h"
#include "bindings.h"
#include "EOSio.h"
#include "helm.h"
#include "../../eos/public/EOSio.h"
#include "../../eos/public/helm.h"
namespace serif::eos {
class EOSio;

View File

@@ -1,4 +1,6 @@
subdir('composition')
subdir('const')
subdir('config')
subdir('mfem')
subdir('eos')

View File

@@ -0,0 +1,51 @@
#include "PyCoefficient.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h> // Needed for std::function
#include <memory>
#include "mfem.hpp"
namespace py = pybind11;
using namespace mfem;
namespace serif::pybind {
real_t PyCoefficient::Eval(ElementTransformation &T, const IntegrationPoint &ip) {
PYBIND11_OVERRIDE_PURE(
real_t, /* Return type */
Coefficient, /* Base class */
Eval, /* Method name */
T, ip /* Arguments */
);
}
// Override virtual SetTime method
void PyCoefficient::SetTime(real_t t) {
PYBIND11_OVERRIDE(
void,
Coefficient,
SetTime,
t
);
}
void PyVectorCoefficient::Eval(Vector &V, ElementTransformation &T, const IntegrationPoint &ip) {
PYBIND11_OVERRIDE_PURE(
void, /* Return type */
VectorCoefficient, /* Base class */
Eval, /* Method name */
V, T, ip /* Arguments */
);
}
// Override the virtual SetTime method
void PyVectorCoefficient::SetTime(real_t t) {
PYBIND11_OVERRIDE(
void,
VectorCoefficient,
SetTime,
t
);
}
}

View File

@@ -0,0 +1,162 @@
/**
* @file PyCoefficient.h
* @brief Defines pybind11 trampoline classes for mfem::Coefficient and mfem::VectorCoefficient.
*
* These trampoline classes allow Python classes to inherit from mfem::Coefficient
* and mfem::VectorCoefficient, enabling Python-defined coefficients to be used
* within the C++ MFEM library.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h> // Needed for std::function
#include <memory>
#include "mfem.hpp"
namespace py = pybind11;
using namespace mfem;
/**
* @namespace serif::pybind
* @brief Contains pybind11 helper classes and trampoline classes for interfacing C++ with Python.
*/
namespace serif::pybind {
/**
* @brief Trampoline class for mfem::Coefficient.
*
* This class allows Python classes to inherit from mfem::Coefficient and override
* its virtual methods. This is essential for creating custom coefficients in Python
* that can be used by MFEM's C++ backend.
*
* @see mfem::Coefficient
*
* @par Python Usage Example:
* @code{.py}
* import mfem.ser_ext as mfem
*
* class MyPythonCoefficient(mfem.Coefficient):
* def __init__(self):
* super().__init__() # Call the base C++ constructor
*
* def Eval(self, T, ip):
* # T is an mfem.ElementTransformation
* # ip is an mfem.IntegrationPoint
* # Example: return a constant value
* return 1.0
*
* def SetTime(self, t):
* # Optionally handle time-dependent coefficients
* super().SetTime(t) # Call base class method
* print(f"Time set to: {t}")
*
* # Using the Python coefficient
* py_coeff = MyPythonCoefficient()
* # py_coeff can now be passed to MFEM functions expecting an mfem::Coefficient
* @endcode
*/
class PyCoefficient : public Coefficient {
public:
using Coefficient::Coefficient; /**< Inherit constructors from mfem::Coefficient. */
/**
* @brief Evaluate the coefficient at a given IntegrationPoint in an ElementTransformation.
*
* This method is called by MFEM when the value of the coefficient is needed.
* If a Python class inherits from PyCoefficient, it *must* override this method.
*
* @param T The element transformation.
* @param ip The integration point.
* @return The value of the coefficient at the given point.
*
* @note This method forwards the call to the Python override.
* PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this.
*/
real_t Eval(ElementTransformation &T, const IntegrationPoint &ip) override;
/**
* @brief Set the current time for time-dependent coefficients.
*
* This method is called by MFEM to update the time for time-dependent coefficients.
* Python classes inheriting from PyCoefficient can override this method to implement
* time-dependent behavior.
*
* @param t The current time.
*
* @note This method forwards the call to the Python override if one exists.
* PYBIND11_OVERRIDE is used in the .cpp file to handle this.
*/
void SetTime(real_t t) override;
};
/**
* @brief Trampoline class for mfem::VectorCoefficient.
*
* This class allows Python classes to inherit from mfem::VectorCoefficient and override
* its virtual methods. This is essential for creating custom vector-valued coefficients
* in Python that can be used by MFEM's C++ backend.
*
* @see mfem::VectorCoefficient
*
* @par Python Usage Example:
* @code{.py}
* import mfem.ser_ext as mfem
* import numpy as np
*
* class MyPythonVectorCoefficient(mfem.VectorCoefficient):
* def __init__(self, dim):
* super().__init__(dim) # Call the base C++ constructor, pass vector dimension
* self.dim = dim
*
* def Eval(self, V, T, ip):
* # V is an mfem.Vector (output parameter, must be filled)
* # T is an mfem.ElementTransformation
* # ip is an mfem.IntegrationPoint
* # Example: return a constant vector [1.0, 2.0, ...]
* for i in range(self.dim):
* V[i] = float(i + 1)
*
* def SetTime(self, t):
* super().SetTime(t)
* print(f"VectorCoefficient time set to: {t}")
*
* # Using the Python vector coefficient
* vec_dim = 2
* py_vec_coeff = MyPythonVectorCoefficient(vec_dim)
* # py_vec_coeff can now be passed to MFEM functions expecting an mfem::VectorCoefficient
* @endcode
*/
class PyVectorCoefficient : public VectorCoefficient {
public:
using VectorCoefficient::VectorCoefficient; /**< Inherit constructors from mfem::VectorCoefficient. */
/**
* @brief Evaluate the vector coefficient at a given IntegrationPoint in an ElementTransformation.
*
* This method is called by MFEM when the value of the vector coefficient is needed.
* If a Python class inherits from PyVectorCoefficient, it *must* override this method.
* The result should be stored in the output Vector @p V.
*
* @param V Output vector to store the result. Its size should match the coefficient's dimension.
* @param T The element transformation.
* @param ip The integration point.
*
* @note This method forwards the call to the Python override.
* PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this.
*/
void Eval(Vector &V, ElementTransformation &T, const IntegrationPoint &ip) override;
/**
* @brief Set the current time for time-dependent vector coefficients.
*
* This method is called by MFEM to update the time for time-dependent vector coefficients.
* Python classes inheriting from PyVectorCoefficient can override this method to implement
* time-dependent behavior.
*
* @param t The current time.
*
* @note This method forwards the call to the Python override if one exists.
* PYBIND11_OVERRIDE is used in the .cpp file to handle this.
*/
void SetTime(real_t t) override;
};
}

View File

@@ -0,0 +1,5 @@
PyCoefficient_sources = files(
'PyCoefficient.cpp'
)
trampoline_sources += PyCoefficient_sources

View File

@@ -0,0 +1,137 @@
#include "PyMFEMTrampolines/Operator/Matrix/PyMatrix.h"
#include "mfem.hpp"
namespace serif::pybind {
// --- Trampolines for new mfem::Matrix pure virtual methods ---
mfem::real_t& PyMatrix::Elem(int i, int j) {
PYBIND11_OVERRIDE_PURE(
mfem::real_t&, /* Return type */
mfem::Matrix, /* C++ base class */
Elem, /* C++ function name */
i, /* Argument 1 */
j /* Argument 2 */
);
}
const mfem::real_t& PyMatrix::Elem(int i, int j) const {
PYBIND11_OVERRIDE_PURE(
const mfem::real_t&,
mfem::Matrix,
Elem,
i,
j
);
}
mfem::MatrixInverse* PyMatrix::Inverse() const {
PYBIND11_OVERRIDE_PURE(
mfem::MatrixInverse*,
mfem::Matrix,
Inverse
);
}
// --- Trampoline for new mfem::Matrix regular virtual methods ---
void PyMatrix::Finalize(int skip_zeros) {
PYBIND11_OVERRIDE(
void,
mfem::Matrix,
Finalize,
skip_zeros
);
}
// --- Trampolines for inherited mfem::Operator virtual methods ---
void PyMatrix::Mult(const mfem::Vector &x, mfem::Vector &y) const {
// This remains PURE as mfem::Matrix does not implement it.
PYBIND11_OVERRIDE_PURE(
void,
mfem::Matrix,
Mult,
x,
y
);
}
void PyMatrix::MultTranspose(const mfem::Vector &x, mfem::Vector &y) const {
PYBIND11_OVERRIDE(
void,
mfem::Matrix,
MultTranspose,
x,
y
);
}
void PyMatrix::AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
PYBIND11_OVERRIDE(
void,
mfem::Matrix,
AddMult,
x,
y,
a
);
}
void PyMatrix::AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
PYBIND11_OVERRIDE(
void,
mfem::Matrix,
AddMultTranspose,
x,
y,
a
);
}
mfem::Operator& PyMatrix::GetGradient(const mfem::Vector &x) const {
PYBIND11_OVERRIDE(
mfem::Operator&,
mfem::Matrix,
GetGradient,
x
);
}
void PyMatrix::AssembleDiagonal(mfem::Vector &diag) const {
PYBIND11_OVERRIDE(
void,
mfem::Matrix,
AssembleDiagonal,
diag
);
}
const mfem::Operator* PyMatrix::GetProlongation() const {
PYBIND11_OVERRIDE(
const mfem::Operator*,
mfem::Matrix,
GetProlongation
);
}
const mfem::Operator* PyMatrix::GetRestriction() const {
PYBIND11_OVERRIDE(
const mfem::Operator*,
mfem::Matrix,
GetRestriction
);
}
void PyMatrix::RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) {
PYBIND11_OVERRIDE(
void,
mfem::Matrix,
RecoverFEMSolution,
X,
b,
x
);
}
} // namespace serif::pybind

View File

@@ -0,0 +1,253 @@
/**
* @file PyMatrix.h
* @brief Defines a pybind11 trampoline class for mfem::Matrix.
*
* This trampoline class allows Python classes to inherit from mfem::Matrix,
* enabling Python-defined matrices to be used within the C++ MFEM library,
* including overriding methods from its base class mfem::Operator.
*/
#pragma once
#include "mfem.hpp"
#include "pybind11/pybind11.h"
namespace serif::pybind {
/**
* @brief A trampoline class for mfem::Matrix.
*
* This class allows Python classes to inherit from mfem::Matrix and correctly
* override its virtual functions, including those inherited from its base,
* mfem::Operator. This is useful for creating custom matrix types in Python
* that can interact seamlessly with MFEM's C++ components.
*
* @see mfem::Matrix
* @see mfem::Operator
* @see PyOperator
*
* @par Python Usage Example:
* @code{.py}
* import mfem.ser_ext as mfem
* import numpy as np
*
* class MyPythonMatrix(mfem.Matrix):
* def __init__(self, size=0):
* super().__init__(size) # Call base C++ constructor
* if size > 0:
* # Example: Initialize with some data if size is provided
* # Note: Direct data manipulation might be complex due to
* # C++ ownership. This example is conceptual.
* # For a real-world scenario, you might manage data in Python
* # and use Eval or Mult to reflect its state.
* self._py_data = np.zeros((size, size))
* else:
* self._py_data = None
*
* # --- mfem.Matrix pure virtual overrides ---
* def Elem(self, i, j):
* # This is problematic for direct override for write access
* # due to returning a reference.
* # Typically, you'd manage data and use it in Mult/GetRow etc.
* # For read access (const version), it's more straightforward.
* if self._py_data is not None:
* return self._py_data[i, j]
* raise IndexError("Matrix data not initialized or index out of bounds")
*
* # const Elem version for read access
* # def GetElem(self, i, j): # A helper in Python if Elem is tricky
* # return self._py_data[i,j]
*
* def Inverse(self):
* # Return an mfem.MatrixInverse object
* # This is a simplified example; a real inverse is complex.
* print("MyPythonMatrix.Inverse() called, returning dummy inverse.")
* # For a real implementation, you'd compute and return an actual
* # mfem.MatrixInverse or a Python-derived version of it.
* # This might involve creating a new PyMatrix representing the inverse.
* identity_inv = mfem.DenseMatrix(self.Width())
* for i in range(self.Width()):
* identity_inv[i,i] = 1.0
* return mfem.MatrixInverse(identity_inv) # Example
*
* # --- mfem.Matrix regular virtual overrides ---
* def Finalize(self, skip_zeros=1):
* super().Finalize(skip_zeros) # Call base
* print(f"MyPythonMatrix.Finalize({skip_zeros}) called.")
*
* # --- mfem.Operator virtual overrides ---
* def Mult(self, x, y):
* # x is mfem.Vector (input), y is mfem.Vector (output)
* if self._py_data is not None:
* x_np = x.GetDataArray()
* y_np = np.dot(self._py_data, x_np)
* y.SetSize(self._py_data.shape[0])
* y.Assign(y_np)
* else:
* # Fallback to base class if no Python data
* # super().Mult(x,y) # This might not work as expected if base is abstract
* # or if the C++ mfem.Matrix itself doesn't have data.
* # For a sparse matrix, this would call its sparse Mult.
* # For a dense matrix, it would use its data.
* # If this PyMatrix is purely Python defined, it must implement Mult.
* raise NotImplementedError("Mult not implemented or data not set")
* print("MyPythonMatrix.Mult called.")
*
* # Using the Python Matrix
* mat_size = 3
* py_mat = MyPythonMatrix(mat_size)
* py_mat._py_data = np.array([[1,2,3],[4,5,6],[7,8,9]], dtype=float)
*
* x_vec = mfem.Vector(mat_size)
* x_vec.Assign(np.array([1,1,1], dtype=float))
* y_vec = mfem.Vector(mat_size)
*
* py_mat.Mult(x_vec, y_vec)
* print("Result y:", y_vec.GetDataArray())
*
* # inv_op = py_mat.Inverse() # Be cautious with dummy implementations
* # print("Inverse type:", type(inv_op))
*
* # Note: Overriding Elem(i,j) to return a C++ reference from Python
* # for write access (mfem::real_t&) is complex and often not directly feasible
* # in a straightforward way with pybind11 for typical Python data structures.
* # Python users would typically interact via methods like Set, Add, or by
* # providing data that the C++ side can access (e.g., via GetData).
* # The const version of Elem (read-only) is more manageable.
* @endcode
*/
class PyMatrix : public mfem::Matrix {
public:
// Inherit constructors from the base mfem::Matrix class.
using mfem::Matrix::Matrix;
// --- Trampolines for new mfem::Matrix pure virtual methods ---
/**
* @brief Access element (i,j) for read/write.
* Pure virtual in mfem::Matrix. Must be overridden in Python.
* @param i Row index.
* @param j Column index.
* @return Reference to the matrix element.
* @note Returning a C++ reference from Python for write access can be complex.
* Consider alternative data management strategies in Python.
* PYBIND11_OVERRIDE_PURE is used in the .cpp file.
*/
mfem::real_t& Elem(int i, int j) override;
/**
* @brief Access element (i,j) for read-only.
* Pure virtual in mfem::Matrix. Must be overridden in Python.
* @param i Row index.
* @param j Column index.
* @return Const reference to the matrix element.
* @note PYBIND11_OVERRIDE_PURE is used in the .cpp file.
*/
const mfem::real_t& Elem(int i, int j) const override;
/**
* @brief Get the inverse of the matrix.
* Pure virtual in mfem::Matrix. Must be overridden in Python.
* The caller is responsible for deleting the returned MatrixInverse object.
* @return Pointer to an mfem::MatrixInverse object.
* @note PYBIND11_OVERRIDE_PURE is used in the .cpp file.
*/
mfem::MatrixInverse* Inverse() const override;
// --- Trampoline for new mfem::Matrix regular virtual methods ---
/**
* @brief Finalize matrix assembly.
* For sparse matrices, this typically involves finalizing the sparse structure.
* Can be overridden in Python.
* @param skip_zeros See mfem::SparseMatrix::Finalize documentation.
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
void Finalize(int skip_zeros) override;
// --- Trampolines for inherited mfem::Operator virtual methods ---
// These must be repeated here to allow Python classes inheriting from
// Matrix to override methods originally from the Operator base class.
/**
* @brief Perform the operator action: y = A*x.
* Inherited from mfem::Operator. Can be overridden in Python.
* If not overridden, mfem::Matrix's default implementation is used.
* @param x The input vector.
* @param y The output vector (result of A*x).
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
void Mult(const mfem::Vector &x, mfem::Vector &y) const override;
/**
* @brief Perform the transpose operator action: y = A^T*x.
* Inherited from mfem::Operator. Can be overridden in Python.
* @param x The input vector.
* @param y The output vector (result of A^T*x).
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override;
/**
* @brief Perform the action y += a*(A*x).
* Inherited from mfem::Operator. Can be overridden in Python.
* @param x The input vector.
* @param y The vector to which a*(A*x) is added.
* @param a Scalar multiplier (defaults to 1.0).
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
/**
* @brief Perform the action y += a*(A^T*x).
* Inherited from mfem::Operator. Can be overridden in Python.
* @param x The input vector.
* @param y The vector to which a*(A^T*x) is added.
* @param a Scalar multiplier (defaults to 1.0).
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
/**
* @brief Get the gradient operator (Jacobian) at a given point x.
* Inherited from mfem::Operator. Can be overridden in Python.
* For a linear matrix operator, the gradient is typically the matrix itself.
* @param x The point at which to evaluate the gradient (often unused for linear operators).
* @return A reference to the gradient operator.
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
mfem::Operator& GetGradient(const mfem::Vector &x) const override;
/**
* @brief Assemble the diagonal of the operator.
* Inherited from mfem::Operator. Can be overridden in Python.
* @param diag Output vector to store the diagonal entries.
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
void AssembleDiagonal(mfem::Vector &diag) const override;
/**
* @brief Get the prolongation operator.
* Inherited from mfem::Operator. Can be overridden in Python.
* @return A const pointer to the prolongation operator, or nullptr if not applicable.
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
const mfem::Operator* GetProlongation() const override;
/**
* @brief Get the restriction operator.
* Inherited from mfem::Operator. Can be overridden in Python.
* @return A const pointer to the restriction operator, or nullptr if not applicable.
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
const mfem::Operator* GetRestriction() const override;
/**
* @brief Recover the FEM solution.
* Inherited from mfem::Operator. Can be overridden in Python.
* @param X The reduced solution vector.
* @param b The right-hand side vector.
* @param x Output vector for the full FEM solution.
* @note PYBIND11_OVERRIDE is used in the .cpp file.
*/
void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override;
};
} // namespace serif::pybind

View File

@@ -0,0 +1,5 @@
PyMatrix_sources = files(
'PyMatrix.cpp'
)
trampoline_sources += PyMatrix_sources

View File

@@ -0,0 +1,100 @@
#include "PyMFEMTrampolines/Operator/PyOperator.h"
#include "mfem.hpp"
namespace serif::pybind {
// --- Pure virtual function implementation ---
// This override is mandatory for the trampoline class to be concrete.
void PyOperator::Mult(const mfem::Vector &x, mfem::Vector &y) const {
PYBIND11_OVERRIDE_PURE(
void, /* Return type */
mfem::Operator, /* C++ base class */
Mult, /* C++ function name */
x, /* Argument 1 */
y /* Argument 2 */
);
}
// --- Regular virtual function implementations ---
// These overrides allow Python subclasses to optionally provide their own
// implementation for these methods. If they don't, the default C++
// implementation from mfem::Operator will be called.
void PyOperator::MultTranspose(const mfem::Vector &x, mfem::Vector &y) const {
PYBIND11_OVERRIDE(
void,
mfem::Operator,
MultTranspose,
x,
y
);
}
void PyOperator::AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
PYBIND11_OVERRIDE(
void,
mfem::Operator,
AddMult,
x,
y,
a
);
}
void PyOperator::AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
PYBIND11_OVERRIDE(
void,
mfem::Operator,
AddMultTranspose,
x,
y,
a
);
}
mfem::Operator& PyOperator::GetGradient(const mfem::Vector &x) const {
PYBIND11_OVERRIDE(
mfem::Operator&,
mfem::Operator,
GetGradient,
x
);
}
void PyOperator::AssembleDiagonal(mfem::Vector &diag) const {
PYBIND11_OVERRIDE(
void,
mfem::Operator,
AssembleDiagonal,
diag
);
}
const mfem::Operator* PyOperator::GetProlongation() const {
PYBIND11_OVERRIDE(
const mfem::Operator*,
mfem::Operator,
GetProlongation
);
}
const mfem::Operator* PyOperator::GetRestriction() const {
PYBIND11_OVERRIDE(
const mfem::Operator*,
mfem::Operator,
GetRestriction
);
}
void PyOperator::RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) {
PYBIND11_OVERRIDE(
void,
mfem::Operator,
RecoverFEMSolution,
X,
b,
x
);
}
} // namespace serif::pybind::mfem

View File

@@ -0,0 +1,210 @@
/**
* @file PyOperator.h
* @brief Defines a pybind11 trampoline class for mfem::Operator.
*
* This trampoline class allows Python classes to inherit from mfem::Operator,
* enabling Python-defined operators to be used within the C++ MFEM library.
*/
#pragma once
#include "mfem.hpp"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h" // Needed for vectors, maps, sets, strings
namespace py = pybind11;
/**
* @namespace serif::pybind
* @brief Contains pybind11 helper classes and trampoline classes for interfacing C++ with Python.
*/
namespace serif::pybind {
/**
* @brief Trampoline class for mfem::Operator.
*
* This class allows Python classes to inherit from mfem::Operator and correctly
* override its virtual functions. When a virtual function is called from C++,
* the trampoline ensures the call is forwarded to the Python implementation if one exists.
* This is crucial for integrating Python-defined linear or non-linear operators
* into MFEM's C++-based solvers and algorithms.
*
* @see mfem::Operator
*
* @par Python Usage Example:
* @code{.py}
* import mfem.ser_ext as mfem
* import numpy as np
*
* class MyPythonOperator(mfem.Operator):
* def __init__(self, size):
* super().__init__(size) # Call the base C++ constructor
* # Or super().__init__() if default constructing and setting height/width later
* self.matrix = np.random.rand(size, size) # Example: store a dense matrix
*
* # Must override Mult
* def Mult(self, x, y):
* # x is an mfem.Vector (input)
* # y is an mfem.Vector (output, y = A*x)
* # Ensure y is correctly sized if not already
* if y.Size() != self.Height():
* y.SetSize(self.Height())
*
* # Example: y = self.matrix * x
* # This is a conceptual illustration. For actual matrix-vector products
* # with numpy, you'd convert mfem.Vector to numpy array or iterate.
* x_np = x.GetDataArray() # Get a numpy view if configured
* y_np = np.dot(self.matrix, x_np)
* y.Assign(y_np) # Assign result back to mfem.Vector
* print("MyPythonOperator.Mult called")
*
* # Optionally override other methods like MultTranspose, GetGradient, etc.
* def MultTranspose(self, x, y):
* if y.Size() != self.Width():
* y.SetSize(self.Width())
* # Example: y = self.matrix.T * x
* x_np = x.GetDataArray()
* y_np = np.dot(self.matrix.T, x_np)
* y.Assign(y_np)
* print("MyPythonOperator.MultTranspose called")
*
* # Using the Python operator
* op_size = 5
* py_op = MyPythonOperator(op_size)
*
* x_vec = mfem.Vector(op_size)
* x_vec.Assign(np.arange(op_size, dtype=float))
* y_vec = mfem.Vector(op_size)
*
* py_op.Mult(x_vec, y_vec)
* print("Result y:", y_vec.GetDataArray())
* # py_op can now be passed to MFEM functions expecting an mfem::Operator
* @endcode
*/
class PyOperator : public mfem::Operator {
public:
// Inherit the constructors from the base mfem::Operator class.
// This allows Python classes to call e.g., super().__init__(size).
using mfem::Operator::Operator;
// --- Trampoline declarations for all overridable virtual functions ---
/**
* @brief Perform the operator action: y = A*x.
*
* This is a pure virtual function in mfem::Operator and **must** be overridden
* by any Python class inheriting from PyOperator.
*
* @param x The input vector.
* @param y The output vector (result of A*x).
* @note This method forwards the call to the Python override.
* PYBIND11_OVERRIDE_PURE is used in the .cpp file to handle this.
*/
void Mult(const mfem::Vector &x, mfem::Vector &y) const override;
/**
* @brief Perform the transpose operator action: y = A^T*x.
*
* Optional override. If not overridden, MFEM's base implementation
* (which may raise an error or be a no-op) will be used.
*
* @param x The input vector.
* @param y The output vector (result of A^T*x).
* @note This method forwards the call to the Python override if one exists.
* PYBIND11_OVERRIDE is used in the .cpp file to handle this.
*/
void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override;
/**
* @brief Perform the action y += a*(A*x).
*
* Optional override.
*
* @param x The input vector.
* @param y The vector to which a*(A*x) is added.
* @param a Scalar multiplier (defaults to 1.0).
* @note This method forwards the call to the Python override if one exists.
*/
void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
/**
* @brief Perform the action y += a*(A^T*x).
*
* Optional override.
*
* @param x The input vector.
* @param y The vector to which a*(A^T*x) is added.
* @param a Scalar multiplier (defaults to 1.0).
* @note This method forwards the call to the Python override if one exists.
*/
void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
/**
* @brief Get the gradient operator (Jacobian) at a given point x.
*
* For non-linear operators, this method should return the linearization (Jacobian)
* of the operator at the point `x`. The returned Operator is owned by this
* Operator and should not be deleted by the caller.
* Optional override.
*
* @param x The point at which to evaluate the gradient.
* @return A reference to the gradient operator.
* @note This method forwards the call to the Python override if one exists.
*/
Operator& GetGradient(const mfem::Vector &x) const override;
/**
* @brief Assemble the diagonal of the operator.
*
* For discrete operators (e.g., matrices), this method should compute and store
* the diagonal entries of the operator in the vector `diag`.
* Optional override.
*
* @param diag Output vector to store the diagonal entries.
* @note This method forwards the call to the Python override if one exists.
*/
void AssembleDiagonal(mfem::Vector &diag) const override;
/**
* @brief Get the prolongation operator.
*
* Used in multilevel methods (e.g., AMG). Returns a pointer to the prolongation
* operator (interpolation from a coarser level to this level).
* The returned Operator is typically owned by this Operator.
* Optional override.
*
* @return A const pointer to the prolongation operator, or nullptr if not applicable.
* @note This method forwards the call to the Python override if one exists.
*/
const mfem::Operator* GetProlongation() const override;
/**
* @brief Get the restriction operator.
*
* Used in multilevel methods (e.g., AMG). Returns a pointer to the restriction
* operator (projection from this level to a coarser level).
* Typically, this is the transpose of the prolongation operator.
* The returned Operator is typically owned by this Operator.
* Optional override.
*
* @return A const pointer to the restriction operator, or nullptr if not applicable.
* @note This method forwards the call to the Python override if one exists.
*/
const mfem::Operator* GetRestriction() const override;
/**
* @brief Recover the FEM solution.
*
* For operators that are part of a system solve (e.g., static condensation),
* this method can be used to reconstruct the full finite element solution `x`
* from a reduced solution `X` and the right-hand side `b`.
* Optional override.
*
* @param X The reduced solution vector.
* @param b The right-hand side vector.
* @param x Output vector for the full FEM solution.
* @note This method forwards the call to the Python override if one exists.
*/
void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override;
};
} // namespace serif::pybind

View File

@@ -0,0 +1,7 @@
PyOperator_sources = files(
'PyOperator.cpp'
)
trampoline_sources += PyOperator_sources
subdir('Matrix')

View File

@@ -0,0 +1,2 @@
subdir('Operator')
subdir('Coefficient')

View File

@@ -0,0 +1,23 @@
trampoline_sources = []
subdir('PyMFEMTrampolines')
dependencies = [
mfem_dep,
pybind11_dep,
python3_dep,
]
trampoline_lib = static_library(
'mfem_trampolines',
trampoline_sources,
include_directories: include_directories('.'),
dependencies: dependencies,
install: false,
)
trampoline_dep = declare_dependency(
link_with: trampoline_lib,
include_directories: ('.'),
dependencies: dependencies,
)

View File

@@ -0,0 +1,862 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h> // For operator overloads
#include <memory>
#include <sstream>
// Include your trampoline class header. The implementation will be in a separate .cpp file.
#include "bindings.h"
#include "Trampoline/PyMFEMTrampolines/Operator/PyOperator.h"
#include "Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.h"
#include "Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.h"
#include "mfem.hpp"
namespace py = pybind11;
using namespace mfem;
// This function registers all the mfem-related classes to the python module
void register_mfem_bindings(py::module &mfem_submodule) {
register_operator_bindings(mfem_submodule);
register_matrix_bindings(mfem_submodule);
register_vector_bindings(mfem_submodule);
register_array_bindings(mfem_submodule);
register_table_bindings(mfem_submodule);
register_mesh_bindings(mfem_submodule);
auto formsModule = mfem_submodule.def_submodule("forms", "MFEM forms module");
register_bilinear_form_bindings(formsModule);
register_mixed_bilinear_form_bindings(formsModule);
auto fecModule = mfem_submodule.def_submodule("fec", "MFEM finite element collection module");
register_basis_type_bindings(fecModule);
register_finite_element_collection_bindings(fecModule);
register_H1_FECollection_bindings(fecModule);
register_RT_FECollection_bindings(fecModule);
register_ND_FECollection_bindings(fecModule);
auto fesModule = mfem_submodule.def_submodule("fes", "MFEM finite element space module");
register_finite_element_space_bindings(fesModule);
register_coefficient_bindings(mfem_submodule);
register_intrule_bindings(mfem_submodule);
register_eltrans_bindings(mfem_submodule);
register_grid_function_bindings(mfem_submodule);
}
void register_operator_bindings(py::module &mfem_submodule) {
// Use the PyOperator trampoline when binding mfem::Operator
// This allows Python classes to inherit from mfem::Operator
py::class_<Operator, serif::pybind::PyOperator /* Trampoline */>(mfem_submodule, "Operator")
// NOTE: We DO NOT define an __init__ method because Operator is abstract.
// Python users will instantiate concrete derived classes instead.
// --- Bind Properties ---
.def_property_readonly("height", &Operator::Height, "Get the height (number of rows) of the Operator.")
.def_property_readonly("width", &Operator::Width, "Get the width (number of columns) of the Operator.")
// --- Bind Core Virtual Methods ---
// We bind the methods of the C++ base class so they can be called from Python.
// The trampoline handles redirecting to a Python override if one exists.
// Mult: y = A(x)
.def("Mult", py::overload_cast<const Vector&, Vector&>(&Operator::Mult, py::const_),
py::arg("x"), py::arg("y"), "Calculates y = A(x). y must be pre-allocated.")
// Pythonic overload for Mult that returns a new vector
.def("Mult", [](const Operator &op, const Vector &x) {
Vector y(op.Height());
op.Mult(x, y);
return y;
}, py::arg("x"), "Calculates and returns a new vector y = A(x).")
// MultTranspose: y = A^T(x)
.def("MultTranspose", py::overload_cast<const Vector&, Vector&>(&Operator::MultTranspose, py::const_),
py::arg("x"), py::arg("y"), "Calculates y = A^T(x). y must be pre-allocated.")
.def("MultTranspose", [](const Operator &op, const Vector &x) {
Vector y(op.Width());
op.MultTranspose(x, y);
return y;
}, py::arg("x"), "Calculates and returns a new vector y = A^T(x).")
// Additive versions
.def("AddMult", &Operator::AddMult, py::arg("x"), py::arg("y"), py::arg("a") = 1.0, "Performs y += a * A(x).")
.def("AddMultTranspose", &Operator::AddMultTranspose, py::arg("x"), py::arg("y"), py::arg("a") = 1.0, "Performs y += a * A^T(x).")
// Other core virtual methods
.def("AssembleDiagonal", &Operator::AssembleDiagonal, py::arg("diag"), "Assembles the operator diagonal into the given Vector.")
.def("RecoverFEMSolution", &Operator::RecoverFEMSolution, py::arg("X"), py::arg("b"), py::arg("x"), "Recovers the full FE solution.")
.def("GetGradient", &Operator::GetGradient, py::return_value_policy::reference, "Returns the Gradient of a non-linear operator.")
// Methods returning other operators (e.g., for parallel/BCs)
.def("GetProlongation", &Operator::GetProlongation, py::return_value_policy::reference, "Returns the prolongation operator.")
.def("GetRestriction", &Operator::GetRestriction, py::return_value_policy::reference, "Returns the restriction operator.")
// --- Pythonic Operator Overloading ---
.def("__matmul__", [](const Operator &op, const Vector &x) {
Vector y(op.Height());
op.Mult(x, y);
return y;
}, py::is_operator());
}
void register_matrix_bindings(py::module &mfem_submodule) {
py::class_<Matrix, Operator, serif::pybind::PyMatrix>(mfem_submodule, "Matrix")
// No constructor since it's an abstract base class.
.def_property_readonly("is_square", &Matrix::IsSquare,
"Returns true if the matrix is square.")
.def("finalize", &Matrix::Finalize, py::arg("skip_zeros") = 1,
"Finalizes the matrix initialization.")
.def("inverse", &Matrix::Inverse,
"Returns a pointer to (an approximation) of the matrix inverse.",
py::return_value_policy::take_ownership) // The caller owns the returned pointer
// Pythonic element access: mat[i, j]
.def("__getitem__", [](const Matrix &m, py::tuple t) {
if (t.size() != 2) {
throw py::index_error("Matrix index must be a 2-tuple (i, j)");
}
return m.Elem(t[0].cast<int>(), t[1].cast<int>());
})
// Pythonic element assignment: mat[i, j] = value
.def("__setitem__", [](Matrix &m, py::tuple t, real_t value) {
if (t.size() != 2) {
throw py::index_error("Matrix index must be a 2-tuple (i, j)");
}
m.Elem(t[0].cast<int>(), t[1].cast<int>()) = value;
})
.def("__repr__", [](const Matrix &m) {
return "<mfem.Matrix (Abstract) " +
std::to_string(m.Height()) + "x" +
std::to_string(m.Width()) + ">";
});
}
void register_vector_bindings(py::module &mfem_submodule) {
// Register the mfem::Vector class
py::class_<mfem::Vector>(mfem_submodule, "Vector")
.def(py::init<int>(), py::arg("size"))
.def(py::init<const mfem::Vector &>(), py::arg("other"))
.def(py::init([](py::array_t<double> arr) {
py::buffer_info info = arr.request();
if (info.ndim != 1) {
throw std::runtime_error("Vector(): expected a 1-D numpy array");
}
mfem::Vector v(info.size);
std::memcpy(v.GetData(), info.ptr, info.size * sizeof(double));
return v;
}), py::arg("array"))
.def("GetData", &mfem::Vector::GetData, py::return_value_policy::reference_internal)
.def("Size", &mfem::Vector::Size)
.def("__getitem__", [](const mfem::Vector &v, int i) { return v[i]; })
.def("__len__", &mfem::Vector::Size)
.def("__setitem__", [](mfem::Vector &v, int i, double value) { v[i] = value; })
.def("__repr__", [](const mfem::Vector &v) {
return "<mfem.Vector(size=" + std::to_string(v.Size()) + ")>";
})
.def("as_numpy", [](mfem::Vector &self) {
return py::array_t<double>(self.Size(), self.GetData());
});
}
void register_array_bindings(py::module &mfem_submodule) {
py::class_<Array<int>>(mfem_submodule, "IntArray")
// --- Constructors ---
.def(py::init<>(), "Default constructor.")
.def(py::init<int>(), py::arg("size"), "Constructor with size.")
.def(py::init([](const std::vector<int> &v) {
auto *arr = new Array<int>(v.size());
for (size_t i = 0; i < v.size(); ++i) {
(*arr)[i] = v[i];
}
return arr;
}), py::arg("list"), "Constructor from a Python list.")
// --- Pythonic Features ---
.def("__len__", &Array<int>::Size)
.def("__getitem__", [](const Array<int> &self, int i) {
if (i < 0) i += self.Size(); // Handle negative indices
if (i < 0 || i >= self.Size()) throw py::index_error();
return self[i];
})
.def("__setitem__", [](Array<int> &self, int i, int value) {
if (i < 0) i += self.Size(); // Handle negative indices
if (i < 0 || i >= self.Size()) throw py::index_error();
self[i] = value;
})
.def("__iter__", [](Array<int> &self) {
return py::make_iterator(self.begin(), self.end());
}, py::keep_alive<0, 1>()) // Keep array alive while iterator is used
.def("__repr__", [](const Array<int> &self) {
std::stringstream ss;
ss << "[";
for (int i = 0; i < self.Size(); ++i) {
ss << self[i] << (i == self.Size() - 1 ? "" : ", ");
}
ss << "]";
return ss.str();
})
// --- Core Methods ---
.def("Size", &Array<int>::Size)
.def("SetSize", py::overload_cast<int>(&Array<int>::SetSize), py::arg("size"))
.def("Append", py::overload_cast<const int &>(&Array<int>::Append), py::arg("el"))
.def("Last", py::overload_cast<>(&Array<int>::Last, py::const_))
.def("DeleteLast", &Array<int>::DeleteLast)
.def("DeleteAll", &Array<int>::DeleteAll)
.def("Sort", [](Array<int> &self) { self.Sort(); })
.def("Unique", &Array<int>::Unique)
.def("Assign", [](Array<int> &self, const Array<int> &other) {
self.Assign(other.GetData());
})
.def("as_numpy", [](Array<int> &self) {
// Create a Python-owned copy to avoid memory issues
return py::array_t<int>(self.Size(), self.GetData());
});
}
// Main function to register BilinearForm
void register_bilinear_form_bindings(py::module &mfem_submodule) {
// It's good practice to bind enums used by the class
bind_assembly_level_enum(mfem_submodule);
// Bind the mfem::BilinearForm class, inheriting from mfem::Matrix
// No trampoline is needed because this is a concrete class.
py::class_<BilinearForm, Matrix>(mfem_submodule, "BilinearForm")
// --- Constructor ---
.def(py::init<FiniteElementSpace *>(), py::arg("fespace"),
// The keep_alive policy ensures the Python object for the
// FiniteElementSpace is not garbage-collected while the
// BilinearForm is still using it.
py::keep_alive<1, 2>(),
"Constructs a bilinear form on the given FiniteElementSpace.")
// --- Setup Methods ---
.def("SetAssemblyLevel", &BilinearForm::SetAssemblyLevel, py::arg("assembly_level"),
"Set the assembly level (e.g., LEGACY, FULL, PARTIAL, NONE).")
.def("EnableStaticCondensation", &BilinearForm::EnableStaticCondensation,
"Enable static condensation to reduce system size.")
.def("EnableHybridization", &BilinearForm::EnableHybridization,
"Enable hybridization.")
.def("UsePrecomputedSparsity", &BilinearForm::UsePrecomputedSparsity, py::arg("ps") = 1,
"Enable use of precomputed sparsity pattern.")
.def("SetDiagonalPolicy", &BilinearForm::SetDiagonalPolicy, py::arg("policy"),
"Set the policy for handling diagonal entries of essential DOFs.")
// --- Integrator Methods ---
.def("AddDomainIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddDomainIntegrator),
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds a domain integrator to the form.")
.def("AddBoundaryIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddBoundaryIntegrator),
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds a boundary integrator to the form.")
.def("AddInteriorFaceIntegrator", &BilinearForm::AddInteriorFaceIntegrator,
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds an interior face integrator (e.g., for DG methods).")
.def("AddBdrFaceIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddBdrFaceIntegrator),
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds a boundary face integrator.")
// --- Assembly and System Formulation ---
.def("Assemble", &BilinearForm::Assemble, py::arg("skip_zeros") = 1,
"Assembles the bilinear form into a sparse matrix.")
.def("AssembleDiagonal", &BilinearForm::AssembleDiagonal, py::arg("diag"),
"Assembles the diagonal of the operator into a Vector.")
.def("FormLinearSystem",
[](BilinearForm &self, const Array<int> &ess_tdof_list, Vector &x,
Vector &b, OperatorHandle &A, Vector &X, Vector &B, int copy_interior) {
self.FormLinearSystem(ess_tdof_list, x, b, A, X, B, copy_interior);
},
py::arg("ess_tdof_list"), py::arg("x"), py::arg("b"), py::arg("A"),
py::arg("X"), py::arg("B"), py::arg("copy_interior") = 0,
"Forms the linear system AX=B, applying boundary conditions and other transformations.")
.def("FormSystemMatrix",
[](BilinearForm &self, const Array<int> &ess_tdof_list, OperatorHandle &A) {
self.FormSystemMatrix(ess_tdof_list, A);
},
py::arg("ess_tdof_list"), py::arg("A"),
"Forms the system matrix A, applying necessary transformations.")
.def("RecoverFEMSolution", py::overload_cast<const Vector&, const Vector&, Vector&>(&BilinearForm::RecoverFEMSolution),
py::arg("X"), py::arg("b"), py::arg("x"),
"Recovers the full FE solution vector after solving a linear system.")
// --- Accessor Methods ---
.def("FESpace", py::overload_cast<>(&BilinearForm::FESpace, py::const_), py::return_value_policy::reference_internal,
"Returns a pointer to the associated FiniteElementSpace.")
.def("SpMat", py::overload_cast<>(&BilinearForm::SpMat), py::return_value_policy::reference_internal,
"Returns a reference to the internal sparse matrix.")
.def("Update", &BilinearForm::Update, py::arg("nfes") = nullptr,
"Update the BilinearForm after the FE space has changed.");
// You will also need to bind the OperatorHandle and DiagonalPolicy enum
// if you haven't already. Example for DiagonalPolicy:
py::enum_<mfem::Matrix::DiagonalPolicy>(mfem_submodule, "DiagonalPolicy")
.value("DIAG_ZERO", mfem::Matrix::DiagonalPolicy::DIAG_ZERO)
.value("DIAG_ONE", mfem::Matrix::DiagonalPolicy::DIAG_ONE)
.value("DIAG_KEEP", mfem::Matrix::DiagonalPolicy::DIAG_KEEP)
.export_values();
}
// Helper function to bind the AssemblyLevel enum
void bind_assembly_level_enum(py::module &m) {
py::enum_<AssemblyLevel>(m, "AssemblyLevel")
.value("LEGACY", AssemblyLevel::LEGACY)
.value("FULL", AssemblyLevel::FULL)
.value("ELEMENT", AssemblyLevel::ELEMENT)
.value("PARTIAL", AssemblyLevel::PARTIAL)
.value("NONE", AssemblyLevel::NONE)
.export_values();
}
// Main function to register MixedBilinearForm
void register_mixed_bilinear_form_bindings(py::module &mfem_submodule) {
// Bind the mfem::MixedBilinearForm class, inheriting from mfem::Matrix
// No trampoline is needed because this is a concrete class.
py::class_<MixedBilinearForm, Matrix>(mfem_submodule, "MixedBilinearForm")
// --- Constructor ---
.def(py::init<FiniteElementSpace *, FiniteElementSpace *>(),
py::arg("trial_fespace"), py::arg("test_fespace"),
// Keep alive policies ensure the FE space objects are not garbage
// collected while the MixedBilinearForm is still using them.
py::keep_alive<1, 2>(), py::keep_alive<1, 3>(),
"Constructs a mixed bilinear form on the given trial and test FE spaces.")
// --- Setup Methods ---
.def("SetAssemblyLevel", &MixedBilinearForm::SetAssemblyLevel, py::arg("assembly_level"),
"Set the assembly level (e.g., LEGACY, FULL, PARTIAL, NONE).")
// --- Integrator Methods ---
.def("AddDomainIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddDomainIntegrator),
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds a domain integrator to the form.")
.def("AddBoundaryIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddBoundaryIntegrator),
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds a boundary integrator to the form.")
.def("AddInteriorFaceIntegrator", &MixedBilinearForm::AddInteriorFaceIntegrator,
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds an interior face integrator.")
.def("AddBdrFaceIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddBdrFaceIntegrator),
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds a boundary face integrator.")
.def("AddTraceFaceIntegrator", &MixedBilinearForm::AddTraceFaceIntegrator,
py::arg("bfi"), py::keep_alive<1, 2>(),
"Adds a trace face integrator.")
// --- Assembly and System Formulation ---
.def("Assemble", &MixedBilinearForm::Assemble, py::arg("skip_zeros") = 1,
"Assembles the mixed bilinear form into a sparse matrix.")
.def("FormRectangularSystemMatrix",
[](MixedBilinearForm &self, const Array<int> &trial_tdof_list, const Array<int> &test_tdof_list, OperatorHandle &A) {
self.FormRectangularSystemMatrix(trial_tdof_list, test_tdof_list, A);
},
py::arg("trial_tdof_list"), py::arg("test_tdof_list"), py::arg("A"),
"Forms the rectangular system matrix A, applying necessary transformations.")
.def("FormRectangularLinearSystem",
[](MixedBilinearForm &self, const Array<int> &trial_tdof_list, const Array<int> &test_tdof_list,
Vector &x, Vector &b, OperatorHandle &A, Vector &X, Vector &B) {
self.FormRectangularLinearSystem(trial_tdof_list, test_tdof_list, x, b, A, X, B);
},
py::arg("trial_tdof_list"), py::arg("test_tdof_list"), py::arg("x"), py::arg("b"),
py::arg("A"), py::arg("X"), py::arg("B"),
"Forms the rectangular linear system AX=B.")
// --- Accessor Methods ---
.def("TrialFESpace", py::overload_cast<>(&MixedBilinearForm::TrialFESpace, py::const_),
py::return_value_policy::reference_internal,
"Returns a pointer to the associated trial FiniteElementSpace.")
.def("TestFESpace", py::overload_cast<>(&MixedBilinearForm::TestFESpace, py::const_),
py::return_value_policy::reference_internal,
"Returns a pointer to the associated test FiniteElementSpace.")
.def("SpMat", py::overload_cast<>(&MixedBilinearForm::SpMat),
py::return_value_policy::reference_internal,
"Returns a reference to the internal sparse matrix.")
.def("Update", &MixedBilinearForm::Update,
"Update the MixedBilinearForm after the FE spaces have changed.");
}
// This function can be called from your main registration function.
void register_mesh_bindings(py::module &mfem_submodule) {
// Bind the mfem::Mesh class. No trampoline needed.
py::class_<Mesh>(mfem_submodule, "Mesh")
// --- Constructors & Loading ---
// Default constructor for creating an empty mesh object
.def(py::init<>())
// Constructor to load from a file path
.def(py::init<const std::string &, int, int, bool>(),
py::arg("filename"), py::arg("generate_edges") = 0,
py::arg("refine") = 1, py::arg("fix_orientation") = true)
// Static factory method for a more Pythonic loading interface
.def_static("LoadFromFile", &Mesh::LoadFromFile,
py::arg("filename"), py::arg("generate_edges") = 0,
py::arg("refine") = 1, py::arg("fix_orientation") = true,
"Creates a mesh by reading a file.")
// --- Basic Properties & Stats ---
.def_property_readonly("dim", &Mesh::Dimension)
.def_property_readonly("space_dim", &Mesh::SpaceDimension)
.def_property_readonly("nv", &Mesh::GetNV, "Number of Vertices")
.def_property_readonly("ne", &Mesh::GetNE, "Number of Elements")
.def_property_readonly("nbe", &Mesh::GetNBE, "Number of Boundary Elements")
.def_property_readonly("n_edges", &Mesh::GetNEdges)
.def_property_readonly("n_faces", &Mesh::GetNFaces)
.def_readonly("attributes", &Mesh::attributes)
.def_readonly("bdr_attributes", &Mesh::bdr_attributes)
.def("GetBoundingBox",
[](Mesh &self, int ref) {
Vector min, max;
self.GetBoundingBox(min, max, ref);
// Here you might want to return a tuple of numpy arrays instead
return py::make_tuple(min, max);
}, py::arg("ref") = 2,
"Returns the min and max corners of the mesh bounding box.")
// --- Connectivity Data ---
.def("GetElementVertices",
[](const Mesh &self, int i) {
Array<int> v;
self.GetElementVertices(i, v);
return v;
}, py::arg("i"), "Returns a list of vertex indices for element i.")
.def("GetBdrElementVertices",
[](const Mesh &self, int i) {
Array<int> v;
self.GetBdrElementVertices(i, v);
return v;
}, py::arg("i"), "Returns a list of vertex indices for boundary element i.")
.def("GetFaceElements",
[](const Mesh &self, int i) {
int e1, e2;
self.GetFaceElements(i, &e1, &e2);
return py::make_tuple(e1, e2);
}, py::arg("face_idx"), "Returns a tuple of the two elements sharing a face.")
.def("GetBdrElementFace",
[](const Mesh &self, int i) {
int f, o;
self.GetBdrElementFace(i, &f, &o);
return py::make_tuple(f, o);
}, py::arg("bdr_elem_idx"), "Returns the face index and orientation for a boundary element.")
.def("ElementToEdgeTable", &Mesh::ElementToEdgeTable, py::return_value_policy::reference_internal)
.def("ElementToFaceTable", &Mesh::ElementToFaceTable, py::return_value_policy::reference_internal)
// --- Coordinate Transformations & Curvature ---
.def("GetNodes", py::overload_cast<>(&Mesh::GetNodes, py::const_), py::return_value_policy::reference_internal,
"Returns the GridFunction for the mesh nodes (if any).")
.def("SetCurvature", &Mesh::SetCurvature, py::arg("order"), py::arg("discontinuous") = false,
py::arg("space_dim") = -1, py::arg("ordering") = 1,
"Set the curvature of the mesh, creating a high-order nodal GridFunction.")
// Use py::overload_cast to resolve ambiguity for overloaded methods
.def("GetElementTransformation", py::overload_cast<int>(&Mesh::GetElementTransformation), py::arg("i"),
py::return_value_policy::reference_internal)
.def("GetFaceElementTransformations", py::overload_cast<int, int>(&Mesh::GetFaceElementTransformations), py::arg("i"),
py::arg("mask")=31, py::return_value_policy::reference_internal)
// --- I/O & Visualization ---
.def("Save", &Mesh::Save, py::arg("filename"), py::arg("precision") = 16);
// Bind the VTKFormat enum used in PrintVTU
py::enum_<VTKFormat>(mfem_submodule, "VTKFormat")
.value("ASCII", VTKFormat::ASCII)
.value("BINARY", VTKFormat::BINARY)
.value("BINARY32", VTKFormat::BINARY32)
.export_values();
}
void register_table_bindings(pybind11::module &mfem_submodule) {
// Bind mfem::Table first, as it's a return type for some Mesh methods.
py::class_<Table>(mfem_submodule, "Table")
.def("GetRow", [](const Table &self, int row) {
Array<int> row_data;
self.GetRow(row, row_data);
// pybind11 will automatically convert Array<int> to a Python list
return row_data;
}, py::arg("row"), "Get a row of the table as a list.")
.def("RowSize", &Table::RowSize, py::arg("row"), "Get the number of entries in a specific row.")
.def_property_readonly("height", &Table::Size, "Number of rows in the table.")
.def_property_readonly("width", &Table::Width, "Number of columns in the table.")
.def("__len__", &Table::Size)
.def("__getitem__", [](const Table &self, int row) {
if (row < 0 || row >= self.Size()) {
throw py::index_error("Row index out of bounds");
}
Array<int> row_data;
self.GetRow(row, row_data);
return row_data;
})
.def("__repr__", [](const Table &self) {
return "<mfem.Table (" + std::to_string(self.Size()) + "x" +
std::to_string(self.Width()) + ")>";
});
}
void register_finite_element_collection_bindings(pybind11::module &mfem_submodule) {
py::class_<FiniteElementCollection>(mfem_submodule, "FiniteElementCollection")
// No constructor for abstract base classes
.def("Name", &FiniteElementCollection::Name)
.def("GetOrder", &FiniteElementCollection::GetOrder);
}
// Binds the mfem::BasisType class and its internal anonymous enum values
// as static properties of a Python class. This should be called before
// binding any class that uses BasisType values in its constructor.
void register_basis_type_bindings(py::module &m) {
py::class_<BasisType> basis_type(m, "BasisType", "Possible basis types.");
basis_type.def_property_readonly_static("Invalid", [](py::object) { return BasisType::Invalid; });
basis_type.def_property_readonly_static("GaussLegendre", [](py::object) { return BasisType::GaussLegendre; });
basis_type.def_property_readonly_static("GaussLobatto", [](py::object) { return BasisType::GaussLobatto; });
basis_type.def_property_readonly_static("Positive", [](py::object) { return BasisType::Positive; });
basis_type.def_property_readonly_static("OpenUniform", [](py::object) { return BasisType::OpenUniform; });
basis_type.def_property_readonly_static("ClosedUniform", [](py::object) { return BasisType::ClosedUniform; });
basis_type.def_property_readonly_static("OpenHalfUniform", [](py::object) { return BasisType::OpenHalfUniform; });
basis_type.def_property_readonly_static("Serendipity", [](py::object) { return BasisType::Serendipity; });
basis_type.def_property_readonly_static("ClosedGL", [](py::object) { return BasisType::ClosedGL; });
basis_type.def_property_readonly_static("IntegratedGLL", [](py::object) { return BasisType::IntegratedGLL; });
}
void register_H1_FECollection_bindings(py::module &m) {
py::class_<H1_FECollection, FiniteElementCollection>(m, "H1_FECollection")
.def(py::init([](int p, int dim) {
// The lambda explicitly calls the constructor, avoiding the
// default argument parsing issue.
return new H1_FECollection(p, dim);
}),
py::arg("p"),
py::arg("dim") = 3,
"Constructs an H1 finite element collection.")
.def("GetBasisType", &H1_FECollection::GetBasisType);
}
void register_RT_FECollection_bindings(py::module &m) {
py::class_<RT_FECollection, FiniteElementCollection>(m, "RT_FECollection")
.def(py::init<const int, const int>(),
py::arg("p"), py::arg("dim"),
"Constructs a Raviart-Thomas H(div)-conforming FE collection.")
.def("GetClosedBasisType", &RT_FECollection::GetClosedBasisType)
.def("GetOpenBasisType", &RT_FECollection::GetOpenBasisType);
}
void register_ND_FECollection_bindings(py::module &m) {
py::class_<ND_FECollection, FiniteElementCollection>(m, "ND_FECollection")
.def(py::init<const int, const int>(),
py::arg("p"), py::arg("dim"),
"Constructs a Nedelec H(curl)-conforming FE collection.")
.def("GetClosedBasisType", &ND_FECollection::GetClosedBasisType)
.def("GetOpenBasisType", &ND_FECollection::GetOpenBasisType);
}
void register_finite_element_space_bindings(py::module &mfem_submodule) {
// Bind dependent enums first
bind_ordering_enum(mfem_submodule);
// Bind the mfem::FiniteElementSpace class
py::class_<FiniteElementSpace>(mfem_submodule, "FiniteElementSpace")
// --- Constructors ---
.def(py::init<Mesh *, const FiniteElementCollection *, int, int>(),
py::arg("mesh"), py::arg("fec"), py::arg("vdim") = 1, py::arg("ordering") = Ordering::byNODES,
// Keep alive policies prevent the mesh and fec from being
// garbage collected by Python while the C++ FESpace object exists.
py::keep_alive<1, 2>(), py::keep_alive<1, 3>())
// --- Core Properties & Stats ---
.def_property_readonly("ndofs", &FiniteElementSpace::GetNDofs, "Number of local scalar degrees of freedom.")
.def_property_readonly("vdim", &FiniteElementSpace::GetVDim, "Vector dimension of the space.")
.def_property_readonly("vsize", &FiniteElementSpace::GetVSize, "Total number of local vector degrees of freedom.")
.def_property_readonly("true_vsize", &FiniteElementSpace::GetTrueVSize, "Number of true (conforming) vector degrees of freedom.")
.def("GetMesh", &FiniteElementSpace::GetMesh, py::return_value_policy::reference_internal, "Get the associated Mesh.")
.def("FEColl", &FiniteElementSpace::FEColl, py::return_value_policy::reference_internal, "Get the associated FiniteElementCollection.")
// --- DOF Management ---
.def("GetElementDofs",
[](const FiniteElementSpace &self, int i) {
Array<int> dofs;
self.GetElementDofs(i, dofs);
return dofs;
}, py::arg("elem_idx"), "Get the local scalar DOFs for a given element.")
.def("GetElementVDofs",
[](const FiniteElementSpace &self, int i) {
Array<int> vdofs;
self.GetElementVDofs(i, vdofs);
return vdofs;
}, py::arg("elem_idx"), "Get the local vector DOFs for a given element.")
.def("GetBdrElementDofs",
[](const FiniteElementSpace &self, int i) {
Array<int> dofs;
self.GetBdrElementDofs(i, dofs);
return dofs;
}, py::arg("bdr_elem_idx"), "Get the local scalar DOFs for a given boundary element.")
.def("GetEssentialVDofs",
[](const FiniteElementSpace &self, const Array<int> &bdr_attr_is_ess, int component) {
Array<int> ess_vdofs;
self.GetEssentialVDofs(bdr_attr_is_ess, ess_vdofs, component);
return ess_vdofs;
}, py::arg("bdr_attr_is_ess"), py::arg("component") = -1,
"Get a list of essential (Dirichlet) vector DOFs based on boundary attributes.")
.def("GetEssentialTrueDofs",
[](const FiniteElementSpace &self, const Array<int> &bdr_attr_is_ess, int component) {
Array<int> ess_tdof_list;
self.GetEssentialTrueDofs(bdr_attr_is_ess, ess_tdof_list, component);
return ess_tdof_list;
}, py::arg("bdr_attr_is_ess"), py::arg("component") = -1,
"Get a list of essential true (conforming) DOFs for use in linear systems.")
// --- Updating & Operators ---
.def("Update", &FiniteElementSpace::Update, py::arg("want_transform") = true,
"Update the FE space after the mesh has been modified (e.g., refined).")
.def("GetUpdateOperator", py::overload_cast<>(&FiniteElementSpace::GetUpdateOperator),
py::return_value_policy::reference_internal,
"Get the operator that maps GridFunctions from the old space to the new space after an Update.")
.def("GetProlongationMatrix", &FiniteElementSpace::GetProlongationMatrix,
py::return_value_policy::reference_internal, "Get the P operator (true DOFs to local DOFs).")
.def("GetRestrictionOperator", &FiniteElementSpace::GetRestrictionOperator,
py::return_value_policy::reference_internal, "Get the R operator (local DOFs to true DOFs).")
.def("__repr__", [](const FiniteElementSpace &fes) {
std::stringstream ss;
ss << "<mfem.FiniteElementSpace with " << fes.GetNDofs() << " DOFs, vdim=" << fes.GetVDim() << ">";
return ss.str();
});
}
void bind_ordering_enum(py::module &mfem_submodule) {
py::enum_<Ordering::Type>(mfem_submodule, "Ordering")
.value("byNODES", Ordering::byNODES)
.value("byVDIM", Ordering::byVDIM)
.export_values();
}
/// Binds mfem::GridFunction
void register_grid_function_bindings(py::module &m) {
// Assumes that Vector, FiniteElementSpace, Coefficient, VectorCoefficient,
// IntegrationRule, and ElementTransformation are already bound.
py::class_<GridFunction, Vector>(m, "GridFunction")
.def(py::init<FiniteElementSpace *>(), py::arg("fespace"),
py::keep_alive<1, 2>(), // Keep FE space alive
"Construct a GridFunction on a given FiniteElementSpace.")
.def(py::init<FiniteElementSpace *, real_t *>(),
py::arg("fespace"), py::arg("data"),
py::keep_alive<1, 2>(),
"Construct a GridFunction using previously allocated data.")
.def("FESpace", py::overload_cast<>(&GridFunction::FESpace, py::const_),
py::return_value_policy::reference_internal,
"Returns the associated FiniteElementSpace.")
.def("Update", &GridFunction::Update,
"Update the GridFunction after its FE space has been modified.")
.def("SetSpace", &GridFunction::SetSpace, py::arg("fespace"),
py::keep_alive<1, 2>(), "Associate a new FE space with the GridFunction.")
// --- Projection Methods ---
.def("ProjectCoefficient",
py::overload_cast<Coefficient &>(&GridFunction::ProjectCoefficient),
py::arg("coeff"),
"Project a scalar Coefficient onto the GridFunction.")
.def("ProjectCoefficient",
py::overload_cast<VectorCoefficient &>(&GridFunction::ProjectCoefficient),
py::arg("vcoeff"),
"Project a vector Coefficient onto the GridFunction.")
.def("ProjectBdrCoefficient",
py::overload_cast<Coefficient &, const Array<int> &>(&GridFunction::ProjectBdrCoefficient),
py::arg("coeff"), py::arg("attr"),
"Project a scalar Coefficient onto the boundary degrees of freedom.")
.def("ProjectBdrCoefficient",
py::overload_cast<VectorCoefficient &, const Array<int> &>(&GridFunction::ProjectBdrCoefficient),
py::arg("vcoeff"), py::arg("attr"),
"Project a vector Coefficient onto the boundary degrees of freedom.")
// --- Evaluation Methods ---
.def("GetValue", py::overload_cast<ElementTransformation &, const IntegrationPoint &, int, Vector *>(&GridFunction::GetValue, py::const_),
py::arg("T"), py::arg("ip"), py::arg("comp") = 0, py::arg("tr") = nullptr,
"Get the scalar value at a point described by an ElementTransformation.")
.def("GetVectorValue",
[](const GridFunction &self, ElementTransformation &T, const IntegrationPoint &ip) {
Vector val;
self.GetVectorValue(T, ip, val);
return val;
},
py::arg("T"), py::arg("ip"),
"Get the vector value at a point described by an ElementTransformation.")
.def("GetValues",
[](const GridFunction &self, ElementTransformation &T, const IntegrationRule &ir, int comp) {
Vector vals;
self.GetValues(T, ir, vals, comp);
return vals;
},
py::arg("T"), py::arg("ir"), py::arg("comp") = 0,
"Get scalar values at all points of an IntegrationRule.")
.def("GetVectorValues",
[](const GridFunction &self, ElementTransformation &T, const IntegrationRule &ir) {
DenseMatrix vals;
self.GetVectorValues(T, ir, vals);
return vals; // pybind11 handles DenseMatrix
},
py::arg("T"), py::arg("ir"),
"Get vector values at all points of an IntegrationRule.")
// --- True DOF (constrained system) Methods ---
.def("GetTrueDofs", &GridFunction::GetTrueDofs, py::arg("tv"),
"Extract the true-dofs from the GridFunction.")
.def("SetFromTrueDofs", &GridFunction::SetFromTrueDofs, py::arg("tv"),
"Set the GridFunction from a true-dof vector.")
// --- I/O ---
.def("Save", py::overload_cast<const char *, int>(&GridFunction::Save, py::const_),
py::arg("fname"), py::arg("precision") = 16)
.def("__repr__", [](const GridFunction &gf) {
std::stringstream ss;
ss << "<mfem.GridFunction with " << gf.Size() << " DOFs>";
return ss.str();
})
.def("as_numpy", [](const GridFunction &self) {
// Convert the GridFunction to a numpy array
return py::array_t<real_t>(self.Size(), self.GetData());
}, "Convert the GridFunction to a numpy array.");
}
// This function registers the coefficient classes to the python module
void register_coefficient_bindings(py::module &m) {
// --- Bind abstract base classes using trampolines ---
py::class_<Coefficient, serif::pybind::PyCoefficient> coefficient(m, "Coefficient");
coefficient
.def(py::init<>())
.def("SetTime", &Coefficient::SetTime, py::arg("t"))
.def("GetTime", &Coefficient::GetTime)
.def("Eval", py::overload_cast<ElementTransformation &, const IntegrationPoint&>(&Coefficient::Eval),
"Evaluate the coefficient at a point in an element.",
py::arg("T"), py::arg("ip"));
py::class_<VectorCoefficient, serif::pybind::PyVectorCoefficient> vector_coefficient(m, "VectorCoefficient");
vector_coefficient
.def(py::init<int>(), py::arg("vdim"))
.def("SetTime", &VectorCoefficient::SetTime, py::arg("t"))
.def("GetTime", &VectorCoefficient::GetTime)
.def("GetVDim", &VectorCoefficient::GetVDim)
.def("Eval", py::overload_cast<Vector &, ElementTransformation &, const IntegrationPoint &>(&VectorCoefficient::Eval),
"Evaluate the vector coefficient at a point in an element.",
py::arg("V"), py::arg("T"), py::arg("ip"));
// --- Bind useful concrete classes ---
// ConstantCoefficient
py::class_<ConstantCoefficient, Coefficient>(m, "ConstantCoefficient")
.def(py::init<real_t>(), py::arg("c") = 1.0)
.def_readwrite("constant", &ConstantCoefficient::constant);
// FunctionCoefficient (allows using Python functions as coefficients)
py::class_<FunctionCoefficient, Coefficient>(m, "FunctionCoefficient")
.def(py::init<std::function<real_t(const Vector &)>>(),
py::arg("F"), "Create a coefficient from a Python function of space.")
.def(py::init<std::function<real_t(const Vector &, real_t)>>(),
py::arg("TDF"), "Create a coefficient from a Python function of space and time.");
// VectorConstantCoefficient
py::class_<VectorConstantCoefficient, VectorCoefficient>(m, "VectorConstantCoefficient")
.def(py::init<const Vector &>(), py::arg("v"));
// VectorFunctionCoefficient
py::class_<VectorFunctionCoefficient, VectorCoefficient>(m, "VectorFunctionCoefficient")
.def(py::init<int, std::function<void(const Vector &, Vector &)>>(),
py::arg("dim"), py::arg("F"), "Create a vector coefficient from a Python function of space.")
.def(py::init<int, std::function<void(const Vector &, real_t, Vector &)>>(),
py::arg("dim"), py::arg("TDF"), "Create a vector coefficient from a Python function of space and time.");
// GridFunctionCoefficient (useful for source terms that depend on the solution)
py::class_<GridFunctionCoefficient, Coefficient>(m, "GridFunctionCoefficient")
.def(py::init<>())
.def(py::init<const GridFunction *>(), py::arg("gf"))
.def("SetGridFunction", &GridFunctionCoefficient::SetGridFunction, py::arg("gf"))
.def("GetGridFunction", &GridFunctionCoefficient::GetGridFunction, py::return_value_policy::reference);
}
/// Binds mfem::IntegrationPoint and mfem::IntegrationRule
void register_intrule_bindings(py::module &m) {
py::class_<IntegrationPoint>(m, "IntegrationPoint")
.def(py::init<>())
.def_readwrite("x", &IntegrationPoint::x)
.def_readwrite("y", &IntegrationPoint::y)
.def_readwrite("z", &IntegrationPoint::z)
.def_readwrite("weight", &IntegrationPoint::weight)
.def("__repr__", [](const IntegrationPoint &ip) {
std::stringstream ss;
ss << "IP(x=" << ip.x << ", y=" << ip.y << ", z=" << ip.z << ", w=" << ip.weight << ")";
return ss.str();
});
py::class_<IntegrationRule>(m, "IntegrationRule")
.def(py::init<>())
.def(py::init<int>(), py::arg("NumPoints"))
.def("GetNPoints", &IntegrationRule::GetNPoints)
.def("GetOrder", &IntegrationRule::GetOrder)
.def("__len__", &IntegrationRule::GetNPoints)
.def("__getitem__", [](const IntegrationRule &self, int i) {
if (i < 0 || i >= self.GetNPoints()) throw py::index_error();
return self.IntPoint(i);
})
.def("__iter__", [](const IntegrationRule &self) {
return py::make_iterator(self.begin(), self.end());
}, py::keep_alive<0, 1>());
}
/// Binds mfem::ElementTransformation and related classes
void register_eltrans_bindings(py::module &m) {
// Bind the base class. Users get pointers to this, but never create it.
// It is not abstract in the C++ sense, but we treat it as such for bindings.
py::class_<ElementTransformation> eltrans(m, "ElementTransformation");
eltrans
.def_readonly("Attribute", &ElementTransformation::Attribute)
.def_readonly("ElementNo", &ElementTransformation::ElementNo)
.def("SetIntPoint", &ElementTransformation::SetIntPoint, py::arg("ip"))
.def("Transform", py::overload_cast<const IntegrationPoint &, Vector &>(&ElementTransformation::Transform),
py::arg("ip"), py::arg("transip"))
.def("Weight", &ElementTransformation::Weight)
// Properties that return references to internal data should use reference_internal policy
.def_property_readonly("Jacobian", &ElementTransformation::Jacobian, py::return_value_policy::reference_internal)
.def_property_readonly("InverseJacobian", &ElementTransformation::InverseJacobian, py::return_value_policy::reference_internal)
.def_property_readonly("AdjugateJacobian", &ElementTransformation::AdjugateJacobian, py::return_value_policy::reference_internal);
// Bind IsoparametricTransformation, which is a concrete type of ElementTransformation
py::class_<IsoparametricTransformation, ElementTransformation>(m, "IsoparametricTransformation")
.def(py::init<>());
// Bind FaceElementTransformations, crucial for DG methods
py::class_<FaceElementTransformations, ElementTransformation>(m, "FaceElementTransformations")
.def(py::init<>())
.def_readonly("Elem1No", &FaceElementTransformations::Elem1No)
.def_readonly("Elem2No", &FaceElementTransformations::Elem2No)
.def_property_readonly("Elem1", [](FaceElementTransformations &self) { return self.Elem1; }, py::return_value_policy::reference)
.def_property_readonly("Elem2", [](FaceElementTransformations &self) { return self.Elem2; }, py::return_value_policy::reference)
.def("GetElement1IntPoint", &FaceElementTransformations::GetElement1IntPoint)
.def("GetElement2IntPoint", &FaceElementTransformations::GetElement2IntPoint);
// Bind the enum used for selecting transformations
py::enum_<FaceElementTransformations::ConfigMasks>(eltrans, "ConfigMasks")
.value("HAVE_ELEM1", FaceElementTransformations::HAVE_ELEM1)
.value("HAVE_ELEM2", FaceElementTransformations::HAVE_ELEM2)
.value("HAVE_LOC1", FaceElementTransformations::HAVE_LOC1)
.value("HAVE_LOC2", FaceElementTransformations::HAVE_LOC2)
.value("HAVE_FACE", FaceElementTransformations::HAVE_FACE)
.export_values();
}

307
src/python/mfem/bindings.h Normal file
View File

@@ -0,0 +1,307 @@
/**
* @file bindings.h
* @brief Declares functions to register MFEM core library components with pybind11.
*
* This header file lists the functions responsible for creating Python bindings
* for various parts of the MFEM library. Each function typically registers
* a set of related classes, enums, or functionalities to a pybind11::module,
* which is expected to be a submodule named `mfem` within the main `serif` Python module.
*
* @see /Users/tboudreaux/Programming/SERiF/src/python/bindings.cpp for how these are used.
*/
#pragma once
#include <pybind11/pybind11.h>
/**
* @brief Registers all core MFEM bindings to the given Python submodule.
*
* This function serves as the main entry point for exposing MFEM functionalities
* to Python. It calls various other `register_*_bindings` and `bind_*_enum`
* functions to populate the `mfem_submodule`.
*
* @param mfem_submodule The pybind11 module (typically `serif.mfem`) to which
* MFEM bindings will be added.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* # Now mfem.Operator, mfem.Vector, mfem.Mesh, etc., are accessible.
* vec = mfem.Vector(10)
* print(vec.Size())
* @endcode
*/
void register_mfem_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::Operator and related classes.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* # Assuming PyOperator trampoline is bound
* # op = mfem.Operator() # Or a derived class like mfem.DenseMatrix
* @endcode
*/
void register_operator_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::Matrix and its derived classes (e.g., mfem::DenseMatrix, mfem::SparseMatrix).
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* mat = mfem.DenseMatrix(2, 2)
* mat[0,0] = 1.0
* mat[0,1] = 2.0
* mat[1,0] = 3.0
* mat[1,1] = 4.0
* mat.Print()
* @endcode
*/
void register_matrix_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::Vector.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* vec = mfem.Vector(5)
* vec[0] = 1.5
* print(vec.Size(), vec[0])
* @endcode
*/
void register_vector_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::Array.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* arr_int = mfem.intArray(5) # Assuming intArray is a typedef or specific binding
* arr_int[0] = 10
* print(arr_int.Size(), arr_int[0])
* @endcode
*/
void register_array_bindings(pybind11::module &mfem_submodule);
/**
* @brief Binds the mfem::AssemblyLevel enum.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* # level = mfem.AssemblyLevel.LEGACY # Or other enum values
* # print(level)
* @endcode
*/
void bind_assembly_level_enum(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::BilinearForm and related functionalities.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* # Assuming FiniteElementSpace (fes) is created
* # bform = mfem.BilinearForm(fes)
* # bform.AddDomainIntegrator(mfem.MassIntegrator()) # Assuming MassIntegrator is bound
* # bform.Assemble()
* # A = bform.SpMat()
* @endcode
*/
void register_bilinear_form_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::MixedBilinearForm and related functionalities.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* # Assuming trial_fes and test_fes are FiniteElementSpaces
* # mbform = mfem.MixedBilinearForm(trial_fes, test_fes)
* # mbform.AddDomainIntegrator(mfem.VectorFEMassIntegrator()) # Example
* # mbform.Assemble()
* @endcode
*/
void register_mixed_bilinear_form_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::Table.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* table = mfem.Table()
* # ... use table methods ...
* @endcode
*/
void register_table_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::Mesh.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* mesh = mfem.Mesh.MakeCartesian1D(10) # Example constructor
* print(mesh.Dimension(), mesh.GetNE())
* @endcode
*/
void register_mesh_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::BasisType enum and related constants.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* basis_type = mfem.BasisType.GaussLobatto
* print(basis_type)
* @endcode
*/
void register_basis_type_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::FiniteElementCollection base class.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* # fec = mfem.FiniteElementCollection() # Typically use derived classes
* @endcode
*/
void register_finite_element_collection_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::H1_FECollection.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* order = 1
* dim = 2
* fec = mfem.H1_FECollection(order, dim)
* print(fec.GetName())
* @endcode
*/
void register_H1_FECollection_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::RT_FECollection.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* order = 1
* dim = 2
* fec = mfem.RT_FECollection(order-1, dim) # RT_FECollection uses p-1 for order p
* print(fec.GetName())
* @endcode
*/
void register_RT_FECollection_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::ND_FECollection (Nedelec finite elements).
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* order = 1
* dim = 3
* fec = mfem.ND_FECollection(order, dim)
* print(fec.GetName())
* @endcode
*/
void register_ND_FECollection_bindings(pybind11::module &mfem_submodule);
/**
* @brief Binds the mfem::Ordering::Type enum.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* ordering = mfem.Ordering.byNODES
* print(ordering)
* @endcode
*/
void bind_ordering_enum(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::FiniteElementSpace.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* mesh = mfem.Mesh.MakeCartesian1D(5)
* fec = mfem.H1_FECollection(1, mesh.Dimension())
* fes = mfem.FiniteElementSpace(mesh, fec)
* print(fes.GetNDofs())
* @endcode
*/
void register_finite_element_space_bindings(pybind11::module &mfem_submodule);
/**
* @brief Registers mfem::Coefficient, mfem::VectorCoefficient and related classes/trampolines.
* @param m The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* const_coeff = mfem.ConstantCoefficient(2.5)
* # vec_coeff = mfem.VectorConstantCoefficient(mfem.Vector([1.0, 2.0]))
*
* # Using a Python-derived coefficient (if PyCoefficient trampoline is bound)
* class MyCoeff(mfem.Coefficient):
* def Eval(self, T, ip):
* return 1.0
* my_c = MyCoeff()
* @endcode
*/
void register_coefficient_bindings(pybind11::module &m);
/**
* @brief Registers mfem::ElementTransformation.
* @param m The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* # ElementTransformation objects are usually obtained from Mesh or FiniteElementSpace
* # mesh = mfem.Mesh.MakeCartesian1D(1)
* # el_trans = mesh.GetElementTransformation(0)
* # print(el_trans.ElementNo)
* @endcode
*/
void register_eltrans_bindings(pybind11::module &m);
/**
* @brief Registers mfem::IntegrationRule and mfem::IntegrationPoint.
* @param m The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* # Get a standard integration rule
* ir = mfem.IntRules.Get(mfem.Geometry.SEGMENT, 3) # Order 3 for a segment
* for i in range(ir.GetNPoints()):
* ip = ir.IntPoint(i)
* # print(f"Point {i}: coords {ip.x}, weight {ip.weight}")
* @endcode
*/
void register_intrule_bindings(pybind11::module &m);
/**
* @brief Registers mfem::GridFunction.
* @param mfem_submodule The `serif.mfem` Python submodule.
* @par Python Usage Example:
* @code{.py}
* import serif.mfem as mfem
* mesh = mfem.Mesh.MakeCartesian1D(5)
* fec = mfem.H1_FECollection(1, mesh.Dimension())
* fes = mfem.FiniteElementSpace(mesh, fec)
* gf = mfem.GridFunction(fes)
* gf.Assign(0.0) # Set all values to 0
* print(gf.Size())
* @endcode
*/
void register_grid_function_bindings(pybind11::module &mfem_submodule);

View File

@@ -0,0 +1,26 @@
subdir('Trampoline')
# Define the library
bindings_sources = files(
'bindings.cpp',
)
bindings_headers = files(
'bindings.h',
)
dependencies = [
config_dep,
resourceManager_dep,
python3_dep,
pybind11_dep,
mpi_dep,
trampoline_dep,
]
shared_module('py_mfem',
bindings_sources,
include_directories: include_directories('.'),
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
)

View File

@@ -0,0 +1,26 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
#include <pybind11/numpy.h>
#include "bindings.h"
#include "EOSio.h"
#include "helm.h"
#include "polySolver.h"
#include "mfem.hpp"
namespace py = pybind11;
void register_polytrope_bindings(pybind11::module &polytrope_submodule) {
py::class_<serif::polytrope::PolySolver>(polytrope_submodule, "PolySolver")
.def(py::init<double, int>(), py::arg("polytropic_index"), py::arg("FEM_order"))
.def("solve", &serif::polytrope::PolySolver::solve, "Solve the polytrope equation.")
.def("get_theta", &serif::polytrope::PolySolver::getTheta, py::return_value_policy::reference_internal)
.def("get_phi", &serif::polytrope::PolySolver::getPhi, py::return_value_policy::reference_internal)
.def("get_order", &serif::polytrope::PolySolver::getOrder)
.def("get_n", &serif::polytrope::PolySolver::getN);
py::class_<serif::polytrope::PolytropeOperator, mfem::Operator>(polytrope_submodule, "PolytropeOperator")
.def("Mult", &serif::polytrope::PolytropeOperator::Mult);
}

View File

@@ -0,0 +1,31 @@
/**
* @file bindings.h
* @brief Declares the function to register polytrope module C++ components with pybind11.
*
* This file contains the declaration for `register_polytrope_bindings`, which is responsible
* for creating Python bindings for classes and functions within the `serif::polytrope` C++
* namespace. These bindings will be accessible in Python under the `serif.polytrope` submodule.
*/
#pragma once
#include <pybind11/pybind11.h>
/**
* @brief Registers C++ classes and functions from the `serif::polytrope` namespace to Python.
*
* This function takes a pybind11::module object, representing the `serif.polytrope` Python submodule,
* and adds bindings for various components like `PolytropeOperator`, `PolySolver`, etc.
* This allows these C++ components to be instantiated and used directly from Python.
*
* @param polytrope_submodule The pybind11 module (typically `serif.polytrope`) to which
* the polytrope C++ bindings will be added.
*
* @par Python Usage Example:
* After these bindings are registered and the Python module is imported:
* @code{.py}
* from serif.polytrope import PolySolver
* polytrope = PolySolver(1.5, 1)
* polytrope.solve()
* @endcode
*/
void register_polytrope_bindings(pybind11::module &polytrope_submodule);

View File

@@ -0,0 +1,19 @@
# Define the library
bindings_sources = files('bindings.cpp')
bindings_headers = files('bindings.h')
dependencies = [
polysolver_dep,
config_dep,
resourceManager_dep,
python3_dep,
pybind11_dep,
]
shared_module('py_polytrope',
bindings_sources,
include_directories: include_directories('.'),
cpp_args: ['-fvisibility=default'],
install : true,
dependencies: dependencies,
)

View File

@@ -0,0 +1,11 @@
from serif.mfem import Mesh
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test loading a mesh using MFEM's mesh loading called from Python")
parser.add_argument("path", type=str, help="path to mesh")
args = parser.parse_args()
mesh = Mesh(args.path)
print(mesh.nv)

View File

@@ -0,0 +1,18 @@
from serif import config
from serif.polytrope import PolySolver
from serif.mfem import Matrix
from serif.mfem import Operator
from serif.mfem.forms import BilinearForm
config.loadConfig('../../testsConfig.yaml')
n = config.get("Tests:Poly:Index", 0.0)
polytrope = PolySolver(n, 1)
polytrope.solve()
theta = polytrope.get_theta()
print(theta)
FESpace = theta.FESpace()
print(FESpace)
print(theta.as_numpy())