feat(utils/opatio): began update to account for OPAL type II tables (or generally an arbitrary number of index values per table)

This commit is contained in:
2025-02-16 19:03:41 -05:00
parent 9a302fd5d3
commit c4cd731520

View File

@@ -24,6 +24,7 @@ class Header:
creationDate: str #< Creation date of the file creationDate: str #< Creation date of the file
sourceInfo: str #< Source information sourceInfo: str #< Source information
comment: str #< Comment section comment: str #< Comment section
numIndex: int #< Number of values to use when indexing table
reserved: bytes #< Reserved for future use reserved: bytes #< Reserved for future use
@dataclass @dataclass
@@ -31,8 +32,7 @@ class TableIndex:
""" """
@brief Structure to hold the index information of a table in an OPAT file. @brief Structure to hold the index information of a table in an OPAT file.
""" """
X: float #< X composition value index: List[float] #< Index values of the table
Z: float #< Z composition value
byteStart: int #< Byte start position of the table byteStart: int #< Byte start position of the table
byteEnd: int #< Byte end position of the table byteEnd: int #< Byte end position of the table
sha256: bytes #< SHA-256 hash of the table data sha256: bytes #< SHA-256 hash of the table data
@@ -57,7 +57,8 @@ defaultHeader = Header(
creationDate=datetime.now().strftime("%b %d, %Y"), creationDate=datetime.now().strftime("%b %d, %Y"),
sourceInfo="no source provided by user", sourceInfo="no source provided by user",
comment="default header", comment="default header",
reserved=b"\x00" * 26 numIndex=2,
reserved=b"\x00" * 24
) )
class OpatIO: class OpatIO:
@@ -207,16 +208,28 @@ class OpatIO:
self.header.comment = comment self.header.comment = comment
return self.header.comment return self.header.comment
def add_table(self, X: float, Z: float, logR: Iterable[float], logT: Iterable[float], logKappa: Iterable[Iterable[float]]): def set_numIndex(self, numIndex: int) -> int:
"""
@brief Set the number of values to use when indexing table.
@param numIndex The number of values to use when indexing table.
@return The set number of values to use when indexing table.
"""
if numIndex < 1:
raise ValueError(f"numIndex must be greater than 0! It is currently {numIndex}")
self.header.numIndex = numIndex
return self.header.numIndex
def add_table(self, indicies: Tuple[float], logR: Iterable[float], logT: Iterable[float], logKappa: Iterable[Iterable[float]]):
""" """
@brief Add a table to the OPAT file. @brief Add a table to the OPAT file.
@param X The X composition value. @param indicies The index values of the table.
@param Z The Z composition value.
@param logR The logR values. @param logR The logR values.
@param logT The logT values. @param logT The logT values.
@param logKappa The logKappa values. @param logKappa The logKappa values.
@throws ValueError if logKappa is not a non-empty 2D array or if logR and logT are not 1D arrays. @throws ValueError if logKappa is not a non-empty 2D array or if logR and logT are not 1D arrays.
""" """
if len(indicies) != self.header.numIndex:
raise ValueError(f"indicies must have length {self.header.numIndex}! Currently it has length {len(indicies)}")
self.validate_logKappa(logKappa) self.validate_logKappa(logKappa)
self.validate_1D(logR, "logR") self.validate_1D(logR, "logR")
self.validate_1D(logT, "logT") self.validate_1D(logT, "logT")
@@ -236,7 +249,7 @@ class OpatIO:
logKappa = logKappa logKappa = logKappa
) )
self.tables.append(((X, Z), table)) self.tables.append(indicies, table))
self.header.numTables += 1 self.header.numTables += 1
@@ -246,7 +259,7 @@ class OpatIO:
@return The header as bytes. @return The header as bytes.
""" """
headerBytes = struct.pack( headerBytes = struct.pack(
"<4s H I I Q 16s 64s 128s 26s", "<4s H I I Q 16s 64s 128s H 24s",
self.header.magic.encode('utf-8'), self.header.magic.encode('utf-8'),
self.header.version, self.header.version,
self.header.numTables, self.header.numTables,
@@ -255,6 +268,7 @@ class OpatIO:
self.header.creationDate.encode('utf-8'), self.header.creationDate.encode('utf-8'),
self.header.sourceInfo.encode('utf-8'), self.header.sourceInfo.encode('utf-8'),
self.header.comment.encode('utf-8'), self.header.comment.encode('utf-8'),
self.header.numIndex,
self.header.reserved self.header.reserved
) )
return headerBytes return headerBytes
@@ -286,10 +300,10 @@ class OpatIO:
@return The table index as bytes. @return The table index as bytes.
@throws RuntimeError if the table index entry does not have 64 bytes. @throws RuntimeError if the table index entry does not have 64 bytes.
""" """
tableIndexFMTString = "<"+"d"*self.header.numIndex+f"QQ"
tableIndexBytes = struct.pack( tableIndexBytes = struct.pack(
'<ddQQ', tableIndexFMTString,
tableIndex.X, *tableIndex.index,
tableIndex.Z,
tableIndex.byteStart, tableIndex.byteStart,
tableIndex.byteEnd tableIndex.byteEnd
) )
@@ -313,21 +327,24 @@ class OpatIO:
creationDate: {self.header.creationDate} creationDate: {self.header.creationDate}
sourceInfo: {self.header.sourceInfo} sourceInfo: {self.header.sourceInfo}
comment: {self.header.comment} comment: {self.header.comment}
numIndex: {self.header.numIndex}
reserved: {self.header.reserved} reserved: {self.header.reserved}
)""" )"""
return reprString return reprString
def _format_table_as_string(self, table: OPATTable, X: float, Z: float) -> str: def _format_table_as_string(self, table: OPATTable, indices: List[float]) -> str:
""" """
@brief Format a table as a string. @brief Format a table as a string.
@param table The OPAT table. @param table The OPAT table.
@param X The X composition value. @param indices The index values of the table.
@param Z The Z composition value.
@return The formatted table as a string. @return The formatted table as a string.
""" """
tableString: List[str] = [] tableString: List[str] = []
# fixed width X and Z header per table # fixed width header of index values per table
tableString.append(f"X: {X:<10.4f} Z: {Z:<10.4f}") tableIndexString: List[str] = []
for index in indices:
tableIndexString.append(f"{index:<10.4f}")
tableString.append(" ".join(tableIndexString))
tableString.append("-" * 80) tableString.append("-" * 80)
# write logR across the top (reserving one col for where logT will be) # write logR across the top (reserving one col for where logT will be)
logRRow = f"{'':<10}" logRRow = f"{'':<10}"
@@ -354,10 +371,19 @@ class OpatIO:
tableRows: List[str] = [] tableRows: List[str] = []
tableRows.append("\nTable Indexes in OPAT File:\n") tableRows.append("\nTable Indexes in OPAT File:\n")
tableRows.append(f"{'X':<10} {'Z':<10} {'Byte Start':<15} {'Byte End':<15} {'Checksum (SHA-256)'}") headerString: str = ''
for indexID, index in enumerate(table_indexes[0].index):
indexKey = f"Index {indexID}"
headerString += f"{indexKey:<10}"
headerString += f"{'Byte Start':<15} {'Byte End':<15} {'Checksum (SHA-256)'}"
tableRows.append(headerString)
tableRows.append("=" * 80) tableRows.append("=" * 80)
for entry in table_indexes: for entry in table_indexes:
tableRows.append(f"{entry.X:<10.4f} {entry.Z:<10.4f} {entry.byteStart:<15} {entry.byteEnd:<15} {entry.sha256[:16]}...") tableEntry = ''
for index in entry.index:
tableEntry += f"{index:<10.4f}"
tableEntry += f"{entry.byteStart:<15} {entry.byteEnd:<15} {entry.sha256[:16]}..."
tableRows.append(tableEntry)
return '\n'.join(tableRows) return '\n'.join(tableRows)
def save_as_ascii(self, filename: str) -> str: def save_as_ascii(self, filename: str) -> str:
@@ -370,12 +396,11 @@ class OpatIO:
currentStartByte: int = 256 currentStartByte: int = 256
tableIndexs: List[bytes] = [] tableIndexs: List[bytes] = []
tableStrings: List[bytes] = [] tableStrings: List[bytes] = []
for (X, Z), table in self.tables: for index, table in self.tables:
checksum, tableBytes = self._table_bytes(table) checksum, tableBytes = self._table_bytes(table)
tableStrings.append(self._format_table_as_string(table, X, Z) + "\n") tableStrings.append(self._format_table_as_string(table, index) + "\n")
tableIndex = TableIndex( tableIndex = TableIndex(
X = X, index = index,
Z = Z,
byteStart = currentStartByte, byteStart = currentStartByte,
byteEnd = currentStartByte + len(tableBytes), byteEnd = currentStartByte + len(tableBytes),
sha256 = checksum sha256 = checksum
@@ -394,6 +419,7 @@ class OpatIO:
f.write(f"Creation Date: {self.header.creationDate}\n") f.write(f"Creation Date: {self.header.creationDate}\n")
f.write(f"Source Info: {self.header.sourceInfo}\n") f.write(f"Source Info: {self.header.sourceInfo}\n")
f.write(f"Comment: {self.header.comment}\n") f.write(f"Comment: {self.header.comment}\n")
f.write(f"numIndex: {self.header.numIndex}\n")
f.write("="*80 + "\n") f.write("="*80 + "\n")
f.write("="*80 + "\n") f.write("="*80 + "\n")
for tableString in tableStrings: for tableString in tableStrings:
@@ -417,11 +443,10 @@ class OpatIO:
currentStartByte: int = 256 currentStartByte: int = 256
tableIndicesBytes: List[bytes] = [] tableIndicesBytes: List[bytes] = []
tablesBytes: List[bytes] = [] tablesBytes: List[bytes] = []
for (X, Z), table in self.tables: for index, table in self.tables:
checksum, tableBytes = self._table_bytes(table) checksum, tableBytes = self._table_bytes(table)
tableIndex = TableIndex( tableIndex = TableIndex(
X = X, index,
Z = Z,
byteStart = currentStartByte, byteStart = currentStartByte,
byteEnd = currentStartByte + len(tableBytes), byteEnd = currentStartByte + len(tableBytes),
sha256 = checksum sha256 = checksum
@@ -455,7 +480,7 @@ def loadOpat(filename: str) -> OpatIO:
opat = OpatIO() opat = OpatIO()
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
headerBytes: bytes = f.read(256) headerBytes: bytes = f.read(256)
unpackedHeader = struct.unpack("<4s H I I Q 16s 64s 128s 26s", headerBytes) unpackedHeader = struct.unpack("<4s H I I Q 16s 64s 128s H 24s", headerBytes)
loadedHeader = Header( loadedHeader = Header(
magic = unpackedHeader[0].decode().replace("\x00", ""), magic = unpackedHeader[0].decode().replace("\x00", ""),
version = unpackedHeader[1], version = unpackedHeader[1],
@@ -465,14 +490,19 @@ def loadOpat(filename: str) -> OpatIO:
creationDate = unpackedHeader[5].decode().replace("\x00", ""), creationDate = unpackedHeader[5].decode().replace("\x00", ""),
sourceInfo = unpackedHeader[6].decode().replace("\x00", ""), sourceInfo = unpackedHeader[6].decode().replace("\x00", ""),
comment = unpackedHeader[7].decode().replace("\x00", ""), comment = unpackedHeader[7].decode().replace("\x00", ""),
reserved = unpackedHeader[8] numIndex = unpackedHeader[8],
reserved = unpackedHeader[9]
) )
opat.header = loadedHeader opat.header = loadedHeader
f.seek(opat.header.indexOffset) f.seek(opat.header.indexOffset)
tableIndices: List[TableIndex] = [] tableIndices: List[TableIndex] = []
while tableIndexEntryBytes := f.read(32): tableIndexChunkSize = 16 + loadedHeader.numIndex*8
unpackedTableIndexEntry = struct.unpack("<ddQQ", tableIndexEntryBytes) tableIndexFMTString = "<"+"d"*loadedHeader.numIndex+"QQ"
while tableIndexEntryBytes := f.read(tableIndexChunkSize):
unpackedTableIndexEntry = struct.unpack(tableIndexFMTString, tableIndexEntryBytes)
checksum = f.read(32) checksum = f.read(32)
# TODO: Update this to read the full, general set of index values (it currently still reads only X and Z, which TableIndex no longer defines)
# TODO: Also update the spec to reflect the new header index set and the new table index format.
tableIndexEntry = TableIndex( tableIndexEntry = TableIndex(
X = unpackedTableIndexEntry[0], X = unpackedTableIndexEntry[0],
Z = unpackedTableIndexEntry[1], Z = unpackedTableIndexEntry[1],