mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
187 lines
6.4 KiB
Python
187 lines
6.4 KiB
Python
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
|
|
# Added partial read support for up to 200x speedup
|
|
|
|
import os
|
|
from typing import List, Tuple
|
|
|
|
class JXLBitstream:
    """
    Little-endian bit reader over a JPEG XL codestream stored in a file.

    The codestream is either one contiguous byte run starting at ``offset``
    (a bare codestream or the payload of a ``jxlc`` box) or, when ``offsets``
    is non-empty, the concatenation of several ``jxlp`` box payloads. Each
    ``offsets`` entry is ``[index, payload_offset, payload_length]`` and the
    list must already be in codestream order.

    Bytes are fetched from ``file`` lazily and only as far as the bits
    requested so far require, so decoding a header touches just a few bytes.
    Bits beyond the end of the available data read as zero.
    """

    def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
        self.shift = 0  # number of bits consumed so far
        self.bitstream = bytearray()  # codestream bytes fetched so far
        self.file = file
        self.offset = offset
        self.offsets = offsets
        if self.offsets:
            # Start at the payload of the first jxlp box.
            self.offset = self.offsets[0][1]
            self.index = 0  # position in ``offsets`` of the box being read
            self.remaining_in_box = self.offsets[0][2]  # unread payload bytes in that box
        self.file.seek(self.offset)

    def get_bits(self, length: int = 1) -> int:
        """
        Return the next ``length`` bits (least significant bit first) and advance.
        """
        self._buffer_bits(length)
        bitmask = (1 << length) - 1
        bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
        self.shift += length
        return bits

    def partial_read(self, current_length: int = 0, length: int = 1) -> None:
        """
        Make sure enough bytes are buffered for the next ``length`` bits.

        Kept for backward compatibility with the earlier reader, whose
        ``get_bits`` used this name to cross jxlp box boundaries;
        ``current_length`` is accepted but no longer used.
        """
        self._buffer_bits(length)

    def _buffer_bits(self, length: int) -> None:
        # Fetch only the bytes still missing to cover bits [shift, shift + length).
        missing = (self.shift + length + 7) // 8 - len(self.bitstream)
        if missing > 0:
            self.bitstream.extend(self._read_bytes(missing))

    def _read_bytes(self, count: int) -> bytes:
        # Read up to ``count`` codestream bytes, following jxlp boxes in order.
        # Returns fewer bytes when the file or the last box runs out.
        if not self.offsets:
            return self.file.read(count)
        chunks = []
        while count > 0:
            if self.remaining_in_box <= 0:
                if self.index + 1 >= len(self.offsets):
                    break  # no further codestream data
                self.index += 1
                self.file.seek(self.offsets[self.index][1])
                self.remaining_in_box = self.offsets[self.index][2]
                continue
            chunk = self.file.read(min(count, self.remaining_in_box))
            if not chunk:
                break  # file truncated inside a box
            chunks.append(chunk)
            count -= len(chunk)
            self.remaining_in_box -= len(chunk)
        return b"".join(chunks)
|
|
|
|
|
|
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int, int]:
    """
    Decode the image dimensions from the SizeHeader of a JPEG XL codestream.

    Only the signature and the SizeHeader are read; the rest of the
    codestream is never touched.

    JXL codestream specification: ISO/IEC 18181-1 (SizeHeader).

    Args:
        file: Binary file object containing the codestream.
        offset: Byte offset at which a contiguous codestream starts.
        offsets: Optional ``[index, payload_offset, payload_length]`` entries
            describing ``jxlp`` box payloads in codestream order; when given,
            ``offset`` is ignored.

    Returns:
        ``(width, height)`` in pixels.

    Raises:
        ValueError: If the codestream does not start with the FF 0A signature.
    """
    # Extra-bit counts selected by the 2-bit distribution of the
    # U32(1 + u(9), 1 + u(13), 1 + u(18), 1 + u(30)) size fields.
    size_bits = (9, 13, 18, 30)
    # Fixed aspect ratios (numerator, denominator) indexed by ``ratio``;
    # ratio 0 means the width is coded explicitly.
    aspect_ratios = (None, (1, 1), (12, 10), (4, 3), (3, 2), (16, 9), (5, 4), (2, 1))

    codestream = JXLBitstream(file, offset=offset, offsets=offsets)

    # The signature is the byte pair FF 0A, i.e. 0x0AFF when read little-endian.
    # Validating it turns a wrong offset into an error instead of bogus sizes.
    if codestream.get_bits(16) != 0x0AFF:
        raise ValueError("Invalid codestream signature.")

    def read_dimension(small: int) -> int:
        # "small" images code the size as a multiple of 8 in 5 bits.
        if small:
            return 8 * (1 + codestream.get_bits(5))
        return 1 + codestream.get_bits(size_bits[codestream.get_bits(2)])

    small = codestream.get_bits(1)
    height = read_dimension(small)
    ratio = codestream.get_bits(3)
    if ratio == 0:
        width = read_dimension(small)
    else:
        numerator, denominator = aspect_ratios[ratio]
        width = (height * numerator) // denominator
    return width, height
