feat(debug-utils): added framework for shared debug util tools

2025-04-10 09:05:30 -04:00
parent 08b68c22de
commit 41460acacf
21 changed files with 465 additions and 1799 deletions

View File

@@ -43,11 +43,18 @@ subdir('build-config')
subdir('assets/static')
if get_option('build_debug_utils')
    subdir('utils/debugUtils')
endif
# Build the main project
subdir('src')
if get_option('build_tests')
    subdir('tests')
endif
# Build the utilities
if get_option('build_post_run_utils')
    subdir('utils')
endif

View File

@@ -1,2 +1,4 @@
option('build_tests', type: 'boolean', value: true, description: 'Build tests')
option('user_mode', type: 'boolean', value: false, description: 'Enable user mode (set mode = 0)')
option('build_post_run_utils', type: 'boolean', value: true, description: 'Build Helper Utilities')
option('build_debug_utils', type: 'boolean', value: true, description: 'Build Debug Utilities')
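# These options can be toggled at configure time, e.g.:
#   meson setup build -Dbuild_debug_utils=false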

View File

@@ -30,6 +30,7 @@ dependencies = [
    quill_dep,
    config_dep,
    types_dep,
    mfemanalysis_dep,
]
libpolyutils = static_library('polyutils',

View File

@@ -0,0 +1 @@
mfemanalysis_dep = declare_dependency(include_directories: 'src/include')

View File

@@ -0,0 +1,166 @@
//
// Created by Emily Boudreaux on 4/10/25.
//
#ifndef MFEM_SMOUT_H
#define MFEM_SMOUT_H

#include "mfem.hpp"

#include <cstdint>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
/**
* @brief Saves an mfem::SparseMatrix to a custom compact binary file (.csrbin).
*
* @param mat The mfem::SparseMatrix to save (assumed to be in CSR format).
* @param filename The path to the output file.
* @return true if saving was successful, false otherwise.
*
* File Format (.csrbin):
* - Magic (4 bytes): 'C','S','R','B'
* - Version (1 byte): 1
* - IntSize (1 byte): 8 (using int64_t for indices/dims)
* - FltSize (1 byte): 8 (using double for data)
* - Reserved (1 byte): 0
* - Height (uint64_t): Number of rows
* - Width (uint64_t): Number of columns
* - NNZ (uint64_t): Number of non-zeros
* - I array (int64_t * (Height + 1)): CSR Row Pointers
* - J array (int64_t * NNZ): CSR Column Indices
* - Data array (double * NNZ): CSR Non-zero values
*/
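//
// Worked size example (derived from the layout above): a 3x3 matrix with
// 5 non-zeros occupies 8 bytes (magic + version + sizes) + 24 bytes
// (height/width/nnz) + 8*(3+1) bytes (I) + 8*5 bytes (J) + 8*5 bytes (data)
// = 144 bytes.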
inline bool saveSparseMatrixBinary(const mfem::SparseMatrix& mat, const std::string& filename) {
    std::ofstream outfile(filename, std::ios::binary | std::ios::trunc);
    if (!outfile) {
        std::cerr << "Error: Cannot open file for writing: " << filename << std::endl;
        return false;
    }
    try {
        // --- Get Data Pointers and Dimensions from MFEM Matrix ---
        const int* mfem_I = mat.GetI();
        const int* mfem_J = mat.GetJ();
        const double* mfem_data = mat.GetData();

        const uint64_t height = static_cast<uint64_t>(mat.Height());
        const uint64_t width = static_cast<uint64_t>(mat.Width());
        const uint64_t nnz = static_cast<uint64_t>(mat.NumNonZeroElems());
        const uint64_t i_count = height + 1;
        const uint64_t j_count = nnz;
        const uint64_t data_count = nnz;

        // --- Write Header ---
        const char magic[4] = {'C', 'S', 'R', 'B'};
        const uint8_t version = 1;
        const uint8_t int_size = 8;
        const uint8_t flt_size = 8;
        const uint8_t reserved = 0;
        outfile.write(magic, 4);
        outfile.write(reinterpret_cast<const char*>(&version), 1);
        outfile.write(reinterpret_cast<const char*>(&int_size), 1);
        outfile.write(reinterpret_cast<const char*>(&flt_size), 1);
        outfile.write(reinterpret_cast<const char*>(&reserved), 1);
        outfile.write(reinterpret_cast<const char*>(&height), sizeof(height));
        outfile.write(reinterpret_cast<const char*>(&width), sizeof(width));
        outfile.write(reinterpret_cast<const char*>(&nnz), sizeof(nnz));
        if (!outfile) throw std::runtime_error("Error writing header.");

        // --- Write Arrays (widening MFEM's int indices to int64_t) ---
        std::vector<int64_t> i_buffer(i_count);
        for (uint64_t idx = 0; idx < i_count; ++idx) {
            i_buffer[idx] = static_cast<int64_t>(mfem_I[idx]);
        }
        outfile.write(reinterpret_cast<const char*>(i_buffer.data()), i_count * sizeof(int64_t));
        if (!outfile) throw std::runtime_error("Error writing I array.");

        std::vector<int64_t> j_buffer(j_count);
        for (uint64_t idx = 0; idx < j_count; ++idx) {
            j_buffer[idx] = static_cast<int64_t>(mfem_J[idx]);
        }
        outfile.write(reinterpret_cast<const char*>(j_buffer.data()), j_count * sizeof(int64_t));
        if (!outfile) throw std::runtime_error("Error writing J array.");

        outfile.write(reinterpret_cast<const char*>(mfem_data), data_count * sizeof(double));
        if (!outfile) throw std::runtime_error("Error writing Data array.");
    } catch (const std::exception& e) {
        std::cerr << "Error during binary matrix save: " << e.what() << std::endl;
        outfile.close();
        return false;
    }
    outfile.close();
    if (!outfile) {
        std::cerr << "Error closing file after writing: " << filename << std::endl;
        return false;
    }
    return true;
}
/**
 * @brief Writes an mfem::DenseMatrix to a CSV file with a column-index header row.
 *
 * @param filename The name of the output CSV file.
 * @param precision Number of decimal places for floating-point values.
 * @param mat The matrix to write; must not be null.
 */
inline void writeDenseMatrixToCSV(const std::string &filename, int precision, const mfem::DenseMatrix *mat) {
    if (!mat) {
        throw std::runtime_error("writeDenseMatrixToCSV received a null DenseMatrix pointer.");
    }
    std::ofstream outfile(filename);
    if (!outfile.is_open()) {
        throw std::runtime_error("Failed to open file: " + filename);
    }
    const int height = mat->Height();
    const int width = mat->Width();

    // Set precision for floating-point output
    outfile << std::fixed << std::setprecision(precision);

    // Header row: column indices
    for (int i = 0; i < width; i++) {
        outfile << i;
        if (i < width - 1) {
            outfile << ",";
        } else {
            outfile << "\n";
        }
    }
    // Iterate through rows
    for (int i = 0; i < height; ++i) {
        for (int j = 0; j < width; ++j) {
            outfile << mat->Elem(i, j);
            if (j < width - 1) {
                outfile << ",";
            }
        }
        outfile << std::endl;
    }
    outfile.close();
}
/**
 * @brief Writes the dense representation of an MFEM Operator (if it is a SparseMatrix) to a CSV file.
 *
 * @param op The MFEM Operator to write.
 * @param filename The name of the output CSV file.
 * @param precision Number of decimal places for floating-point values.
 */
inline void writeOperatorToCSV(const mfem::Operator &op,
                               const std::string &filename,
                               int precision = 6)
{
    // Attempt to cast the Operator to a SparseMatrix
    const auto *sparse_mat = dynamic_cast<const mfem::SparseMatrix*>(&op);
    if (!sparse_mat) {
        throw std::runtime_error("The operator is not a SparseMatrix.");
    }
    // ToDenseMatrix() allocates a new matrix that the caller owns, so free it here.
    const mfem::DenseMatrix *mat = sparse_mat->ToDenseMatrix();
    writeDenseMatrixToCSV(filename, precision, mat);
    delete mat;
}
#endif //MFEM_SMOUT_H

View File

@@ -0,0 +1,36 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "SSEDebug"
version = "0.1.0"
description = "A python module for general 4DSSE debugging"
readme = "readme.md"
authors = [
    {name = "Emily M. Boudreaux", email = "emily.boudreaux@dartmouth.edu"},
    {name = "4D-STAR Collaboration"},
]
maintainers = [
    {name = "Emily M. Boudreaux", email = "emily.boudreaux@dartmouth.edu"}
]
keywords = ["astrophysics", "MFEM"]
requires-python = ">=3.8"
dependencies = ["numpy >= 1.21.1", "scipy >= 1.13.1"]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Science/Research",
    "Programming Language :: Python :: 3",
    "Topic :: Scientific/Engineering :: Astronomy",
    "Operating System :: OS Independent"
]

[tool.setuptools]
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]
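# Usage note: with this setuptools configuration the package can be
# installed locally via `pip install .` from this directory.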

View File

@@ -0,0 +1 @@
__version__ = "0.1.0"

View File

@@ -0,0 +1 @@
from .smread import loadSparseMatrixBinary, analyze_sparse_matrix, load_and_analyze_sparse_matrix

View File

@@ -0,0 +1,234 @@
import argparse
import struct
import time

import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla  # For matrix norm


def loadSparseMatrixBinary(filename):
    """
    Loads a sparse matrix from the custom binary format (.csrbin).

    Args:
        filename (str): The path to the .csrbin file.

    Returns:
        scipy.sparse.csr_matrix: The loaded sparse matrix.

    Raises:
        ValueError: If the file format is incorrect or sizes don't match.
        IOError: If the file cannot be read.
    """
    INT_SIZE = 8  # Expecting int64_t from the C++ writer
    FLT_SIZE = 8  # Expecting double from the C++ writer
    EXPECTED_MAGIC = b'CSRB'
    EXPECTED_VERSION = 1
    try:
        with open(filename, 'rb') as f:
            # --- Read Header ---
            magic = f.read(4)
            if magic != EXPECTED_MAGIC:
                raise ValueError(f"Invalid magic number. Expected {EXPECTED_MAGIC}, got {magic}")
            # '<' means little-endian, 'B' means unsigned char (1 byte)
            version, int_size_file, flt_size_file, reserved = struct.unpack('<BBBB', f.read(4))
            if version != EXPECTED_VERSION:
                print(f"Warning: File version {version} differs from expected {EXPECTED_VERSION}.")
            if int_size_file != INT_SIZE:
                raise ValueError(f"Integer size mismatch. Expected {INT_SIZE}, file has {int_size_file}")
            if flt_size_file != FLT_SIZE:
                raise ValueError(f"Float size mismatch. Expected {FLT_SIZE}, file has {flt_size_file}")
            # '<' means little-endian, 'Q' means unsigned long long (8 bytes)
            height, width, nnz = struct.unpack('<QQQ', f.read(24))
            i_count = height + 1
            j_count = nnz
            data_count = nnz
            if nnz == 0:  # Handle empty matrix case
                print("Warning: Matrix file contains zero non-zero elements.")
                # Return an empty matrix with the correct shape
                return sp.csr_matrix((height, width), dtype=np.float64)
            # --- Read Arrays ---
            # Read I array (Row Pointers) as int64
            I_array = np.fromfile(f, dtype=np.int64, count=i_count)
            if I_array.size != i_count:
                raise ValueError(f"Error reading I array. Expected {i_count} elements, read {I_array.size}. File truncated or corrupt?")
            # Read J array (Column Indices) as int64
            J_array = np.fromfile(f, dtype=np.int64, count=j_count)
            if J_array.size != j_count:
                raise ValueError(f"Error reading J array. Expected {j_count} elements, read {J_array.size}. File truncated or corrupt?")
            # Read Data array (Values) as float64
            Data_array = np.fromfile(f, dtype=np.float64, count=data_count)
            if Data_array.size != data_count:
                raise ValueError(f"Error reading Data array. Expected {data_count} elements, read {Data_array.size}. File truncated or corrupt?")
            # --- Check for extra data ---
            extra_data = f.read()
            if extra_data:
                print(f"Warning: {len(extra_data)} extra bytes found at the end of the file.")
            # --- Construct SciPy CSR Matrix ---
            sparse_matrix = sp.csr_matrix((Data_array, J_array, I_array), shape=(height, width))
            if sparse_matrix.nnz != nnz:
                print(f"Warning: NNZ mismatch after loading. Header NNZ: {nnz}, Scipy NNZ: {sparse_matrix.nnz}")
            return sparse_matrix
    except FileNotFoundError:
        raise IOError(f"Error: File not found at {filename}")
    except Exception as e:
        raise RuntimeError(f"An error occurred while reading {filename}: {e}")


def analyze_sparse_matrix(sp_mat):
    """
    Analyzes a SciPy sparse matrix and prints various statistics.

    Args:
        sp_mat (scipy.sparse.spmatrix): The sparse matrix to analyze
            (e.g., csr_matrix, csc_matrix).
    """
    print("-" * 50)
    print("Sparse Matrix Analysis Report")
    print("-" * 50)
    if not sp.issparse(sp_mat):
        print("Error: Input is not a SciPy sparse matrix.")
        return
    rows, cols = sp_mat.shape
    print(f"Size (Shape): {rows} rows x {cols} columns")
    if rows == 0 or cols == 0:
        print("\nMatrix is empty. No further analysis possible.")
        print("-" * 50)
        return
    nnz = sp_mat.nnz
    total_elements = rows * cols
    if total_elements > 0:
        sparsity = 1.0 - (nnz / total_elements)
    else:
        sparsity = 1.0
    print(f"Non-zero elements (NNZ): {nnz}")
    print(f"Total elements: {total_elements}")
    print(f"Sparsity: {sparsity:.6%} (percentage of zeros)")
    if nnz == 0:
        print("\nMatrix contains only zero elements.")
        diag_elements = sp_mat.diagonal()
        print(f"\nDiagonal Mean: {np.mean(diag_elements):.6e}")
        print(f"Diagonal Max: {np.max(diag_elements):.6e}")
        print(f"Diagonal Min: {np.min(diag_elements):.6e}")
        print("Value Range (Min): N/A (no non-zero values)")
        print("Value Range (Max): N/A (no non-zero values)")
        print("Mean Non-Zero Value: N/A (no non-zero values)")
        print("Relative Diagonal Norm: N/A (matrix norm is zero)")
        print("-" * 50)
        return
    all_values = sp_mat.data  # Access non-zero values directly
    min_val = np.min(all_values)
    max_val = np.max(all_values)
    mean_val = np.mean(all_values)
    print(f"\nValue Range (Min): {min_val:.6e}")
    print(f"Value Range (Max): {max_val:.6e}")
    print(f"Mean Non-Zero Value: {mean_val:.6e}")
    print("\n--- Diagonal Properties ---")
    start_diag = time.time()
    diag_elements = sp_mat.diagonal()
    end_diag = time.time()
    print(f"(Diagonal extraction time: {end_diag - start_diag:.4f}s)")
    if diag_elements.size > 0:  # Should always be true unless rows == 0 (handled above)
        mean_diag = np.mean(diag_elements)
        max_diag = np.max(diag_elements)
        min_diag = np.min(diag_elements)
        diag_nonzero = diag_elements[diag_elements != 0]
        if diag_nonzero.size > 0:
            mean_diag_nz = np.mean(diag_nonzero)
            print(f"Mean Diagonal (all): {mean_diag:.6e}")
            print(f"Mean Diagonal (non-zero): {mean_diag_nz:.6e} ({diag_nonzero.size} elements)")
        else:
            print(f"Mean Diagonal (all): {mean_diag:.6e}")
            print("Mean Diagonal (non-zero): N/A (all diagonal elements are zero)")
        print(f"Max Diagonal: {max_diag:.6e}")
        print(f"Min Diagonal: {min_diag:.6e}")
        # "Diagonality" - relative diagonal norm (using the Frobenius norm)
        # The Frobenius norm is sqrt(sum(abs(A_ij)^2))
        start_norm = time.time()
        norm_diag = np.linalg.norm(diag_elements)
        norm_matrix = spla.norm(sp_mat, ord='fro')
        end_norm = time.time()
        print(f"(Norm calculation time: {end_norm - start_norm:.4f}s)")
        if norm_matrix > 1e-15:  # Avoid division by zero
            diagonality_ratio = norm_diag / norm_matrix
            print(f"\nRelative Diagonal Norm (Frobenius): {diagonality_ratio:.6f}")
            print("  (Ratio of ||diag(A)||_F / ||A||_F)")
            print(f"  (Diagonal Norm = {norm_diag:.6e}, Matrix Norm = {norm_matrix:.6e})")
            if diagonality_ratio > 0.99:
                print("  -> Matrix is strongly diagonally dominant by norm.")
            elif diagonality_ratio < 0.1:
                print("  -> Matrix norm is dominated by off-diagonal elements.")
        else:
            print("\nRelative Diagonal Norm: N/A (matrix Frobenius norm is zero)")
    else:  # Should not happen if rows > 0
        print("\nCould not extract diagonal (matrix has zero rows?).")
    # Other useful stats
    print("\n--- Other Properties ---")
    is_square = rows == cols
    print(f"Is Square: {is_square}")
    if is_square:
        try:
            diff_norm = spla.norm(sp_mat - sp_mat.T, ord='fro')
            if diff_norm < 1e-10 * norm_matrix:  # Check relative difference norm
                print("Is Symmetric (approx): True (||A - A.T||_F / ||A||_F < 1e-10)")
            else:
                print(f"Is Symmetric (approx): False (||A - A.T||_F = {diff_norm:.2e})")
        except Exception as e:
            print(f"Is Symmetric (approx): Check failed ({e})")
    else:
        print("Is Symmetric (approx): False (not square)")
    print("-" * 50)


def load_and_analyze_sparse_matrix(filename: str):
    sm = loadSparseMatrixBinary(filename)
    analyze_sparse_matrix(sm)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple tool to get some statistics about a sparse matrix from MFEM")
    parser.add_argument("path", help="path to the .csrbin file", type=str)
    args = parser.parse_args()
    load_and_analyze_sparse_matrix(args.path)

View File

@@ -0,0 +1 @@
subdir('MFEMAnalysis-cpp')

View File

@@ -0,0 +1,12 @@
# Tools for analyzing MFEM sparse matrices (among other things)
MFEM does a lot of work with sparse matrices but does not provide trivial tools for working with them. Here I include some basic utilities to analyze these matrices.
## Python
There is a Python script to perform the actual analysis.
## C++
There is a small C++ header-only library which provides an interface to write MFEM sparse matrices out to disk.
The C++ utility writes MFEM sparse matrices in a custom binary format (.csrbin) that was designed to be simple. The Python script is currently the only reader that understands this format.
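## Example
A minimal sketch of the intended round trip, assuming a matrix has already been written from the C++ side with `saveSparseMatrixBinary` and that the `SSEDebug` package is installed (the filename here is only illustrative):
```python
from SSEDebug import loadSparseMatrixBinary, analyze_sparse_matrix

# Load the .csrbin file written by the C++ utility into a scipy CSR matrix
mat = loadSparseMatrixBinary("stiffness.csrbin")

# Print the analysis report (shape, sparsity, diagonal statistics, symmetry)
analyze_sparse_matrix(mat)
```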

View File

@@ -0,0 +1 @@
subdir('MFEMAnalysisUtils')

View File

@@ -1,16 +0,0 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "opatio"
version = "0.1.0a"
description = "A python module for handling OPAT files"
readme = "readme.md"
authors = [{name = "Emily M. Boudreaux", email = "emily.boudreaux@dartmouth.edu"}]
requires-python = ">=3.8"
dependencies = ["numpy >= 1.21.1"]

[tool.setuptools]
packages = ["opatio", "opatio.opat"]
package-dir = {"" = "src"}

View File

@@ -1,46 +0,0 @@
# opatIO python module
This module defines a set of tools to build, write, and read OPAT files.
The OPAT file format is a custom format designed to efficiently store
opacity information for a variety of compositions.
## Installation
You can install this module with pip
```bash
git clone <repo>
cd 4DSSE/utils/opat
pip install .
```
## General Usage
The general way that this module is meant to be used is to first build a schema for the opacity table and then save that to disk. The module will handle all the byte alignment and lookup table construction for you.
A simple example might look like the following
```python
from opatio import OpatIO
opacityFile = OpatIO()
opacityFile.set_comment("This is a sample opacity file")
opacityFile.set_source("OPLIB")
# some code to get a logR, logT, and logKappa table
# where logKappa is of size (n,m) if logR is size n and
# logT is size m
opacityFile.add_table((X, Z), logR, logT, logKappa)
opacityFile.save("opacity.opat")
```
You can also read OPAT files which have been generated, using the loadOpat function
```python
from opatio import loadOpat
opacityFile = loadOpat("opacity.opat")
print(opacityFile.header)
print(opacityFile.tables[0])
```
## Problems
If you have problems, feel free to either submit an issue to the root GitHub repo (tagged as utils/opatio) or email Emily Boudreaux at emily.boudreaux@dartmouth.edu

View File

@@ -1 +0,0 @@
from .opat.opat import OpatIO, loadOpat

View File

@@ -1,544 +0,0 @@
import struct
import numpy as np
from datetime import datetime
from dataclasses import dataclass
from typing import Iterable, List, Tuple
from collections.abc import Iterable as collectionIterable
import hashlib
import os
@dataclass
class Header:
    """
    @brief Structure to hold the header information of an OPAT file.
    """
    magic: str          #< Magic number to identify the file type
    version: int        #< Version of the OPAT file format
    numTables: int      #< Number of tables in the file
    headerSize: int     #< Size of the header
    indexOffset: int    #< Offset to the index section
    creationDate: str   #< Creation date of the file
    sourceInfo: str     #< Source information
    comment: str        #< Comment section
    numIndex: int       #< Number of values to use when indexing a table
    reserved: bytes     #< Reserved for future use


@dataclass
class TableIndex:
    """
    @brief Structure to hold the index information of a table in an OPAT file.
    """
    index: List[float]  #< Index values of the table
    byteStart: int      #< Byte start position of the table
    byteEnd: int        #< Byte end position of the table
    sha256: bytes       #< SHA-256 hash of the table data


@dataclass
class OPATTable:
    """
    @brief Structure to hold the data of an OPAT table.
    """
    N_R: int                             #< Number of R values
    N_T: int                             #< Number of T values
    logR: Iterable[float]                #< Logarithm of R values
    logT: Iterable[float]                #< Logarithm of T values
    logKappa: Iterable[Iterable[float]]  #< Logarithm of Kappa values


def make_default_header() -> Header:
    """
    @brief Create a default header for an OPAT file.
    @return The default header.
    """
    return Header(
        magic="OPAT",
        version=1,
        numTables=0,
        headerSize=256,
        indexOffset=0,
        creationDate=datetime.now().strftime("%b %d, %Y"),
        sourceInfo="no source provided by user",
        comment="default header",
        numIndex=2,
        reserved=b"\x00" * 24
    )
class OpatIO:
    """
    @brief Class for handling OPAT file input/output operations.

    This class provides methods to validate, manipulate, and save OPAT files. It includes functionalities to validate character arrays, 1D arrays, and 2D arrays, compute checksums, set header information, add tables, and save the OPAT file in both ASCII and binary formats.

    Attributes:
        header (Header): The header of the OPAT file.
        tables (List[Tuple[Tuple[float, float], OPATTable]]): A list of tables in the OPAT file.

    Methods:
        validate_char_array_size(s: str, nmax: int) -> bool:
            Validate the size of a character array.
        validate_logKappa(logKappa):
            Validate the logKappa array.
        validate_1D(arr, name: str):
            Validate a 1D array.
        compute_checksum(data: np.ndarray) -> bytes:
            Compute the SHA-256 checksum of the given data.
        set_version(version: int) -> int:
            Set the version of the OPAT file.
        set_source(source: str) -> str:
            Set the source information of the OPAT file.
        set_comment(comment: str) -> str:
            Set the comment of the OPAT file.
        set_numIndex(numIndex: int) -> int:
            Set the number of values used when indexing a table.
        add_table(indicies: Tuple[float], logR: Iterable[float], logT: Iterable[float], logKappa: Iterable[Iterable[float]]):
            Add a table to the OPAT file.
        _header_bytes() -> bytes:
            Convert the header to bytes.
        _table_bytes(table: OPATTable) -> Tuple[bytes, bytes]:
            Convert a table to bytes.
        _tableIndex_bytes(tableIndex: TableIndex) -> bytes:
            Convert a table index to bytes.
        __repr__() -> str:
            Get the string representation of the OpatIO object.
        _format_table_as_string(table: OPATTable, indices: List[float]) -> str:
            Format a table as a string.
        print_table_indexes(table_indexes: List[TableIndex]) -> str:
            Format the table indexes for printing.
        save_as_ascii(filename: str) -> str:
            Save the OPAT file as an ASCII file.
        save(filename: str) -> str:
            Save the OPAT file as a binary file.
    """
    def __init__(self):
        self.header: Header = make_default_header()
        self.tables: List[Tuple[Tuple[float, float], OPATTable]] = []
    @staticmethod
    def validate_char_array_size(s: str, nmax: int) -> bool:
        """
        @brief Validate the size of a character array.
        @param s The string to validate.
        @param nmax The maximum allowed size.
        @return True if the string size is valid, False otherwise.
        """
        return len(s) <= nmax

    @staticmethod
    def validate_logKappa(logKappa):
        """
        @brief Validate the logKappa array.
        @param logKappa The logKappa array to validate.
        @throws ValueError if logKappa is not a non-empty 2D array.
        @throws TypeError if logKappa is not a 2D array or iterable.
        """
        if isinstance(logKappa, np.ndarray):
            if logKappa.ndim == 2:
                return
            else:
                raise ValueError("logKappa must be a non-empty 2D array")
        if isinstance(logKappa, collectionIterable) and all(isinstance(row, collectionIterable) for row in logKappa):
            try:
                first_row = next(iter(logKappa))
                if all(isinstance(x, (int, float)) for x in first_row):
                    return
                else:
                    raise ValueError("logKappa must be fully numeric")
            except StopIteration:
                raise ValueError("logKappa must be a non-empty 2D iterable")
        else:
            raise TypeError("logKappa must be a non-empty 2D array or iterable")

    @staticmethod
    def validate_1D(arr, name: str):
        """
        @brief Validate a 1D array.
        @param arr The array to validate.
        @param name The name of the array.
        @throws ValueError if the array is not 1D or not fully numeric.
        @throws TypeError if the array is not a non-empty 1D array or iterable.
        """
        if isinstance(arr, np.ndarray):
            if arr.ndim == 1:
                return
            else:
                raise ValueError(f"{name} must be a 1D numpy array")
        if isinstance(arr, collectionIterable) and not isinstance(arr, (str, bytes)):
            if all(isinstance(x, (int, float)) for x in arr):
                return
            else:
                raise ValueError(f"{name} must be fully numeric")
        else:
            raise TypeError(f"{name} must be a non-empty 1D array or iterable")

    @staticmethod
    def compute_checksum(data: np.ndarray) -> bytes:
        """
        @brief Compute the SHA-256 checksum of the given data.
        @param data The data to compute the checksum for.
        @return The SHA-256 checksum.
        """
        return hashlib.sha256(data.tobytes()).digest()
    def set_version(self, version: int) -> int:
        """
        @brief Set the version of the OPAT file.
        @param version The version to set.
        @return The set version.
        """
        self.header.version = version
        return self.header.version

    def set_source(self, source: str) -> str:
        """
        @brief Set the source information of the OPAT file.
        @param source The source information to set.
        @return The set source information.
        @throws TypeError if the source string is too long.
        """
        if not self.validate_char_array_size(source, 64):
            raise TypeError(f"sourceInfo string ({source}) is too long ({len(source)}). Max length is 64")
        self.header.sourceInfo = source
        return self.header.sourceInfo

    def set_comment(self, comment: str) -> str:
        """
        @brief Set the comment of the OPAT file.
        @param comment The comment to set.
        @return The set comment.
        @throws TypeError if the comment string is too long.
        """
        if not self.validate_char_array_size(comment, 128):
            raise TypeError(f"comment string ({comment}) is too long ({len(comment)}). Max length is 128")
        self.header.comment = comment
        return self.header.comment

    def set_numIndex(self, numIndex: int) -> int:
        """
        @brief Set the number of values to use when indexing a table.
        @param numIndex The number of values to use when indexing a table.
        @return The set number of index values.
        """
        if numIndex < 1:
            raise ValueError(f"numIndex must be greater than 0! It is currently {numIndex}")
        self.header.numIndex = numIndex
        return self.header.numIndex
    def add_table(self, indicies: Tuple[float], logR: Iterable[float], logT: Iterable[float], logKappa: Iterable[Iterable[float]]):
        """
        @brief Add a table to the OPAT file.
        @param indicies The index values of the table.
        @param logR The logR values.
        @param logT The logT values.
        @param logKappa The logKappa values.
        @throws ValueError if logKappa is not a non-empty 2D array or if logR and logT are not 1D arrays.
        """
        if len(indicies) != self.header.numIndex:
            raise ValueError(f"indicies must have length {self.header.numIndex}! Currently it has length {len(indicies)}")
        self.validate_logKappa(logKappa)
        self.validate_1D(logR, "logR")
        self.validate_1D(logT, "logT")
        logR = np.array(logR)
        logT = np.array(logT)
        logKappa = np.array(logKappa)
        if logKappa.shape != (logR.shape[0], logT.shape[0]):
            raise ValueError(f"logKappa must be of shape ({len(logR)} x {len(logT)})! Currently logKappa has shape {logKappa.shape}")
        table = OPATTable(
            N_R = logR.shape[0],
            N_T = logT.shape[0],
            logR = logR,
            logT = logT,
            logKappa = logKappa
        )
        self.tables.append((indicies, table))
        self.header.numTables += 1

    def _header_bytes(self) -> bytes:
        """
        @brief Convert the header to bytes.
        @return The header as bytes.
        """
        headerBytes = struct.pack(
            "<4s H I I Q 16s 64s 128s H 24s",
            self.header.magic.encode('utf-8'),
            self.header.version,
            self.header.numTables,
            self.header.headerSize,
            self.header.indexOffset,
            self.header.creationDate.encode('utf-8'),
            self.header.sourceInfo.encode('utf-8'),
            self.header.comment.encode('utf-8'),
            self.header.numIndex,
            self.header.reserved
        )
        return headerBytes

    def _table_bytes(self, table: OPATTable) -> Tuple[bytes, bytes]:
        """
        @brief Convert a table to bytes.
        @param table The OPAT table.
        @return A tuple containing the checksum and the table as bytes.
        """
        logR = table.logR.flatten()
        logT = table.logT.flatten()
        logKappa = table.logKappa.flatten()
        tableBytes = struct.pack(
            f"<II{table.N_R}d{table.N_T}d{table.N_R*table.N_T}d",
            table.N_R,
            table.N_T,
            *logR,
            *logT,
            *logKappa
        )
        checksum = self.compute_checksum(logKappa)
        return (checksum, tableBytes)
    def _tableIndex_bytes(self, tableIndex: TableIndex) -> bytes:
        """
        @brief Convert a table index to bytes.
        @param tableIndex The table index.
        @return The table index as bytes.
        @throws RuntimeError if the table index entry has an unexpected size.
        """
        tableIndexFMTString = "<" + "d" * self.header.numIndex + "QQ"
        tableIndexBytes = struct.pack(
            tableIndexFMTString,
            *tableIndex.index,
            tableIndex.byteStart,
            tableIndex.byteEnd
        )
        tableIndexBytes += tableIndex.sha256
        expectedSize = 16 + self.header.numIndex * 8 + 32
        if len(tableIndexBytes) != expectedSize:
            raise RuntimeError(f"Each table index entry must have {expectedSize} bytes. Due to an unknown error the table index entry for index {tableIndex.index} has {len(tableIndexBytes)} bytes")
        return tableIndexBytes
    def __repr__(self) -> str:
        """
        @brief Get the string representation of the OpatIO object.
        @return The string representation.
        """
        reprString = f"""OpatIO(
    version: {self.header.version}
    numTables: {self.header.numTables}
    headerSize: {self.header.headerSize}
    indexOffset: {self.header.indexOffset}
    creationDate: {self.header.creationDate}
    sourceInfo: {self.header.sourceInfo}
    comment: {self.header.comment}
    numIndex: {self.header.numIndex}
    reserved: {self.header.reserved}
)"""
        return reprString
    def _format_table_as_string(self, table: OPATTable, indices: List[float]) -> str:
        """
        @brief Format a table as a string.
        @param table The OPAT table.
        @param indices The index values of the table.
        @return The formatted table as a string.
        """
        tableString: List[str] = []
        # fixed-width index header per table
        tableIndexString: List[str] = []
        for index in indices:
            tableIndexString.append(f"{index:<10.4f}")
        tableString.append(" ".join(tableIndexString))
        tableString.append("-" * 80)
        # write logR across the top (reserving one column for where logT will be)
        tableString.append(f"{'':<10}{'logR':<10}")
        logRRow = f"{'logT':<10}"
        logRRowTrue = "".join(f"{r:<10.4f}" for r in table.logR)
        tableString.append(logRRow + logRRowTrue)
        for i, logT in enumerate(table.logT):
            row = f"{logT:<10.4f}"
            for kappa in table.logKappa[:, i]:
                row += f"{kappa:<10.4f}"
            tableString.append(row)
        tableString.append("=" * 80)
        return '\n'.join(tableString)
    @staticmethod
    def print_table_indexes(table_indexes: List[TableIndex]) -> str:
        """
        @brief Format the table indexes for printing.
        @param table_indexes The list of table indexes.
        @return The formatted table indexes as a string.
        """
        if not table_indexes:
            return "No table indexes found."
        tableRows: List[str] = []
        tableRows.append("\nTable Indexes in OPAT File:\n")
        headerString: str = ''
        for indexID, index in enumerate(table_indexes[0].index):
            indexKey = f"Index {indexID}"
            headerString += f"{indexKey:<10}"
        headerString += f"{'Byte Start':<15} {'Byte End':<15} {'Checksum (SHA-256)'}"
        tableRows.append(headerString)
        tableRows.append("=" * 80)
        for entry in table_indexes:
            tableEntry = ''
            for index in entry.index:
                tableEntry += f"{index:<10.4f}"
            tableEntry += f"{entry.byteStart:<15} {entry.byteEnd:<15} {entry.sha256[:16]}..."
            tableRows.append(tableEntry)
        return '\n'.join(tableRows)

    def save_as_ascii(self, filename: str) -> str:
        """
        @brief Save the OPAT file as an ASCII file.
        @param filename The name of the file.
        @return The name of the saved file.
        """
        currentStartByte: int = 256
        tableIndexs: List[TableIndex] = []
        tableStrings: List[str] = []
        for index, table in self.tables:
            checksum, tableBytes = self._table_bytes(table)
            tableStrings.append(self._format_table_as_string(table, index) + "\n")
            tableIndex = TableIndex(
                index = index,
                byteStart = currentStartByte,
                byteEnd = currentStartByte + len(tableBytes),
                sha256 = checksum
            )
            tableIndexs.append(tableIndex)
            currentStartByte += len(tableBytes)
        self.header.indexOffset = currentStartByte
        with open(filename, 'w') as f:
            f.write("This is an ASCII representation of an OPAT file, it is not a valid OPAT file in and of itself.\n")
            f.write("This file is meant to be human readable and is not meant to be read by a computer.\n")
            f.write("The purpose of this file is to provide a human readable representation of the OPAT file which can be used for debugging purposes.\n")
            f.write("The full binary specification of the OPAT file can be found in the OPAT file format documentation at:\n")
            f.write("    https://github.com/4D-STAR/4DSSE/blob/main/specs/OPAT/OPAT.pdf\n")
            f.write("="*35 + " HEADER " + "="*36 + "\n")
            f.write(f">> {self.header.magic}\n")
            f.write(f">> Version: {self.header.version}\n")
            f.write(f">> numTables: {self.header.numTables}\n")
            f.write(f">> headerSize (bytes): {self.header.headerSize}\n")
            f.write(f">> tableIndex Offset (bytes): {self.header.indexOffset}\n")
            f.write(f">> Creation Date: {self.header.creationDate}\n")
            f.write(f">> Source Info: {self.header.sourceInfo}\n")
            f.write(f">> Comment: {self.header.comment}\n")
            f.write(f">> numIndex: {self.header.numIndex}\n")
            f.write("="*37 + " DATA " + "="*37 + "\n")
            f.write("="*80 + "\n")
            for tableString in tableStrings:
                f.write(tableString)
            f.write("="*80 + "\n")
            f.write("="*36 + " INDEX " + "="*37 + "\n")
            f.write(self.print_table_indexes(tableIndexs))
        return filename
    def save(self, filename: str) -> str:
        """
        @brief Save the OPAT file as a binary file.
        @param filename The name of the file.
        @return The name of the saved file.
        @throws RuntimeError if the header does not have 256 bytes.
        """
        tempHeaderBytes = self._header_bytes()
        if len(tempHeaderBytes) != 256:
            raise RuntimeError(f"Header must have 256 bytes. Due to an unknown error the header has {len(tempHeaderBytes)} bytes")
        currentStartByte: int = 256
        tableIndicesBytes: List[bytes] = []
        tablesBytes: List[bytes] = []
        for index, table in self.tables:
            checksum, tableBytes = self._table_bytes(table)
            tableIndex = TableIndex(
                index,
                byteStart = currentStartByte,
                byteEnd = currentStartByte + len(tableBytes),
                sha256 = checksum
            )
            tableIndexBytes = self._tableIndex_bytes(tableIndex)
            tablesBytes.append(tableBytes)
            tableIndicesBytes.append(tableIndexBytes)
            currentStartByte += len(tableBytes)
        self.header.indexOffset = currentStartByte
        headerBytes = self._header_bytes()
        with open(filename, 'wb') as f:
            f.write(headerBytes)
            for tableBytes in tablesBytes:
                f.write(tableBytes)
            for tableIndexBytes in tableIndicesBytes:
                f.write(tableIndexBytes)
        if os.path.exists(filename):
            return filename
def loadOpat(filename: str) -> OpatIO:
    """
    @brief Load an OPAT file.
    @param filename The name of the file.
    @return The loaded OpatIO object.
    @throws RuntimeError if the header does not have 256 bytes.
    """
    opat = OpatIO()
    with open(filename, 'rb') as f:
        headerBytes: bytes = f.read(256)
        unpackedHeader = struct.unpack("<4s H I I Q 16s 64s 128s H 24s", headerBytes)
        loadedHeader = Header(
            magic = unpackedHeader[0].decode().replace("\x00", ""),
            version = unpackedHeader[1],
            numTables = unpackedHeader[2],
            headerSize = unpackedHeader[3],
            indexOffset = unpackedHeader[4],
            creationDate = unpackedHeader[5].decode().replace("\x00", ""),
            sourceInfo = unpackedHeader[6].decode().replace("\x00", ""),
            comment = unpackedHeader[7].decode().replace("\x00", ""),
            numIndex = unpackedHeader[8],
            reserved = unpackedHeader[9]
        )
        opat.header = loadedHeader
        f.seek(opat.header.indexOffset)
        tableIndices: List[TableIndex] = []
        tableIndexChunkSize = 16 + loadedHeader.numIndex * 8
        tableIndexFMTString = "<" + "d" * loadedHeader.numIndex + "QQ"
        while tableIndexEntryBytes := f.read(tableIndexChunkSize):
            unpackedTableIndexEntry = struct.unpack(tableIndexFMTString, tableIndexEntryBytes)
            checksum = f.read(32)
            index = unpackedTableIndexEntry[:loadedHeader.numIndex]
            tableIndexEntry = TableIndex(
                index = index,
                byteStart = unpackedTableIndexEntry[loadedHeader.numIndex],
                byteEnd = unpackedTableIndexEntry[loadedHeader.numIndex + 1],
                sha256 = checksum
            )
            tableIndices.append(tableIndexEntry)
        # add_table() increments numTables, so reset the loaded count before
        # re-populating the tables to avoid double counting.
        opat.header.numTables = 0
        currentStartByte = opat.header.headerSize
        f.seek(currentStartByte)
        for tableIndex in tableIndices:
            f.seek(tableIndex.byteStart)
            byteLength = tableIndex.byteEnd - tableIndex.byteStart
            tableBytes = f.read(byteLength)
            nr_nt_fmt = "<II"
            nr_nt_size = struct.calcsize(nr_nt_fmt)
            N_R, N_T = struct.unpack(nr_nt_fmt, tableBytes[:nr_nt_size])
            dataFormat = f"<{N_R}d{N_T}d{N_R*N_T}d"
            unpackedData = struct.unpack(dataFormat, tableBytes[nr_nt_size:])
            logR = np.array(unpackedData[:N_R], dtype=np.float64)
            logT = np.array(unpackedData[N_R: N_R + N_T], dtype=np.float64)
            logKappa = np.array(unpackedData[N_R + N_T:], dtype=np.float64).reshape((N_R, N_T))
            opat.add_table(tableIndex.index, logR, logT, logKappa)
    return opat

View File

@@ -1,27 +0,0 @@
from opatio import OpatIO
import numpy as np

np.random.seed(42)


def generate_synthetic_opacity_table(n_r, n_t):
    logR = np.linspace(-8, 2, n_r, dtype=np.float64)  # log Density grid
    logT = np.linspace(3, 9, n_t, dtype=np.float64)   # log Temperature grid
    logK = np.random.uniform(-2, 2, size=(n_r, n_t)).astype(np.float64)  # Synthetic Opacity
    return logR, logT, logK


if __name__ == "__main__":
    n_r = 50
    n_t = 50
    num_tables = 20
    XValues = np.linspace(0.1, 0.7, num_tables)
    ZValues = np.linspace(0.001, 0.03, num_tables)
    opat = OpatIO()
    opat.set_comment("Synthetic Opacity Tables")
    opat.set_source("utils/opatio/utils/mkTestData.py")
    for i in range(num_tables):
        logR, logT, logK = generate_synthetic_opacity_table(n_r, n_t)
        opat.add_table((XValues[i], ZValues[i]), logR, logT, logK)
    opat.save("testData/synthetic_tables.opat")
    opat.save_as_ascii("testData/synthetic_tables_OPAT.ascii")

File diff suppressed because it is too large