diff --git a/utils/opatio/src/opatio/opat/opat.py b/utils/opatio/src/opatio/opat/opat.py index eea33da..744b8c9 100644 --- a/utils/opatio/src/opatio/opat/opat.py +++ b/utils/opatio/src/opatio/opat/opat.py @@ -24,6 +24,7 @@ class Header: creationDate: str #< Creation date of the file sourceInfo: str #< Source information comment: str #< Comment section + numIndex: int #< Number of values to use when indexing table reserved: bytes #< Reserved for future use @dataclass @@ -31,8 +32,7 @@ class TableIndex: """ @brief Structure to hold the index information of a table in an OPAT file. """ - X: float #< X composition value - Z: float #< Z composition value + index: List[float] #< Index values of the table byteStart: int #< Byte start position of the table byteEnd: int #< Byte end position of the table sha256: bytes #< SHA-256 hash of the table data @@ -57,7 +57,8 @@ defaultHeader = Header( creationDate=datetime.now().strftime("%b %d, %Y"), sourceInfo="no source provided by user", comment="default header", - reserved=b"\x00" * 26 + numIndex=2, + reserved=b"\x00" * 24 ) class OpatIO: @@ -206,17 +207,29 @@ class OpatIO: raise TypeError(f"comment string ({comment}) is too long ({len(comment)}). Max length is 128") self.header.comment = comment return self.header.comment + + def set_numIndex(self, numIndex: int) -> int: + """ + @brief Set the number of values to use when indexing table. + @param numIndex The number of values to use when indexing table. + @return The set number of values to use when indexing table. + """ + if numIndex < 1: + raise ValueError(f"numIndex must be greater than 0! It is currently {numIndex}") + self.header.numIndex = numIndex + return self.header.numIndex - def add_table(self, X: float, Z: float, logR: Iterable[float], logT: Iterable[float], logKappa: Iterable[Iterable[float]]): + def add_table(self, indicies: Tuple[float], logR: Iterable[float], logT: Iterable[float], logKappa: Iterable[Iterable[float]]): """ @brief Add a table to the OPAT file. - @param X The X composition value. - @param Z The Z composition value. + @param indicies The index values of the table. @param logR The logR values. @param logT The logT values. @param logKappa The logKappa values. @throws ValueError if logKappa is not a non-empty 2D array or if logR and logT are not 1D arrays. """ + if len(indicies) != self.header.numIndex: + raise ValueError(f"indicies must have length {self.header.numIndex}! Currently it has length {len(indicies)}") self.validate_logKappa(logKappa) self.validate_1D(logR, "logR") self.validate_1D(logT, "logT") @@ -236,7 +249,7 @@ class OpatIO: logKappa = logKappa ) - self.tables.append(((X, Z), table)) + self.tables.append(indicies, table)) self.header.numTables += 1 @@ -246,7 +259,7 @@ class OpatIO: @return The header as bytes. """ headerBytes = struct.pack( - "<4s H I I Q 16s 64s 128s 26s", + "<4s H I I Q 16s 64s 128s H 24s", self.header.magic.encode('utf-8'), self.header.version, self.header.numTables, @@ -255,6 +268,7 @@ class OpatIO: self.header.creationDate.encode('utf-8'), self.header.sourceInfo.encode('utf-8'), self.header.comment.encode('utf-8'), + self.header.numIndex, self.header.reserved ) return headerBytes @@ -286,10 +300,10 @@ class OpatIO: @return The table index as bytes. @throws RuntimeError if the table index entry does not have 64 bytes. """ + tableIndexFMTString = "<"+"d"*self.header.numIndex+f"QQ" tableIndexBytes = struct.pack( - ' str: + def _format_table_as_string(self, table: OPATTable, indices: List[float]) -> str: """ @brief Format a table as a string. @param table The OPAT table. - @param X The X composition value. - @param Z The Z composition value. + @indices The index values of the table. @return The formatted table as a string. """ tableString: List[str] = [] # fixed width X and Z header per table - tableString.append(f"X: {X:<10.4f} Z: {Z:<10.4f}") + tableIndexString: List[str] = [] + for index in indices: + tableIndexString.append(f"{index:<10.4f}") + tableString.append(" ".join(tableIndexString)) tableString.append("-" * 80) # write logR across the top (reserving one col for where logT will be) logRRow = f"{'':<10}" @@ -354,10 +371,19 @@ class OpatIO: tableRows: List[str] = [] tableRows.append("\nTable Indexes in OPAT File:\n") - tableRows.append(f"{'X':<10} {'Z':<10} {'Byte Start':<15} {'Byte End':<15} {'Checksum (SHA-256)'}") + headerString: str = '' + for indexID, index in enumerate(table_indexes[0].index): + indexKey = f"Index {indexID}" + headerString += f"{indexKey:<10}" + headerString += f"{'Byte Start':<15} {'Byte End':<15} {'Checksum (SHA-256)'}" + tableRows.append(headerString) tableRows.append("=" * 80) for entry in table_indexes: - tableRows.append(f"{entry.X:<10.4f} {entry.Z:<10.4f} {entry.byteStart:<15} {entry.byteEnd:<15} {entry.sha256[:16]}...") + tableEntry = '' + for index in entry.index: + tableEntry += f"{index:<10.4f}" + tableEntry += f"{entry.byteStart:<15} {entry.byteEnd:<15} {entry.sha256[:16]}..." + tableRows.append(tableEntry) return '\n'.join(tableRows) def save_as_ascii(self, filename: str) -> str: @@ -370,12 +396,11 @@ class OpatIO: currentStartByte: int = 256 tableIndexs: List[bytes] = [] tableStrings: List[bytes] = [] - for (X, Z), table in self.tables: + for index, table in self.tables: checksum, tableBytes = self._table_bytes(table) - tableStrings.append(self._format_table_as_string(table, X, Z) + "\n") + tableStrings.append(self._format_table_as_string(table, index) + "\n") tableIndex = TableIndex( - X = X, - Z = Z, + index = index, byteStart = currentStartByte, byteEnd = currentStartByte + len(tableBytes), sha256 = checksum @@ -394,6 +419,7 @@ class OpatIO: f.write(f"Creation Date: {self.header.creationDate}\n") f.write(f"Source Info: {self.header.sourceInfo}\n") f.write(f"Comment: {self.header.comment}\n") + f.write(f"numIndex: {self.header.numIndex}\n") f.write("="*80 + "\n") f.write("="*80 + "\n") for tableString in tableStrings: @@ -417,11 +443,10 @@ class OpatIO: currentStartByte: int = 256 tableIndicesBytes: List[bytes] = [] tablesBytes: List[bytes] = [] - for (X, Z), table in self.tables: + for index, table in self.tables: checksum, tableBytes = self._table_bytes(table) tableIndex = TableIndex( - X = X, - Z = Z, + index, byteStart = currentStartByte, byteEnd = currentStartByte + len(tableBytes), sha256 = checksum @@ -455,7 +480,7 @@ def loadOpat(filename: str) -> OpatIO: opat = OpatIO() with open(filename, 'rb') as f: headerBytes: bytes = f.read(256) - unpackedHeader = struct.unpack("<4s H I I Q 16s 64s 128s 26s", headerBytes) + unpackedHeader = struct.unpack("<4s H I I Q 16s 64s 128s H 24s", headerBytes) loadedHeader = Header( magic = unpackedHeader[0].decode().replace("\x00", ""), version = unpackedHeader[1], @@ -465,14 +490,19 @@ def loadOpat(filename: str) -> OpatIO: creationDate = unpackedHeader[5].decode().replace("\x00", ""), sourceInfo = unpackedHeader[6].decode().replace("\x00", ""), comment = unpackedHeader[7].decode().replace("\x00", ""), - reserved = unpackedHeader[8] + numIndex = unpackedHeader[8], + reserved = unpackedHeader[9] ) opat.header = loadedHeader f.seek(opat.header.indexOffset) tableIndices: List[TableIndex] = [] - while tableIndexEntryBytes := f.read(32): - unpackedTableIndexEntry = struct.unpack("