47 lines
1.7 KiB
C++
47 lines
1.7 KiB
C++
#pragma once
|
|
|
|
#include "mfem.hpp"
|
|
#include "pybind11/pybind11.h"
|
|
#include "pybind11/stl.h" // Needed for vectors, maps, sets, strings
|
|
|
|
// Conventional short alias for the pybind11 namespace.
// NOTE(review): this alias is declared at global scope in a header, so every
// file that includes this header also sees `py`. It is kept here for
// compatibility with includers that may rely on it.
namespace py = ::pybind11;
|
|
|
|
namespace serif::pybind {
|
|
|
|
/**
|
|
* @brief A trampoline class for mfem::Operator.
|
|
* * This class allows Python classes to inherit from mfem::Operator and correctly
|
|
* override its virtual functions. When a virtual function is called from C++,
|
|
* the trampoline ensures the call is forwarded to the Python implementation if one exists.
|
|
*/
|
|
class PyOperator : public mfem::Operator {
|
|
public:
|
|
// Inherit the constructors from the base mfem::Operator class.
|
|
// This allows Python classes to call e.g., super().__init__(size).
|
|
using mfem::Operator::Operator;
|
|
|
|
// --- Trampoline declarations for all overridable virtual functions ---
|
|
|
|
// Pure virtual function (MANDATORY override)
|
|
void Mult(const mfem::Vector &x, mfem::Vector &y) const override;
|
|
|
|
// Regular virtual functions (RECOMMENDED overrides)
|
|
void MultTranspose(const mfem::Vector &x, mfem::Vector &y) const override;
|
|
|
|
void AddMult(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
|
|
|
void AddMultTranspose(const mfem::Vector &x, mfem::Vector &y, const mfem::real_t a = 1.0) const override;
|
|
|
|
Operator& GetGradient(const mfem::Vector &x) const override;
|
|
|
|
void AssembleDiagonal(mfem::Vector &diag) const override;
|
|
|
|
const mfem::Operator* GetProlongation() const override;
|
|
|
|
const mfem::Operator* GetRestriction() const override;
|
|
|
|
void RecoverFEMSolution(const mfem::Vector &X, const mfem::Vector &b, mfem::Vector &x) override;
|
|
};
|
|
|
|
} // namespace serif::pybind::mfem
|