refactor(mfem_smout): broke logic saving sparse matrix into two methods
This commit is contained in:
@@ -34,15 +34,39 @@
|
||||
* - J array (int64_t * NNZ): CSR Column Indices
|
||||
* - Data array (double * NNZ): CSR Non-zero values
|
||||
*/
|
||||
void write_sparse_matrix(const mfem::SparseMatrix &mat, std::ostream &outfile) {
|
||||
inline void write_sparse_matrix(mfem::Operator &op, std::ostream &outfile, bool neg) {
|
||||
mfem::SparseMatrix* mat;
|
||||
try {
|
||||
mat = dynamic_cast<mfem::SparseMatrix*>(&op);
|
||||
if (!mat) {
|
||||
throw std::runtime_error("The operator is not a SparseMatrix.");
|
||||
}
|
||||
} catch (const std::runtime_error&) {
|
||||
try {
|
||||
const auto& blf = dynamic_cast<mfem::BilinearForm*>(&op);
|
||||
if (!blf) {
|
||||
throw std::runtime_error("The operator is not a SparseMatrix or BilinearForm.");
|
||||
}
|
||||
mat = &blf->SpMat();
|
||||
} catch (const std::runtime_error&) {
|
||||
auto mblf = dynamic_cast<mfem::MixedBilinearForm*>(&op);
|
||||
if (!mblf) {
|
||||
throw std::runtime_error("The operator is not a SparseMatrix or BilinearForm or MixedBilinear Form.");
|
||||
}
|
||||
mat = &mblf->SpMat();
|
||||
}
|
||||
}
|
||||
if (neg) {
|
||||
*mat *= -1.0;
|
||||
}
|
||||
// --- Get Data Pointers and Dimensions from MFEM Matrix ---
|
||||
const int* mfem_I = mat.GetI();
|
||||
const int* mfem_J = mat.GetJ();
|
||||
const double* mfem_data = mat.GetData();
|
||||
const int* mfem_I = mat->GetI();
|
||||
const int* mfem_J = mat->GetJ();
|
||||
const double* mfem_data = mat->GetData();
|
||||
|
||||
uint64_t height = static_cast<uint64_t>(mat.Height());
|
||||
uint64_t width = static_cast<uint64_t>(mat.Width());
|
||||
uint64_t nnz = static_cast<uint64_t>(mat.NumNonZeroElems());
|
||||
auto height = static_cast<uint64_t>(mat->Height());
|
||||
auto width = static_cast<uint64_t>(mat->Width());
|
||||
auto nnz = static_cast<uint64_t>(mat->NumNonZeroElems());
|
||||
uint64_t i_count = height + 1;
|
||||
uint64_t j_count = nnz;
|
||||
uint64_t data_count = nnz;
|
||||
@@ -86,7 +110,7 @@ void write_sparse_matrix(const mfem::SparseMatrix &mat, std::ostream &outfile) {
|
||||
if (!outfile) throw std::runtime_error("Error writing Data array.");
|
||||
}
|
||||
|
||||
bool saveSparseMatrixBinary(const mfem::SparseMatrix& mat, const std::string& filename) {
|
||||
inline bool saveSparseMatrixBinary(mfem::SparseMatrix& mat, const std::string& filename) {
|
||||
std::ofstream outfile(filename, std::ios::binary | std::ios::trunc);
|
||||
if (!outfile) {
|
||||
std::cerr << "Error: Cannot open file for writing: " << filename << std::endl;
|
||||
@@ -94,7 +118,7 @@ bool saveSparseMatrixBinary(const mfem::SparseMatrix& mat, const std::string& fi
|
||||
}
|
||||
|
||||
try {
|
||||
write_sparse_matrix(mat, outfile);
|
||||
write_sparse_matrix(mat, outfile, false);
|
||||
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
@@ -111,7 +135,7 @@ bool saveSparseMatrixBinary(const mfem::SparseMatrix& mat, const std::string& fi
|
||||
return true;
|
||||
}
|
||||
|
||||
void writeDenseMatrixToCSV(const std::string &filename, int precision, const mfem::DenseMatrix *mat) {
|
||||
inline void writeDenseMatrixToCSV(const std::string &filename, int precision, const mfem::DenseMatrix *mat) {
|
||||
if (!mat) {
|
||||
throw std::runtime_error("The operator is not a SparseMatrix.");
|
||||
}
|
||||
@@ -159,9 +183,9 @@ void writeDenseMatrixToCSV(const std::string &filename, int precision, const mfe
|
||||
* @param filename The name of the output CSV file.
|
||||
* @param precision Number of decimal places for floating-point values.
|
||||
*/
|
||||
void writeOperatorToCSV(const mfem::Operator &op,
|
||||
const std::string &filename,
|
||||
int precision = 6) // Add precision argument
|
||||
inline void writeOperatorToCSV(const mfem::Operator &op,
|
||||
const std::string &filename,
|
||||
int precision = 6) // Add precision argument
|
||||
{
|
||||
// Attempt to cast the Operator to a SparseMatrix
|
||||
const auto *sparse_mat = dynamic_cast<const mfem::SparseMatrix*>(&op);
|
||||
@@ -172,7 +196,7 @@ void writeDenseMatrixToCSV(const std::string &filename, int precision, const mfe
|
||||
writeDenseMatrixToCSV(filename, precision, mat);
|
||||
}
|
||||
|
||||
void saveBlockFormToBinary(std::vector<mfem::SparseMatrix *> &block_diags, std::vector<std::array<int, 2>> block, std::string filename) {
|
||||
inline void saveBlockFormToBinary(const std::vector<mfem::Operator*> &block_diags, const std::vector<std::array<int, 2>>& block, const std::vector<bool> neg, const std::string &filename) {
|
||||
// First write a magic number and version
|
||||
|
||||
// --- Open the file ---
|
||||
@@ -191,11 +215,11 @@ void saveBlockFormToBinary(std::vector<mfem::SparseMatrix *> &block_diags, std::
|
||||
outfile.write(reinterpret_cast<const char*>(&magic), 4);
|
||||
outfile.write(reinterpret_cast<const char*>(&size), sizeof(size));
|
||||
|
||||
for (const auto&& [block_diag, blockIDs] : std::views::zip(block_diags, block)) {
|
||||
for (const auto&& [block_diag, blockIDs, isNeg] : std::views::zip(block_diags, block, neg)) {
|
||||
// Write the sparse matrix data
|
||||
outfile.write(reinterpret_cast<const char*>(&datastart), 9);
|
||||
outfile.write(reinterpret_cast<const char*>(&blockIDs), sizeof(blockIDs));
|
||||
write_sparse_matrix(*block_diag, outfile);
|
||||
write_sparse_matrix(*block_diag, outfile, isNeg);
|
||||
outfile.write(reinterpret_cast<const char*>(&dataend), 7);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user