feat(pythonInterface/mfem): added loads of mfem bindings to make interacting through python easy
This commit is contained in:
@@ -9,6 +9,10 @@ py_mod = py_installation.extension_module(
|
||||
meson.project_source_root() + '/src/python/const/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/config/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/eos/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/bindings.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.cpp',
|
||||
meson.project_source_root() + '/src/python/mfem/Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.cpp',
|
||||
meson.project_source_root() + '/src/python/polytrope/bindings.cpp',
|
||||
],
|
||||
dependencies : [
|
||||
@@ -18,7 +22,9 @@ py_mod = py_installation.extension_module(
|
||||
composition_dep,
|
||||
eos_dep,
|
||||
species_weight_dep,
|
||||
mfem_dep,
|
||||
polysolver_dep,
|
||||
trampoline_dep
|
||||
],
|
||||
cpp_args : ['-UNDEBUG'], # Example: Ensure assertions are enabled if needed
|
||||
install : true,
|
||||
|
||||
@@ -39,8 +39,8 @@
|
||||
#include "utilities.h"
|
||||
#include "quill/LogMacros.h"
|
||||
|
||||
namespace serif {
|
||||
namespace polytrope {
|
||||
|
||||
namespace serif::polytrope {
|
||||
|
||||
namespace laneEmden {
|
||||
|
||||
@@ -382,7 +382,6 @@ void PolySolver::setInitialGuess() const {
|
||||
serif::probe::glVisView(*m_theta, m_mesh, "θ init");
|
||||
serif::probe::glVisView(*m_phi, m_mesh, "φ init");
|
||||
}
|
||||
std::cout << "HERE" << std::endl;
|
||||
|
||||
}
|
||||
|
||||
@@ -497,5 +496,5 @@ solverBundle PolySolver::setupNewtonSolver() const {
|
||||
return solver;
|
||||
}
|
||||
|
||||
} // namespace polytrope
|
||||
} // namespace serif
|
||||
} // namespace serif::polytrope
|
||||
|
||||
|
||||
@@ -271,7 +271,14 @@ public: // Public methods
|
||||
* @return A reference to the `mfem::GridFunction` storing the \f$\theta\f$ solution.
|
||||
* @note The solution is populated after `solve()` has been successfully called.
|
||||
*/
|
||||
mfem::GridFunction& getSolution() const { return *m_theta; }
|
||||
mfem::GridFunction& getTheta() const { return *m_theta; }
|
||||
|
||||
/**
|
||||
* @brief Gets a reference to the solution grid function for \f$\phi\f$.
|
||||
* @return A reference to the `mfem::GridFunction` storing the \f$\phi\f$ solution.
|
||||
* @note The solution is populated after `solve()` has been successfully called.
|
||||
*/
|
||||
mfem::GridFunction& getPhi() const { return *m_phi; }
|
||||
|
||||
private: // Private Attributes
|
||||
// --- Configuration and Logging ---
|
||||
|
||||
@@ -119,7 +119,6 @@ namespace serif::utilities {
|
||||
mfem::IsoparametricTransformation T;
|
||||
mesh->GetElementTransformation(ne, &T);
|
||||
phi_gf.GetCurl(T, curl_mag_vec);
|
||||
std::cout << "HERE" << std::endl;
|
||||
}
|
||||
mfem::L2_FECollection fac(order, dim);
|
||||
mfem::FiniteElementSpace fs(mesh, &fac);
|
||||
|
||||
@@ -31,27 +31,26 @@
|
||||
* @brief A collection of utilities for working with MFEM and solving the lane-emden equation.
|
||||
*/
|
||||
|
||||
namespace serif {
|
||||
namespace polytrope {
|
||||
/**
|
||||
namespace serif::polytrope {
|
||||
/**
|
||||
* @namespace polyMFEMUtils
|
||||
* @brief A namespace for utilities for working with MFEM and solving the lane-emden equation.
|
||||
*/
|
||||
namespace polyMFEMUtils {
|
||||
/**
|
||||
namespace polyMFEMUtils {
|
||||
/**
|
||||
* @brief A class for nonlinear power integrator.
|
||||
*/
|
||||
class NonlinearPowerIntegrator: public mfem::NonlinearFormIntegrator {
|
||||
public:
|
||||
/**
|
||||
class NonlinearPowerIntegrator: public mfem::NonlinearFormIntegrator {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor for NonlinearPowerIntegrator.
|
||||
*
|
||||
* @param coeff The function coefficient.
|
||||
* @param n The polytropic index.
|
||||
*/
|
||||
NonlinearPowerIntegrator(double n);
|
||||
NonlinearPowerIntegrator(double n);
|
||||
|
||||
/**
|
||||
/**
|
||||
* @brief Assembles the element vector.
|
||||
*
|
||||
* @param el The finite element.
|
||||
@@ -59,8 +58,8 @@ namespace polyMFEMUtils {
|
||||
* @param elfun The element function.
|
||||
* @param elvect The element vector to be assembled.
|
||||
*/
|
||||
virtual void AssembleElementVector(const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::Vector &elvect) override;
|
||||
/**
|
||||
virtual void AssembleElementVector(const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::Vector &elvect) override;
|
||||
/**
|
||||
* @brief Assembles the element gradient.
|
||||
*
|
||||
* @param el The finite element.
|
||||
@@ -68,40 +67,39 @@ namespace polyMFEMUtils {
|
||||
* @param elfun The element function.
|
||||
* @param elmat The element matrix to be assembled.
|
||||
*/
|
||||
virtual void AssembleElementGrad (const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::DenseMatrix &elmat) override;
|
||||
private:
|
||||
serif::config::Config& m_config = serif::config::Config::getInstance();
|
||||
serif::probe::LogManager& m_logManager = serif::probe::LogManager::getInstance();
|
||||
quill::Logger* m_logger = m_logManager.getLogger("log");
|
||||
double m_polytropicIndex;
|
||||
double m_epsilon;
|
||||
static constexpr double m_regularizationRadius = 0.15; ///< Regularization radius for the epsilon function, used to avoid singularities in the power law.
|
||||
static constexpr double m_regularizationCoeff = 1.0/6.0; ///< Coefficient for the regularization term, used to ensure smoothness in the power law.
|
||||
};
|
||||
virtual void AssembleElementGrad (const mfem::FiniteElement &el, mfem::ElementTransformation &Trans, const mfem::Vector &elfun, mfem::DenseMatrix &elmat) override;
|
||||
private:
|
||||
serif::config::Config& m_config = serif::config::Config::getInstance();
|
||||
serif::probe::LogManager& m_logManager = serif::probe::LogManager::getInstance();
|
||||
quill::Logger* m_logger = m_logManager.getLogger("log");
|
||||
double m_polytropicIndex;
|
||||
double m_epsilon;
|
||||
static constexpr double m_regularizationRadius = 0.15; ///< Regularization radius for the epsilon function, used to avoid singularities in the power law.
|
||||
static constexpr double m_regularizationCoeff = 1.0/6.0; ///< Coefficient for the regularization term, used to ensure smoothness in the power law.
|
||||
};
|
||||
|
||||
inline double dfmod(const double epsilon, const double n) {
|
||||
if (n == 0.0) {
|
||||
return 0.0;
|
||||
inline double dfmod(const double epsilon, const double n) {
|
||||
if (n == 0.0) {
|
||||
return 0.0;
|
||||
}
|
||||
if (n == 1.0) {
|
||||
return 1.0;
|
||||
}
|
||||
return n * std::pow(epsilon, n - 1.0);
|
||||
}
|
||||
if (n == 1.0) {
|
||||
return 1.0;
|
||||
|
||||
inline double fmod(const double theta, const double n, const double epsilon) {
|
||||
if (n == 0.0) {
|
||||
return 1.0;
|
||||
}
|
||||
// For n != 0
|
||||
const double y_prime_at_epsilon = dfmod(epsilon, n); // Uses the robust dfmod
|
||||
const double y_at_epsilon = std::pow(epsilon, n); // epsilon^n
|
||||
|
||||
// f_mod(theta) = y_at_epsilon + y_prime_at_epsilon * (theta - epsilon)
|
||||
return y_at_epsilon + y_prime_at_epsilon * (theta - epsilon);
|
||||
}
|
||||
return n * std::pow(epsilon, n - 1.0);
|
||||
}
|
||||
|
||||
inline double fmod(const double theta, const double n, const double epsilon) {
|
||||
if (n == 0.0) {
|
||||
return 1.0;
|
||||
}
|
||||
// For n != 0
|
||||
const double y_prime_at_epsilon = dfmod(epsilon, n); // Uses the robust dfmod
|
||||
const double y_at_epsilon = std::pow(epsilon, n); // epsilon^n
|
||||
|
||||
// f_mod(theta) = y_at_epsilon + y_prime_at_epsilon * (theta - epsilon)
|
||||
return y_at_epsilon + y_prime_at_epsilon * (theta - epsilon);
|
||||
}
|
||||
|
||||
|
||||
} // namespace polyMFEMUtils
|
||||
} // namespace polytrope
|
||||
} // namespace serif
|
||||
} // namespace polyMFEMUtils
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "composition/bindings.h"
|
||||
#include "config/bindings.h"
|
||||
#include "eos/bindings.h"
|
||||
#include "mfem/bindings.h"
|
||||
#include "polytrope/bindings.h"
|
||||
|
||||
PYBIND11_MODULE(serif, m) {
|
||||
@@ -23,6 +24,9 @@ PYBIND11_MODULE(serif, m) {
|
||||
auto eosMod = m.def_submodule("eos", "EOS-module bindings");
|
||||
register_eos_bindings(eosMod);
|
||||
|
||||
auto mfemMod = m.def_submodule("mfem", "MFEM bindings");
|
||||
register_mfem_bindings(mfemMod);
|
||||
|
||||
auto polytropeMod = m.def_submodule("polytrope", "Polytrope-module bindings");
|
||||
register_polytrope_bindings(polytropeMod);
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
subdir('composition')
|
||||
subdir('const')
|
||||
subdir('config')
|
||||
|
||||
subdir('mfem')
|
||||
subdir('eos')
|
||||
@@ -0,0 +1,51 @@
|
||||
#include "PyCoefficient.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
#include <memory>
|
||||
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mfem;
|
||||
|
||||
namespace serif::pybind {
|
||||
real_t PyCoefficient::Eval(ElementTransformation &T, const IntegrationPoint &ip) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
real_t, /* Return type */
|
||||
Coefficient, /* Base class */
|
||||
Eval, /* Method name */
|
||||
T, ip /* Arguments */
|
||||
);
|
||||
}
|
||||
|
||||
// Override virtual SetTime method
|
||||
void PyCoefficient::SetTime(real_t t) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
Coefficient,
|
||||
SetTime,
|
||||
t
|
||||
);
|
||||
}
|
||||
|
||||
void PyVectorCoefficient::Eval(Vector &V, ElementTransformation &T, const IntegrationPoint &ip) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void, /* Return type */
|
||||
VectorCoefficient, /* Base class */
|
||||
Eval, /* Method name */
|
||||
V, T, ip /* Arguments */
|
||||
);
|
||||
}
|
||||
|
||||
// Override the virtual SetTime method
|
||||
void PyVectorCoefficient::SetTime(real_t t) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
VectorCoefficient,
|
||||
SetTime,
|
||||
t
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/functional.h> // Needed for std::function
|
||||
#include <memory>
|
||||
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mfem;
|
||||
|
||||
namespace serif::pybind {
|
||||
// --- Trampoline for the abstract Coefficient base class ---
|
||||
class PyCoefficient : public Coefficient {
|
||||
public:
|
||||
using Coefficient::Coefficient;
|
||||
real_t Eval(ElementTransformation &T, const IntegrationPoint &ip) override;
|
||||
void SetTime(real_t t) override;
|
||||
};
|
||||
|
||||
// --- Trampoline for the abstract VectorCoefficient base class ---
|
||||
class PyVectorCoefficient : public VectorCoefficient {
|
||||
public:
|
||||
using VectorCoefficient::VectorCoefficient;
|
||||
void Eval(Vector &V, ElementTransformation &T, const IntegrationPoint &ip) override;
|
||||
void SetTime(real_t t) override;
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
PyCoefficient_sources = files(
|
||||
'PyCoefficient.cpp'
|
||||
)
|
||||
|
||||
trampoline_sources += PyCoefficient_sources
|
||||
@@ -0,0 +1,137 @@
|
||||
#include "PyMFEMTrampolines/Operator/Matrix/PyMatrix.h"
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace serif::pybind {
|
||||
|
||||
// --- Trampolines for new mfem::Matrix pure virtual methods ---
|
||||
|
||||
mfem::real_t& PyMatrix::Elem(int i, int j) {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
mfem::real_t&, /* Return type */
|
||||
mfem::Matrix, /* C++ base class */
|
||||
Elem, /* C++ function name */
|
||||
i, /* Argument 1 */
|
||||
j /* Argument 2 */
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::real_t& PyMatrix::Elem(int i, int j) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
const mfem::real_t&,
|
||||
mfem::Matrix,
|
||||
Elem,
|
||||
i,
|
||||
j
|
||||
);
|
||||
}
|
||||
|
||||
mfem::MatrixInverse* PyMatrix::Inverse() const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
mfem::MatrixInverse*,
|
||||
mfem::Matrix,
|
||||
Inverse
|
||||
);
|
||||
}
|
||||
|
||||
// --- Trampoline for new mfem::Matrix regular virtual methods ---
|
||||
|
||||
void PyMatrix::Finalize(int skip_zeros) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
Finalize,
|
||||
skip_zeros
|
||||
);
|
||||
}
|
||||
|
||||
// --- Trampolines for inherited mfem::Operator virtual methods ---
|
||||
|
||||
void PyMatrix::Mult(const mfem::Vector &x, mfem::Vector &y) const {
|
||||
// This remains PURE as mfem::Matrix does not implement it.
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
Mult,
|
||||
x,
|
||||
y
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::MultTranspose(const mfem::Vector &x, mfem::Vector &y) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
MultTranspose,
|
||||
x,
|
||||
y
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
AddMult,
|
||||
x,
|
||||
y,
|
||||
a
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
AddMultTranspose,
|
||||
x,
|
||||
y,
|
||||
a
|
||||
);
|
||||
}
|
||||
|
||||
mfem::Operator& PyMatrix::GetGradient(const mfem::Vector &x) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
mfem::Operator&,
|
||||
mfem::Matrix,
|
||||
GetGradient,
|
||||
x
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::AssembleDiagonal(mfem::Vector &diag) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
AssembleDiagonal,
|
||||
diag
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::Operator* PyMatrix::GetProlongation() const {
|
||||
PYBIND11_OVERRIDE(
|
||||
const mfem::Operator*,
|
||||
mfem::Matrix,
|
||||
GetProlongation
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::Operator* PyMatrix::GetRestriction() const {
|
||||
PYBIND11_OVERRIDE(
|
||||
const mfem::Operator*,
|
||||
mfem::Matrix,
|
||||
GetRestriction
|
||||
);
|
||||
}
|
||||
|
||||
void PyMatrix::RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Matrix,
|
||||
RecoverFEMSolution,
|
||||
X,
|
||||
b,
|
||||
x
|
||||
);
|
||||
}
|
||||
|
||||
} // namespace serif::pybind
|
||||
@@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include "mfem.hpp"
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace serif::pybind {
|
||||
|
||||
/**
|
||||
* @brief A trampoline class for mfem::Matrix.
|
||||
*
|
||||
* This class allows Python classes to inherit from mfem::Matrix and correctly
|
||||
* override its virtual functions, including those inherited from mfem::Operator.
|
||||
*/
|
||||
class PyMatrix : public mfem::Matrix {
|
||||
public:
|
||||
// Inherit constructors from the base mfem::Matrix class.
|
||||
using mfem::Matrix::Matrix;
|
||||
|
||||
// --- Trampolines for new mfem::Matrix pure virtual methods ---
|
||||
mfem::real_t& Elem(int i, int j) override;
|
||||
const mfem::real_t& Elem(int i, int j) const override;
|
||||
mfem::MatrixInverse* Inverse() const override;
|
||||
|
||||
// --- Trampoline for new mfem::Matrix regular virtual methods ---
|
||||
void Finalize(int) override;
|
||||
|
||||
// --- Trampolines for inherited mfem::Operator virtual methods ---
|
||||
// These must be repeated here to allow Python classes inheriting from
|
||||
// Matrix to override methods originally from the Operator base class.
|
||||
void Mult(const mfem::Vector &x, mfem::Vector &y) const override;
|
||||
void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override;
|
||||
void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
||||
void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
||||
mfem::Operator& GetGradient(const mfem::Vector &x) const override;
|
||||
void AssembleDiagonal(mfem::Vector &diag) const override;
|
||||
const mfem::Operator* GetProlongation() const override;
|
||||
const mfem::Operator* GetRestriction() const override;
|
||||
void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override;
|
||||
};
|
||||
|
||||
} // namespace serif::pybind
|
||||
@@ -0,0 +1,5 @@
|
||||
PyMatrix_sources = files(
|
||||
'PyMatrix.cpp'
|
||||
)
|
||||
|
||||
trampoline_sources += PyMatrix_sources
|
||||
@@ -0,0 +1,100 @@
|
||||
#include "PyMFEMTrampolines/Operator/PyOperator.h"
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace serif::pybind {
|
||||
|
||||
// --- Pure virtual function implementation ---
|
||||
// This override is mandatory for the trampoline class to be concrete.
|
||||
void PyOperator::Mult(const mfem::Vector &x, mfem::Vector &y) const {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
void, /* Return type */
|
||||
mfem::Operator, /* C++ base class */
|
||||
Mult, /* C++ function name */
|
||||
x, /* Argument 1 */
|
||||
y /* Argument 2 */
|
||||
);
|
||||
}
|
||||
|
||||
// --- Regular virtual function implementations ---
|
||||
// These overrides allow Python subclasses to optionally provide their own
|
||||
// implementation for these methods. If they don't, the default C++
|
||||
// implementation from mfem::Operator will be called.
|
||||
|
||||
void PyOperator::MultTranspose(const mfem::Vector &x, mfem::Vector &y) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
MultTranspose,
|
||||
x,
|
||||
y
|
||||
);
|
||||
}
|
||||
|
||||
void PyOperator::AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
AddMult,
|
||||
x,
|
||||
y,
|
||||
a
|
||||
);
|
||||
}
|
||||
|
||||
void PyOperator::AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
AddMultTranspose,
|
||||
x,
|
||||
y,
|
||||
a
|
||||
);
|
||||
}
|
||||
|
||||
mfem::Operator& PyOperator::GetGradient(const mfem::Vector &x) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
mfem::Operator&,
|
||||
mfem::Operator,
|
||||
GetGradient,
|
||||
x
|
||||
);
|
||||
}
|
||||
|
||||
void PyOperator::AssembleDiagonal(mfem::Vector &diag) const {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
AssembleDiagonal,
|
||||
diag
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::Operator* PyOperator::GetProlongation() const {
|
||||
PYBIND11_OVERRIDE(
|
||||
const mfem::Operator*,
|
||||
mfem::Operator,
|
||||
GetProlongation
|
||||
);
|
||||
}
|
||||
|
||||
const mfem::Operator* PyOperator::GetRestriction() const {
|
||||
PYBIND11_OVERRIDE(
|
||||
const mfem::Operator*,
|
||||
mfem::Operator,
|
||||
GetRestriction
|
||||
);
|
||||
}
|
||||
|
||||
void PyOperator::RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) {
|
||||
PYBIND11_OVERRIDE(
|
||||
void,
|
||||
mfem::Operator,
|
||||
RecoverFEMSolution,
|
||||
X,
|
||||
b,
|
||||
x
|
||||
);
|
||||
}
|
||||
|
||||
} // namespace serif::pybind::mfem
|
||||
@@ -0,0 +1,46 @@
|
||||
#pragma once
|
||||
|
||||
#include "mfem.hpp"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h" // Needed for vectors, maps, sets, strings
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace serif::pybind {
|
||||
|
||||
/**
|
||||
* @brief A trampoline class for mfem::Operator.
|
||||
* * This class allows Python classes to inherit from mfem::Operator and correctly
|
||||
* override its virtual functions. When a virtual function is called from C++,
|
||||
* the trampoline ensures the call is forwarded to the Python implementation if one exists.
|
||||
*/
|
||||
class PyOperator : public mfem::Operator {
|
||||
public:
|
||||
// Inherit the constructors from the base mfem::Operator class.
|
||||
// This allows Python classes to call e.g., super().__init__(size).
|
||||
using mfem::Operator::Operator;
|
||||
|
||||
// --- Trampoline declarations for all overridable virtual functions ---
|
||||
|
||||
// Pure virtual function (MANDATORY override)
|
||||
void Mult(const mfem::Vector &x, mfem::Vector &y) const override;
|
||||
|
||||
// Regular virtual functions (RECOMMENDED overrides)
|
||||
void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override;
|
||||
|
||||
void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
||||
|
||||
void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
||||
|
||||
Operator& GetGradient(const mfem::Vector &x) const override;
|
||||
|
||||
void AssembleDiagonal(mfem::Vector &diag) const override;
|
||||
|
||||
const mfem::Operator* GetProlongation() const override;
|
||||
|
||||
const mfem::Operator* GetRestriction() const override;
|
||||
|
||||
void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override;
|
||||
};
|
||||
|
||||
} // namespace serif::pybind::mfem
|
||||
@@ -0,0 +1,7 @@
|
||||
PyOperator_sources = files(
|
||||
'PyOperator.cpp'
|
||||
)
|
||||
|
||||
trampoline_sources += PyOperator_sources
|
||||
|
||||
subdir('Matrix')
|
||||
2
src/python/mfem/Trampoline/PyMFEMTrampolines/meson.build
Normal file
2
src/python/mfem/Trampoline/PyMFEMTrampolines/meson.build
Normal file
@@ -0,0 +1,2 @@
|
||||
subdir('Operator')
|
||||
subdir('Coefficient')
|
||||
23
src/python/mfem/Trampoline/meson.build
Normal file
23
src/python/mfem/Trampoline/meson.build
Normal file
@@ -0,0 +1,23 @@
|
||||
trampoline_sources = []
|
||||
|
||||
subdir('PyMFEMTrampolines')
|
||||
|
||||
dependencies = [
|
||||
mfem_dep,
|
||||
pybind11_dep,
|
||||
python3_dep,
|
||||
]
|
||||
|
||||
trampoline_lib = static_library(
|
||||
'mfem_trampolines',
|
||||
trampoline_sources,
|
||||
include_directories: include_directories('.'),
|
||||
dependencies: dependencies,
|
||||
install: false,
|
||||
)
|
||||
|
||||
trampoline_dep = declare_dependency(
|
||||
link_with: trampoline_lib,
|
||||
include_directories: ('.'),
|
||||
dependencies: dependencies,
|
||||
)
|
||||
862
src/python/mfem/bindings.cpp
Normal file
862
src/python/mfem/bindings.cpp
Normal file
@@ -0,0 +1,862 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/operators.h> // For operator overloads
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
// Include your trampoline class header. The implementation will be in a separate .cpp file.
|
||||
#include "bindings.h"
|
||||
|
||||
#include "Trampoline/PyMFEMTrampolines/Operator/PyOperator.h"
|
||||
#include "Trampoline/PyMFEMTrampolines/Operator/Matrix/PyMatrix.h"
|
||||
#include "Trampoline/PyMFEMTrampolines/Coefficient/PyCoefficient.h"
|
||||
|
||||
#include "mfem.hpp"
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mfem;
|
||||
|
||||
// This function registers all the mfem-related classes to the python module
|
||||
void register_mfem_bindings(py::module &mfem_submodule) {
|
||||
register_operator_bindings(mfem_submodule);
|
||||
register_matrix_bindings(mfem_submodule);
|
||||
|
||||
register_vector_bindings(mfem_submodule);
|
||||
register_array_bindings(mfem_submodule);
|
||||
register_table_bindings(mfem_submodule);
|
||||
|
||||
register_mesh_bindings(mfem_submodule);
|
||||
|
||||
auto formsModule = mfem_submodule.def_submodule("forms", "MFEM forms module");
|
||||
register_bilinear_form_bindings(formsModule);
|
||||
register_mixed_bilinear_form_bindings(formsModule);
|
||||
|
||||
auto fecModule = mfem_submodule.def_submodule("fec", "MFEM finite element collection module");
|
||||
register_basis_type_bindings(fecModule);
|
||||
register_finite_element_collection_bindings(fecModule);
|
||||
register_H1_FECollection_bindings(fecModule);
|
||||
register_RT_FECollection_bindings(fecModule);
|
||||
register_ND_FECollection_bindings(fecModule);
|
||||
|
||||
auto fesModule = mfem_submodule.def_submodule("fes", "MFEM finite element space module");
|
||||
register_finite_element_space_bindings(fesModule);
|
||||
|
||||
register_coefficient_bindings(mfem_submodule);
|
||||
register_intrule_bindings(mfem_submodule);
|
||||
register_eltrans_bindings(mfem_submodule);
|
||||
|
||||
register_grid_function_bindings(mfem_submodule);
|
||||
}
|
||||
|
||||
void register_operator_bindings(py::module &mfem_submodule) {
|
||||
// Use the PyOperator trampoline when binding mfem::Operator
|
||||
// This allows Python classes to inherit from mfem::Operator
|
||||
py::class_<Operator, serif::pybind::PyOperator /* Trampoline */>(mfem_submodule, "Operator")
|
||||
// NOTE: We DO NOT define an __init__ method because Operator is abstract.
|
||||
// Python users will instantiate concrete derived classes instead.
|
||||
|
||||
// --- Bind Properties ---
|
||||
.def_property_readonly("height", &Operator::Height, "Get the height (number of rows) of the Operator.")
|
||||
.def_property_readonly("width", &Operator::Width, "Get the width (number of columns) of the Operator.")
|
||||
|
||||
// --- Bind Core Virtual Methods ---
|
||||
// We bind the methods of the C++ base class so they can be called from Python.
|
||||
// The trampoline handles redirecting to a Python override if one exists.
|
||||
|
||||
// Mult: y = A(x)
|
||||
.def("Mult", py::overload_cast<const Vector&, Vector&>(&Operator::Mult, py::const_),
|
||||
py::arg("x"), py::arg("y"), "Calculates y = A(x). y must be pre-allocated.")
|
||||
|
||||
// Pythonic overload for Mult that returns a new vector
|
||||
.def("Mult", [](const Operator &op, const Vector &x) {
|
||||
Vector y(op.Height());
|
||||
op.Mult(x, y);
|
||||
return y;
|
||||
}, py::arg("x"), "Calculates and returns a new vector y = A(x).")
|
||||
|
||||
// MultTranspose: y = A^T(x)
|
||||
.def("MultTranspose", py::overload_cast<const Vector&, Vector&>(&Operator::MultTranspose, py::const_),
|
||||
py::arg("x"), py::arg("y"), "Calculates y = A^T(x). y must be pre-allocated.")
|
||||
|
||||
.def("MultTranspose", [](const Operator &op, const Vector &x) {
|
||||
Vector y(op.Width());
|
||||
op.MultTranspose(x, y);
|
||||
return y;
|
||||
}, py::arg("x"), "Calculates and returns a new vector y = A^T(x).")
|
||||
|
||||
// Additive versions
|
||||
.def("AddMult", &Operator::AddMult, py::arg("x"), py::arg("y"), py::arg("a") = 1.0, "Performs y += a * A(x).")
|
||||
.def("AddMultTranspose", &Operator::AddMultTranspose, py::arg("x"), py::arg("y"), py::arg("a") = 1.0, "Performs y += a * A^T(x).")
|
||||
|
||||
// Other core virtual methods
|
||||
.def("AssembleDiagonal", &Operator::AssembleDiagonal, py::arg("diag"), "Assembles the operator diagonal into the given Vector.")
|
||||
.def("RecoverFEMSolution", &Operator::RecoverFEMSolution, py::arg("X"), py::arg("b"), py::arg("x"), "Recovers the full FE solution.")
|
||||
.def("GetGradient", &Operator::GetGradient, py::return_value_policy::reference, "Returns the Gradient of a non-linear operator.")
|
||||
|
||||
// Methods returning other operators (e.g., for parallel/BCs)
|
||||
.def("GetProlongation", &Operator::GetProlongation, py::return_value_policy::reference, "Returns the prolongation operator.")
|
||||
.def("GetRestriction", &Operator::GetRestriction, py::return_value_policy::reference, "Returns the restriction operator.")
|
||||
|
||||
// --- Pythonic Operator Overloading ---
|
||||
.def("__matmul__", [](const Operator &op, const Vector &x) {
|
||||
Vector y(op.Height());
|
||||
op.Mult(x, y);
|
||||
return y;
|
||||
}, py::is_operator());
|
||||
}
|
||||
|
||||
void register_matrix_bindings(py::module &mfem_submodule) {
|
||||
py::class_<Matrix, Operator, serif::pybind::PyMatrix>(mfem_submodule, "Matrix")
|
||||
// No constructor since it's an abstract base class.
|
||||
.def_property_readonly("is_square", &Matrix::IsSquare,
|
||||
"Returns true if the matrix is square.")
|
||||
|
||||
.def("finalize", &Matrix::Finalize, py::arg("skip_zeros") = 1,
|
||||
"Finalizes the matrix initialization.")
|
||||
|
||||
.def("inverse", &Matrix::Inverse,
|
||||
"Returns a pointer to (an approximation) of the matrix inverse.",
|
||||
py::return_value_policy::take_ownership) // The caller owns the returned pointer
|
||||
|
||||
// Pythonic element access: mat[i, j]
|
||||
.def("__getitem__", [](const Matrix &m, py::tuple t) {
|
||||
if (t.size() != 2) {
|
||||
throw py::index_error("Matrix index must be a 2-tuple (i, j)");
|
||||
}
|
||||
return m.Elem(t[0].cast<int>(), t[1].cast<int>());
|
||||
})
|
||||
|
||||
// Pythonic element assignment: mat[i, j] = value
|
||||
.def("__setitem__", [](Matrix &m, py::tuple t, real_t value) {
|
||||
if (t.size() != 2) {
|
||||
throw py::index_error("Matrix index must be a 2-tuple (i, j)");
|
||||
}
|
||||
m.Elem(t[0].cast<int>(), t[1].cast<int>()) = value;
|
||||
})
|
||||
|
||||
.def("__repr__", [](const Matrix &m) {
|
||||
return "<mfem.Matrix (Abstract) " +
|
||||
std::to_string(m.Height()) + "x" +
|
||||
std::to_string(m.Width()) + ">";
|
||||
});
|
||||
}
|
||||
|
||||
void register_vector_bindings(py::module &mfem_submodule) {
|
||||
// Register the mfem::Vector class
|
||||
py::class_<mfem::Vector>(mfem_submodule, "Vector")
|
||||
.def(py::init<int>(), py::arg("size"))
|
||||
.def(py::init<const mfem::Vector &>(), py::arg("other"))
|
||||
.def(py::init([](py::array_t<double> arr) {
|
||||
py::buffer_info info = arr.request();
|
||||
if (info.ndim != 1) {
|
||||
throw std::runtime_error("Vector(): expected a 1-D numpy array");
|
||||
}
|
||||
mfem::Vector v(info.size);
|
||||
std::memcpy(v.GetData(), info.ptr, info.size * sizeof(double));
|
||||
return v;
|
||||
}), py::arg("array"))
|
||||
.def("GetData", &mfem::Vector::GetData, py::return_value_policy::reference_internal)
|
||||
.def("Size", &mfem::Vector::Size)
|
||||
.def("__getitem__", [](const mfem::Vector &v, int i) { return v[i]; })
|
||||
.def("__len__", &mfem::Vector::Size)
|
||||
.def("__setitem__", [](mfem::Vector &v, int i, double value) { v[i] = value; })
|
||||
.def("__repr__", [](const mfem::Vector &v) {
|
||||
return "<mfem.Vector(size=" + std::to_string(v.Size()) + ")>";
|
||||
})
|
||||
.def("as_numpy", [](mfem::Vector &self) {
|
||||
return py::array_t<double>(self.Size(), self.GetData());
|
||||
});
|
||||
}
|
||||
|
||||
void register_array_bindings(py::module &mfem_submodule) {
|
||||
py::class_<Array<int>>(mfem_submodule, "IntArray")
|
||||
// --- Constructors ---
|
||||
.def(py::init<>(), "Default constructor.")
|
||||
.def(py::init<int>(), py::arg("size"), "Constructor with size.")
|
||||
.def(py::init([](const std::vector<int> &v) {
|
||||
auto *arr = new Array<int>(v.size());
|
||||
for (size_t i = 0; i < v.size(); ++i) {
|
||||
(*arr)[i] = v[i];
|
||||
}
|
||||
return arr;
|
||||
}), py::arg("list"), "Constructor from a Python list.")
|
||||
|
||||
// --- Pythonic Features ---
|
||||
.def("__len__", &Array<int>::Size)
|
||||
.def("__getitem__", [](const Array<int> &self, int i) {
|
||||
if (i < 0) i += self.Size(); // Handle negative indices
|
||||
if (i < 0 || i >= self.Size()) throw py::index_error();
|
||||
return self[i];
|
||||
})
|
||||
.def("__setitem__", [](Array<int> &self, int i, int value) {
|
||||
if (i < 0) i += self.Size(); // Handle negative indices
|
||||
if (i < 0 || i >= self.Size()) throw py::index_error();
|
||||
self[i] = value;
|
||||
})
|
||||
.def("__iter__", [](Array<int> &self) {
|
||||
return py::make_iterator(self.begin(), self.end());
|
||||
}, py::keep_alive<0, 1>()) // Keep array alive while iterator is used
|
||||
.def("__repr__", [](const Array<int> &self) {
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (int i = 0; i < self.Size(); ++i) {
|
||||
ss << self[i] << (i == self.Size() - 1 ? "" : ", ");
|
||||
}
|
||||
ss << "]";
|
||||
return ss.str();
|
||||
})
|
||||
|
||||
// --- Core Methods ---
|
||||
.def("Size", &Array<int>::Size)
|
||||
.def("SetSize", py::overload_cast<int>(&Array<int>::SetSize), py::arg("size"))
|
||||
.def("Append", py::overload_cast<const int &>(&Array<int>::Append), py::arg("el"))
|
||||
.def("Last", py::overload_cast<>(&Array<int>::Last, py::const_))
|
||||
.def("DeleteLast", &Array<int>::DeleteLast)
|
||||
.def("DeleteAll", &Array<int>::DeleteAll)
|
||||
.def("Sort", [](Array<int> &self) { self.Sort(); })
|
||||
.def("Unique", &Array<int>::Unique)
|
||||
.def("Assign", [](Array<int> &self, const Array<int> &other) {
|
||||
self.Assign(other.GetData());
|
||||
})
|
||||
.def("as_numpy", [](Array<int> &self) {
|
||||
// Create a Python-owned copy to avoid memory issues
|
||||
return py::array_t<int>(self.Size(), self.GetData());
|
||||
});
|
||||
}
|
||||
|
||||
// Main function to register BilinearForm
|
||||
void register_bilinear_form_bindings(py::module &mfem_submodule) {
|
||||
|
||||
// It's good practice to bind enums used by the class
|
||||
bind_assembly_level_enum(mfem_submodule);
|
||||
|
||||
// Bind the mfem::BilinearForm class, inheriting from mfem::Matrix
|
||||
// No trampoline is needed because this is a concrete class.
|
||||
py::class_<BilinearForm, Matrix>(mfem_submodule, "BilinearForm")
|
||||
|
||||
// --- Constructor ---
|
||||
.def(py::init<FiniteElementSpace *>(), py::arg("fespace"),
|
||||
// The keep_alive policy ensures the Python object for the
|
||||
// FiniteElementSpace is not garbage-collected while the
|
||||
// BilinearForm is still using it.
|
||||
py::keep_alive<1, 2>(),
|
||||
"Constructs a bilinear form on the given FiniteElementSpace.")
|
||||
|
||||
// --- Setup Methods ---
|
||||
.def("SetAssemblyLevel", &BilinearForm::SetAssemblyLevel, py::arg("assembly_level"),
|
||||
"Set the assembly level (e.g., LEGACY, FULL, PARTIAL, NONE).")
|
||||
.def("EnableStaticCondensation", &BilinearForm::EnableStaticCondensation,
|
||||
"Enable static condensation to reduce system size.")
|
||||
.def("EnableHybridization", &BilinearForm::EnableHybridization,
|
||||
"Enable hybridization.")
|
||||
.def("UsePrecomputedSparsity", &BilinearForm::UsePrecomputedSparsity, py::arg("ps") = 1,
|
||||
"Enable use of precomputed sparsity pattern.")
|
||||
.def("SetDiagonalPolicy", &BilinearForm::SetDiagonalPolicy, py::arg("policy"),
|
||||
"Set the policy for handling diagonal entries of essential DOFs.")
|
||||
|
||||
// --- Integrator Methods ---
|
||||
.def("AddDomainIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddDomainIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a domain integrator to the form.")
|
||||
.def("AddBoundaryIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddBoundaryIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a boundary integrator to the form.")
|
||||
.def("AddInteriorFaceIntegrator", &BilinearForm::AddInteriorFaceIntegrator,
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds an interior face integrator (e.g., for DG methods).")
|
||||
.def("AddBdrFaceIntegrator", py::overload_cast<BilinearFormIntegrator *>(&BilinearForm::AddBdrFaceIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a boundary face integrator.")
|
||||
|
||||
// --- Assembly and System Formulation ---
|
||||
.def("Assemble", &BilinearForm::Assemble, py::arg("skip_zeros") = 1,
|
||||
"Assembles the bilinear form into a sparse matrix.")
|
||||
.def("AssembleDiagonal", &BilinearForm::AssembleDiagonal, py::arg("diag"),
|
||||
"Assembles the diagonal of the operator into a Vector.")
|
||||
.def("FormLinearSystem",
|
||||
[](BilinearForm &self, const Array<int> &ess_tdof_list, Vector &x,
|
||||
Vector &b, OperatorHandle &A, Vector &X, Vector &B, int copy_interior) {
|
||||
self.FormLinearSystem(ess_tdof_list, x, b, A, X, B, copy_interior);
|
||||
},
|
||||
py::arg("ess_tdof_list"), py::arg("x"), py::arg("b"), py::arg("A"),
|
||||
py::arg("X"), py::arg("B"), py::arg("copy_interior") = 0,
|
||||
"Forms the linear system AX=B, applying boundary conditions and other transformations.")
|
||||
.def("FormSystemMatrix",
|
||||
[](BilinearForm &self, const Array<int> &ess_tdof_list, OperatorHandle &A) {
|
||||
self.FormSystemMatrix(ess_tdof_list, A);
|
||||
},
|
||||
py::arg("ess_tdof_list"), py::arg("A"),
|
||||
"Forms the system matrix A, applying necessary transformations.")
|
||||
.def("RecoverFEMSolution", py::overload_cast<const Vector&, const Vector&, Vector&>(&BilinearForm::RecoverFEMSolution),
|
||||
py::arg("X"), py::arg("b"), py::arg("x"),
|
||||
"Recovers the full FE solution vector after solving a linear system.")
|
||||
|
||||
// --- Accessor Methods ---
|
||||
.def("FESpace", py::overload_cast<>(&BilinearForm::FESpace, py::const_), py::return_value_policy::reference_internal,
|
||||
"Returns a pointer to the associated FiniteElementSpace.")
|
||||
.def("SpMat", py::overload_cast<>(&BilinearForm::SpMat), py::return_value_policy::reference_internal,
|
||||
"Returns a reference to the internal sparse matrix.")
|
||||
.def("Update", &BilinearForm::Update, py::arg("nfes") = nullptr,
|
||||
"Update the BilinearForm after the FE space has changed.");
|
||||
|
||||
// You will also need to bind the OperatorHandle and DiagonalPolicy enum
|
||||
// if you haven't already. Example for DiagonalPolicy:
|
||||
py::enum_<mfem::Matrix::DiagonalPolicy>(mfem_submodule, "DiagonalPolicy")
|
||||
.value("DIAG_ZERO", mfem::Matrix::DiagonalPolicy::DIAG_ZERO)
|
||||
.value("DIAG_ONE", mfem::Matrix::DiagonalPolicy::DIAG_ONE)
|
||||
.value("DIAG_KEEP", mfem::Matrix::DiagonalPolicy::DIAG_KEEP)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
// Helper function to bind the AssemblyLevel enum
|
||||
void bind_assembly_level_enum(py::module &m) {
|
||||
py::enum_<AssemblyLevel>(m, "AssemblyLevel")
|
||||
.value("LEGACY", AssemblyLevel::LEGACY)
|
||||
.value("FULL", AssemblyLevel::FULL)
|
||||
.value("ELEMENT", AssemblyLevel::ELEMENT)
|
||||
.value("PARTIAL", AssemblyLevel::PARTIAL)
|
||||
.value("NONE", AssemblyLevel::NONE)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
// Main function to register MixedBilinearForm
|
||||
void register_mixed_bilinear_form_bindings(py::module &mfem_submodule) {
|
||||
|
||||
// Bind the mfem::MixedBilinearForm class, inheriting from mfem::Matrix
|
||||
// No trampoline is needed because this is a concrete class.
|
||||
py::class_<MixedBilinearForm, Matrix>(mfem_submodule, "MixedBilinearForm")
|
||||
|
||||
// --- Constructor ---
|
||||
.def(py::init<FiniteElementSpace *, FiniteElementSpace *>(),
|
||||
py::arg("trial_fespace"), py::arg("test_fespace"),
|
||||
// Keep alive policies ensure the FE space objects are not garbage
|
||||
// collected while the MixedBilinearForm is still using them.
|
||||
py::keep_alive<1, 2>(), py::keep_alive<1, 3>(),
|
||||
"Constructs a mixed bilinear form on the given trial and test FE spaces.")
|
||||
|
||||
// --- Setup Methods ---
|
||||
.def("SetAssemblyLevel", &MixedBilinearForm::SetAssemblyLevel, py::arg("assembly_level"),
|
||||
"Set the assembly level (e.g., LEGACY, FULL, PARTIAL, NONE).")
|
||||
|
||||
// --- Integrator Methods ---
|
||||
.def("AddDomainIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddDomainIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a domain integrator to the form.")
|
||||
.def("AddBoundaryIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddBoundaryIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a boundary integrator to the form.")
|
||||
.def("AddInteriorFaceIntegrator", &MixedBilinearForm::AddInteriorFaceIntegrator,
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds an interior face integrator.")
|
||||
.def("AddBdrFaceIntegrator", py::overload_cast<BilinearFormIntegrator *>(&MixedBilinearForm::AddBdrFaceIntegrator),
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a boundary face integrator.")
|
||||
.def("AddTraceFaceIntegrator", &MixedBilinearForm::AddTraceFaceIntegrator,
|
||||
py::arg("bfi"), py::keep_alive<1, 2>(),
|
||||
"Adds a trace face integrator.")
|
||||
|
||||
|
||||
// --- Assembly and System Formulation ---
|
||||
.def("Assemble", &MixedBilinearForm::Assemble, py::arg("skip_zeros") = 1,
|
||||
"Assembles the mixed bilinear form into a sparse matrix.")
|
||||
.def("FormRectangularSystemMatrix",
|
||||
[](MixedBilinearForm &self, const Array<int> &trial_tdof_list, const Array<int> &test_tdof_list, OperatorHandle &A) {
|
||||
self.FormRectangularSystemMatrix(trial_tdof_list, test_tdof_list, A);
|
||||
},
|
||||
py::arg("trial_tdof_list"), py::arg("test_tdof_list"), py::arg("A"),
|
||||
"Forms the rectangular system matrix A, applying necessary transformations.")
|
||||
.def("FormRectangularLinearSystem",
|
||||
[](MixedBilinearForm &self, const Array<int> &trial_tdof_list, const Array<int> &test_tdof_list,
|
||||
Vector &x, Vector &b, OperatorHandle &A, Vector &X, Vector &B) {
|
||||
self.FormRectangularLinearSystem(trial_tdof_list, test_tdof_list, x, b, A, X, B);
|
||||
},
|
||||
py::arg("trial_tdof_list"), py::arg("test_tdof_list"), py::arg("x"), py::arg("b"),
|
||||
py::arg("A"), py::arg("X"), py::arg("B"),
|
||||
"Forms the rectangular linear system AX=B.")
|
||||
|
||||
// --- Accessor Methods ---
|
||||
.def("TrialFESpace", py::overload_cast<>(&MixedBilinearForm::TrialFESpace, py::const_),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Returns a pointer to the associated trial FiniteElementSpace.")
|
||||
.def("TestFESpace", py::overload_cast<>(&MixedBilinearForm::TestFESpace, py::const_),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Returns a pointer to the associated test FiniteElementSpace.")
|
||||
.def("SpMat", py::overload_cast<>(&MixedBilinearForm::SpMat),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Returns a reference to the internal sparse matrix.")
|
||||
.def("Update", &MixedBilinearForm::Update,
|
||||
"Update the MixedBilinearForm after the FE spaces have changed.");
|
||||
}
|
||||
|
||||
// This function can be called from your main registration function.
|
||||
void register_mesh_bindings(py::module &mfem_submodule) {
|
||||
// Bind the mfem::Mesh class. No trampoline needed.
|
||||
py::class_<Mesh>(mfem_submodule, "Mesh")
|
||||
|
||||
// --- Constructors & Loading ---
|
||||
// Default constructor for creating an empty mesh object
|
||||
.def(py::init<>())
|
||||
// Constructor to load from a file path
|
||||
.def(py::init<const std::string &, int, int, bool>(),
|
||||
py::arg("filename"), py::arg("generate_edges") = 0,
|
||||
py::arg("refine") = 1, py::arg("fix_orientation") = true)
|
||||
// Static factory method for a more Pythonic loading interface
|
||||
.def_static("LoadFromFile", &Mesh::LoadFromFile,
|
||||
py::arg("filename"), py::arg("generate_edges") = 0,
|
||||
py::arg("refine") = 1, py::arg("fix_orientation") = true,
|
||||
"Creates a mesh by reading a file.")
|
||||
|
||||
// --- Basic Properties & Stats ---
|
||||
.def_property_readonly("dim", &Mesh::Dimension)
|
||||
.def_property_readonly("space_dim", &Mesh::SpaceDimension)
|
||||
.def_property_readonly("nv", &Mesh::GetNV, "Number of Vertices")
|
||||
.def_property_readonly("ne", &Mesh::GetNE, "Number of Elements")
|
||||
.def_property_readonly("nbe", &Mesh::GetNBE, "Number of Boundary Elements")
|
||||
.def_property_readonly("n_edges", &Mesh::GetNEdges)
|
||||
.def_property_readonly("n_faces", &Mesh::GetNFaces)
|
||||
.def_readonly("attributes", &Mesh::attributes)
|
||||
.def_readonly("bdr_attributes", &Mesh::bdr_attributes)
|
||||
.def("GetBoundingBox",
|
||||
[](Mesh &self, int ref) {
|
||||
Vector min, max;
|
||||
self.GetBoundingBox(min, max, ref);
|
||||
// Here you might want to return a tuple of numpy arrays instead
|
||||
return py::make_tuple(min, max);
|
||||
}, py::arg("ref") = 2,
|
||||
"Returns the min and max corners of the mesh bounding box.")
|
||||
|
||||
// --- Connectivity Data ---
|
||||
.def("GetElementVertices",
|
||||
[](const Mesh &self, int i) {
|
||||
Array<int> v;
|
||||
self.GetElementVertices(i, v);
|
||||
return v;
|
||||
}, py::arg("i"), "Returns a list of vertex indices for element i.")
|
||||
.def("GetBdrElementVertices",
|
||||
[](const Mesh &self, int i) {
|
||||
Array<int> v;
|
||||
self.GetBdrElementVertices(i, v);
|
||||
return v;
|
||||
}, py::arg("i"), "Returns a list of vertex indices for boundary element i.")
|
||||
.def("GetFaceElements",
|
||||
[](const Mesh &self, int i) {
|
||||
int e1, e2;
|
||||
self.GetFaceElements(i, &e1, &e2);
|
||||
return py::make_tuple(e1, e2);
|
||||
}, py::arg("face_idx"), "Returns a tuple of the two elements sharing a face.")
|
||||
.def("GetBdrElementFace",
|
||||
[](const Mesh &self, int i) {
|
||||
int f, o;
|
||||
self.GetBdrElementFace(i, &f, &o);
|
||||
return py::make_tuple(f, o);
|
||||
}, py::arg("bdr_elem_idx"), "Returns the face index and orientation for a boundary element.")
|
||||
.def("ElementToEdgeTable", &Mesh::ElementToEdgeTable, py::return_value_policy::reference_internal)
|
||||
.def("ElementToFaceTable", &Mesh::ElementToFaceTable, py::return_value_policy::reference_internal)
|
||||
|
||||
// --- Coordinate Transformations & Curvature ---
|
||||
.def("GetNodes", py::overload_cast<>(&Mesh::GetNodes, py::const_), py::return_value_policy::reference_internal,
|
||||
"Returns the GridFunction for the mesh nodes (if any).")
|
||||
.def("SetCurvature", &Mesh::SetCurvature, py::arg("order"), py::arg("discontinuous") = false,
|
||||
py::arg("space_dim") = -1, py::arg("ordering") = 1,
|
||||
"Set the curvature of the mesh, creating a high-order nodal GridFunction.")
|
||||
// Use py::overload_cast to resolve ambiguity for overloaded methods
|
||||
.def("GetElementTransformation", py::overload_cast<int>(&Mesh::GetElementTransformation), py::arg("i"),
|
||||
py::return_value_policy::reference_internal)
|
||||
.def("GetFaceElementTransformations", py::overload_cast<int, int>(&Mesh::GetFaceElementTransformations), py::arg("i"),
|
||||
py::arg("mask")=31, py::return_value_policy::reference_internal)
|
||||
|
||||
// --- I/O & Visualization ---
|
||||
.def("Save", &Mesh::Save, py::arg("filename"), py::arg("precision") = 16);
|
||||
|
||||
// Bind the VTKFormat enum used in PrintVTU
|
||||
py::enum_<VTKFormat>(mfem_submodule, "VTKFormat")
|
||||
.value("ASCII", VTKFormat::ASCII)
|
||||
.value("BINARY", VTKFormat::BINARY)
|
||||
.value("BINARY32", VTKFormat::BINARY32)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
void register_table_bindings(pybind11::module &mfem_submodule) {
|
||||
// Bind mfem::Table first, as it's a return type for some Mesh methods.
|
||||
py::class_<Table>(mfem_submodule, "Table")
|
||||
.def("GetRow", [](const Table &self, int row) {
|
||||
Array<int> row_data;
|
||||
self.GetRow(row, row_data);
|
||||
// pybind11 will automatically convert Array<int> to a Python list
|
||||
return row_data;
|
||||
}, py::arg("row"), "Get a row of the table as a list.")
|
||||
.def("RowSize", &Table::RowSize, py::arg("row"), "Get the number of entries in a specific row.")
|
||||
.def_property_readonly("height", &Table::Size, "Number of rows in the table.")
|
||||
.def_property_readonly("width", &Table::Width, "Number of columns in the table.")
|
||||
.def("__len__", &Table::Size)
|
||||
.def("__getitem__", [](const Table &self, int row) {
|
||||
if (row < 0 || row >= self.Size()) {
|
||||
throw py::index_error("Row index out of bounds");
|
||||
}
|
||||
Array<int> row_data;
|
||||
self.GetRow(row, row_data);
|
||||
return row_data;
|
||||
})
|
||||
.def("__repr__", [](const Table &self) {
|
||||
return "<mfem.Table (" + std::to_string(self.Size()) + "x" +
|
||||
std::to_string(self.Width()) + ")>";
|
||||
});
|
||||
}
|
||||
|
||||
void register_finite_element_collection_bindings(pybind11::module &mfem_submodule) {
|
||||
py::class_<FiniteElementCollection>(mfem_submodule, "FiniteElementCollection")
|
||||
// No constructor for abstract base classes
|
||||
.def("Name", &FiniteElementCollection::Name)
|
||||
.def("GetOrder", &FiniteElementCollection::GetOrder);
|
||||
}
|
||||
|
||||
// Binds the mfem::BasisType class and its internal anonymous enum values
|
||||
// as static properties of a Python class. This should be called before
|
||||
// binding any class that uses BasisType values in its constructor.
|
||||
void register_basis_type_bindings(py::module &m) {
|
||||
py::class_<BasisType> basis_type(m, "BasisType", "Possible basis types.");
|
||||
|
||||
basis_type.def_property_readonly_static("Invalid", [](py::object) { return BasisType::Invalid; });
|
||||
basis_type.def_property_readonly_static("GaussLegendre", [](py::object) { return BasisType::GaussLegendre; });
|
||||
basis_type.def_property_readonly_static("GaussLobatto", [](py::object) { return BasisType::GaussLobatto; });
|
||||
basis_type.def_property_readonly_static("Positive", [](py::object) { return BasisType::Positive; });
|
||||
basis_type.def_property_readonly_static("OpenUniform", [](py::object) { return BasisType::OpenUniform; });
|
||||
basis_type.def_property_readonly_static("ClosedUniform", [](py::object) { return BasisType::ClosedUniform; });
|
||||
basis_type.def_property_readonly_static("OpenHalfUniform", [](py::object) { return BasisType::OpenHalfUniform; });
|
||||
basis_type.def_property_readonly_static("Serendipity", [](py::object) { return BasisType::Serendipity; });
|
||||
basis_type.def_property_readonly_static("ClosedGL", [](py::object) { return BasisType::ClosedGL; });
|
||||
basis_type.def_property_readonly_static("IntegratedGLL", [](py::object) { return BasisType::IntegratedGLL; });
|
||||
}
|
||||
|
||||
|
||||
|
||||
void register_H1_FECollection_bindings(py::module &m) {
|
||||
py::class_<H1_FECollection, FiniteElementCollection>(m, "H1_FECollection")
|
||||
.def(py::init([](int p, int dim) {
|
||||
// The lambda explicitly calls the constructor, avoiding the
|
||||
// default argument parsing issue.
|
||||
return new H1_FECollection(p, dim);
|
||||
}),
|
||||
py::arg("p"),
|
||||
py::arg("dim") = 3,
|
||||
"Constructs an H1 finite element collection.")
|
||||
.def("GetBasisType", &H1_FECollection::GetBasisType);
|
||||
}
|
||||
|
||||
|
||||
void register_RT_FECollection_bindings(py::module &m) {
|
||||
py::class_<RT_FECollection, FiniteElementCollection>(m, "RT_FECollection")
|
||||
.def(py::init<const int, const int>(),
|
||||
py::arg("p"), py::arg("dim"),
|
||||
"Constructs a Raviart-Thomas H(div)-conforming FE collection.")
|
||||
.def("GetClosedBasisType", &RT_FECollection::GetClosedBasisType)
|
||||
.def("GetOpenBasisType", &RT_FECollection::GetOpenBasisType);
|
||||
}
|
||||
|
||||
void register_ND_FECollection_bindings(py::module &m) {
|
||||
py::class_<ND_FECollection, FiniteElementCollection>(m, "ND_FECollection")
|
||||
.def(py::init<const int, const int>(),
|
||||
py::arg("p"), py::arg("dim"),
|
||||
"Constructs a Nedelec H(curl)-conforming FE collection.")
|
||||
.def("GetClosedBasisType", &ND_FECollection::GetClosedBasisType)
|
||||
.def("GetOpenBasisType", &ND_FECollection::GetOpenBasisType);
|
||||
}
|
||||
|
||||
|
||||
void register_finite_element_space_bindings(py::module &mfem_submodule) {
|
||||
// Bind dependent enums first
|
||||
bind_ordering_enum(mfem_submodule);
|
||||
|
||||
// Bind the mfem::FiniteElementSpace class
|
||||
py::class_<FiniteElementSpace>(mfem_submodule, "FiniteElementSpace")
|
||||
// --- Constructors ---
|
||||
.def(py::init<Mesh *, const FiniteElementCollection *, int, int>(),
|
||||
py::arg("mesh"), py::arg("fec"), py::arg("vdim") = 1, py::arg("ordering") = Ordering::byNODES,
|
||||
// Keep alive policies prevent the mesh and fec from being
|
||||
// garbage collected by Python while the C++ FESpace object exists.
|
||||
py::keep_alive<1, 2>(), py::keep_alive<1, 3>())
|
||||
|
||||
// --- Core Properties & Stats ---
|
||||
.def_property_readonly("ndofs", &FiniteElementSpace::GetNDofs, "Number of local scalar degrees of freedom.")
|
||||
.def_property_readonly("vdim", &FiniteElementSpace::GetVDim, "Vector dimension of the space.")
|
||||
.def_property_readonly("vsize", &FiniteElementSpace::GetVSize, "Total number of local vector degrees of freedom.")
|
||||
.def_property_readonly("true_vsize", &FiniteElementSpace::GetTrueVSize, "Number of true (conforming) vector degrees of freedom.")
|
||||
.def("GetMesh", &FiniteElementSpace::GetMesh, py::return_value_policy::reference_internal, "Get the associated Mesh.")
|
||||
.def("FEColl", &FiniteElementSpace::FEColl, py::return_value_policy::reference_internal, "Get the associated FiniteElementCollection.")
|
||||
|
||||
// --- DOF Management ---
|
||||
.def("GetElementDofs",
|
||||
[](const FiniteElementSpace &self, int i) {
|
||||
Array<int> dofs;
|
||||
self.GetElementDofs(i, dofs);
|
||||
return dofs;
|
||||
}, py::arg("elem_idx"), "Get the local scalar DOFs for a given element.")
|
||||
.def("GetElementVDofs",
|
||||
[](const FiniteElementSpace &self, int i) {
|
||||
Array<int> vdofs;
|
||||
self.GetElementVDofs(i, vdofs);
|
||||
return vdofs;
|
||||
}, py::arg("elem_idx"), "Get the local vector DOFs for a given element.")
|
||||
.def("GetBdrElementDofs",
|
||||
[](const FiniteElementSpace &self, int i) {
|
||||
Array<int> dofs;
|
||||
self.GetBdrElementDofs(i, dofs);
|
||||
return dofs;
|
||||
}, py::arg("bdr_elem_idx"), "Get the local scalar DOFs for a given boundary element.")
|
||||
.def("GetEssentialVDofs",
|
||||
[](const FiniteElementSpace &self, const Array<int> &bdr_attr_is_ess, int component) {
|
||||
Array<int> ess_vdofs;
|
||||
self.GetEssentialVDofs(bdr_attr_is_ess, ess_vdofs, component);
|
||||
return ess_vdofs;
|
||||
}, py::arg("bdr_attr_is_ess"), py::arg("component") = -1,
|
||||
"Get a list of essential (Dirichlet) vector DOFs based on boundary attributes.")
|
||||
.def("GetEssentialTrueDofs",
|
||||
[](const FiniteElementSpace &self, const Array<int> &bdr_attr_is_ess, int component) {
|
||||
Array<int> ess_tdof_list;
|
||||
self.GetEssentialTrueDofs(bdr_attr_is_ess, ess_tdof_list, component);
|
||||
return ess_tdof_list;
|
||||
}, py::arg("bdr_attr_is_ess"), py::arg("component") = -1,
|
||||
"Get a list of essential true (conforming) DOFs for use in linear systems.")
|
||||
|
||||
// --- Updating & Operators ---
|
||||
.def("Update", &FiniteElementSpace::Update, py::arg("want_transform") = true,
|
||||
"Update the FE space after the mesh has been modified (e.g., refined).")
|
||||
.def("GetUpdateOperator", py::overload_cast<>(&FiniteElementSpace::GetUpdateOperator),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Get the operator that maps GridFunctions from the old space to the new space after an Update.")
|
||||
.def("GetProlongationMatrix", &FiniteElementSpace::GetProlongationMatrix,
|
||||
py::return_value_policy::reference_internal, "Get the P operator (true DOFs to local DOFs).")
|
||||
.def("GetRestrictionOperator", &FiniteElementSpace::GetRestrictionOperator,
|
||||
py::return_value_policy::reference_internal, "Get the R operator (local DOFs to true DOFs).")
|
||||
|
||||
.def("__repr__", [](const FiniteElementSpace &fes) {
|
||||
std::stringstream ss;
|
||||
ss << "<mfem.FiniteElementSpace with " << fes.GetNDofs() << " DOFs, vdim=" << fes.GetVDim() << ">";
|
||||
return ss.str();
|
||||
});
|
||||
}
|
||||
|
||||
void bind_ordering_enum(py::module &mfem_submodule) {
|
||||
py::enum_<Ordering::Type>(mfem_submodule, "Ordering")
|
||||
.value("byNODES", Ordering::byNODES)
|
||||
.value("byVDIM", Ordering::byVDIM)
|
||||
.export_values();
|
||||
}
|
||||
|
||||
/// Binds mfem::GridFunction
|
||||
void register_grid_function_bindings(py::module &m) {
|
||||
// Assumes that Vector, FiniteElementSpace, Coefficient, VectorCoefficient,
|
||||
// IntegrationRule, and ElementTransformation are already bound.
|
||||
|
||||
py::class_<GridFunction, Vector>(m, "GridFunction")
|
||||
.def(py::init<FiniteElementSpace *>(), py::arg("fespace"),
|
||||
py::keep_alive<1, 2>(), // Keep FE space alive
|
||||
"Construct a GridFunction on a given FiniteElementSpace.")
|
||||
|
||||
.def(py::init<FiniteElementSpace *, real_t *>(),
|
||||
py::arg("fespace"), py::arg("data"),
|
||||
py::keep_alive<1, 2>(),
|
||||
"Construct a GridFunction using previously allocated data.")
|
||||
|
||||
.def("FESpace", py::overload_cast<>(&GridFunction::FESpace, py::const_),
|
||||
py::return_value_policy::reference_internal,
|
||||
"Returns the associated FiniteElementSpace.")
|
||||
.def("Update", &GridFunction::Update,
|
||||
"Update the GridFunction after its FE space has been modified.")
|
||||
.def("SetSpace", &GridFunction::SetSpace, py::arg("fespace"),
|
||||
py::keep_alive<1, 2>(), "Associate a new FE space with the GridFunction.")
|
||||
|
||||
// --- Projection Methods ---
|
||||
.def("ProjectCoefficient",
|
||||
py::overload_cast<Coefficient &>(&GridFunction::ProjectCoefficient),
|
||||
py::arg("coeff"),
|
||||
"Project a scalar Coefficient onto the GridFunction.")
|
||||
.def("ProjectCoefficient",
|
||||
py::overload_cast<VectorCoefficient &>(&GridFunction::ProjectCoefficient),
|
||||
py::arg("vcoeff"),
|
||||
"Project a vector Coefficient onto the GridFunction.")
|
||||
.def("ProjectBdrCoefficient",
|
||||
py::overload_cast<Coefficient &, const Array<int> &>(&GridFunction::ProjectBdrCoefficient),
|
||||
py::arg("coeff"), py::arg("attr"),
|
||||
"Project a scalar Coefficient onto the boundary degrees of freedom.")
|
||||
.def("ProjectBdrCoefficient",
|
||||
py::overload_cast<VectorCoefficient &, const Array<int> &>(&GridFunction::ProjectBdrCoefficient),
|
||||
py::arg("vcoeff"), py::arg("attr"),
|
||||
"Project a vector Coefficient onto the boundary degrees of freedom.")
|
||||
|
||||
// --- Evaluation Methods ---
|
||||
.def("GetValue", py::overload_cast<ElementTransformation &, const IntegrationPoint &, int, Vector *>(&GridFunction::GetValue, py::const_),
|
||||
py::arg("T"), py::arg("ip"), py::arg("comp") = 0, py::arg("tr") = nullptr,
|
||||
"Get the scalar value at a point described by an ElementTransformation.")
|
||||
.def("GetVectorValue",
|
||||
[](const GridFunction &self, ElementTransformation &T, const IntegrationPoint &ip) {
|
||||
Vector val;
|
||||
self.GetVectorValue(T, ip, val);
|
||||
return val;
|
||||
},
|
||||
py::arg("T"), py::arg("ip"),
|
||||
"Get the vector value at a point described by an ElementTransformation.")
|
||||
.def("GetValues",
|
||||
[](const GridFunction &self, ElementTransformation &T, const IntegrationRule &ir, int comp) {
|
||||
Vector vals;
|
||||
self.GetValues(T, ir, vals, comp);
|
||||
return vals;
|
||||
},
|
||||
py::arg("T"), py::arg("ir"), py::arg("comp") = 0,
|
||||
"Get scalar values at all points of an IntegrationRule.")
|
||||
.def("GetVectorValues",
|
||||
[](const GridFunction &self, ElementTransformation &T, const IntegrationRule &ir) {
|
||||
DenseMatrix vals;
|
||||
self.GetVectorValues(T, ir, vals);
|
||||
return vals; // pybind11 handles DenseMatrix
|
||||
},
|
||||
py::arg("T"), py::arg("ir"),
|
||||
"Get vector values at all points of an IntegrationRule.")
|
||||
|
||||
// --- True DOF (constrained system) Methods ---
|
||||
.def("GetTrueDofs", &GridFunction::GetTrueDofs, py::arg("tv"),
|
||||
"Extract the true-dofs from the GridFunction.")
|
||||
.def("SetFromTrueDofs", &GridFunction::SetFromTrueDofs, py::arg("tv"),
|
||||
"Set the GridFunction from a true-dof vector.")
|
||||
|
||||
// --- I/O ---
|
||||
.def("Save", py::overload_cast<const char *, int>(&GridFunction::Save, py::const_),
|
||||
py::arg("fname"), py::arg("precision") = 16)
|
||||
.def("__repr__", [](const GridFunction &gf) {
|
||||
std::stringstream ss;
|
||||
ss << "<mfem.GridFunction with " << gf.Size() << " DOFs>";
|
||||
return ss.str();
|
||||
})
|
||||
.def("as_numpy", [](const GridFunction &self) {
|
||||
// Convert the GridFunction to a numpy array
|
||||
return py::array_t<real_t>(self.Size(), self.GetData());
|
||||
}, "Convert the GridFunction to a numpy array.");
|
||||
}
|
||||
|
||||
// This function registers the coefficient classes to the python module
|
||||
void register_coefficient_bindings(py::module &m) {
|
||||
|
||||
// --- Bind abstract base classes using trampolines ---
|
||||
py::class_<Coefficient, serif::pybind::PyCoefficient> coefficient(m, "Coefficient");
|
||||
coefficient
|
||||
.def(py::init<>())
|
||||
.def("SetTime", &Coefficient::SetTime, py::arg("t"))
|
||||
.def("GetTime", &Coefficient::GetTime)
|
||||
.def("Eval", py::overload_cast<ElementTransformation &, const IntegrationPoint&>(&Coefficient::Eval),
|
||||
"Evaluate the coefficient at a point in an element.",
|
||||
py::arg("T"), py::arg("ip"));
|
||||
|
||||
py::class_<VectorCoefficient, serif::pybind::PyVectorCoefficient> vector_coefficient(m, "VectorCoefficient");
|
||||
vector_coefficient
|
||||
.def(py::init<int>(), py::arg("vdim"))
|
||||
.def("SetTime", &VectorCoefficient::SetTime, py::arg("t"))
|
||||
.def("GetTime", &VectorCoefficient::GetTime)
|
||||
.def("GetVDim", &VectorCoefficient::GetVDim)
|
||||
.def("Eval", py::overload_cast<Vector &, ElementTransformation &, const IntegrationPoint &>(&VectorCoefficient::Eval),
|
||||
"Evaluate the vector coefficient at a point in an element.",
|
||||
py::arg("V"), py::arg("T"), py::arg("ip"));
|
||||
|
||||
// --- Bind useful concrete classes ---
|
||||
|
||||
// ConstantCoefficient
|
||||
py::class_<ConstantCoefficient, Coefficient>(m, "ConstantCoefficient")
|
||||
.def(py::init<real_t>(), py::arg("c") = 1.0)
|
||||
.def_readwrite("constant", &ConstantCoefficient::constant);
|
||||
|
||||
// FunctionCoefficient (allows using Python functions as coefficients)
|
||||
py::class_<FunctionCoefficient, Coefficient>(m, "FunctionCoefficient")
|
||||
.def(py::init<std::function<real_t(const Vector &)>>(),
|
||||
py::arg("F"), "Create a coefficient from a Python function of space.")
|
||||
.def(py::init<std::function<real_t(const Vector &, real_t)>>(),
|
||||
py::arg("TDF"), "Create a coefficient from a Python function of space and time.");
|
||||
|
||||
// VectorConstantCoefficient
|
||||
py::class_<VectorConstantCoefficient, VectorCoefficient>(m, "VectorConstantCoefficient")
|
||||
.def(py::init<const Vector &>(), py::arg("v"));
|
||||
|
||||
// VectorFunctionCoefficient
|
||||
py::class_<VectorFunctionCoefficient, VectorCoefficient>(m, "VectorFunctionCoefficient")
|
||||
.def(py::init<int, std::function<void(const Vector &, Vector &)>>(),
|
||||
py::arg("dim"), py::arg("F"), "Create a vector coefficient from a Python function of space.")
|
||||
.def(py::init<int, std::function<void(const Vector &, real_t, Vector &)>>(),
|
||||
py::arg("dim"), py::arg("TDF"), "Create a vector coefficient from a Python function of space and time.");
|
||||
|
||||
// GridFunctionCoefficient (useful for source terms that depend on the solution)
|
||||
py::class_<GridFunctionCoefficient, Coefficient>(m, "GridFunctionCoefficient")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const GridFunction *>(), py::arg("gf"))
|
||||
.def("SetGridFunction", &GridFunctionCoefficient::SetGridFunction, py::arg("gf"))
|
||||
.def("GetGridFunction", &GridFunctionCoefficient::GetGridFunction, py::return_value_policy::reference);
|
||||
}
|
||||
|
||||
/// Binds mfem::IntegrationPoint and mfem::IntegrationRule
|
||||
void register_intrule_bindings(py::module &m) {
|
||||
py::class_<IntegrationPoint>(m, "IntegrationPoint")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("x", &IntegrationPoint::x)
|
||||
.def_readwrite("y", &IntegrationPoint::y)
|
||||
.def_readwrite("z", &IntegrationPoint::z)
|
||||
.def_readwrite("weight", &IntegrationPoint::weight)
|
||||
.def("__repr__", [](const IntegrationPoint &ip) {
|
||||
std::stringstream ss;
|
||||
ss << "IP(x=" << ip.x << ", y=" << ip.y << ", z=" << ip.z << ", w=" << ip.weight << ")";
|
||||
return ss.str();
|
||||
});
|
||||
|
||||
py::class_<IntegrationRule>(m, "IntegrationRule")
|
||||
.def(py::init<>())
|
||||
.def(py::init<int>(), py::arg("NumPoints"))
|
||||
.def("GetNPoints", &IntegrationRule::GetNPoints)
|
||||
.def("GetOrder", &IntegrationRule::GetOrder)
|
||||
.def("__len__", &IntegrationRule::GetNPoints)
|
||||
.def("__getitem__", [](const IntegrationRule &self, int i) {
|
||||
if (i < 0 || i >= self.GetNPoints()) throw py::index_error();
|
||||
return self.IntPoint(i);
|
||||
})
|
||||
.def("__iter__", [](const IntegrationRule &self) {
|
||||
return py::make_iterator(self.begin(), self.end());
|
||||
}, py::keep_alive<0, 1>());
|
||||
}
|
||||
|
||||
/// Binds mfem::ElementTransformation and related classes
|
||||
void register_eltrans_bindings(py::module &m) {
|
||||
// Bind the base class. Users get pointers to this, but never create it.
|
||||
// It is not abstract in the C++ sense, but we treat it as such for bindings.
|
||||
py::class_<ElementTransformation> eltrans(m, "ElementTransformation");
|
||||
eltrans
|
||||
.def_readonly("Attribute", &ElementTransformation::Attribute)
|
||||
.def_readonly("ElementNo", &ElementTransformation::ElementNo)
|
||||
.def("SetIntPoint", &ElementTransformation::SetIntPoint, py::arg("ip"))
|
||||
.def("Transform", py::overload_cast<const IntegrationPoint &, Vector &>(&ElementTransformation::Transform),
|
||||
py::arg("ip"), py::arg("transip"))
|
||||
.def("Weight", &ElementTransformation::Weight)
|
||||
// Properties that return references to internal data should use reference_internal policy
|
||||
.def_property_readonly("Jacobian", &ElementTransformation::Jacobian, py::return_value_policy::reference_internal)
|
||||
.def_property_readonly("InverseJacobian", &ElementTransformation::InverseJacobian, py::return_value_policy::reference_internal)
|
||||
.def_property_readonly("AdjugateJacobian", &ElementTransformation::AdjugateJacobian, py::return_value_policy::reference_internal);
|
||||
|
||||
// Bind IsoparametricTransformation, which is a concrete type of ElementTransformation
|
||||
py::class_<IsoparametricTransformation, ElementTransformation>(m, "IsoparametricTransformation")
|
||||
.def(py::init<>());
|
||||
|
||||
// Bind FaceElementTransformations, crucial for DG methods
|
||||
py::class_<FaceElementTransformations, ElementTransformation>(m, "FaceElementTransformations")
|
||||
.def(py::init<>())
|
||||
.def_readonly("Elem1No", &FaceElementTransformations::Elem1No)
|
||||
.def_readonly("Elem2No", &FaceElementTransformations::Elem2No)
|
||||
.def_property_readonly("Elem1", [](FaceElementTransformations &self) { return self.Elem1; }, py::return_value_policy::reference)
|
||||
.def_property_readonly("Elem2", [](FaceElementTransformations &self) { return self.Elem2; }, py::return_value_policy::reference)
|
||||
.def("GetElement1IntPoint", &FaceElementTransformations::GetElement1IntPoint)
|
||||
.def("GetElement2IntPoint", &FaceElementTransformations::GetElement2IntPoint);
|
||||
|
||||
// Bind the enum used for selecting transformations
|
||||
py::enum_<FaceElementTransformations::ConfigMasks>(eltrans, "ConfigMasks")
|
||||
.value("HAVE_ELEM1", FaceElementTransformations::HAVE_ELEM1)
|
||||
.value("HAVE_ELEM2", FaceElementTransformations::HAVE_ELEM2)
|
||||
.value("HAVE_LOC1", FaceElementTransformations::HAVE_LOC1)
|
||||
.value("HAVE_LOC2", FaceElementTransformations::HAVE_LOC2)
|
||||
.value("HAVE_FACE", FaceElementTransformations::HAVE_FACE)
|
||||
.export_values();
|
||||
}
|
||||
33
src/python/mfem/bindings.h
Normal file
33
src/python/mfem/bindings.h
Normal file
@@ -0,0 +1,33 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
void register_mfem_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
void register_operator_bindings(pybind11::module &mfem_submodule);
|
||||
void register_matrix_bindings(pybind11::module &mfem_submodule);
|
||||
void register_vector_bindings(pybind11::module &mfem_submodule);
|
||||
void register_array_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
void bind_assembly_level_enum(pybind11::module &mfem_submodule);
|
||||
void register_bilinear_form_bindings(pybind11::module &mfem_submodule);
|
||||
void register_mixed_bilinear_form_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
void register_table_bindings(pybind11::module &mfem_submodule);
|
||||
void register_mesh_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
void register_basis_type_bindings(pybind11::module &mfem_submodule);
|
||||
void register_finite_element_collection_bindings(pybind11::module &mfem_submodule);
|
||||
void register_H1_FECollection_bindings(pybind11::module &mfem_submodule);
|
||||
void register_RT_FECollection_bindings(pybind11::module &mfem_submodule);
|
||||
void register_ND_FECollection_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
void bind_ordering_enum(pybind11::module &mfem_submodule);
|
||||
void register_finite_element_space_bindings(pybind11::module &mfem_submodule);
|
||||
|
||||
|
||||
void register_coefficient_bindings(pybind11::module &m);
|
||||
void register_eltrans_bindings(pybind11::module &m);
|
||||
void register_intrule_bindings(pybind11::module &m);
|
||||
|
||||
void register_grid_function_bindings(pybind11::module &mfem_submodule);
|
||||
26
src/python/mfem/meson.build
Normal file
26
src/python/mfem/meson.build
Normal file
@@ -0,0 +1,26 @@
|
||||
subdir('Trampoline')
|
||||
|
||||
# Define the library
|
||||
bindings_sources = files(
|
||||
'bindings.cpp',
|
||||
)
|
||||
bindings_headers = files(
|
||||
'bindings.h',
|
||||
)
|
||||
|
||||
dependencies = [
|
||||
config_dep,
|
||||
resourceManager_dep,
|
||||
python3_dep,
|
||||
pybind11_dep,
|
||||
mpi_dep,
|
||||
trampoline_dep,
|
||||
]
|
||||
|
||||
shared_module('py_mfem',
|
||||
bindings_sources,
|
||||
include_directories: include_directories('.'),
|
||||
cpp_args: ['-fvisibility=default'],
|
||||
install : true,
|
||||
dependencies: dependencies,
|
||||
)
|
||||
@@ -1,15 +1,26 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // Needed for vectors, maps, sets, strings
|
||||
#include <pybind11/stl_bind.h> // Needed for binding std::vector, std::map etc if needed directly
|
||||
#include <pybind11/numpy.h>
|
||||
|
||||
#include "bindings.h"
|
||||
#include "EOSio.h"
|
||||
#include "helm.h"
|
||||
#include "polySolver.h"
|
||||
#include "../../polytrope/solver/public/polySolver.h"
|
||||
#include "../../polytrope/utils/public/polytropeOperator.h"
|
||||
#include "mfem.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void register_polytrope_bindings(pybind11::module &polytrope_submodule) {
|
||||
py::class_<serif::polytrope::PolySolver>(polytrope_submodule, "PolySolver")
|
||||
.def(py::init<double, int>(), py::arg("polytropic_index"), py::arg("FEM_order"))
|
||||
.def("solve", &serif::polytrope::PolySolver::solve, "Solve the polytrope equation.");
|
||||
.def("solve", &serif::polytrope::PolySolver::solve, "Solve the polytrope equation.")
|
||||
.def("get_theta", &serif::polytrope::PolySolver::getTheta, py::return_value_policy::reference_internal)
|
||||
.def("get_phi", &serif::polytrope::PolySolver::getPhi, py::return_value_policy::reference_internal)
|
||||
.def("get_order", &serif::polytrope::PolySolver::getOrder)
|
||||
.def("get_n", &serif::polytrope::PolySolver::getN);
|
||||
|
||||
py::class_<serif::polytrope::PolytropeOperator, mfem::Operator>(polytrope_submodule, "PolytropeOperator")
|
||||
.def("Mult", &serif::polytrope::PolytropeOperator::Mult);
|
||||
|
||||
}
|
||||
|
||||
11
tests/python/polytrope/loadMesh.py
Normal file
11
tests/python/polytrope/loadMesh.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from serif.mfem import Mesh
|
||||
import argparse
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Test loading a mesh using MFEM's mesh loading called from Python")
|
||||
parser.add_argument("path", type=str, help="path to mesh")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mesh = Mesh(args.path)
|
||||
print(mesh.nv)
|
||||
@@ -1,8 +1,18 @@
|
||||
from serif import config
|
||||
from serif.polytrope import PolySolver
|
||||
from serif.mfem import Matrix
|
||||
from serif.mfem import Operator
|
||||
from serif.mfem.forms import BilinearForm
|
||||
|
||||
config.loadConfig('../../testsConfig.yaml')
|
||||
n = config.get("Tests:Poly:Index", 0.0)
|
||||
|
||||
polytrope = PolySolver(n, 1)
|
||||
polytrope.solve()
|
||||
polytrope.solve()
|
||||
theta = polytrope.get_theta()
|
||||
print(theta)
|
||||
|
||||
FESpace = theta.FESpace()
|
||||
print(FESpace)
|
||||
|
||||
print(theta.as_numpy())
|
||||
|
||||
Reference in New Issue
Block a user