feat(debugUtils): added more sparse matrix debug utilities

This commit is contained in:
2025-04-14 07:58:37 -04:00
parent 41460acacf
commit c680433740
5 changed files with 233 additions and 108 deletions

View File

@@ -8,6 +8,11 @@
#include "mfem.hpp"
#include <iostream>
#include <fstream>
#include <vector>
#include <array>
#include <iomanip>
#include <tuple>
#include <ranges>
/**
* @brief Saves an mfem::SparseMatrix to a custom compact binary file (.csrbin).
@@ -29,6 +34,58 @@
* - J array (int64_t * NNZ): CSR Column Indices
* - Data array (double * NNZ): CSR Non-zero values
*/
void write_sparse_matrix(const mfem::SparseMatrix &mat, std::ostream &outfile) {
// --- Get Data Pointers and Dimensions from MFEM Matrix ---
const int* mfem_I = mat.GetI();
const int* mfem_J = mat.GetJ();
const double* mfem_data = mat.GetData();
uint64_t height = static_cast<uint64_t>(mat.Height());
uint64_t width = static_cast<uint64_t>(mat.Width());
uint64_t nnz = static_cast<uint64_t>(mat.NumNonZeroElems());
uint64_t i_count = height + 1;
uint64_t j_count = nnz;
uint64_t data_count = nnz;
// --- Write Header ---
const char magic[4] = {'C', 'S', 'R', 'B'};
const uint8_t version = 1;
const uint8_t int_size = 8;
const uint8_t flt_size = 8;
const uint8_t reserved = 0;
outfile.write(magic, 4);
outfile.write(reinterpret_cast<const char*>(&version), 1);
outfile.write(reinterpret_cast<const char*>(&int_size), 1);
outfile.write(reinterpret_cast<const char*>(&flt_size), 1);
outfile.write(reinterpret_cast<const char*>(&reserved), 1);
outfile.write(reinterpret_cast<const char*>(&height), sizeof(height));
outfile.write(reinterpret_cast<const char*>(&width), sizeof(width));
outfile.write(reinterpret_cast<const char*>(&nnz), sizeof(nnz));
if (!outfile) throw std::runtime_error("Error writing header.");
// --- Write Arrays (Converting int to int64_t for I and J) ---
std::vector<int64_t> i_buffer(i_count);
for (uint64_t idx = 0; idx < i_count; ++idx) {
i_buffer[idx] = static_cast<int64_t>(mfem_I[idx]);
}
outfile.write(reinterpret_cast<const char*>(i_buffer.data()), i_count * sizeof(int64_t));
if (!outfile) throw std::runtime_error("Error writing I array.");
std::vector<int64_t> j_buffer(j_count);
for (uint64_t idx = 0; idx < j_count; ++idx) {
j_buffer[idx] = static_cast<int64_t>(mfem_J[idx]);
}
outfile.write(reinterpret_cast<const char*>(j_buffer.data()), j_count * sizeof(int64_t));
if (!outfile) throw std::runtime_error("Error writing J array.");
outfile.write(reinterpret_cast<const char*>(mfem_data), data_count * sizeof(double));
if (!outfile) throw std::runtime_error("Error writing Data array.");
}
bool saveSparseMatrixBinary(const mfem::SparseMatrix& mat, const std::string& filename) {
std::ofstream outfile(filename, std::ios::binary | std::ios::trunc);
if (!outfile) {
@@ -37,55 +94,7 @@ bool saveSparseMatrixBinary(const mfem::SparseMatrix& mat, const std::string& fi
}
try {
// --- Get Data Pointers and Dimensions from MFEM Matrix ---
const int* mfem_I = mat.GetI();
const int* mfem_J = mat.GetJ();
const double* mfem_data = mat.GetData();
uint64_t height = static_cast<uint64_t>(mat.Height());
uint64_t width = static_cast<uint64_t>(mat.Width());
uint64_t nnz = static_cast<uint64_t>(mat.NumNonZeroElems());
uint64_t i_count = height + 1;
uint64_t j_count = nnz;
uint64_t data_count = nnz;
// --- Write Header ---
const char magic[4] = {'C', 'S', 'R', 'B'};
const uint8_t version = 1;
const uint8_t int_size = 8;
const uint8_t flt_size = 8;
const uint8_t reserved = 0;
outfile.write(magic, 4);
outfile.write(reinterpret_cast<const char*>(&version), 1);
outfile.write(reinterpret_cast<const char*>(&int_size), 1);
outfile.write(reinterpret_cast<const char*>(&flt_size), 1);
outfile.write(reinterpret_cast<const char*>(&reserved), 1);
outfile.write(reinterpret_cast<const char*>(&height), sizeof(height));
outfile.write(reinterpret_cast<const char*>(&width), sizeof(width));
outfile.write(reinterpret_cast<const char*>(&nnz), sizeof(nnz));
if (!outfile) throw std::runtime_error("Error writing header.");
// --- Write Arrays (Converting int to int64_t for I and J) ---
std::vector<int64_t> i_buffer(i_count);
for (uint64_t idx = 0; idx < i_count; ++idx) {
i_buffer[idx] = static_cast<int64_t>(mfem_I[idx]);
}
outfile.write(reinterpret_cast<const char*>(i_buffer.data()), i_count * sizeof(int64_t));
if (!outfile) throw std::runtime_error("Error writing I array.");
std::vector<int64_t> j_buffer(j_count);
for (uint64_t idx = 0; idx < j_count; ++idx) {
j_buffer[idx] = static_cast<int64_t>(mfem_J[idx]);
}
outfile.write(reinterpret_cast<const char*>(j_buffer.data()), j_count * sizeof(int64_t));
if (!outfile) throw std::runtime_error("Error writing J array.");
outfile.write(reinterpret_cast<const char*>(mfem_data), data_count * sizeof(double));
if (!outfile) throw std::runtime_error("Error writing Data array.");
write_sparse_matrix(mat, outfile);
} catch (const std::exception& e) {
@@ -163,4 +172,33 @@ void writeDenseMatrixToCSV(const std::string &filename, int precision, const mfe
writeDenseMatrixToCSV(filename, precision, mat);
}
void saveBlockFormToBinary(std::vector<mfem::SparseMatrix *> &block_diags, std::vector<std::array<int, 2>> block, std::string filename) {
// First write a magic number and version
// --- Open the file ---
std::ofstream outfile(filename, std::ios::binary | std::ios::trunc);
if (!outfile) {
std::cerr << "Error: Cannot open file for writing: " << filename << std::endl;
return;
}
// --- Write Header ---
const char magic[4] = {'B', 'L', 'C', 'K'};
const char datastart[9] = {'D', 'A', 'T', 'A', 'S', 'T', 'A', 'R', 'T'};
const char dataend[7] = {'D', 'A', 'T', 'A', 'E', 'N', 'D'};
const uint8_t size = block_diags.size();
outfile.write(reinterpret_cast<const char*>(&magic), 4);
outfile.write(reinterpret_cast<const char*>(&size), sizeof(size));
for (const auto&& [block_diag, blockIDs] : std::views::zip(block_diags, block)) {
// Write the sparse matrix data
outfile.write(reinterpret_cast<const char*>(&datastart), 9);
outfile.write(reinterpret_cast<const char*>(&blockIDs), sizeof(blockIDs));
write_sparse_matrix(*block_diag, outfile);
outfile.write(reinterpret_cast<const char*>(&dataend), 7);
}
}
#endif //MFEM_SMOUT_H