mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Merge pull request #1955 from Disty0/dev
Fast image size reading support for JPEG XL
This commit is contained in:
186
library/jpeg_xl_util.py
Normal file
186
library/jpeg_xl_util.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Modified from https://github.com/Fraetor/jxl_decode
|
||||
# Added partial read support for up to 200x speedup
|
||||
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
class JXLBitstream:
|
||||
"""
|
||||
A stream of bits with methods for easy handling.
|
||||
"""
|
||||
|
||||
def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
|
||||
self.shift = 0
|
||||
self.bitstream = bytearray()
|
||||
self.file = file
|
||||
self.offset = offset
|
||||
self.offsets = offsets
|
||||
if self.offsets:
|
||||
self.offset = self.offsets[0][1]
|
||||
self.previous_data_len = 0
|
||||
self.index = 0
|
||||
self.file.seek(self.offset)
|
||||
|
||||
def get_bits(self, length: int = 1) -> int:
|
||||
if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_to_read_length = length
|
||||
if self.shift < self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_read(0, length)
|
||||
self.bitstream.extend(self.file.read(self.partial_to_read_length))
|
||||
else:
|
||||
self.bitstream.extend(self.file.read(length))
|
||||
bitmask = 2**length - 1
|
||||
bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
|
||||
self.shift += length
|
||||
return bits
|
||||
|
||||
def partial_read(self, current_length: int, length: int) -> None:
|
||||
self.previous_data_len += self.offsets[self.index][2]
|
||||
to_read_length = self.previous_data_len - (self.shift + current_length)
|
||||
self.bitstream.extend(self.file.read(to_read_length))
|
||||
current_length += to_read_length
|
||||
self.partial_to_read_length -= to_read_length
|
||||
self.index += 1
|
||||
self.file.seek(self.offsets[self.index][1])
|
||||
if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_read(current_length, length)
|
||||
|
||||
|
||||
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
|
||||
"""
|
||||
Decodes the actual codestream.
|
||||
JXL codestream specification: http://www-internal/2022/18181-1
|
||||
"""
|
||||
|
||||
# Convert codestream to int within an object to get some handy methods.
|
||||
codestream = JXLBitstream(file, offset=offset, offsets=offsets)
|
||||
|
||||
# Skip signature
|
||||
codestream.get_bits(16)
|
||||
|
||||
# SizeHeader
|
||||
div8 = codestream.get_bits(1)
|
||||
if div8:
|
||||
height = 8 * (1 + codestream.get_bits(5))
|
||||
else:
|
||||
distribution = codestream.get_bits(2)
|
||||
match distribution:
|
||||
case 0:
|
||||
height = 1 + codestream.get_bits(9)
|
||||
case 1:
|
||||
height = 1 + codestream.get_bits(13)
|
||||
case 2:
|
||||
height = 1 + codestream.get_bits(18)
|
||||
case 3:
|
||||
height = 1 + codestream.get_bits(30)
|
||||
ratio = codestream.get_bits(3)
|
||||
if div8 and not ratio:
|
||||
width = 8 * (1 + codestream.get_bits(5))
|
||||
elif not ratio:
|
||||
distribution = codestream.get_bits(2)
|
||||
match distribution:
|
||||
case 0:
|
||||
width = 1 + codestream.get_bits(9)
|
||||
case 1:
|
||||
width = 1 + codestream.get_bits(13)
|
||||
case 2:
|
||||
width = 1 + codestream.get_bits(18)
|
||||
case 3:
|
||||
width = 1 + codestream.get_bits(30)
|
||||
else:
|
||||
match ratio:
|
||||
case 1:
|
||||
width = height
|
||||
case 2:
|
||||
width = (height * 12) // 10
|
||||
case 3:
|
||||
width = (height * 4) // 3
|
||||
case 4:
|
||||
width = (height * 3) // 2
|
||||
case 5:
|
||||
width = (height * 16) // 9
|
||||
case 6:
|
||||
width = (height * 5) // 4
|
||||
case 7:
|
||||
width = (height * 2) // 1
|
||||
return width, height
|
||||
|
||||
|
||||
def decode_container(file) -> Tuple[int,int]:
|
||||
"""
|
||||
Parses the ISOBMFF container, extracts the codestream, and decodes it.
|
||||
JXL container specification: http://www-internal/2022/18181-2
|
||||
"""
|
||||
|
||||
def parse_box(file, file_start: int) -> dict:
|
||||
file.seek(file_start)
|
||||
LBox = int.from_bytes(file.read(4), "big")
|
||||
XLBox = None
|
||||
if 1 < LBox <= 8:
|
||||
raise ValueError(f"Invalid LBox at byte {file_start}.")
|
||||
if LBox == 1:
|
||||
file.seek(file_start + 8)
|
||||
XLBox = int.from_bytes(file.read(8), "big")
|
||||
if XLBox <= 16:
|
||||
raise ValueError(f"Invalid XLBox at byte {file_start}.")
|
||||
if XLBox:
|
||||
header_length = 16
|
||||
box_length = XLBox
|
||||
else:
|
||||
header_length = 8
|
||||
if LBox == 0:
|
||||
box_length = os.fstat(file.fileno()).st_size - file_start
|
||||
else:
|
||||
box_length = LBox
|
||||
file.seek(file_start + 4)
|
||||
box_type = file.read(4)
|
||||
file.seek(file_start)
|
||||
return {
|
||||
"length": box_length,
|
||||
"type": box_type,
|
||||
"offset": header_length,
|
||||
}
|
||||
|
||||
file.seek(0)
|
||||
# Reject files missing required boxes. These two boxes are required to be at
|
||||
# the start and contain no values, so we can manually check there presence.
|
||||
# Signature box. (Redundant as has already been checked.)
|
||||
if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
|
||||
raise ValueError("Invalid signature box.")
|
||||
# File Type box.
|
||||
if file.read(20) != bytes.fromhex(
|
||||
"00000014 66747970 6A786C20 00000000 6A786C20"
|
||||
):
|
||||
raise ValueError("Invalid file type box.")
|
||||
|
||||
offset = 0
|
||||
offsets = []
|
||||
data_offset_not_found = True
|
||||
container_pointer = 32
|
||||
file_size = os.fstat(file.fileno()).st_size
|
||||
while data_offset_not_found:
|
||||
box = parse_box(file, container_pointer)
|
||||
match box["type"]:
|
||||
case b"jxlc":
|
||||
offset = container_pointer + box["offset"]
|
||||
data_offset_not_found = False
|
||||
case b"jxlp":
|
||||
file.seek(container_pointer + box["offset"])
|
||||
index = int.from_bytes(file.read(4), "big")
|
||||
offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
|
||||
container_pointer += box["length"]
|
||||
if container_pointer >= file_size:
|
||||
data_offset_not_found = False
|
||||
|
||||
if offsets:
|
||||
offsets.sort(key=lambda i: i[0])
|
||||
file.seek(0)
|
||||
|
||||
return decode_codestream(file, offset=offset, offsets=offsets)
|
||||
|
||||
|
||||
def get_jxl_size(path: str) -> Tuple[int,int]:
|
||||
with open(path, "rb") as file:
|
||||
if file.read(2) == bytes.fromhex("FF0A"):
|
||||
return decode_codestream(file)
|
||||
return decode_container(file)
|
||||
@@ -118,14 +118,16 @@ except:
|
||||
# JPEG-XL on Linux
|
||||
try:
|
||||
from jxlpy import JXLImagePlugin
|
||||
from library.jpeg_xl_util import get_jxl_size
|
||||
|
||||
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||
except:
|
||||
pass
|
||||
|
||||
# JPEG-XL on Windows
|
||||
# JPEG-XL on Linux and Windows
|
||||
try:
|
||||
import pillow_jxl
|
||||
from library.jpeg_xl_util import get_jxl_size
|
||||
|
||||
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||
except:
|
||||
@@ -1156,6 +1158,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
if image_path.endswith(".jxl") or image_path.endswith(".JXL"):
|
||||
return get_jxl_size(image_path)
|
||||
return imagesize.get(image_path)
|
||||
|
||||
def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
|
||||
|
||||
Reference in New Issue
Block a user