#pragma once #include "mfem.hpp" #include "pybind11/pybind11.h" #include "pybind11/stl.h" // Needed for vectors, maps, sets, strings namespace py = pybind11; namespace serif::pybind { /** * @brief A trampoline class for mfem::Operator. * * This class allows Python classes to inherit from mfem::Operator and correctly * override its virtual functions. When a virtual function is called from C++, * the trampoline ensures the call is forwarded to the Python implementation if one exists. */ class PyOperator : public mfem::Operator { public: // Inherit the constructors from the base mfem::Operator class. // This allows Python classes to call e.g., super().__init__(size). using mfem::Operator::Operator; // --- Trampoline declarations for all overridable virtual functions --- // Pure virtual function (MANDATORY override) void Mult(const mfem::Vector &x, mfem::Vector &y) const override; // Regular virtual functions (RECOMMENDED overrides) void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override; void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override; void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override; Operator& GetGradient(const mfem::Vector &x) const override; void AssembleDiagonal(mfem::Vector &diag) const override; const mfem::Operator* GetProlongation() const override; const mfem::Operator* GetRestriction() const override; void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override; }; } // namespace serif::pybind::mfem