Files
SERiF/src/python/mfem/Trampoline/PyMFEMTrampolines/Operator/PyOperator.h

47 lines
1.7 KiB
C++

#pragma once
#include "mfem.hpp"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h" // Needed for vectors, maps, sets, strings
namespace py = pybind11;
namespace serif::pybind {
/**
 * @brief A trampoline class for mfem::Operator.
 *
 * This class allows Python classes to inherit from mfem::Operator and correctly
 * override its virtual functions. When a virtual function is called from C++,
 * the trampoline ensures the call is forwarded to the Python implementation if one exists.
 *
 * Only declarations appear in this header; the forwarding definitions of the
 * member functions below live outside this file.
 */
class PyOperator : public mfem::Operator {
public:
    // Inherit the constructors from the base mfem::Operator class.
    // This allows Python classes to call e.g., super().__init__(size).
    using mfem::Operator::Operator;

    // --- Trampoline declarations for the virtual functions that can be overridden from Python ---

    // Pure virtual function in mfem::Operator (MANDATORY override in Python subclasses).
    void Mult(const mfem::Vector &x, mfem::Vector &y) const override;

    // Regular virtual functions (RECOMMENDED overrides).
    void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override;

    // Default arguments on virtual functions are chosen by the *static* type of the
    // call expression, not the dynamic one. These defaults are intended to mirror
    // the base-class declarations so that calls through PyOperator and through
    // mfem::Operator behave the same.
    // NOTE(review): assumes mfem::Operator also uses a = 1.0 here; re-check if MFEM changes it.
    void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
    void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;

    // Returns a reference, so whatever operator a Python override hands back must
    // stay alive after the call returns.
    // NOTE(review): confirm the out-of-line definition keeps the Python-side object alive.
    mfem::Operator& GetGradient(const mfem::Vector &x) const override;

    void AssembleDiagonal(mfem::Vector &diag) const override;

    // Non-owning pointers in the base-class signature; ownership semantics are
    // whatever the out-of-line definition implements.
    const mfem::Operator* GetProlongation() const override;
    const mfem::Operator* GetRestriction() const override;

    void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override;
};
} // namespace serif::pybind