|
|
|
|
|
|
def decode_container(file) -> Tuple[int, int]:
    """
    Parse the ISOBMFF container, locate the codestream, and decode its size.

    The codestream is stored either in a single ``jxlc`` box or split across
    ``jxlp`` (partial codestream) boxes. Only box headers are read while
    scanning; the codestream itself is read lazily by ``decode_codestream``.

    JXL container specification: ISO/IEC 18181-2.

    Args:
        file: Binary file object backed by a real file (``fileno()`` is used
            to obtain the file size).

    Returns:
        ``(width, height)`` in pixels.

    Raises:
        ValueError: If the signature or file type box is missing, a box
            header is malformed, or the container has no codestream box.
    """
    file_size = os.fstat(file.fileno()).st_size

    def parse_box(file_start: int) -> Tuple[int, bytes, int]:
        """Return ``(box_length, box_type, header_length)`` for the box at ``file_start``."""
        file.seek(file_start)
        LBox = int.from_bytes(file.read(4), "big")
        box_type = file.read(4)
        if 1 < LBox < 8:
            # LBox values 2..7 are reserved; 8 is a valid box with an empty payload.
            raise ValueError(f"Invalid LBox at byte {file_start}.")
        if LBox == 1:
            # A 64-bit XLBox follows the type field and holds the real length.
            XLBox = int.from_bytes(file.read(8), "big")
            if XLBox < 16:
                raise ValueError(f"Invalid XLBox at byte {file_start}.")
            return XLBox, box_type, 16
        if LBox == 0:
            # The box extends to the end of the file.
            return file_size - file_start, box_type, 8
        return LBox, box_type, 8

    file.seek(0)
    # The signature box and the file type box must be the first two boxes and
    # have fixed contents, so their presence is checked byte-for-byte.
    if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
        raise ValueError("Invalid signature box.")
    if file.read(20) != bytes.fromhex("00000014 66747970 6A786C20 00000000 6A786C20"):
        raise ValueError("Invalid file type box.")

    offset = None  # payload offset of a jxlc box, if one is found
    offsets = []  # [sequence index, payload offset, payload length] per jxlp box
    container_pointer = 32  # first byte after the two fixed boxes
    while container_pointer < file_size:
        box_length, box_type, header_length = parse_box(container_pointer)
        payload_start = container_pointer + header_length
        if box_type == b"jxlc":
            offset = payload_start
            break
        if box_type == b"jxlp":
            # A jxlp payload starts with a 4-byte big-endian index whose most
            # significant bit marks the final partial codestream box.
            file.seek(payload_start)
            index = int.from_bytes(file.read(4), "big")
            offsets.append([index & 0x7FFFFFFF, payload_start + 4, box_length - header_length - 4])
            if index & 0x80000000:
                break  # all codestream parts seen; skip the rest of the file
        container_pointer += box_length

    if offsets:
        offsets.sort(key=lambda entry: entry[0])
        return decode_codestream(file, offsets=offsets)
    if offset is None:
        raise ValueError("No codestream box found.")
    return decode_codestream(file, offset=offset)
|
|
|
|
|
|
def get_jxl_size(path: str) -> Tuple[int, int]:
    """
    Return ``(width, height)`` of the JPEG XL image stored at ``path``.

    A file starting with the bytes FF 0A is treated as a bare codestream;
    anything else is parsed as an ISOBMFF container.
    """
    with open(path, "rb") as jxl_file:
        is_bare_codestream = jxl_file.read(2) == bytes.fromhex("FF0A")
        decoder = decode_codestream if is_bare_codestream else decode_container
        return decoder(jxl_file)
|