feat(debug-utils): added framework for shared debug util tools
This commit is contained in:
11
meson.build
11
meson.build
@@ -43,11 +43,18 @@ subdir('build-config')
|
||||
|
||||
subdir('assets/static')
|
||||
|
||||
if get_option('build_debug_utils')
|
||||
subdir('utils/debugUtils')
|
||||
endif
|
||||
|
||||
|
||||
# Build the main project
|
||||
subdir('src')
|
||||
if get_option('build_tests')
|
||||
subdir('tests')
|
||||
endif
|
||||
|
||||
# Build the utilities
|
||||
subdir('utils')
|
||||
if get_option('build_post_run_utils')
|
||||
subdir('utils')
|
||||
endif
|
||||
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
option('build_tests', type: 'boolean', value: true, description: 'Build tests')
|
||||
option('user_mode', type: 'boolean', value: false, description: 'Enable user mode (set mode = 0)')
|
||||
option('build_post_run_utils', type: 'boolean', value: true, description: 'Build Helper Utilities')
|
||||
option('build_debug_utils', type: 'boolean', value: true, description: 'Build Debug Utilities')
|
||||
|
||||
@@ -30,6 +30,7 @@ dependencies = [
|
||||
quill_dep,
|
||||
config_dep,
|
||||
types_dep,
|
||||
mfemanalysis_dep,
|
||||
]
|
||||
|
||||
libpolyutils = static_library('polyutils',
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
mfemanalysis_dep = declare_dependency(include_directories: 'src/include')
|
||||
@@ -0,0 +1,166 @@
|
||||
//
|
||||
// Created by Emily Boudreaux on 4/10/25.
|
||||
//
|
||||
|
||||
#ifndef MFEM_SMOUT_H
|
||||
#define MFEM_SMOUT_H
|
||||
|
||||
#include "mfem.hpp"
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
/**
|
||||
* @brief Saves an mfem::SparseMatrix to a custom compact binary file (.csrbin).
|
||||
*
|
||||
* @param mat The mfem::SparseMatrix to save (assumed to be in CSR format).
|
||||
* @param filename The path to the output file.
|
||||
* @return true if saving was successful, false otherwise.
|
||||
*
|
||||
* File Format (.csrbin):
|
||||
* - Magic (4 bytes): 'C','S','R','B'
|
||||
* - Version (1 byte): 1
|
||||
* - IntSize (1 byte): 8 (using int64_t for indices/dims)
|
||||
* - FltSize (1 byte): 8 (using double for data)
|
||||
* - Reserved (1 byte): 0
|
||||
* - Height (uint64_t): Number of rows
|
||||
* - Width (uint64_t): Number of columns
|
||||
* - NNZ (uint64_t): Number of non-zeros
|
||||
* - I array (int64_t * (Height + 1)): CSR Row Pointers
|
||||
* - J array (int64_t * NNZ): CSR Column Indices
|
||||
* - Data array (double * NNZ): CSR Non-zero values
|
||||
*/
|
||||
bool saveSparseMatrixBinary(const mfem::SparseMatrix& mat, const std::string& filename) {
|
||||
std::ofstream outfile(filename, std::ios::binary | std::ios::trunc);
|
||||
if (!outfile) {
|
||||
std::cerr << "Error: Cannot open file for writing: " << filename << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
// --- Get Data Pointers and Dimensions from MFEM Matrix ---
|
||||
const int* mfem_I = mat.GetI();
|
||||
const int* mfem_J = mat.GetJ();
|
||||
const double* mfem_data = mat.GetData();
|
||||
|
||||
uint64_t height = static_cast<uint64_t>(mat.Height());
|
||||
uint64_t width = static_cast<uint64_t>(mat.Width());
|
||||
uint64_t nnz = static_cast<uint64_t>(mat.NumNonZeroElems());
|
||||
uint64_t i_count = height + 1;
|
||||
uint64_t j_count = nnz;
|
||||
uint64_t data_count = nnz;
|
||||
|
||||
|
||||
// --- Write Header ---
|
||||
const char magic[4] = {'C', 'S', 'R', 'B'};
|
||||
const uint8_t version = 1;
|
||||
const uint8_t int_size = 8;
|
||||
const uint8_t flt_size = 8;
|
||||
const uint8_t reserved = 0;
|
||||
|
||||
outfile.write(magic, 4);
|
||||
outfile.write(reinterpret_cast<const char*>(&version), 1);
|
||||
outfile.write(reinterpret_cast<const char*>(&int_size), 1);
|
||||
outfile.write(reinterpret_cast<const char*>(&flt_size), 1);
|
||||
outfile.write(reinterpret_cast<const char*>(&reserved), 1);
|
||||
|
||||
outfile.write(reinterpret_cast<const char*>(&height), sizeof(height));
|
||||
outfile.write(reinterpret_cast<const char*>(&width), sizeof(width));
|
||||
outfile.write(reinterpret_cast<const char*>(&nnz), sizeof(nnz));
|
||||
|
||||
if (!outfile) throw std::runtime_error("Error writing header.");
|
||||
|
||||
// --- Write Arrays (Converting int to int64_t for I and J) ---
|
||||
std::vector<int64_t> i_buffer(i_count);
|
||||
for (uint64_t idx = 0; idx < i_count; ++idx) {
|
||||
i_buffer[idx] = static_cast<int64_t>(mfem_I[idx]);
|
||||
}
|
||||
outfile.write(reinterpret_cast<const char*>(i_buffer.data()), i_count * sizeof(int64_t));
|
||||
if (!outfile) throw std::runtime_error("Error writing I array.");
|
||||
|
||||
std::vector<int64_t> j_buffer(j_count);
|
||||
for (uint64_t idx = 0; idx < j_count; ++idx) {
|
||||
j_buffer[idx] = static_cast<int64_t>(mfem_J[idx]);
|
||||
}
|
||||
outfile.write(reinterpret_cast<const char*>(j_buffer.data()), j_count * sizeof(int64_t));
|
||||
if (!outfile) throw std::runtime_error("Error writing J array.");
|
||||
|
||||
outfile.write(reinterpret_cast<const char*>(mfem_data), data_count * sizeof(double));
|
||||
if (!outfile) throw std::runtime_error("Error writing Data array.");
|
||||
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error during binary matrix save: " << e.what() << std::endl;
|
||||
outfile.close();
|
||||
return false;
|
||||
}
|
||||
|
||||
outfile.close();
|
||||
if (!outfile) {
|
||||
std::cerr << "Error closing file after writing: " << filename << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void writeDenseMatrixToCSV(const std::string &filename, int precision, const mfem::DenseMatrix *mat) {
|
||||
if (!mat) {
|
||||
throw std::runtime_error("The operator is not a SparseMatrix.");
|
||||
}
|
||||
|
||||
std::ofstream outfile(filename);
|
||||
if (!outfile.is_open()) {
|
||||
throw std::runtime_error("Failed to open file: " + filename);
|
||||
}
|
||||
|
||||
|
||||
int height = mat->Height();
|
||||
int width = mat->Width();
|
||||
|
||||
// Set precision for floating-point output
|
||||
outfile << std::fixed << std::setprecision(precision);
|
||||
|
||||
for (int i = 0; i < width; i++) {
|
||||
outfile << i;
|
||||
if (i < width - 1) {
|
||||
outfile << ",";
|
||||
}
|
||||
else {
|
||||
outfile << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate through rows
|
||||
for (int i = 0; i < height; ++i) {
|
||||
for (int j = 0; j < width; ++j) {
|
||||
outfile << mat->Elem(i, j);
|
||||
if (j < width - 1) {
|
||||
outfile << ",";
|
||||
}
|
||||
}
|
||||
outfile << std::endl;
|
||||
}
|
||||
|
||||
outfile.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Writes the dense representation of an MFEM Operator (if it's a SparseMatrix) to a CSV file.
|
||||
*
|
||||
* @param op The MFEM Operator to write.
|
||||
* @param filename The name of the output CSV file.
|
||||
* @param precision Number of decimal places for floating-point values.
|
||||
*/
|
||||
void writeOperatorToCSV(const mfem::Operator &op,
|
||||
const std::string &filename,
|
||||
int precision = 6) // Add precision argument
|
||||
{
|
||||
// Attempt to cast the Operator to a SparseMatrix
|
||||
const auto *sparse_mat = dynamic_cast<const mfem::SparseMatrix*>(&op);
|
||||
if (!sparse_mat) {
|
||||
throw std::runtime_error("The operator is not a SparseMatrix.");
|
||||
}
|
||||
const mfem::DenseMatrix *mat = sparse_mat->ToDenseMatrix();
|
||||
writeDenseMatrixToCSV(filename, precision, mat);
|
||||
}
|
||||
|
||||
#endif //MFEM_SMOUT_H
|
||||
36
utils/debugUtils/MFEMAnalysisUtils/SSEDebug/pyproject.toml
Normal file
36
utils/debugUtils/MFEMAnalysisUtils/SSEDebug/pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "SSEDebug"
|
||||
version = "0.1.0"
|
||||
description = "A python module for general 4DSSE debugging"
|
||||
readme = "readme.md"
|
||||
authors = [
|
||||
{name = "Emily M. Boudreaux", email = "emily.boudreaux@dartmouth.edu"},
|
||||
{name = "4D-STAR Collaboration"},
|
||||
]
|
||||
|
||||
maintainers = [
|
||||
{name = "Emily M. Boudreaux", email="emily.boudreaux@dartmouth.edu"}
|
||||
]
|
||||
|
||||
keywords = ["astrophysics", "MFEM"]
|
||||
requires-python = ">=3.8"
|
||||
dependencies = ["numpy >= 1.21.1", "scipy>=1.13.1"]
|
||||
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Topic :: Scientific/Engineering :: Astronomy",
|
||||
"Operating System :: OS Independent"
|
||||
]
|
||||
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "src"}
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -0,0 +1 @@
|
||||
__version__="0.1.0"
|
||||
@@ -0,0 +1 @@
|
||||
from .smread import loadSparseMatrixBinary, analyze_sparse_matrix, load_and_analyze_sparse_matrix
|
||||
@@ -0,0 +1,234 @@
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import struct
|
||||
import scipy.sparse.linalg as spla # For matrix norm
|
||||
import time
|
||||
import os
|
||||
|
||||
def loadSparseMatrixBinary(filename):
|
||||
"""
|
||||
Loads a sparse matrix from the custom binary format (.csrbin).
|
||||
|
||||
Args:
|
||||
filename (str): The path to the .csrbin file.
|
||||
|
||||
Returns:
|
||||
scipy.sparse.csr_matrix: The loaded sparse matrix.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file format is incorrect or sizes don't match.
|
||||
IOError: If the file cannot be read.
|
||||
"""
|
||||
INT_SIZE = 8 # Expecting int64_t from the C++ writer
|
||||
FLT_SIZE = 8 # Expecting double from the C++ writer
|
||||
EXPECTED_MAGIC = b'CSRB'
|
||||
EXPECTED_VERSION = 1
|
||||
|
||||
try:
|
||||
with open(filename, 'rb') as f:
|
||||
# --- Read Header ---
|
||||
magic = f.read(4)
|
||||
if magic != EXPECTED_MAGIC:
|
||||
raise ValueError(f"Invalid magic number. Expected {EXPECTED_MAGIC}, got {magic}")
|
||||
|
||||
version, int_size_file, flt_size_file, reserved = struct.unpack('<BBBB', f.read(4))
|
||||
# '<' means little-endian, 'B' means unsigned char (1 byte)
|
||||
|
||||
if version != EXPECTED_VERSION:
|
||||
print(f"Warning: File version {version} differs from expected {EXPECTED_VERSION}.")
|
||||
if int_size_file != INT_SIZE:
|
||||
raise ValueError(f"Integer size mismatch. Expected {INT_SIZE}, file has {int_size_file}")
|
||||
if flt_size_file != FLT_SIZE:
|
||||
raise ValueError(f"Float size mismatch. Expected {FLT_SIZE}, file has {flt_size_file}")
|
||||
|
||||
height, width, nnz = struct.unpack('<QQQ', f.read(24))
|
||||
# '<' means little-endian, 'Q' means unsigned long long (8 bytes)
|
||||
|
||||
i_count = height + 1
|
||||
j_count = nnz
|
||||
data_count = nnz
|
||||
|
||||
if nnz == 0: # Handle empty matrix case
|
||||
print("Warning: Matrix file contains zero non-zero elements.")
|
||||
# Return an empty matrix with correct shape
|
||||
return sp.csr_matrix((height, width), dtype=np.float64)
|
||||
|
||||
|
||||
# --- Read Arrays ---
|
||||
|
||||
# Read I array (Row Pointers)
|
||||
expected_i_bytes = i_count * INT_SIZE
|
||||
I_array = np.fromfile(f, dtype=np.int64, count=i_count) # Read as int64
|
||||
if I_array.size != i_count:
|
||||
raise ValueError(f"Error reading I array. Expected {i_count} elements, read {I_array.size}. File truncated or corrupt?")
|
||||
|
||||
# Read J array (Column Indices)
|
||||
expected_j_bytes = j_count * INT_SIZE
|
||||
J_array = np.fromfile(f, dtype=np.int64, count=j_count) # Read as int64
|
||||
if J_array.size != j_count:
|
||||
raise ValueError(f"Error reading J array. Expected {j_count} elements, read {J_array.size}. File truncated or corrupt?")
|
||||
|
||||
# Read Data array (Values)
|
||||
expected_data_bytes = data_count * FLT_SIZE
|
||||
Data_array = np.fromfile(f, dtype=np.float64, count=data_count) # Read as float64
|
||||
if Data_array.size != data_count:
|
||||
raise ValueError(f"Error reading Data array. Expected {data_count} elements, read {Data_array.size}. File truncated or corrupt?")
|
||||
|
||||
# --- Check for extra data ---
|
||||
extra_data = f.read()
|
||||
if extra_data:
|
||||
print(f"Warning: {len(extra_data)} extra bytes found at the end of the file.")
|
||||
|
||||
|
||||
# --- Construct SciPy CSR Matrix ---
|
||||
sparse_matrix = sp.csr_matrix((Data_array, J_array, I_array), shape=(height, width))
|
||||
|
||||
if sparse_matrix.nnz != nnz:
|
||||
print(f"Warning: NNZ mismatch after loading. Header NNZ: {nnz}, Scipy NNZ: {sparse_matrix.nnz}")
|
||||
|
||||
|
||||
return sparse_matrix
|
||||
|
||||
except FileNotFoundError:
|
||||
raise IOError(f"Error: File not found at {filename}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"An error occurred while reading {filename}: {e}")
|
||||
|
||||
|
||||
def analyze_sparse_matrix(sp_mat):
|
||||
"""
|
||||
Analyzes a SciPy sparse matrix and prints various statistics.
|
||||
|
||||
Args:
|
||||
sp_mat (scipy.sparse.spmatrix): The sparse matrix to analyze.
|
||||
(e.g., csr_matrix, csc_matrix).
|
||||
"""
|
||||
print("-" * 50)
|
||||
print("Sparse Matrix Analysis Report")
|
||||
print("-" * 50)
|
||||
|
||||
if not isinstance(sp_mat, sp.spmatrix):
|
||||
print("Error: Input is not a SciPy sparse matrix.")
|
||||
return
|
||||
|
||||
rows, cols = sp_mat.shape
|
||||
print(f"Size (Shape): {rows} rows x {cols} columns")
|
||||
|
||||
if rows == 0 or cols == 0:
|
||||
print("\nMatrix is empty. No further analysis possible.")
|
||||
print("-" * 50)
|
||||
return
|
||||
|
||||
nnz = sp_mat.nnz
|
||||
total_elements = rows * cols
|
||||
sparsity = 0.0
|
||||
if total_elements > 0:
|
||||
sparsity = 1.0 - (nnz / total_elements)
|
||||
else:
|
||||
sparsity = 1.0
|
||||
|
||||
print(f"Non-zero elements (NNZ): {nnz}")
|
||||
print(f"Total elements: {total_elements}")
|
||||
print(f"Sparsity: {sparsity:.6%} (percentage of zeros)")
|
||||
|
||||
if nnz == 0:
|
||||
print("\nMatrix contains only zero elements.")
|
||||
diag_elements = sp_mat.diagonal()
|
||||
print(f"\nDiagonal Mean: {np.mean(diag_elements):.6e}")
|
||||
print(f"Diagonal Max: {np.max(diag_elements):.6e}")
|
||||
print(f"Diagonal Min: {np.min(diag_elements):.6e}")
|
||||
print(f"Value Range (Min): N/A (no non-zero values)")
|
||||
print(f"Value Range (Max): N/A (no non-zero values)")
|
||||
print(f"Mean Non-Zero Value: N/A (no non-zero values)")
|
||||
print(f"Relative Diagonal Norm: N/A (matrix norm is zero)")
|
||||
print("-" * 50)
|
||||
return
|
||||
|
||||
all_values = sp_mat.data # Access non-zero values directly
|
||||
min_val = np.min(all_values)
|
||||
max_val = np.max(all_values)
|
||||
mean_val = np.mean(all_values)
|
||||
print(f"\nValue Range (Min): {min_val:.6e}")
|
||||
print(f"Value Range (Max): {max_val:.6e}")
|
||||
print(f"Mean Non-Zero Value: {mean_val:.6e}")
|
||||
|
||||
|
||||
print("\n--- Diagonal Properties ---")
|
||||
start_diag = time.time()
|
||||
diag_elements = sp_mat.diagonal()
|
||||
end_diag = time.time()
|
||||
print(f"(Diagonal extraction time: {end_diag - start_diag:.4f}s)")
|
||||
|
||||
if diag_elements.size > 0: # Should always be true unless rows=0 (handled above)
|
||||
mean_diag = np.mean(diag_elements)
|
||||
max_diag = np.max(diag_elements)
|
||||
min_diag = np.min(diag_elements)
|
||||
|
||||
diag_nonzero = diag_elements[diag_elements != 0]
|
||||
if diag_nonzero.size > 0:
|
||||
mean_diag_nz = np.mean(diag_nonzero)
|
||||
print(f"Mean Diagonal (all): {mean_diag:.6e}")
|
||||
print(f"Mean Diagonal (non-zero):{mean_diag_nz:.6e} ({diag_nonzero.size} elements)")
|
||||
else:
|
||||
print(f"Mean Diagonal (all): {mean_diag:.6e}")
|
||||
print(f"Mean Diagonal (non-zero): N/A (all diagonal elements are zero)")
|
||||
|
||||
print(f"Max Diagonal: {max_diag:.6e}")
|
||||
print(f"Min Diagonal: {min_diag:.6e}")
|
||||
|
||||
# 5. "Diagonality" - Relative Diagonal Norm (using Frobenius norm)
|
||||
# The Frobenius norm is sqrt(sum(abs(A_ij)^2))
|
||||
start_norm = time.time()
|
||||
norm_diag = np.linalg.norm(diag_elements)
|
||||
norm_matrix = spla.norm(sp_mat, ord='fro')
|
||||
end_norm = time.time()
|
||||
print(f"(Norm calculation time: {end_norm - start_norm:.4f}s)")
|
||||
|
||||
if norm_matrix > 1e-15: # Avoid division by zero
|
||||
diagonality_ratio = norm_diag / norm_matrix
|
||||
print(f"\nRelative Diagonal Norm (Frobenius): {diagonality_ratio:.6f}")
|
||||
print(f" (Ratio of ||diag(A)||_F / ||A||_F)")
|
||||
print(f" (Diagonal Norm = {norm_diag:.6e}, Matrix Norm = {norm_matrix:.6e})")
|
||||
if diagonality_ratio > 0.99:
|
||||
print(" -> Matrix is strongly diagonal dominant by norm.")
|
||||
elif diagonality_ratio < 0.1:
|
||||
print(" -> Matrix norm is dominated by off-diagonal elements.")
|
||||
else:
|
||||
print("\nRelative Diagonal Norm: N/A (matrix Frobenius norm is zero)")
|
||||
|
||||
else: # Should not happen if rows > 0
|
||||
print("\nCould not extract diagonal (matrix has zero rows?).")
|
||||
|
||||
|
||||
# 6. Other Useful Stats
|
||||
print("\n--- Other Properties ---")
|
||||
is_square = rows == cols
|
||||
print(f"Is Square: {is_square}")
|
||||
if is_square:
|
||||
try:
|
||||
diff_norm = spla.norm(sp_mat - sp_mat.T, ord='fro')
|
||||
if diff_norm < 1e-10 * norm_matrix : # Check relative difference norm
|
||||
print(f"Is Symmetric (approx): True (||A - A.T||_F / ||A||_F < 1e-10)")
|
||||
else:
|
||||
print(f"Is Symmetric (approx): False (||A - A.T||_F = {diff_norm:.2e})")
|
||||
except Exception as e:
|
||||
print(f"Is Symmetric (approx): Check failed ({e})")
|
||||
else:
|
||||
print(f"Is Symmetric (approx): False (not square)")
|
||||
|
||||
|
||||
print("-" * 50)
|
||||
|
||||
def load_and_analyze_sparse_matrix(filename: str):
|
||||
sm = loadSparseMatrixBinary(filename)
|
||||
analyze_sparse_matrix(sm)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Simple tool to get some statistics about a sparse matrix from mfem")
|
||||
parser.add_argument("path", help="path to the output file", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
load_and_analyze_sparse_matrix(args.filename)
|
||||
1
utils/debugUtils/MFEMAnalysisUtils/meson.build
Normal file
1
utils/debugUtils/MFEMAnalysisUtils/meson.build
Normal file
@@ -0,0 +1 @@
|
||||
subdir('MFEMAnalysis-cpp')
|
||||
12
utils/debugUtils/MFEMAnalysisUtils/readme.md
Normal file
12
utils/debugUtils/MFEMAnalysisUtils/readme.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# Tools for analyzing MFEM Sparse Matricies (among other things)
|
||||
MFEM does a lot of work with sparse matrixes but does not provide trivial tools to use them. Here I include some basic utilities to analyze these matricies.
|
||||
|
||||
## Python
|
||||
There is a python script to preform the actual analysis.
|
||||
|
||||
## C++
|
||||
There is a small C++ header only library which provides an interface to write MFEM sparse matrixes out to disk.
|
||||
|
||||
The C++ utility writes mfem sparse matricies in a custom format which was written to be simple. The python script
|
||||
only understands this format.
|
||||
|
||||
1
utils/debugUtils/meson.build
Normal file
1
utils/debugUtils/meson.build
Normal file
@@ -0,0 +1 @@
|
||||
subdir('MFEMAnalysisUtils')
|
||||
@@ -1,16 +0,0 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "opatio"
|
||||
version = "0.1.0a"
|
||||
description = "A python module for handling OPAT files"
|
||||
readme = "readme.md"
|
||||
authors = [{name = "Emily M. Boudreaux", email = "emily.boudreaux@dartmouth.edu"}]
|
||||
requires-python = ">=3.8"
|
||||
dependencies = ["numpy >= 1.21.1"]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["opatio", "opatio.opat"]
|
||||
package-dir = {"" = "src"}
|
||||
@@ -1,46 +0,0 @@
|
||||
# opatIO python module
|
||||
This module defines a set of tools to build, write, and read OPAT files.
|
||||
The OPAT fileformat is a custom file format designed to efficiently store
|
||||
opacity information for a variety of compositions.
|
||||
|
||||
## Installation
|
||||
You can install this module with pip
|
||||
```bash
|
||||
git clone <repo>
|
||||
cd 4DSSE/utils/opat
|
||||
pip install .
|
||||
```
|
||||
|
||||
## General Usage
|
||||
The general way that this module is mean to be used is to first build a schema for the opaticy table and then save that to disk. The module will handle all the byte aligment and lookup table construction for you.
|
||||
|
||||
A simple example might look like the following
|
||||
|
||||
```python
|
||||
from opatio import OpatIO
|
||||
|
||||
opacityFile = OpatIO()
|
||||
opacityFile.set_comment("This is a sample opacity file")
|
||||
opaticyFile.set_source("OPLIB")
|
||||
|
||||
# some code to get a logR, logT, and logKappa table
|
||||
# where logKappa is of size (n,m) if logR is size n and
|
||||
# logT is size m
|
||||
|
||||
opacityFile.add_table(X, Z, logR, logT, logKappa)
|
||||
opacityFile.save("opacity.opat")
|
||||
```
|
||||
|
||||
You can also read opat files which have been generated with the loadOpat function
|
||||
|
||||
```python
|
||||
from opatio import loadOpat
|
||||
|
||||
opacityFile = loadOpat("opacity.opat")
|
||||
|
||||
print(opacityFile.header)
|
||||
print(opaticyFile.tables[0])
|
||||
```
|
||||
|
||||
## Problems
|
||||
If you have problems feel free to either submit an issue to the root github repo (tagged as utils/opatio) or email Emily Boudreaux at emily.boudreaux@dartmouth.edu
|
||||
@@ -1 +0,0 @@
|
||||
from .opat.opat import OpatIO, loadOpat
|
||||
@@ -1,544 +0,0 @@
|
||||
import struct
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import Iterable, List, Tuple
|
||||
from collections.abc import Iterable as collectionIterable
|
||||
|
||||
import hashlib
|
||||
|
||||
import os
|
||||
|
||||
@dataclass
|
||||
class Header:
|
||||
"""
|
||||
@brief Structure to hold the header information of an OPAT file.
|
||||
"""
|
||||
magic: str #< Magic number to identify the file type
|
||||
version: int #< Version of the OPAT file format
|
||||
numTables: int #< Number of tables in the file
|
||||
headerSize: int #< Size of the header
|
||||
indexOffset: int #< Offset to the index section
|
||||
creationDate: str #< Creation date of the file
|
||||
sourceInfo: str #< Source information
|
||||
comment: str #< Comment section
|
||||
numIndex: int #< Number of values to use when indexing table
|
||||
reserved: bytes #< Reserved for future use
|
||||
|
||||
@dataclass
|
||||
class TableIndex:
|
||||
"""
|
||||
@brief Structure to hold the index information of a table in an OPAT file.
|
||||
"""
|
||||
index: List[float] #< Index values of the table
|
||||
byteStart: int #< Byte start position of the table
|
||||
byteEnd: int #< Byte end position of the table
|
||||
sha256: bytes #< SHA-256 hash of the table data
|
||||
|
||||
@dataclass
|
||||
class OPATTable:
|
||||
"""
|
||||
@brief Structure to hold the data of an OPAT table.
|
||||
"""
|
||||
N_R: int #< Number of R values
|
||||
N_T: int #< Number of T values
|
||||
logR: Iterable[float] #< Logarithm of R values
|
||||
logT: Iterable[float] #< Logarithm of T values
|
||||
logKappa: Iterable[Iterable[float]] #< Logarithm of Kappa values
|
||||
|
||||
def make_default_header() -> Header:
|
||||
"""
|
||||
@brief Create a default header for an OPAT file.
|
||||
@return The default header.
|
||||
"""
|
||||
return Header(
|
||||
magic="OPAT",
|
||||
version=1,
|
||||
numTables=0,
|
||||
headerSize=256,
|
||||
indexOffset=0,
|
||||
creationDate=datetime.now().strftime("%b %d, %Y"),
|
||||
sourceInfo="no source provided by user",
|
||||
comment="default header",
|
||||
numIndex=2,
|
||||
reserved=b"\x00" * 24
|
||||
)
|
||||
|
||||
class OpatIO:
|
||||
"""
|
||||
@brief Class for handling OPAT file input/output operations.
|
||||
This class provides methods to validate, manipulate, and save OPAT files. It includes functionalities to validate character arrays, 1D arrays, and 2D arrays, compute checksums, set header information, add tables, and save the OPAT file in both ASCII and binary formats.
|
||||
Attributes:
|
||||
header (Header): The header of the OPAT file.
|
||||
tables (List[Tuple[Tuple[float, float], OPATTable]]): A list of tables in the OPAT file.
|
||||
Methods:
|
||||
validate_char_array_size(s: str, nmax: int) -> bool:
|
||||
Validate the size of a character array.
|
||||
validate_logKappa(logKappa):
|
||||
Validate the logKappa array.
|
||||
validate_1D(arr, name: str):
|
||||
Validate a 1D array.
|
||||
compute_checksum(data: bytes) -> bytes:
|
||||
Compute the SHA-256 checksum of the given data.
|
||||
set_version(version: int) -> int:
|
||||
Set the version of the OPAT file.
|
||||
set_source(source: str) -> str:
|
||||
Set the source information of the OPAT file.
|
||||
set_comment(comment: str) -> str:
|
||||
Set the comment of the OPAT file.
|
||||
add_table(X: float, Z: float, logR: Iterable[float], logT: Iterable[float], logKappa: Iterable[Iterable[float]]):
|
||||
Add a table to the OPAT file.
|
||||
_header_bytes() -> bytes:
|
||||
Convert the header to bytes.
|
||||
_table_bytes(table: OPATTable) -> Tuple[bytes, bytes]:
|
||||
Convert a table to bytes.
|
||||
_tableIndex_bytes(tableIndex: TableIndex) -> bytes:
|
||||
Convert a table index to bytes.
|
||||
__repr__() -> str:
|
||||
Get the string representation of the OpatIO object.
|
||||
_format_table_as_string(table: OPATTable, X: float, Z: float) -> str:
|
||||
Format a table as a string.
|
||||
print_table_indexes(table_indexes: List[TableIndex]) -> str:
|
||||
Print the table indexes.
|
||||
save_as_ascii(filename: str) -> str:
|
||||
Save the OPAT file as an ASCII file.
|
||||
save(filename: str) -> str:
|
||||
Save the OPAT file as a binary file.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.header: Header = make_default_header()
|
||||
self.tables: List[Tuple[Tuple[float, float], OPATTable]] = []
|
||||
|
||||
@staticmethod
|
||||
def validate_char_array_size(s: str, nmax: int) -> bool:
|
||||
"""
|
||||
@brief Validate the size of a character array.
|
||||
@param s The string to validate.
|
||||
@param nmax The maximum allowed size.
|
||||
@return True if the string size is valid, False otherwise.
|
||||
"""
|
||||
if len(s) > nmax:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def validate_logKappa(logKappa):
|
||||
"""
|
||||
@brief Validate the logKappa array.
|
||||
@param logKappa The logKappa array to validate.
|
||||
@throws ValueError if logKappa is not a non-empty 2D array.
|
||||
@throws TypeError if logKappa is not a 2D array or iterable.
|
||||
"""
|
||||
if isinstance(logKappa, np.ndarray):
|
||||
if logKappa.ndim == 2:
|
||||
return
|
||||
else:
|
||||
raise ValueError("logKappa must be a non-empty 2D array")
|
||||
|
||||
if isinstance(logKappa, collectionIterable) and all(isinstance(row, collectionIterable) for row in logKappa):
|
||||
try:
|
||||
first_row = next(iter(logKappa))
|
||||
if all(isinstance(x, (int, float)) for x in first_row):
|
||||
return
|
||||
else:
|
||||
raise ValueError("logKappa must be fully numeric")
|
||||
except StopIteration:
|
||||
raise ValueError("logKappa must be a non-empty 2D iterable")
|
||||
else:
|
||||
raise TypeError("logKappa must be a non-empty 2D array or iterable")
|
||||
|
||||
@staticmethod
|
||||
def validate_1D(arr, name: str):
|
||||
"""
|
||||
@brief Validate a 1D array.
|
||||
@param arr The array to validate.
|
||||
@param name The name of the array.
|
||||
@throws ValueError if the array is not 1D or not fully numeric.
|
||||
@throws TypeError if the array is not a non-empty 1D array or iterable.
|
||||
"""
|
||||
if isinstance(arr, np.ndarray):
|
||||
if arr.ndim == 1:
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"{name} must be a 1D numpy array")
|
||||
if isinstance(arr, collectionIterable) and not isinstance(arr, (str, bytes)):
|
||||
if all(isinstance(x, (int, float)) for x in arr):
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"{name} must be fully numeric")
|
||||
else:
|
||||
raise TypeError(f"{name} must be a non-empty 1D array or iterable")
|
||||
|
||||
@staticmethod
|
||||
def compute_checksum(data: np.ndarray) -> bytes:
|
||||
"""
|
||||
@brief Compute the SHA-256 checksum of the given data.
|
||||
@param data The data to compute the checksum for.
|
||||
@return The SHA-256 checksum.
|
||||
"""
|
||||
return hashlib.sha256(data.tobytes()).digest()
|
||||
|
||||
def set_version(self, version: int) -> int:
|
||||
"""
|
||||
@brief Set the version of the OPAT file.
|
||||
@param version The version to set.
|
||||
@return The set version.
|
||||
"""
|
||||
self.header.version = version
|
||||
return self.header.version
|
||||
|
||||
def set_source(self, source: str) -> str:
|
||||
"""
|
||||
@brief Set the source information of the OPAT file.
|
||||
@param source The source information to set.
|
||||
@return The set source information.
|
||||
@throws TypeError if the source string is too long.
|
||||
"""
|
||||
if not self.validate_char_array_size(source, 64):
|
||||
raise TypeError(f"sourceInfo string ({source}) is too long ({len(source)}). Max length is 64")
|
||||
self.header.sourceInfo = source
|
||||
return self.header.sourceInfo
|
||||
|
||||
def set_comment(self, comment: str) -> str:
|
||||
"""
|
||||
@brief Set the comment of the OPAT file.
|
||||
@param comment The comment to set.
|
||||
@return The set comment.
|
||||
@throws TypeError if the comment string is too long.
|
||||
"""
|
||||
if not self.validate_char_array_size(comment, 128):
|
||||
raise TypeError(f"comment string ({comment}) is too long ({len(comment)}). Max length is 128")
|
||||
self.header.comment = comment
|
||||
return self.header.comment
|
||||
|
||||
def set_numIndex(self, numIndex: int) -> int:
|
||||
"""
|
||||
@brief Set the number of values to use when indexing table.
|
||||
@param numIndex The number of values to use when indexing table.
|
||||
@return The set number of values to use when indexing table.
|
||||
"""
|
||||
if numIndex < 1:
|
||||
raise ValueError(f"numIndex must be greater than 0! It is currently {numIndex}")
|
||||
self.header.numIndex = numIndex
|
||||
return self.header.numIndex
|
||||
|
||||
def add_table(self, indicies: Tuple[float], logR: Iterable[float], logT: Iterable[float], logKappa: Iterable[Iterable[float]]):
|
||||
"""
|
||||
@brief Add a table to the OPAT file.
|
||||
@param indicies The index values of the table.
|
||||
@param logR The logR values.
|
||||
@param logT The logT values.
|
||||
@param logKappa The logKappa values.
|
||||
@throws ValueError if logKappa is not a non-empty 2D array or if logR and logT are not 1D arrays.
|
||||
"""
|
||||
if len(indicies) != self.header.numIndex:
|
||||
raise ValueError(f"indicies must have length {self.header.numIndex}! Currently it has length {len(indicies)}")
|
||||
self.validate_logKappa(logKappa)
|
||||
self.validate_1D(logR, "logR")
|
||||
self.validate_1D(logT, "logT")
|
||||
|
||||
logR = np.array(logR)
|
||||
logT = np.array(logT)
|
||||
logKappa = np.array(logKappa)
|
||||
|
||||
if logKappa.shape != (logR.shape[0], logT.shape[0]):
|
||||
raise ValueError(f"logKappa must be of shape ({len(logR)} x {len(logT)})! Currently logKappa has shape {logKappa.shape}")
|
||||
|
||||
table = OPATTable(
|
||||
N_R = logR.shape[0],
|
||||
N_T = logT.shape[0],
|
||||
logR = logR,
|
||||
logT = logT,
|
||||
logKappa = logKappa
|
||||
)
|
||||
|
||||
self.tables.append((indicies, table))
|
||||
self.header.numTables += 1
|
||||
|
||||
|
||||
def _header_bytes(self) -> bytes:
|
||||
"""
|
||||
@brief Convert the header to bytes.
|
||||
@return The header as bytes.
|
||||
"""
|
||||
headerBytes = struct.pack(
|
||||
"<4s H I I Q 16s 64s 128s H 24s",
|
||||
self.header.magic.encode('utf-8'),
|
||||
self.header.version,
|
||||
self.header.numTables,
|
||||
self.header.headerSize,
|
||||
self.header.indexOffset,
|
||||
self.header.creationDate.encode('utf-8'),
|
||||
self.header.sourceInfo.encode('utf-8'),
|
||||
self.header.comment.encode('utf-8'),
|
||||
self.header.numIndex,
|
||||
self.header.reserved
|
||||
)
|
||||
return headerBytes
|
||||
|
||||
def _table_bytes(self, table: OPATTable) -> Tuple[bytes, bytes]:
|
||||
"""
|
||||
@brief Convert a table to bytes.
|
||||
@param table The OPAT table.
|
||||
@return A tuple containing the checksum and the table as bytes.
|
||||
"""
|
||||
logR = table.logR.flatten()
|
||||
logT = table.logT.flatten()
|
||||
logKappa = table.logKappa.flatten()
|
||||
tableBytes = struct.pack(
|
||||
f"<II{table.N_R}d{table.N_T}d{table.N_R*table.N_T}d",
|
||||
table.N_R,
|
||||
table.N_T,
|
||||
*logR,
|
||||
*logT,
|
||||
*logKappa
|
||||
)
|
||||
checksum = self.compute_checksum(logKappa)
|
||||
return (checksum, tableBytes)
|
||||
|
||||
def _tableIndex_bytes(self, tableIndex: TableIndex) -> bytes:
|
||||
"""
|
||||
@brief Convert a table index to bytes.
|
||||
@param tableIndex The table index.
|
||||
@return The table index as bytes.
|
||||
@throws RuntimeError if the table index entry does not have 64 bytes.
|
||||
"""
|
||||
tableIndexFMTString = "<"+"d"*self.header.numIndex+f"QQ"
|
||||
tableIndexBytes = struct.pack(
|
||||
tableIndexFMTString,
|
||||
*tableIndex.index,
|
||||
tableIndex.byteStart,
|
||||
tableIndex.byteEnd
|
||||
)
|
||||
tableIndexBytes += tableIndex.sha256
|
||||
|
||||
if len(tableIndexBytes) != 16+self.header.numIndex*8+32:
|
||||
raise RuntimeError(f"Each table index entry must have 64 bytes. Due to an unknown error the table index entry for (X,Z)=({tableIndex.X},{tableIndex.Z}) header has {len(tableIndexBytes)} bytes")
|
||||
|
||||
return tableIndexBytes
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
@brief Get the string representation of the OpatIO object.
|
||||
@return The string representation.
|
||||
"""
|
||||
reprString = f"""OpatIO(
|
||||
version: {self.header.version}
|
||||
numTables: {self.header.numTables}
|
||||
headerSize: {self.header.headerSize}
|
||||
indexOffset: {self.header.indexOffset}
|
||||
creationDate: {self.header.creationDate}
|
||||
sourceInfo: {self.header.sourceInfo}
|
||||
comment: {self.header.comment}
|
||||
numIndex: {self.header.numIndex}
|
||||
reserved: {self.header.reserved}
|
||||
)"""
|
||||
return reprString
|
||||
|
||||
def _format_table_as_string(self, table: OPATTable, indices: List[float]) -> str:
|
||||
"""
|
||||
@brief Format a table as a string.
|
||||
@param table The OPAT table.
|
||||
@indices The index values of the table.
|
||||
@return The formatted table as a string.
|
||||
"""
|
||||
tableString: List[str] = []
|
||||
# fixed width X and Z header per table
|
||||
tableIndexString: List[str] = []
|
||||
for index in indices:
|
||||
tableIndexString.append(f"{index:<10.4f}")
|
||||
tableString.append(" ".join(tableIndexString))
|
||||
tableString.append("-" * 80)
|
||||
# write logR across the top (reserving one col for where logT will be)
|
||||
tableString.append(f"{'':<10}{'logR':<10}")
|
||||
logRRow = f"{'logT':<10}"
|
||||
logRRowTrue = "".join(f"{r:<10.4f}" for r in table.logR)
|
||||
tableString.append(logRRow + logRRowTrue)
|
||||
for i, logT in enumerate(table.logT):
|
||||
row = f"{logT:<10.4f}"
|
||||
for kappa in table.logKappa[:, i]:
|
||||
row += f"{kappa:<10.4f}"
|
||||
tableString.append(row)
|
||||
tableString.append("=" * 80)
|
||||
return '\n'.join(tableString)
|
||||
|
||||
@staticmethod
|
||||
def print_table_indexes(table_indexes: List[TableIndex]) -> str:
|
||||
"""
|
||||
@brief Print the table indexes.
|
||||
@param table_indexes The list of table indexes.
|
||||
@return The formatted table indexes as a string.
|
||||
"""
|
||||
if not table_indexes:
|
||||
print("No table indexes found.")
|
||||
return
|
||||
|
||||
tableRows: List[str] = []
|
||||
tableRows.append("\nTable Indexes in OPAT File:\n")
|
||||
headerString: str = ''
|
||||
for indexID, index in enumerate(table_indexes[0].index):
|
||||
indexKey = f"Index {indexID}"
|
||||
headerString += f"{indexKey:<10}"
|
||||
headerString += f"{'Byte Start':<15} {'Byte End':<15} {'Checksum (SHA-256)'}"
|
||||
tableRows.append(headerString)
|
||||
tableRows.append("=" * 80)
|
||||
for entry in table_indexes:
|
||||
tableEntry = ''
|
||||
for index in entry.index:
|
||||
tableEntry += f"{index:<10.4f}"
|
||||
tableEntry += f"{entry.byteStart:<15} {entry.byteEnd:<15} {entry.sha256[:16]}..."
|
||||
tableRows.append(tableEntry)
|
||||
return '\n'.join(tableRows)
|
||||
|
||||
def save_as_ascii(self, filename: str) -> str:
|
||||
"""
|
||||
@brief Save the OPAT file as an ASCII file.
|
||||
@param filename The name of the file.
|
||||
@return The name of the saved file.
|
||||
"""
|
||||
numericFMT = "{:.18e}"
|
||||
currentStartByte: int = 256
|
||||
tableIndexs: List[bytes] = []
|
||||
tableStrings: List[bytes] = []
|
||||
for index, table in self.tables:
|
||||
checksum, tableBytes = self._table_bytes(table)
|
||||
tableStrings.append(self._format_table_as_string(table, index) + "\n")
|
||||
tableIndex = TableIndex(
|
||||
index = index,
|
||||
byteStart = currentStartByte,
|
||||
byteEnd = currentStartByte + len(tableBytes),
|
||||
sha256 = checksum
|
||||
)
|
||||
tableIndexs.append(tableIndex)
|
||||
|
||||
|
||||
currentStartByte += len(tableBytes)
|
||||
self.header.indexOffset = currentStartByte
|
||||
with open(filename, 'w') as f:
|
||||
f.write("This is an ASCII representation of an OPAT file, it is not a valid OPAT file in and of itself.\n")
|
||||
f.write("This file is meant to be human readable and is not meant to be read by a computer.\n")
|
||||
f.write("The purpose of this file is to provide a human readable representation of the OPAT file which can be used for debugging purposes.\n")
|
||||
f.write("The full binary specification of the OPAT file can be found in the OPAT file format documentation at:\n")
|
||||
f.write(" https://github.com/4D-STAR/4DSSE/blob/main/specs/OPAT/OPAT.pdf\n")
|
||||
f.write("="*35 + " HEADER " + "="*36 + "\n")
|
||||
f.write(f">> {self.header.magic}\n")
|
||||
f.write(f">> Version: {self.header.version}\n")
|
||||
f.write(f">> numTables: {self.header.numTables}\n")
|
||||
f.write(f">> headerSize (bytes): {self.header.headerSize}\n")
|
||||
f.write(f">> tableIndex Offset (bytes): {self.header.indexOffset}\n")
|
||||
f.write(f">> Creation Date: {self.header.creationDate}\n")
|
||||
f.write(f">> Source Info: {self.header.sourceInfo}\n")
|
||||
f.write(f">> Comment: {self.header.comment}\n")
|
||||
f.write(f">> numIndex: {self.header.numIndex}\n")
|
||||
f.write("="*37 + " DATA " + "="*37 + "\n")
|
||||
f.write("="*80 + "\n")
|
||||
for tableString in tableStrings:
|
||||
f.write(tableString)
|
||||
f.write("="*80 + "\n")
|
||||
f.write("="*36 + " INDEX " + "="*37 + "\n")
|
||||
f.write(self.print_table_indexes(tableIndexs))
|
||||
|
||||
def save(self, filename: str) -> str:
|
||||
"""
|
||||
@brief Save the OPAT file as a binary file.
|
||||
@param filename The name of the file.
|
||||
@return The name of the saved file.
|
||||
@throws RuntimeError if the header does not have 256 bytes.
|
||||
"""
|
||||
tempHeaderBytes = self._header_bytes()
|
||||
|
||||
if len(tempHeaderBytes) != 256:
|
||||
raise RuntimeError(f"Header must have 256 bytes. Due to an unknown error the header has {len(tempHeaderBytes)} bytes")
|
||||
|
||||
currentStartByte: int = 256
|
||||
tableIndicesBytes: List[bytes] = []
|
||||
tablesBytes: List[bytes] = []
|
||||
for index, table in self.tables:
|
||||
checksum, tableBytes = self._table_bytes(table)
|
||||
tableIndex = TableIndex(
|
||||
index,
|
||||
byteStart = currentStartByte,
|
||||
byteEnd = currentStartByte + len(tableBytes),
|
||||
sha256 = checksum
|
||||
)
|
||||
tableIndexBytes = self._tableIndex_bytes(tableIndex)
|
||||
tablesBytes.append(tableBytes)
|
||||
tableIndicesBytes.append(tableIndexBytes)
|
||||
|
||||
currentStartByte += len(tableBytes)
|
||||
self.header.indexOffset = currentStartByte
|
||||
headerBytes = self._header_bytes()
|
||||
|
||||
with open(filename, 'wb') as f:
|
||||
f.write(headerBytes)
|
||||
for tableBytes in tablesBytes:
|
||||
f.write(tableBytes)
|
||||
for tableIndexBytes in tableIndicesBytes:
|
||||
f.write(tableIndexBytes)
|
||||
|
||||
if os.path.exists(filename):
|
||||
return filename
|
||||
|
||||
|
||||
def loadOpat(filename: str) -> OpatIO:
|
||||
"""
|
||||
@brief Load an OPAT file.
|
||||
@param filename The name of the file.
|
||||
@return The loaded OpatIO object.
|
||||
@throws RuntimeError if the header does not have 256 bytes.
|
||||
"""
|
||||
opat = OpatIO()
|
||||
with open(filename, 'rb') as f:
|
||||
headerBytes: bytes = f.read(256)
|
||||
unpackedHeader = struct.unpack("<4s H I I Q 16s 64s 128s H 24s", headerBytes)
|
||||
loadedHeader = Header(
|
||||
magic = unpackedHeader[0].decode().replace("\x00", ""),
|
||||
version = unpackedHeader[1],
|
||||
numTables = unpackedHeader[2],
|
||||
headerSize = unpackedHeader[3],
|
||||
indexOffset = unpackedHeader[4],
|
||||
creationDate = unpackedHeader[5].decode().replace("\x00", ""),
|
||||
sourceInfo = unpackedHeader[6].decode().replace("\x00", ""),
|
||||
comment = unpackedHeader[7].decode().replace("\x00", ""),
|
||||
numIndex = unpackedHeader[8],
|
||||
reserved = unpackedHeader[9]
|
||||
)
|
||||
opat.header = loadedHeader
|
||||
f.seek(opat.header.indexOffset)
|
||||
tableIndices: List[TableIndex] = []
|
||||
tableIndexChunkSize = 16 + loadedHeader.numIndex*8
|
||||
tableIndexFMTString = "<"+"d"*loadedHeader.numIndex+"QQ"
|
||||
while tableIndexEntryBytes := f.read(tableIndexChunkSize):
|
||||
unpackedTableIndexEntry = struct.unpack(tableIndexFMTString, tableIndexEntryBytes)
|
||||
checksum = f.read(32)
|
||||
index = unpackedTableIndexEntry[:loadedHeader.numIndex]
|
||||
tableIndexEntry = TableIndex(
|
||||
index = index,
|
||||
byteStart = unpackedTableIndexEntry[loadedHeader.numIndex],
|
||||
byteEnd = unpackedTableIndexEntry[loadedHeader.numIndex+1],
|
||||
sha256 = checksum
|
||||
)
|
||||
tableIndices.append(tableIndexEntry)
|
||||
|
||||
currentStartByte = opat.header.headerSize
|
||||
f.seek(currentStartByte)
|
||||
for tableIndex in tableIndices:
|
||||
f.seek(tableIndex.byteStart)
|
||||
byteLength = tableIndex.byteEnd - tableIndex.byteStart
|
||||
tableBytes = f.read(byteLength)
|
||||
|
||||
nr_nt_fmt = "<II"
|
||||
nr_nt_size = struct.calcsize(nr_nt_fmt)
|
||||
N_R, N_T = struct.unpack(nr_nt_fmt, tableBytes[:nr_nt_size])
|
||||
|
||||
dataFormat = f"<{N_R}d{N_T}d{N_R*N_T}d"
|
||||
unpackedData = struct.unpack(dataFormat, tableBytes[nr_nt_size:])
|
||||
|
||||
logR = np.array(unpackedData[:N_R], dtype=np.float64)
|
||||
logT = np.array(unpackedData[N_R: N_R+N_T], dtype=np.float64)
|
||||
logKappa = np.array(unpackedData[N_R+N_T:], dtype=np.float64).reshape((N_R, N_T))
|
||||
|
||||
opat.add_table(tableIndex.index, logR, logT, logKappa)
|
||||
return opat
|
||||
@@ -1,27 +0,0 @@
|
||||
from opatio import OpatIO
|
||||
import numpy as np
|
||||
np.random.seed(42)
|
||||
|
||||
def generate_synthetic_opacity_table(n_r, n_t):
|
||||
logR = np.linspace(-8, 2, n_r, dtype=np.float64) # log Density grid
|
||||
logT = np.linspace(3, 9, n_t, dtype=np.float64) # log Temperature grid
|
||||
logK = np.random.uniform(-2, 2, size=(n_r, n_t)).astype(np.float64) # Synthetic Opacity
|
||||
return logR, logT, logK
|
||||
|
||||
if __name__ == "__main__":
|
||||
n_r = 50
|
||||
n_t = 50
|
||||
num_tables = 20
|
||||
XValues = np.linspace(0.1, 0.7, num_tables)
|
||||
ZValues = np.linspace(0.001, 0.03, num_tables)
|
||||
opat = OpatIO()
|
||||
opat.set_comment("Synthetic Opacity Tables")
|
||||
opat.set_source("utils/opatio/utils/mkTestData.py")
|
||||
|
||||
for i in range(num_tables):
|
||||
logR, logT, logK = generate_synthetic_opacity_table(n_r, n_t)
|
||||
opat.add_table((XValues[i], ZValues[i]), logR, logT, logK)
|
||||
|
||||
opat.save("testData/synthetic_tables.opat")
|
||||
opat.save_as_ascii("testData/synthetic_tables_OPAT.ascii")
|
||||
|
||||
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user