Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 21:52:27 +00:00)

Compare commits: v0.8.2...vae_batch_ (869 commits)
.github/workflows/tests.yml (vendored, new file)

@@ -0,0 +1,48 @@
```yaml
name: Test with pytest

on:
  push:
    branches:
      - main
      - dev
      - sd3
  pull_request:
    branches:
      - main
      - dev
      - sd3

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]  # Python versions to test
        pytorch-version: ["2.4.0"]  # PyTorch versions to test

    steps:
      - uses: actions/checkout@v4
        with:
          # https://woodruffw.github.io/zizmor/audits/#artipacked
          persist-credentials: false

      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'

      - name: Install and update pip, setuptools, wheel
        run: |
          # Setuptools, wheel for compiling some packages
          python -m pip install --upgrade pip setuptools wheel

      - name: Install dependencies
        run: |
          # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
          pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4
          pip install -r requirements.txt

      - name: Test with pytest
        run: pytest  # See pytest.ini for configuration
```
.github/workflows/typos.yml (vendored)
```diff
@@ -1,9 +1,11 @@
+---
+# yamllint disable rule:line-length
 name: Typos

-on:
+on:  # yamllint disable-line rule:truthy
   push:
     branches:
       - main
       - dev
   pull_request:
     types:
       - opened
@@ -16,6 +18,9 @@ jobs:

     steps:
       - uses: actions/checkout@v4
+        with:
+          # https://woodruffw.github.io/zizmor/audits/#artipacked
+          persist-credentials: false

       - name: typos-action
-        uses: crate-ci/typos@v1.16.26
+        uses: crate-ci/typos@v1.28.1
```
README-ja.md
```diff
@@ -1,12 +1,12 @@
-SDXL is now supported. The sdxl branch has been merged into the main branch. When you update the repository, please follow the upgrade steps. Also, since the accelerate version has been raised, please run accelerate config again.
-
-For SDXL training, see [here](./README.md#sdxl-training) (in English).
-
 ## About this repository
 This repository contains scripts for Stable Diffusion training, image generation, and other tasks.

 [README in English](./README.md) ← update information is available here

+The version under development is on the dev branch. Please check the dev branch for the latest changes.
+
+FLUX.1 and SD3/SD3.5 support is being worked on in the sd3 branch. Please use the sd3 branch to train those models.
+
 A GUI, PowerShell scripts, and other features that make the scripts easier to use are provided in [bmaltais's repository](https://github.com/bmaltais/kohya_ss) (in English); please check it out as well. Thanks to bmaltais.

 The following scripts are available.
@@ -21,6 +21,7 @@
 * [Training, common guide](./docs/train_README-ja.md): data preparation, options, etc.
 * [Dataset configuration](./docs/config_README-ja.md)
+* [SDXL training](./docs/train_SDXL-en.md) (English)
 * [About DreamBooth training](./docs/train_db_README-ja.md)
 * [Fine-tuning guide](./docs/fine_tune_README_ja.md)
 * [About LoRA training](./docs/train_network_README-ja.md)
@@ -35,6 +36,8 @@ Python 3.10.6 and Git are required.
 - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
 - git: https://git-scm.com/download/win

+Python 3.10.x, 3.11.x, and 3.12.x will probably work as well, but the scripts are tested with 3.10.6.
+
 When using PowerShell, change the security settings as follows to make venv usable.
 (Note that this allows script execution in general, not just venv.)
@@ -44,9 +47,7 @@
 ## Installation on Windows

-The scripts are tested with PyTorch 2.0.1. They should also work with PyTorch 1.12.1.
-
-The following example installs the PyTorch 2.0.1 / CUDA 11.8 build. If you use the CUDA 11.6 build or PyTorch 1.12.1, rewrite the commands accordingly.
+The scripts are tested with PyTorch 2.1.2. They will probably also work with PyTorch 2.2 or later.

 (If the python -m venv line prints only "python", change python to py, as in py -m venv.)
```

````diff
@@ -59,21 +60,23 @@ cd sd-scripts
 python -m venv venv
 .\venv\Scripts\activate

-pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
+pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
 pip install --upgrade -r requirements.txt
-pip install xformers==0.0.20
+pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118

 accelerate config
 ```

 The same applies in the command prompt.

-(Note: this was rewritten to ``python -m venv venv`` because it seems safer than ``python -m venv --system-site-packages venv``. With the latter, various problems occur when packages are installed in the global Python.)
+Note: `bitsandbytes==0.44.0`, `prodigyopt==1.0`, and `lion-pytorch==0.0.6` are now included in `requirements.txt`. To use other versions, install them as appropriate.
+
+This example installs the PyTorch and xformers 2.1.2 / CUDA 11.8 builds. For the CUDA 12.1 build or PyTorch 1.12.1, rewrite the commands accordingly. For example, for CUDA 12.1, use `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
+
+To use PyTorch 2.2 or later, change `torch==2.1.2`, `torchvision==0.16.2`, and `xformers==0.0.23.post1` accordingly.

 Answer the accelerate config questions as follows. (To train with bf16, answer bf16 to the last question.)

 Note: since accelerate 0.15.0, pressing the cursor keys to make a selection crashes in Japanese environments (...). Use the number keys 0, 1, 2, ... to select instead.

 ```txt
 - This machine
 - No distributed training
````

````diff
@@ -87,41 +90,6 @@
 Note: in some cases the error ``ValueError: fp16 mixed precision requires a GPU`` may occur. If so, answer "0" to the sixth question
 (``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``). (The GPU with id `0` will be used.)

-### Optional: using `bitsandbytes` (8bit optimizer)
-
-`bitsandbytes` is now optional. On Linux, it can be installed normally with pip (version 0.41.1 or later is recommended).
-
-On Windows, 0.35.0 or 0.41.1 is recommended.
-
-- `bitsandbytes` 0.35.0: a version that appears to be stable. AdamW8bit can be used, but some other 8bit optimizers and the `full_bf16` training option cannot.
-- `bitsandbytes` 0.41.1: supports Lion8bit, PagedAdamW8bit, and PagedLion8bit. `full_bf16` can be used.
-
-Note: `bitsandbytes` versions 0.35.0 through 0.41.0 appear to have problems. https://github.com/TimDettmers/bitsandbytes/issues/659
-
-Install `bitsandbytes` following the steps below.
-
-### Using 0.35.0
-
-A PowerShell example. In the command prompt, use copy instead of cp.
-
-```powershell
-cd sd-scripts
-.\venv\Scripts\activate
-pip install bitsandbytes==0.35.0
-
-cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
-cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
-cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
-```
-
-### Using 0.41.1
-
-Install a Windows whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui), distributed by jllllll, or from another location.
-
-```powershell
-python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
-```
-
 ## Upgrading

 When there is a new release, you can update with the following commands.
````

````diff
@@ -151,4 +119,47 @@

 [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause

+## Other information
+
+### LoRA naming
+
+To avoid confusion, names have been given to the LoRA types supported by `train_network.py`, and the documentation has been updated. These names are specific to this repository.
+
+1. __LoRA-LierLa__: (LoRA for __Li__n__e__a__r__ __La__yers, pronounced "LierLa")
+
+    LoRA applied to Linear layers and to Conv2d layers with a 1x1 kernel
+
+2. __LoRA-C3Lier__: (LoRA for __C__onvolutional layers with a __3__x3 kernel and __Li__n__e__a__r__ layers, pronounced "C3Lier")
+
+    In addition to 1., LoRA applied to Conv2d layers with a 3x3 kernel
+
+LoRA-LierLa is used by default. To use LoRA-C3Lier, specify `conv_dim` in `--network_args`.
+
+<!--
+LoRA-LierLa can be used with the [extension for the Web UI](https://github.com/kohya-ss/sd-webui-additional-networks) or with the LoRA feature of AUTOMATIC1111's Web UI.
+
+To generate images with LoRA-C3Lier in the Web UI, use the extension.
+-->
+
+### Generating sample images during training
+
+A prompt file looks like this, for example:
+
+```
+# prompt 1
+masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
+
+# prompt 2
+masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
+```
+
+Lines beginning with `#` are comments. Options are specified in the form "two hyphens plus a lowercase letter", like `--n`. The following options are available:
+
+* `--n` Negative prompt up to the next option.
+* `--w` Specifies the width of the generated image.
+* `--h` Specifies the height of the generated image.
+* `--d` Specifies the seed of the generated image.
+* `--l` Specifies the CFG scale of the generated image.
+* `--s` Specifies the number of steps in the generation.
+
+Weighting such as `( )` and `[ ]` also works.
````
```diff
@@ -1,7 +1,7 @@
 import torch
-from library.ipex_interop import init_ipex
+from library.device_utils import init_ipex

 init_ipex()

 from typing import Union, List, Optional, Dict, Any, Tuple
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
```
_typos.toml

```diff
@@ -2,6 +2,7 @@
 # Instruction: https://github.com/marketplace/actions/typos-action#getting-started

 [default.extend-identifiers]
+ddPn08="ddPn08"

 [default.extend-words]
 NIN="NIN"
@@ -27,6 +28,7 @@ rik="rik"
 koo="koo"
 yos="yos"
 wn="wn"
+hime="hime"


 [files]
```
docs/config_README-en.md (new file, 389 lines)

@@ -0,0 +1,389 @@
Original source by kohya-ss.

First version: AI translation by the model NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, with editing by Darkstorm2150. Some parts were added manually.

# Config README

This README is about the configuration files that can be passed with the `--dataset_config` option.

## Overview

By passing a configuration file, users can make detailed settings.

* Multiple datasets can be configured
    * For example, by setting `resolution` for each dataset, datasets with different resolutions can be mixed and trained together.
    * In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the two methods can be mixed.
* Settings can be changed for each subset
    * A subset is a partition of the dataset by image directory or metadata file. Several subsets make up a dataset.
    * Options such as `keep_tokens` and `flip_aug` can be set per subset. Options such as `resolution` and `batch_size`, on the other hand, are set per dataset, and their values are shared by all subsets belonging to the same dataset. Details are given below.

The configuration file format can be JSON or TOML. For ease of writing, [TOML](https://toml.io/ja/v1.0.0-rc.2) is recommended. The following explanation assumes TOML.

Here is an example of a configuration file written in TOML.

```toml
[general]
shuffle_caption = true
caption_extension = '.txt'
keep_tokens = 1

# This is a DreamBooth-style dataset
[[datasets]]
resolution = 512
batch_size = 4
keep_tokens = 2

[[datasets.subsets]]
image_dir = 'C:\hoge'
class_tokens = 'hoge girl'
# This subset uses keep_tokens = 2 (the value of the parent datasets)

[[datasets.subsets]]
image_dir = 'C:\fuga'
class_tokens = 'fuga boy'
keep_tokens = 3

[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg'
class_tokens = 'human'
keep_tokens = 1

# This is a fine-tuning dataset
[[datasets]]
resolution = [768, 768]
batch_size = 2

[[datasets.subsets]]
image_dir = 'C:\piyo'
metadata_file = 'C:\piyo\piyo_md.json'
# This subset uses keep_tokens = 1 (the value of [general])
```

In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2).

## Settings for datasets and subsets

Settings for datasets and subsets are divided into several registration locations.

* `[general]`
    * Options that apply to all datasets or all subsets are specified here.
    * If an option with the same name exists in a dataset-specific or subset-specific setting, the more specific setting takes precedence.
* `[[datasets]]`
    * `datasets` is where settings for datasets are registered. Options that apply individually to each dataset are specified here.
    * If a subset-specific setting exists, it takes precedence.
* `[[datasets.subsets]]`
    * `datasets.subsets` is where settings for subsets are registered. Options that apply individually to each subset are specified here.

Here is a diagram showing the correspondence between image directories and registration locations in the previous example.

```
C:\
├─ hoge  ->  [[datasets.subsets]] No.1  ┐                       ┐
├─ fuga  ->  [[datasets.subsets]] No.2  |->  [[datasets]] No.1  |->  [general]
├─ reg   ->  [[datasets.subsets]] No.3  ┘                       |
└─ piyo  ->  [[datasets.subsets]] No.4  -->  [[datasets]] No.2  ┘
```

Each image directory corresponds to one `[[datasets.subsets]]`. Multiple `[[datasets.subsets]]` combine to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.

The available options differ by registration location, but if the same option is specified in more than one place, the value in the lower (more specific) location takes precedence. Checking how the `keep_tokens` option is handled in the previous example should make this clearer.

Additionally, the available options vary depending on the methods that the training approach supports.

* Options specific to the DreamBooth method
* Options specific to the fine-tuning method
* Options available when the caption dropout technique can be used

In training approaches that support both the DreamBooth method and the fine-tuning method, the two can be used together.
A point to note when combining them is that the method is determined per dataset, so DreamBooth-method subsets and fine-tuning-method subsets cannot be mixed within the same dataset.
In other words, to use both methods together, subsets of different methods must belong to different datasets.

In terms of program behavior, a subset is treated as a fine-tuning subset if the `metadata_file` option exists. Therefore, for subsets belonging to the same dataset, there is no problem as long as they either all have the `metadata_file` option or all lack it.

The available options are explained below. For options with the same name as a command-line argument, the explanation is in principle omitted; please refer to the other READMEs.

### Common options for all training methods

These options can be specified regardless of the training method.

#### Dataset-specific options

These options relate to the configuration of the dataset. They cannot be written in `datasets.subsets`.

| Option name | Example setting | `[general]` | `[[datasets]]` |
| ---- | ---- | ---- | ---- |
| `batch_size` | `1` | o | o |
| `bucket_no_upscale` | `true` | o | o |
| `bucket_reso_steps` | `64` | o | o |
| `enable_bucket` | `true` | o | o |
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |

* `batch_size`
    * Equivalent to the command-line argument `--train_batch_size`.
* `max_bucket_reso`, `min_bucket_reso`
    * Specify the maximum and minimum bucket resolutions. They must be divisible by `bucket_reso_steps`.

These settings are fixed per dataset, which means that subsets belonging to the same dataset share them. For example, to prepare datasets with different resolutions, define them as separate datasets, as in the example above, and give each its own resolution.

#### Subset-specific options

These options relate to the configuration of subsets.

| Option name | Example | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `color_aug` | `false` | o | o | o |
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
| `flip_aug` | `true` | o | o | o |
| `keep_tokens` | `2` | o | o | o |
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o |
| `caption_suffix` | `", from side"` | o | o | o |
| `caption_separator` | (not specified) | o | o | o |
| `keep_tokens_separator` | `"\|\|\|"` | o | o | o |
| `secondary_separator` | `";;;"` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` | (not specified) | o | o | o |

* `num_repeats`
    * Specifies the number of repeats for the images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
* `caption_prefix`, `caption_suffix`
    * Specify strings to prepend and append to captions. Shuffling is performed with these strings included, so be careful when also using `keep_tokens`.
* `caption_separator`
    * Specifies the string that separates tags. The default is `,`. This option usually does not need to be set.
* `keep_tokens_separator`
    * Specifies the string that separates the parts of the caption to be kept fixed. For example, with `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` remain in place while the rest is shuffled and subject to dropping. The commas around `|||` are not necessary. The resulting prompt is e.g. `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`.
* `secondary_separator`
    * Specifies an additional separator. A part separated by this separator is treated as a single tag for shuffling and dropping, and is then replaced by `caption_separator`. For example, `aaa;;;bbb;;;ccc` is either replaced by `aaa,bbb,ccc` or dropped as a whole.
* `enable_wildcard`
    * Enables wildcard notation. This is explained later.
* `resize_interpolation`
    * Specifies the interpolation method used when resizing images. Normally there is no need to set this. The options `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, and `box` can be specified. By default (when not specified), `area` is used for downscaling and `lanczos` for upscaling. If this option is specified, the same interpolation method is used for both upscaling and downscaling. `lanczos` and `box` use PIL; the other options use OpenCV.

### DreamBooth-specific options

DreamBooth-specific options exist only as subset-specific options.

#### Subset-specific options

Options related to the configuration of DreamBooth subsets.

| Option name | Example setting | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |

First, note that `image_dir` must be a path with the image files placed directly under it. Unlike the traditional DreamBooth method, where images were placed in subdirectories, this is not compatible with that layout. Also, even if you name a folder something like `5_cat`, the image repeat count and class name are not taken from the folder name. To set these individually, specify them explicitly with `num_repeats` and `class_tokens`, as in the sketch after this list.

* `image_dir`
    * Specifies the path to the image directory. This option is required.
    * Images must be placed directly under the directory.
* `class_tokens`
    * Sets the class tokens.
    * Used during training only when no caption file exists for the image; the determination is made per image. If `class_tokens` is not specified and no caption file is found either, an error occurs.
* `cache_info`
    * Specifies whether to cache image sizes and captions. Defaults to `false` if not specified. The cache is saved as `metadata_cache.json` in `image_dir`.
    * Caching speeds up dataset loading from the second run onward. It is useful when dealing with thousands of images or more.
* `is_reg`
    * Specifies whether the subset images are regularization images. Defaults to `false` if not specified, i.e. the images are not regularization images.
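For instance, here is a minimal sketch of the explicit equivalent of an old-style `5_cat` folder; the directory path is illustrative:

```toml
[[datasets.subsets]]
image_dir = 'C:\train\cat'  # images placed directly under this directory
num_repeats = 5             # replaces the "5_" prefix of the old folder-name convention
class_tokens = 'cat'        # replaces the "_cat" suffix; used only for images without a caption file
```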
### Fine-tuning-specific options

Fine-tuning-specific options exist only as subset-specific options.

#### Subset-specific options

These options relate to the configuration of fine-tuning subsets.

| Option name | Example setting | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o |
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |

* `image_dir`
    * Specifies the path to the image directory. Unlike the DreamBooth method, it is not mandatory, but specifying it is recommended.
    * It does not need to be specified when the metadata file was generated with `--full_path` on the command line.
    * Images must be placed directly under the directory.
* `metadata_file`
    * Specifies the path to the metadata file used for the subset. This option is required.
    * Equivalent to the command-line argument `--in_json`.
    * Because a metadata file must be specified per subset, avoid creating a single metadata file spanning images from different directories. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets, as sketched below.
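For example, a sketch of two image directories, each with its own metadata file, registered as separate subsets of one dataset (all paths are illustrative):

```toml
[[datasets]]
resolution = 512
batch_size = 4

[[datasets.subsets]]
image_dir = 'C:\piyo'
metadata_file = 'C:\piyo\piyo_md.json'

[[datasets.subsets]]
image_dir = 'C:\fuga'
metadata_file = 'C:\fuga\fuga_md.json'
```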
### Options available when the caption dropout technique can be used

Options for the caption dropout technique exist only for subsets. Whether the DreamBooth method or the fine-tuning method is used, these options can be specified if the training approach supports caption dropout.

#### Subset-specific options

Options related to subsets for which caption dropout can be used; a sketch follows the table.

| Option name | `[general]` | `[[datasets]]` | `[[datasets.subsets]]` |
| ---- | ---- | ---- | ---- |
| `caption_dropout_every_n_epochs` | o | o | o |
| `caption_dropout_rate` | o | o | o |
| `caption_tag_dropout_rate` | o | o | o |
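A sketch of how these might be set; the values are illustrative, and the exact semantics of each option follow the command-line arguments of the same name (see the training README):

```toml
[general]
caption_tag_dropout_rate = 0.05  # default for every subset

[[datasets]]
resolution = 512

[[datasets.subsets]]
image_dir = 'C:\hoge'
caption_dropout_rate = 0.1       # per-subset setting takes precedence over [general]
caption_dropout_every_n_epochs = 3
```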
## Behavior when there are duplicate subsets

For DreamBooth datasets, subsets with the same `image_dir` are considered duplicates. For fine-tuning datasets, subsets with the same `metadata_file` are considered duplicates. If duplicate subsets exist within a dataset, the second and subsequent ones are ignored.

However, subsets belonging to different datasets are not considered duplicates. For example, subsets with the same `image_dir` placed in different datasets are both used. This is useful when you want to train on the same images at different resolutions.

```toml
# If the subsets are in separate datasets, they are not considered duplicates and are both used for training.

[[datasets]]
resolution = 512

[[datasets.subsets]]
image_dir = 'C:\hoge'

[[datasets]]
resolution = 768

[[datasets.subsets]]
image_dir = 'C:\hoge'
```

## Command-line arguments and the configuration file

Some options in the configuration file overlap in role with command-line options.

The following command-line options are ignored if a configuration file is passed (their config-file equivalents are sketched below):

* `--train_data_dir`
* `--reg_data_dir`
* `--in_json`
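As a rough sketch of the correspondence (all paths illustrative), `--train_data_dir` becomes the `image_dir` of a training subset, `--reg_data_dir` the `image_dir` of a subset with `is_reg = true`, and `--in_json` the `metadata_file` of a fine-tuning subset:

```toml
# --train_data_dir -> image_dir of a normal DreamBooth subset
# --reg_data_dir   -> image_dir of a subset with is_reg = true
[[datasets]]
resolution = 512

[[datasets.subsets]]
image_dir = 'C:\train'

[[datasets.subsets]]
image_dir = 'C:\reg'
is_reg = true

# --in_json -> metadata_file of a fine-tuning subset (in a separate dataset,
# since DreamBooth and fine-tuning subsets cannot share a dataset)
[[datasets]]
resolution = 512

[[datasets.subsets]]
image_dir = 'C:\piyo'
metadata_file = 'C:\piyo\piyo_md.json'
```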
If a command-line option and a configuration-file option are specified at the same time, the command-line options below take priority. In most cases they have the same name as the corresponding configuration-file option.

| Command-line option | Corresponding configuration-file option |
| ------------------------------- | ------------------------------------- |
| `--bucket_no_upscale` | |
| `--bucket_reso_steps` | |
| `--caption_dropout_every_n_epochs` | |
| `--caption_dropout_rate` | |
| `--caption_extension` | |
| `--caption_tag_dropout_rate` | |
| `--color_aug` | |
| `--dataset_repeats` | `num_repeats` |
| `--enable_bucket` | |
| `--face_crop_aug_range` | |
| `--flip_aug` | |
| `--keep_tokens` | |
| `--min_bucket_reso` | |
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--train_batch_size` | `batch_size` |

## Error guide

An external library is currently used to check whether the configuration file is written correctly, but its integration is incomplete and the error messages can be unclear. We plan to improve this in the future.

As a stopgap, common errors and their solutions are listed here. If you encounter an error even though the configuration should be correct, or if an error message is incomprehensible, please contact us, as it may be a bug.

* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: a required option was not provided. You may have forgotten to specify the option or misspelled its name (see the sketch after this list).
    * The error location is indicated by the `...` part of the message. For example, `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` means that the `image_dir` option is missing from the 0th `subsets` of the 0th `datasets` setting.
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: the specified value has the wrong format. The `int` part changes depending on the option. The example settings in this README may be helpful.
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: an unsupported option name is present. You may have misspelled an option name or included one by mistake.
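For example, a config like the following sketch reproduces the first error above, since `image_dir` is required for DreamBooth subsets:

```toml
[[datasets]]
resolution = 512

[[datasets.subsets]]
# image_dir is missing here, so validation fails with:
# voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']
class_tokens = 'hoge girl'
```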
## Miscellaneous

### Multi-line captions

By setting `enable_wildcard = true`, multi-line captions are also enabled. If a caption file consists of multiple lines, one line is selected at random as the caption.

```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```

This can be combined with wildcard notation.

Multi-line captions can also be specified in metadata files. In the `.json` metadata file, use `\n` to represent a line break. If a caption file consists of multiple lines, `merge_captions_to_metadata.py` creates a metadata file in this format.

The tags in the metadata (`tags`) are appended to each line of the caption.

```json
{
    "/path/to/image.png": {
        "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
        "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
    },
    ...
}
```

In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc.

### Example configuration file: `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc.

```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]

[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator = "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;"  # can also be written on the subset side
enable_wildcard = true  # same as above

[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1

# No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"

# The part after ||| is not shuffled or dropped and remains
# It is simply concatenated as a string, so you need to add commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```

### Example caption with `secondary_separator` notation: `secondary_separator = ";;;"`

```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```

The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without being shuffled or dropped. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, the caption becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).

### Example caption with wildcard notation: `enable_wildcard = true`

```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```

`simple` or `white` is selected at random, producing `simple background` or `white background`.

```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```

To include `{` or `}` themselves in the tag string, double them as `{{` or `}}` (in this example, the caption actually used for training is `{retro style}`).

### Example caption with `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`

```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```

The caption becomes, for example, `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general`.
@@ -1,5 +1,3 @@
|
||||
For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
|
||||
|
||||
`--dataset_config` で渡すことができる設定ファイルに関する説明です。
|
||||
|
||||
## 概要
|
||||
@@ -120,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
|
||||
|
||||
* `batch_size`
|
||||
* コマンドライン引数の `--train_batch_size` と同等です。
|
||||
* `max_bucket_reso`, `min_bucket_reso`
|
||||
* bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
|
||||
|
||||
これらの設定はデータセットごとに固定です。
|
||||
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
|
||||
@@ -140,12 +140,32 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
|
||||
| `shuffle_caption` | `true` | o | o | o |
|
||||
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
|
||||
| `caption_suffix` | `“, from side”` | o | o | o |
|
||||
| `caption_separator` | (通常は設定しません) | o | o | o |
|
||||
| `keep_tokens_separator` | `“|||”` | o | o | o |
|
||||
| `secondary_separator` | `“;;;”` | o | o | o |
|
||||
| `enable_wildcard` | `true` | o | o | o |
|
||||
| `resize_interpolation` |(通常は設定しません) | o | o | o |
|
||||
|
||||
* `num_repeats`
|
||||
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
|
||||
* `caption_prefix`, `caption_suffix`
|
||||
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
|
||||
|
||||
* `caption_separator`
|
||||
* タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。
|
||||
|
||||
* `keep_tokens_separator`
|
||||
* キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb` と `ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh` や `aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。
|
||||
|
||||
* `secondary_separator`
|
||||
* 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。
|
||||
|
||||
* `enable_wildcard`
|
||||
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
|
||||
|
||||
* `resize_interpolation`
|
||||
* 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos`、`box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。
|
||||
|
||||
### DreamBooth 方式専用のオプション
|
||||
|
||||
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
|
||||
@@ -159,6 +179,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
|
||||
| `image_dir` | `‘C:\hoge’` | - | - | o(必須) |
|
||||
| `caption_extension` | `".txt"` | o | o | o |
|
||||
| `class_tokens` | `“sks girl”` | - | - | o |
|
||||
| `cache_info` | `false` | o | o | o |
|
||||
| `is_reg` | `false` | - | - | o |
|
||||
|
||||
まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。
|
||||
@@ -169,6 +190,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
|
||||
* `class_tokens`
|
||||
* クラストークンを設定します。
|
||||
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
|
||||
* `cache_info`
|
||||
* 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir` に `metadata_cache.json` というファイル名で保存されます。
|
||||
* キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
|
||||
* `is_reg`
|
||||
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
|
||||
|
||||
@@ -280,4 +304,89 @@ resolution = 768
|
||||
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
|
||||
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
|
||||
|
||||
## その他
|
||||
|
||||
### 複数行キャプション
|
||||
|
||||
`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
|
||||
a girl with a microphone standing on a stage
|
||||
detailed digital art of a girl with a microphone on a stage
|
||||
```
|
||||
|
||||
ワイルドカード記法と組み合わせることも可能です。
|
||||
|
||||
メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。
|
||||
|
||||
メタデータのタグ (`tags`) は、キャプションの各行に追加されます。
|
||||
|
||||
```json
|
||||
{
|
||||
"/path/to/image.png": {
|
||||
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
|
||||
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
|
||||
},
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...`、 `test multiline caption2, open mouth, simple background ...` 等になります。
|
||||
|
||||
### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等

```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]

[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator = "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above

[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1

# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"

# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合

```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```

`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。
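
処理順序のイメージを示す仮のスケッチです(実装そのものではありません)。

```python
import random

def process_caption(caption, secondary_separator=";;;", shuffle=True):
    # "," で分割した時点では sky;;;cloud;;;day は 1 トークンのまま保持される
    tags = [t.strip() for t in caption.split(",")]
    if shuffle:
        random.shuffle(tags)  # shuffle / drop はこのトークン単位で行われる
    # 最後に secondary_separator を "," に置換する
    tags = [t.replace(secondary_separator, ",") for t in tags]
    return ", ".join(tags)
```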
### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合

```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```

ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。

```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```

タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
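
展開ルールのイメージを示す仮のスケッチです(エスケープの扱いを含め、説明用の簡略版です)。

```python
import random
import re

def expand_wildcards(caption):
    # {{ と }} を一時的に退避してから {a|b|c} を展開し、最後に { } として復元する
    caption = caption.replace("{{", "\0").replace("}}", "\1")
    caption = re.sub(r"\{([^}]*)\}", lambda m: random.choice(m.group(1).split("|")), caption)
    return caption.replace("\0", "{").replace("\1", "}")
```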
### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合

```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```

`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。
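
この例のように区切りが二つある場合の挙動を示す仮のスケッチです(説明用で、実装そのものではありません)。

```python
import random

def shuffle_with_keep_tokens(caption, separator="|||"):
    # 先頭部と末尾部は固定され、中央部のみ shuffle / drop の対象になる
    head, middle, tail = (p.strip() for p in caption.split(separator, 2))
    tags = [t.strip() for t in middle.split(",") if t.strip()]
    random.shuffle(tags)
    return ", ".join([head, ", ".join(tags), tail])
```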
@@ -452,3 +452,36 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt

- `--network_show_meta` : 追加ネットワークのメタデータを表示します。

---
# About Gradual Latent

Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.

- `--gradual_latent_timesteps`: Specifies the timestep at which to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
- `--gradual_latent_ratio_step`: Specifies the ratio by which to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, and 1.0.
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval at which to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.

Each option can also be specified with the prompt options `--glt`, `--glr`, `--gls`, and `--gle`.

__Please specify `euler_a` for the sampler.__ The sampler's source code is modified, so Gradual Latent will not work with other samplers.

It is more effective with SD 1.5. The effect is quite subtle with SDXL.
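
The following is a rough sketch of the resulting upscaling schedule under the default values. It is illustrative only: the linearly spaced timestep schedule is an assumption for the example, not the actual sampler code.

```python
def gradual_latent_ratios(timesteps=750, ratio=0.5, step=0.125, every_n_steps=3, total_steps=30):
    """Yield (sampler_step, timestep, latent_ratio); the ratio grows once sampling passes `timesteps`."""
    current = ratio
    steps_since_start = 0
    for i in range(total_steps):
        # assume a linearly spaced timestep schedule from 1000 down to 0 (illustration only)
        t = 1000 - (i + 1) * (1000 // total_steps)
        if t <= timesteps and current < 1.0:
            steps_since_start += 1
            if steps_since_start % every_n_steps == 0:
                current = min(1.0, current + step)  # 0.5 -> 0.625 -> ... -> 1.0
        yield i, t, current

for i, t, r in gradual_latent_ratios():
    print(f"step {i:2d}  t={t:4d}  latent ratio={r:.3f}")
```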
# Gradual Latent について

latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、`sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。

- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。

それぞれのオプションは、プロンプトオプション `--glt`、`--glr`、`--gls`、`--gle` でも指定できます。

サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。

SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。
57
docs/masked_loss_README-ja.md
Normal file
@@ -0,0 +1,57 @@
## マスクロスについて

マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。
たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。

マスクロスのマスクには、二種類の指定方法があります。

- マスク画像を用いる方法
- 透明度(アルファチャネル)を使用する方法

なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。
### マスク画像を用いる方法

学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。

- 学習画像
  
- マスク画像
  

```.toml
[[datasets.subsets]]
image_dir = "/path/to/a_zundamon"
caption_extension = ".txt"
conditioning_data_dir = "/path/to/a_zundamon_mask"
num_repeats = 8
```

マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。

DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。
### 透明度(アルファチャネル)を使用する方法

学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。



※それぞれの画像は透過PNG

学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。

```toml
[[datasets.subsets]]
image_dir = "/path/to/image/dir"
caption_extension = ".txt"
num_repeats = 8
alpha_mask = true
```

## 学習時の注意事項

- 現時点では DreamBooth 方式の dataset のみ対応しています。
- マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。
- マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証)
- `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。
56
docs/masked_loss_README.md
Normal file
@@ -0,0 +1,56 @@
## Masked Loss

Masked loss is a feature that lets you train on only part of an image by computing the loss only for the region specified by a mask on the input image. For example, if you want to train a character, you can mask the character region so that only it is trained, ignoring the background.

There are two ways to specify the mask for masked loss.

- Using a mask image
- Using the transparency (alpha channel) of the image

The samples use the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html).
### Using a mask image

This method requires preparing a mask image for each training image. Use a mask image with the same file name as the training image, and save it in a different directory from the training images.

- Training image
  
- Mask image
  

```.toml
[[datasets.subsets]]
image_dir = "/path/to/a_zundamon"
caption_extension = ".txt"
conditioning_data_dir = "/path/to/a_zundamon_mask"
num_repeats = 8
```

The mask image must be the same size as the training image, with the region to be trained drawn in white and the region to be ignored in black. Grayscale is also supported (a value of 127 gives a loss weight of 0.5). Strictly speaking, only the R channel of the mask image is used.

Use a DreamBooth method dataset, and save the mask images in the directory specified by `conditioning_data_dir`. It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details.
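
A minimal sketch of how such a mask could weight the loss, assuming the mask is loaded as a `(B, 3, H, W)` tensor in `[0, 1]`. The function name and structure are illustrative, not the actual implementation in the training scripts.

```python
import torch
import torch.nn.functional as F

def masked_loss(noise_pred: torch.Tensor, target: torch.Tensor, mask_images: torch.Tensor) -> torch.Tensor:
    # only the R channel of the mask image is used
    mask = mask_images[:, :1, :, :]
    # downscale the mask to the latent resolution (1/8 of the image size)
    mask = F.interpolate(mask, size=noise_pred.shape[-2:], mode="area")
    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none")
    loss = loss * mask  # 127/255 in the mask gives a weight of about 0.5
    return loss.mean()
```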
### Using the transparency (alpha channel) of the image

The transparency (alpha channel) of the training image is used as the mask. Regions with transparency 0 are ignored, and regions with transparency 255 are trained. For semi-transparent regions, the loss weight scales with the transparency (127 gives a weight of about 0.5).



※Each image is a transparent PNG
Specify `--alpha_mask` in the training script options, or specify `alpha_mask` in the subset of the dataset configuration file. For example:

```toml
[[datasets.subsets]]
image_dir = "/path/to/image/dir"
caption_extension = ".txt"
num_repeats = 8
alpha_mask = true
```
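
A small sketch of deriving the per-pixel weights from the alpha channel (illustrative only; the path is hypothetical):

```python
import numpy as np
from PIL import Image

def alpha_mask_from_image(path: str) -> np.ndarray:
    # 0 = ignored, 1.0 = fully trained; an alpha of 127 gives a weight of ~0.5
    rgba = np.array(Image.open(path).convert("RGBA"), dtype=np.float32)
    return rgba[:, :, 3] / 255.0
```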
## Notes on training

- At the moment, only DreamBooth method datasets are supported.
- The mask is applied after being downscaled to 1/8, the size of the latents. Fine details (such as ahoge or earrings) may therefore not be learned well; slightly dilating the mask may help.
- When using masked loss, it may not be necessary to include the untrained regions in the caption. (To be verified)
- For `alpha_mask`, the latents cache is automatically regenerated when the mask is enabled or disabled.
@@ -648,7 +648,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b

詳細については各自お調べください。

任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--lr_scheduler_args`でオプション引数を指定してください。

### オプティマイザの指定について

@@ -582,7 +582,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b

有关详细信息,请自行研究。

要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。
要使用任何调度程序,请像使用任何优化器一样使用“--lr_scheduler_args”指定可选参数。

### 关于指定优化器

使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外,您可以指定多个值,以逗号分隔。例如,要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。
84
docs/train_SDXL-en.md
Normal file
@@ -0,0 +1,84 @@
## SDXL training

This documentation will be moved into the training documentation in the future. The following is a brief explanation of the training scripts for SDXL.

### Training scripts for SDXL

- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth datasets.
  - The `--full_bf16` option is added. Thanks to KohakuBlueleaf!
    - This option enables full bfloat16 training (including gradients). It is useful for reducing GPU memory usage.
    - Full bfloat16 training might be unstable. Please use it at your own risk.
  - Different learning rates for each U-Net block are now supported in `sdxl_train.py`. Specify them with the `--block_lr` option, as 23 comma-separated values like `--block_lr 1e-3,1e-3 ... 1e-3` (see the sketch after this list).
    - The 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
  - `prepare_buckets_latents.py` now supports SDXL fine-tuning.

- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.

- Both scripts have the following additional options:
  - `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This is useful for reducing GPU memory usage. These options cannot be used together with options for shuffling or dropping captions.
  - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. The SDXL VAE seems to produce NaNs in some cases; this option helps avoid them.

- The `--weighted_captions` option is not supported yet for either script.

- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
  - `--cache_text_encoder_outputs` is not supported.
  - There are two options for captions:
    1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
    2. Use the `--use_object_template` or `--use_style_template` option. The captions are generated from the template, and any existing captions are ignored.
  - See below for the format of the embeddings.

- `--min_timestep` and `--max_timestep` options are added to each training script. These can be used to train U-Net on a restricted range of timesteps. The default values are 0 and 1000.
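
A small helper sketch for building the 23-value `--block_lr` string. The block names and the helper itself are illustrative assumptions; only the count and ordering follow the mapping above.

```python
BLOCK_NAMES = (
    ["time_label_embed"]
    + [f"input_block_{i}" for i in range(9)]
    + [f"mid_block_{i}" for i in range(3)]
    + [f"output_block_{i}" for i in range(9)]
    + ["out"]
)

def make_block_lr(default_lr: float, overrides: dict) -> str:
    lrs = [overrides.get(name, default_lr) for name in BLOCK_NAMES]
    assert len(lrs) == 23  # --block_lr always takes exactly 23 values
    return ",".join(f"{lr:.1e}" for lr in lrs)

# e.g. a lower learning rate for the embeddings and the final out block:
print(make_block_lr(1e-3, {"time_label_embed": 1e-4, "out": 1e-4}))
```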
### Utility scripts for SDXL

- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
  - The options are almost the same as `sdxl_train.py`. See the help message for the usage.
  - Please launch the script as follows:
    `accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
  - This script should work with multi-GPU, but it has not been tested in my environment.

- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
  - The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.

- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion, and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training

- The default resolution of SDXL is 1024x1024.
- Fine-tuning can be done with 24GB of GPU memory with a batch size of 1. The following options are recommended __for fine-tuning with 24GB of GPU memory__:
  - Train U-Net only.
  - Use gradient checkpointing.
  - Use the `--cache_text_encoder_outputs` option and caching latents.
  - Use the Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- LoRA training can be done with 8GB of GPU memory (10GB recommended). The following options are recommended for reducing GPU memory usage:
  - Train U-Net only.
  - Use gradient checkpointing.
  - Use the `--cache_text_encoder_outputs` option and caching latents.
  - Use one of the 8bit optimizers or the Adafactor optimizer.
  - Use a lower dim (4 to 8 for an 8GB GPU).
- The `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of training them may be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Values smaller than 32 will not work for SDXL training.

Example of the optimizer settings for Adafactor with a fixed learning rate:

```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL

```python
from safetensors.torch import save_file

state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
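
Conversely, an embedding saved in this format can be inspected as follows (a sketch; the filename is hypothetical):

```python
from safetensors.torch import load_file

state_dict = load_file("my_embedding.safetensors")
print(state_dict["clip_g"].shape)  # (num_tokens, 1280) — text encoder 2
print(state_dict["clip_l"].shape)  # (num_tokens, 768) — text encoder 1
```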
### ControlNet-LLLite

ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See the [documentation](./docs/train_lllite_README.md) for details.
@@ -21,9 +21,13 @@ ComfyUIのカスタムノードを用意しています。: https://github.com/k

## モデルの学習

### データセットの準備

通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。

たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
(finetuning 方式の dataset はサポートしていません。)

conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。

たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。

```toml
[[datasets.subsets]]
@@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1

### Preparing the dataset

In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.

(We do not support the finetuning method dataset.)

```toml
[[datasets.subsets]]
@@ -183,7 +185,7 @@ for img_file in img_files:

### Creating a dataset configuration file

You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.

```toml
[general]
@@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py

* Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
* `--network_args`
  * 複数の引数を指定できます。後述します。
* `--alpha_mask`
  * 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)

`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
@@ -181,16 +183,16 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf

詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。

SDXLは現在サポートしていません。

フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。

SDXL では down/up 9 個、middle 3 個の値を指定してください。

`--network_args` で以下の引数を指定してください。

- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
  - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
  - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個(SDXL では 9 個)の数値を指定します。
  - プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。後述のスケッチも参照してください。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"mid_lr_weight=0.5"` のように数値を一つだけ指定します。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"mid_lr_weight=0.5"` のように数値を一つだけ指定します(SDXL の場合は 3 個)。
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
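
プリセット指定のイメージを示す仮のスケッチです。カーブの正確な定義は実装と異なる可能性があります(説明用の仮定義です)。

```python
import math

def preset_weights(name: str, num_blocks: int = 12) -> list:
    # sine / cosine / linear / reverse_linear / zeros と "+数値" の加算のスケッチ
    base, _, offset = name.partition("+")
    xs = [i / (num_blocks - 1) for i in range(num_blocks)]
    if base == "sine":
        ws = [math.sin(math.pi * x) for x in xs]
    elif base == "cosine":
        ws = [math.cos(math.pi * x / 2) for x in xs]  # 仮の定義
    elif base == "linear":
        ws = xs
    elif base == "reverse_linear":
        ws = xs[::-1]
    elif base == "zeros":
        ws = [0.0] * num_blocks
    else:
        raise ValueError(base)
    return [w + float(offset or 0) for w in ws]

print(preset_weights("cosine+.25"))  # おおむね 0.25~1.25 の範囲になる
```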
@@ -215,6 +217,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l

フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。

SDXL では 23 個の値を指定してください。一部のブロックには LoRA が存在しませんが、`sdxl_train.py` の[階層別学習率](./train_SDXL-en.md) との互換性のためです。
対応は、`0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out` です。

`--network_args` で以下の引数を指定してください。

- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
@@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中

* 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。
* `--network_args`
  * 可以指定多个参数。将在下面详细说明。
* `--alpha_mask`
  * 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)

当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。
88
docs/wd14_tagger_README-en.md
Normal file
@@ -0,0 +1,88 @@
# Image Tagging using WD14Tagger

This document is based on the information from this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger).

Using onnx for inference is recommended. Please install onnx with the following command:

```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```

The model weights will be automatically downloaded from Hugging Face.

# Usage

Run the script to perform tagging.

```powershell
python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id <model repo id> --batch_size <batch size> <training data folder>
```

For example, if using the repository `SmilingWolf/wd-swinv2-tagger-v3` with a batch size of 4, and the training data is located in the parent folder `train_data`, it would be:

```powershell
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```

On the first run, the model files will be automatically downloaded to the `wd14_tagger_model` folder (the folder can be changed with an option).
Tag files will be created in the same directory as the training data images, with the same filename and a `.txt` extension.





## Example

To output in the Animagine XL 3.1 format, it would be as follows (enter on a single line in practice):

```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
    --batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
    --use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
    --always_first_tags "1girl,1boy" ..\train_data
```

## Available Repository IDs

[SmilingWolf's V2 and V3 models](https://huggingface.co/SmilingWolf) are available for use. Specify them in a format like `SmilingWolf/wd-vit-tagger-v3`. The default when omitted is `SmilingWolf/wd-v1-4-convnext-tagger-v2`.
# Options

## General Options

- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used; in that case, please install TensorFlow separately.
- `--batch_size`: Number of images to process at once. Default is 1. Adjust according to VRAM capacity.
- `--caption_extension`: File extension for caption files. Default is `.txt`.
- `--max_data_loader_n_workers`: Maximum number of workers for the DataLoader. Specifying a value of 1 or more will use a DataLoader to speed up image loading. If unspecified, no DataLoader is used.
- `--thresh`: Confidence threshold for outputting tags. Default is 0.35. Lowering the value assigns more tags but decreases accuracy.
- `--general_threshold`: Confidence threshold for general tags. If omitted, same as `--thresh`.
- `--character_threshold`: Confidence threshold for character tags. If omitted, same as `--thresh`.
- `--recursive`: If specified, subfolders within the specified folder are also processed recursively.
- `--append_tags`: Append tags to existing tag files.
- `--frequency_tags`: Output tag frequencies.
- `--debug`: Debug mode. Outputs debug information if specified.

## Model Download

- `--model_dir`: Folder to save model files. Default is `wd14_tagger_model`.
- `--force_download`: Re-download model files if specified.

## Tag Editing

- `--remove_underscore`: Remove underscores from output tags.
- `--undesired_tags`: Specify tags not to output. Multiple tags can be specified, separated by commas. For example, `black eyes,black hair`.
- `--use_rating_tags`: Output rating tags at the beginning of the tags.
- `--use_rating_tags_as_last_tag`: Add rating tags at the end of the tags.
- `--character_tags_first`: Output character tags first.
- `--character_tag_expand`: Expand character tag series names. For example, split the tag `chara_name_(series)` into `chara_name, series`.
- `--always_first_tags`: Specify tags to always output first when they appear in an image. Multiple tags can be specified, separated by commas. For example, `1girl,1boy`.
- `--caption_separator`: Separate tags with this string in the output file. Default is `, `.
- `--tag_replacement`: Perform tag replacement. Specify in the format `tag1,tag2;tag3,tag4`. If using `,` or `;` within a tag, escape them with `\`. \
  For example, specify `aira tsubase,aira tsubase (uniform)` (when you want to train a specific costume), or `aira tsubase,aira tsubase\, heir of shadows` (when the series name is not included in the tag).

When using `tag_replacement`, it is applied after `character_tag_expand`.

When specifying `remove_underscore`, specify `undesired_tags`, `always_first_tags`, and `tag_replacement` without underscores.

When specifying `caption_separator`, separate `undesired_tags` and `always_first_tags` with `caption_separator`. Always separate `tag_replacement` with `,`.
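
A sketch of parsing the `--tag_replacement` format with escapes (illustrative only; the actual parser may differ):

```python
import re

def parse_tag_replacement(spec: str) -> dict:
    # split on unescaped ";" into pairs, then on unescaped "," into (old, new)
    pairs = re.split(r"(?<!\\);", spec)
    replacement = {}
    for pair in pairs:
        old, new = re.split(r"(?<!\\),", pair)
        unescape = lambda s: s.replace("\\,", ",").replace("\\;", ";")
        replacement[unescape(old)] = unescape(new)
    return replacement

print(parse_tag_replacement(r"aira tsubase,aira tsubase\, heir of shadows"))
# {'aira tsubase': 'aira tsubase, heir of shadows'}
```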
88
docs/wd14_tagger_README-ja.md
Normal file
@@ -0,0 +1,88 @@
# WD14Taggerによるタグ付け

こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。

onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。

```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```

モデルの重みはHugging Faceから自動的にダウンロードしてきます。

# 使い方

スクリプトを実行してタグ付けを行います。

```
python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id <モデルのrepo id> --batch_size <バッチサイズ> <教師データフォルダ>
```

レポジトリに `SmilingWolf/wd-swinv2-tagger-v3` を使用し、バッチサイズを4にして、教師データを親フォルダの `train_data` に置いた場合、以下のようになります。

```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```

初回起動時にはモデルファイルが `wd14_tagger_model` フォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。

タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。





## 記述例

Animagine XL 3.1 方式で出力する場合、以下のようになります(実際には 1 行で入力してください)。

```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
    --batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
    --use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
    --always_first_tags "1girl,1boy" ..\train_data
```

## 使用可能なリポジトリID

[SmilingWolf 氏の V2、V3 のモデル](https://huggingface.co/SmilingWolf)が使用可能です。`SmilingWolf/wd-vit-tagger-v3` のように指定してください。省略時のデフォルトは `SmilingWolf/wd-v1-4-convnext-tagger-v2` です。

# オプション

## 一般オプション

- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
- `--batch_size` : 一度に処理する画像の数。デフォルトは1です。VRAMの容量に応じて増減してください。
- `--caption_extension` : キャプションファイルの拡張子。デフォルトは `.txt` です。
- `--max_data_loader_n_workers` : DataLoader の最大ワーカー数です。このオプションに 1 以上の数値を指定すると、DataLoader を用いて画像読み込みを高速化します。未指定時は DataLoader を用いません。
- `--thresh` : 出力するタグの信頼度の閾値。デフォルトは0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
- `--general_threshold` : 一般タグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--character_threshold` : キャラクタータグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--recursive` : 指定すると、指定したフォルダ内のサブフォルダも再帰的に処理します。
- `--append_tags` : 既存のタグファイルにタグを追加します。
- `--frequency_tags` : タグの頻度を出力します。
- `--debug` : デバッグモード。指定するとデバッグ情報を出力します。

## モデルのダウンロード

- `--model_dir` : モデルファイルの保存先フォルダ。デフォルトは `wd14_tagger_model` です。
- `--force_download` : 指定するとモデルファイルを再ダウンロードします。

## タグ編集関連

- `--remove_underscore` : 出力するタグからアンダースコアを削除します。
- `--undesired_tags` : 出力しないタグを指定します。カンマ区切りで複数指定できます。たとえば `black eyes,black hair` のように指定します。
- `--use_rating_tags` : タグの最初にレーティングタグを出力します。
- `--use_rating_tags_as_last_tag` : タグの最後にレーティングタグを追加します。
- `--character_tags_first` : キャラクタータグを最初に出力します。
- `--character_tag_expand` : キャラクタータグのシリーズ名を展開します。たとえば `chara_name_(series)` のタグを `chara_name, series` に分割します。
- `--always_first_tags` : あるタグが画像に出力されたとき、そのタグを最初に出力するタグを指定します。カンマ区切りで複数指定できます。たとえば `1girl,1boy` のように指定します。
- `--caption_separator` : 出力するファイルでタグをこの文字列で区切ります。デフォルトは `, ` です。
- `--tag_replacement` : タグの置換を行います。`tag1,tag2;tag3,tag4` のように指定します。`,` および `;` を使う場合は `\` でエスケープしてください。\
  たとえば `aira tsubase,aira tsubase (uniform)` (特定の衣装を学習させたいとき)、`aira tsubase,aira tsubase\, heir of shadows` (シリーズ名がタグに含まれないとき)のように指定します。

`tag_replacement` は `character_tag_expand` の後に適用されます。

`remove_underscore` 指定時は、`undesired_tags`、`always_first_tags`、`tag_replacement` はアンダースコアを含めずに指定してください。

`caption_separator` 指定時は、`undesired_tags`、`always_first_tags` は `caption_separator` で区切ってください。`tag_replacement` は必ず `,` で区切ってください。
170
fine_tune.py
@@ -2,22 +2,29 @@
# XXX dropped option: hypernetwork training

import argparse
import gc
import math
import os
from multiprocessing import Value
import toml

from tqdm import tqdm
import torch

from library.ipex_interop import init_ipex
import torch
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
from diffusers import DDPMScheduler

from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)

import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
@@ -32,28 +39,39 @@ from library.custom_train_functions import (
    scale_v_prediction_loss_like_noise_prediction,
    apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd


def train(args):
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)
    deepspeed_utils.prepare_deepspeed_args(args)
    setup_logging(args, reset=True)

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)  # 乱数系列を初期化する

    tokenizer = train_util.load_tokenizer(args)
    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

    # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
    if cache_latents:
        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
            False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # データセットを準備する
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                print(
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
@@ -72,21 +90,24 @@ def train(args):
            ]
        }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
        blueprint = blueprint_generator.generate(user_config, args)
        train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
        train_dataset_group = train_util.load_arbitrary_dataset(args)
        val_dataset_group = None

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    train_dataset_group.verify_bucket_reso_steps(64)

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group)
        return
    if len(train_dataset_group) == 0:
        print(
        logger.error(
            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
        )
        return
@@ -97,11 +118,12 @@ def train(args):
    ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # acceleratorを準備する
    print("prepare accelerator")
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # mixed precisionに対応した型を用意しておき適宜castする
    weight_dtype, save_dtype = train_util.prepare_dtype(args)
    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

    # モデルを読み込む
    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -152,15 +174,14 @@ def train(args):

    # 学習を準備する
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.to(accelerator.device, dtype=vae_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)

        train_dataset_group.new_cache_latents(vae, accelerator)

        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        clean_memory_on_device(accelerator.device)

        accelerator.wait_for_everyone()

@@ -184,10 +205,13 @@ def train(args):
    else:
        text_encoder.eval()

    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.to(accelerator.device, dtype=vae_dtype)

    for m in training_models:
        m.requires_grad_(True)
@@ -206,9 +230,13 @@ def train(args):
    accelerator.print("prepare optimizer, data loader etc.")
    _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

    # dataloaderを準備する
    # DataLoaderのプロセス数:0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
    # prepare dataloader
    # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
    # some strategies can be None
    train_dataset_group.set_current_strategies()

    # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
@@ -223,7 +251,9 @@ def train(args):
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
        accelerator.print(
            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
        )

    # データセット側にも学習ステップを送信
    train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -240,13 +270,23 @@ def train(args):
        unet.to(weight_dtype)
        text_encoder.to(weight_dtype)

    # acceleratorがなんかよろしくやってくれるらしい
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
    if args.deepspeed:
        if args.train_text_encoder:
            ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
        else:
            ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            ds_model, optimizer, train_dataloader, lr_scheduler
        )
        training_models = [ds_model]
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
        # acceleratorがなんかよろしくやってくれるらしい
        if args.train_text_encoder:
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
            )
        else:
            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
    if args.full_fp16:
@@ -287,13 +327,22 @@ def train(args):
    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs['wandb'] = {'name': args.wandb_run_name}
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
        accelerator.init_trackers(
            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
            config=train_util.get_sanitized_config_or_none(args),
            init_kwargs=init_kwargs,
        )

    # For --sample_at_first
    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
    train_util.sample_images(
        accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
    )
    if len(accelerator.trackers) > 0:
        # log empty object to commit the sample images to wandb
        accelerator.log({}, step=0)

    loss_recorder = train_util.LossRecorder()
    for epoch in range(num_train_epochs):
@@ -305,32 +354,30 @@ def train(args):

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
            with accelerator.accumulate(*training_models):
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)  # .to(dtype=weight_dtype)
                        latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                    else:
                        # latentに変換
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
                    latents = latents * 0.18215
                b_size = latents.shape[0]

                with torch.set_grad_enabled(args.train_text_encoder):
                    # Get the text embedding for conditioning
                    if args.weighted_captions:
                        encoder_hidden_states = get_weighted_text_embeddings(
                            tokenizer,
                            text_encoder,
                            batch["captions"],
                            accelerator.device,
                            args.max_token_length // 75 if args.max_token_length else 1,
                            clip_skip=args.clip_skip,
                        )
                        input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
                        encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
                            tokenize_strategy, [text_encoder], input_ids_list, weights_list
                        )[0]
                    else:
                        input_ids = batch["input_ids"].to(accelerator.device)
                        encoder_hidden_states = train_util.get_hidden_states(
                            args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
                        )
                        input_ids = batch["input_ids_list"][0].to(accelerator.device)
                        encoder_hidden_states = text_encoding_strategy.encode_tokens(
                            tokenize_strategy, [text_encoder], [input_ids]
                        )[0]
                    if args.full_fp16:
                        encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

                # Sample noise, sample a random timestep for each image, and add noise to the latents,
                # with noise offset and/or multires noise if specified
@@ -346,9 +393,10 @@ def train(args):
                else:
                    target = noise

                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                    # do not mean over batch dimension for snr weight or scale v-pred loss
                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
                    loss = loss.mean([1, 2, 3])

                    if args.min_snr_gamma:
@@ -356,11 +404,11 @@ def train(args):
                    if args.scale_v_pred_loss_like_noise_pred:
                        loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                    if args.debiased_estimation_loss:
                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                    loss = loss.mean()  # mean over batch dimension
                else:
                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -379,7 +427,7 @@ def train(args):
                global_step += 1

                train_util.sample_images(
                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
                    accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
                )

                # 指定ステップごとにモデルを保存
@@ -404,7 +452,7 @@ def train(args):
                    )

            current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
            if args.logging_dir is not None:
            if len(accelerator.trackers) > 0:
                logs = {"loss": current_loss}
                train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
                accelerator.log(logs, step=global_step)
@@ -417,7 +465,7 @@ def train(args):
            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
        if len(accelerator.trackers) > 0:
            logs = {"loss/epoch": loss_recorder.moving_average}
            accelerator.log(logs, step=epoch + 1)

@@ -442,7 +490,9 @@ def train(args):
                vae,
            )

        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
        train_util.sample_images(
            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
        )

    is_main_process = accelerator.is_main_process
    if is_main_process:
@@ -451,7 +501,7 @@ def train(args):

    accelerator.end_training()

    if args.save_state and is_main_process:
    if is_main_process and (args.save_state or args.save_state_on_train_end):
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # この後メモリを使うのでこれは消す
@@ -461,21 +511,25 @@ def train(args):
        train_util.save_sd_model_on_train_end(
            args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
        )
        print("model saved.")
        logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, False, True, True)
    train_util.add_training_arguments(parser, False)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser)

    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
    parser.add_argument(
        "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
    parser.add_argument(
        "--learning_rate_te",
@@ -483,6 +537,11 @@ def setup_parser() -> argparse.ArgumentParser:
        default=None,
        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
    )
    parser.add_argument(
        "--no_half_vae",
        action="store_true",
        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
    )

    return parser

@@ -491,6 +550,7 @@ if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    train(args)
@@ -21,6 +21,10 @@ import torch.nn.functional as F
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

class BLIP_Base(nn.Module):
    def __init__(self,
@@ -130,8 +134,9 @@ class BLIP_Decoder(nn.Module):
    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
        image_embeds = self.visual_encoder(image)

        if not sample:
            image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
        # recent version of transformers seems to do repeat_interleave automatically
        # if not sample:
        #     image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)

        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
@@ -235,6 +240,6 @@ def load_checkpoint(model,url_or_filename):
            del state_dict[key]

    msg = model.load_state_dict(state_dict,strict=False)
    print('load checkpoint from %s'%url_or_filename)
    logger.info('load checkpoint from %s'%url_or_filename)
    return model,msg
@@ -8,6 +8,10 @@ import json
import re

from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
@@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
    tokens = tags.split(", rating")
    if len(tokens) == 1:
        # WD14 taggerのときはこちらになるのでメッセージは出さない
        # print("no rating:")
        # print(f"{image_key} {tags}")
        # logger.info("no rating:")
        # logger.info(f"{image_key} {tags}")
        pass
    else:
        if len(tokens) > 2:
            print("multiple ratings:")
            print(f"{image_key} {tags}")
            logger.info("multiple ratings:")
            logger.info(f"{image_key} {tags}")
        tags = tokens[0]

    tags = ", " + tags.replace(", ", ", , ") + ", "  # カンマ付きで検索をするための身も蓋もない対策
@@ -124,43 +128,43 @@ def clean_caption(caption):

def main(args):
    if os.path.exists(args.in_json):
        print(f"loading existing metadata: {args.in_json}")
        logger.info(f"loading existing metadata: {args.in_json}")
        with open(args.in_json, "rt", encoding='utf-8') as f:
            metadata = json.load(f)
    else:
        print("no metadata / メタデータファイルがありません")
        logger.error("no metadata / メタデータファイルがありません")
        return

    print("cleaning captions and tags.")
    logger.info("cleaning captions and tags.")
    image_keys = list(metadata.keys())
    for image_key in tqdm(image_keys):
        tags = metadata[image_key].get('tags')
        if tags is None:
            print(f"image does not have tags / メタデータにタグがありません: {image_key}")
            logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
        else:
            org = tags
            tags = clean_tags(image_key, tags)
            metadata[image_key]['tags'] = tags
            if args.debug and org != tags:
                print("FROM: " + org)
                print("TO: " + tags)
                logger.info("FROM: " + org)
                logger.info("TO: " + tags)

        caption = metadata[image_key].get('caption')
        if caption is None:
            print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
            logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
        else:
            org = caption
            caption = clean_caption(caption)
            metadata[image_key]['caption'] = caption
            if args.debug and org != caption:
                print("FROM: " + org)
                print("TO: " + caption)
                logger.info("FROM: " + org)
                logger.info("TO: " + caption)

    # metadataを書き出して終わり
    print(f"writing metadata: {args.out_json}")
    logger.info(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")
    logger.info("done!")


def setup_parser() -> argparse.ArgumentParser:
@@ -178,10 +182,10 @@ if __name__ == '__main__':

    args, unknown = parser.parse_known_args()
    if len(unknown) == 1:
        print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
        print("All captions and tags in the metadata are processed.")
        print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
        print("メタデータ内のすべてのキャプションとタグが処理されます。")
        logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
        logger.warning("All captions and tags in the metadata are processed.")
        logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
        logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
        args.in_json = args.out_json
        args.out_json = unknown[0]
    elif len(unknown) > 0:
@@ -9,14 +9,22 @@ from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np

import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()


IMAGE_SIZE = 384
@@ -47,7 +55,7 @@ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
            # convert to tensor temporarily so dataloader will accept it
            tensor = IMAGE_TRANSFORM(image)
        except Exception as e:
            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
            logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
            return None

        return (tensor, img_path)
@@ -74,21 +82,21 @@ def main(args):
    args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path

    cwd = os.getcwd()
    print("Current Working Directory is: ", cwd)
    logger.info(f"Current Working Directory is: {cwd}")
    os.chdir("finetune")
    if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
        args.caption_weights = os.path.join("..", args.caption_weights)

    print(f"load images from {args.train_data_dir}")
    logger.info(f"load images from {args.train_data_dir}")
    train_data_dir_path = Path(args.train_data_dir)
    image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
    print(f"found {len(image_paths)} images.")
    logger.info(f"found {len(image_paths)} images.")

    print(f"loading BLIP caption: {args.caption_weights}")
    logger.info(f"loading BLIP caption: {args.caption_weights}")
    model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
    model.eval()
    model = model.to(DEVICE)
    print("BLIP loaded")
    logger.info("BLIP loaded")

    # captioningする
    def run_batch(path_imgs):
@@ -108,7 +116,7 @@ def main(args):
            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
                f.write(caption + "\n")
                if args.debug:
                    print(image_path, caption)
                    logger.info(f'{image_path} {caption}')

    # 読み込みの高速化のためにDataLoaderを使うオプション
    if args.max_data_loader_n_workers is not None:
@@ -138,7 +146,7 @@ def main(args):
|
||||
raw_image = raw_image.convert("RGB")
|
||||
img_tensor = IMAGE_TRANSFORM(raw_image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
b_imgs.append((image_path, img_tensor))
|
||||
@@ -148,7 +156,7 @@ def main(args):
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
print("done!")
|
||||
logger.info("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
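
make_captions.py also swaps its hard-coded CUDA check for library.device_utils.get_preferred_device(). A hedged sketch of the fallback order it plausibly implements (the real logic lives in library/device_utils.py and may differ):

    import torch

    def get_preferred_device() -> torch.device:
        # assumed order: CUDA, then XPU (after init_ipex()), then MPS, then CPU
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")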
@@ -5,12 +5,19 @@ import re
 from pathlib import Path
 from PIL import Image
 from tqdm import tqdm
+
 import torch
+from library.device_utils import init_ipex, get_preferred_device
+init_ipex()
+
 from transformers import AutoProcessor, AutoModelForCausalLM
 from transformers.generation.utils import GenerationMixin
+
 import library.train_util as train_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -35,8 +42,8 @@ def remove_words(captions, debug):
         for pat in PATTERN_REPLACE:
             cap = pat.sub("", cap)
         if debug and cap != caption:
-            print(caption)
-            print(cap)
+            logger.info(caption)
+            logger.info(cap)
         removed_caps.append(cap)
     return removed_caps

@@ -70,16 +77,16 @@ def main(args):
        GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
    """

-    print(f"load images from {args.train_data_dir}")
+    logger.info(f"load images from {args.train_data_dir}")
     train_data_dir_path = Path(args.train_data_dir)
     image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

     # できればcacheに依存せず明示的にダウンロードしたい (ideally this would download explicitly instead of relying on the cache)
-    print(f"loading GIT: {args.model_id}")
+    logger.info(f"loading GIT: {args.model_id}")
     git_processor = AutoProcessor.from_pretrained(args.model_id)
     git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
-    print("GIT loaded")
+    logger.info("GIT loaded")

     # captioningする (run captioning)
     def run_batch(path_imgs):

@@ -97,7 +104,7 @@ def main(args):
             with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
                 f.write(caption + "\n")
                 if args.debug:
-                    print(image_path, caption)
+                    logger.info(f"{image_path} {caption}")

     # 読み込みの高速化のためにDataLoaderを使うオプション (option to use a DataLoader to speed up loading)
     if args.max_data_loader_n_workers is not None:

@@ -126,7 +133,7 @@ def main(args):
             if image.mode != "RGB":
                 image = image.convert("RGB")
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
             continue

         b_imgs.append((image_path, image))

@@ -137,7 +144,7 @@ def main(args):
     if len(b_imgs) > 0:
         run_batch(b_imgs)

-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:
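
For reference, the captioning flow that make_captions_by_git.py wraps is the stock transformers GIT pipeline; a minimal standalone sketch (model id and file name are illustrative):

    from PIL import Image
    from transformers import AutoProcessor, AutoModelForCausalLM

    processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
    model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

    image = Image.open("example.png").convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    ids = model.generate(pixel_values=pixel_values, max_length=50)
    print(processor.batch_decode(ids, skip_special_tokens=True)[0])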
@@ -5,72 +5,96 @@ from typing import List
 from tqdm import tqdm
 import library.train_util as train_util
 import os
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)


 def main(args):
-    assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
+    assert not args.recursive or (
+        args.recursive and args.full_path
+    ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"

-    train_data_dir_path = Path(args.train_data_dir)
-    image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
+    logger.info(f"found {len(image_paths)} images.")

-    if args.in_json is None and Path(args.out_json).is_file():
-        args.in_json = args.out_json
+    if args.in_json is None and Path(args.out_json).is_file():
+        args.in_json = args.out_json

-    if args.in_json is not None:
-        print(f"loading existing metadata: {args.in_json}")
-        metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
-        print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
-    else:
-        print("new metadata will be created / 新しいメタデータファイルが作成されます")
-        metadata = {}
+    if args.in_json is not None:
+        logger.info(f"loading existing metadata: {args.in_json}")
+        metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
+        logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
+    else:
+        logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
+        metadata = {}

-    print("merge caption texts to metadata json.")
-    for image_path in tqdm(image_paths):
-        caption_path = image_path.with_suffix(args.caption_extension)
-        caption = caption_path.read_text(encoding='utf-8').strip()
+    logger.info("merge caption texts to metadata json.")
+    for image_path in tqdm(image_paths):
+        caption_path = image_path.with_suffix(args.caption_extension)
+        caption = caption_path.read_text(encoding="utf-8").strip()

-        if not os.path.exists(caption_path):
-            caption_path = os.path.join(image_path, args.caption_extension)
+        if not os.path.exists(caption_path):
+            caption_path = os.path.join(image_path, args.caption_extension)

-        image_key = str(image_path) if args.full_path else image_path.stem
-        if image_key not in metadata:
-            metadata[image_key] = {}
+        image_key = str(image_path) if args.full_path else image_path.stem
+        if image_key not in metadata:
+            metadata[image_key] = {}

-        metadata[image_key]['caption'] = caption
-        if args.debug:
-            print(image_key, caption)
+        metadata[image_key]["caption"] = caption
+        if args.debug:
+            logger.info(f"{image_key} {caption}")

-    # metadataを書き出して終わり
-    print(f"writing metadata: {args.out_json}")
-    Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
-    print("done!")
+    # metadataを書き出して終わり (write the metadata and finish)
+    logger.info(f"writing metadata: {args.out_json}")
+    Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
-    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
-    parser.add_argument("--in_json", type=str,
-                        help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
-    parser.add_argument("--caption_extention", type=str, default=None,
-                        help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
-    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
-    parser.add_argument("--full_path", action="store_true",
-                        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
-    parser.add_argument("--recursive", action="store_true",
-                        help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
-    parser.add_argument("--debug", action="store_true", help="debug mode")
+    parser = argparse.ArgumentParser()
+    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
+    parser.add_argument(
+        "--in_json",
+        type=str,
+        help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)",
+    )
+    parser.add_argument(
+        "--caption_extention",
+        type=str,
+        default=None,
+        help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
+    )
+    parser.add_argument(
+        "--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子"
+    )
+    parser.add_argument(
+        "--full_path",
+        action="store_true",
+        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
+    )
+    parser.add_argument(
+        "--recursive",
+        action="store_true",
+        help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
+    )
+    parser.add_argument("--debug", action="store_true", help="debug mode")

-    return parser
+    return parser


-if __name__ == '__main__':
-    parser = setup_parser()
+if __name__ == "__main__":
+    parser = setup_parser()

-    args = parser.parse_args()
+    args = parser.parse_args()

-    # スペルミスしていたオプションを復元する
-    if args.caption_extention is not None:
-        args.caption_extension = args.caption_extention
+    # スペルミスしていたオプションを復元する (restore the misspelled option)
+    if args.caption_extention is not None:
+        args.caption_extension = args.caption_extention

-    main(args)
+    main(args)
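
The metadata file this script builds is a flat JSON object keyed by image stem (or full path with --full_path); illustrative values:

    import json

    metadata = {
        "img_0001": {"caption": "a photo of a cat"},
        "img_0002": {"caption": "a dog on grass"},
    }
    print(json.dumps(metadata, indent=2))  # exactly what out_json receives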
@@ -5,67 +5,89 @@ from typing import List
 from tqdm import tqdm
 import library.train_util as train_util
 import os
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)


 def main(args):
-    assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
+    assert not args.recursive or (
+        args.recursive and args.full_path
+    ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"

-    train_data_dir_path = Path(args.train_data_dir)
-    image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    train_data_dir_path = Path(args.train_data_dir)
+    image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
+    logger.info(f"found {len(image_paths)} images.")

-    if args.in_json is None and Path(args.out_json).is_file():
-        args.in_json = args.out_json
+    if args.in_json is None and Path(args.out_json).is_file():
+        args.in_json = args.out_json

-    if args.in_json is not None:
-        print(f"loading existing metadata: {args.in_json}")
-        metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
-        print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
-    else:
-        print("new metadata will be created / 新しいメタデータファイルが作成されます")
-        metadata = {}
+    if args.in_json is not None:
+        logger.info(f"loading existing metadata: {args.in_json}")
+        metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
+        logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
+    else:
+        logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
+        metadata = {}

-    print("merge tags to metadata json.")
-    for image_path in tqdm(image_paths):
-        tags_path = image_path.with_suffix(args.caption_extension)
-        tags = tags_path.read_text(encoding='utf-8').strip()
+    logger.info("merge tags to metadata json.")
+    for image_path in tqdm(image_paths):
+        tags_path = image_path.with_suffix(args.caption_extension)
+        tags = tags_path.read_text(encoding="utf-8").strip()

-        if not os.path.exists(tags_path):
-            tags_path = os.path.join(image_path, args.caption_extension)
+        if not os.path.exists(tags_path):
+            tags_path = os.path.join(image_path, args.caption_extension)

-        image_key = str(image_path) if args.full_path else image_path.stem
-        if image_key not in metadata:
-            metadata[image_key] = {}
+        image_key = str(image_path) if args.full_path else image_path.stem
+        if image_key not in metadata:
+            metadata[image_key] = {}

-        metadata[image_key]['tags'] = tags
-        if args.debug:
-            print(image_key, tags)
+        metadata[image_key]["tags"] = tags
+        if args.debug:
+            logger.info(f"{image_key} {tags}")

-    # metadataを書き出して終わり
-    print(f"writing metadata: {args.out_json}")
-    Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
+    # metadataを書き出して終わり (write the metadata and finish)
+    logger.info(f"writing metadata: {args.out_json}")
+    Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")

-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
-    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
-    parser.add_argument("--in_json", type=str,
-                        help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
-    parser.add_argument("--full_path", action="store_true",
-                        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
-    parser.add_argument("--recursive", action="store_true",
-                        help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
-    parser.add_argument("--caption_extension", type=str, default=".txt",
-                        help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
-    parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
+    parser = argparse.ArgumentParser()
+    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
+    parser.add_argument(
+        "--in_json",
+        type=str,
+        help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)",
+    )
+    parser.add_argument(
+        "--full_path",
+        action="store_true",
+        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
+    )
+    parser.add_argument(
+        "--recursive",
+        action="store_true",
+        help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
+    )
+    parser.add_argument(
+        "--caption_extension",
+        type=str,
+        default=".txt",
+        help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子",
+    )
+    parser.add_argument("--debug", action="store_true", help="debug mode, print tags")

-    return parser
+    return parser


-if __name__ == '__main__':
-    parser = setup_parser()
+if __name__ == "__main__":
+    parser = setup_parser()

-    args = parser.parse_args()
-    main(args)
+    args = parser.parse_args()
+    main(args)
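
Because both merge scripts read and rewrite the same out_json, running them back to back accumulates fields per image (values illustrative):

    entry = {"caption": "a photo of a cat", "tags": "cat, outdoors, sitting"}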
@@ -8,13 +8,24 @@ from tqdm import tqdm
 import numpy as np
 from PIL import Image
 import cv2
+
 import torch
+from library.device_utils import init_ipex, get_preferred_device
+
+init_ipex()
+
 from torchvision import transforms
+
 import library.model_util as model_util
 import library.train_util as train_util
+from library.utils import setup_logging

-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+DEVICE = get_preferred_device()

 IMAGE_TRANSFORMS = transforms.Compose(
     [

@@ -51,22 +62,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
 def main(args):
     # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
     if args.bucket_reso_steps % 8 > 0:
-        print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
+        logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
     if args.bucket_reso_steps % 32 > 0:
-        print(
+        logger.warning(
             f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
         )

     train_data_dir_path = Path(args.train_data_dir)
     image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

     if os.path.exists(args.in_json):
-        print(f"loading existing metadata: {args.in_json}")
+        logger.info(f"loading existing metadata: {args.in_json}")
         with open(args.in_json, "rt", encoding="utf-8") as f:
             metadata = json.load(f)
     else:
-        print(f"no metadata / メタデータファイルがありません: {args.in_json}")
+        logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}")
         return

     weight_dtype = torch.float32

@@ -81,7 +92,9 @@ def main(args):

     # bucketのサイズを計算する (compute the bucket sizes)
     max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
-    assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
+    assert (
+        len(max_reso) == 2
+    ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

     bucket_manager = train_util.BucketManager(
         args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps

@@ -89,7 +102,7 @@ def main(args):
     if not args.bucket_no_upscale:
         bucket_manager.make_buckets()
     else:
-        print(
+        logger.warning(
             "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
         )

@@ -99,7 +112,7 @@ def main(args):
     def process_batch(is_last):
         for bucket in bucket_manager.buckets:
             if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
-                train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
+                train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False)
                 bucket.clear()

     # 読み込みの高速化のためにDataLoaderを使うオプション (option to use a DataLoader to speed up loading)

@@ -130,7 +143,7 @@ def main(args):
             if image.mode != "RGB":
                 image = image.convert("RGB")
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
             continue

         image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]

@@ -183,15 +196,15 @@ def main(args):
     for i, reso in enumerate(bucket_manager.resos):
         count = bucket_counts.get(reso, 0)
         if count > 0:
-            print(f"bucket {i} {reso}: {count}")
+            logger.info(f"bucket {i} {reso}: {count}")
     img_ar_errors = np.array(img_ar_errors)
-    print(f"mean ar error: {np.mean(img_ar_errors)}")
+    logger.info(f"mean ar error: {np.mean(img_ar_errors)}")

     # metadataを書き出して終わり (write the metadata and finish)
-    print(f"writing metadata: {args.out_json}")
+    logger.info(f"writing metadata: {args.out_json}")
     with open(args.out_json, "wt", encoding="utf-8") as f:
         json.dump(metadata, f, indent=2)
-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:

@@ -200,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
     parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
     parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
-    parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
+    parser.add_argument(
+        "--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
+    )
     parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
     parser.add_argument(
         "--max_data_loader_n_workers",

@@ -223,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser:
         help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
     )
     parser.add_argument(
-        "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
+        "--bucket_no_upscale",
+        action="store_true",
+        help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
     )
     parser.add_argument(
-        "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help="use mixed precision / 混合精度を使う場合、その精度",
    )
     parser.add_argument(
         "--full_path",

@@ -234,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser:
         help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
     )
     parser.add_argument(
-        "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
+        "--flip_aug",
+        action="store_true",
+        help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
    )
+    parser.add_argument(
+        "--alpha_mask",
+        type=str,
+        default="",
+        help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する",
+    )
     parser.add_argument(
         "--skip_existing",
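
A hedged sketch of the constraint those bucket_reso_steps warnings guard (the actual rounding lives in train_util.BucketManager and may differ): bucket edge lengths land on multiples of the step, and SDXL additionally needs edges divisible by 32:

    def snap(edge: int, step: int = 64) -> int:
        # assumed behavior: round an edge length down to the bucket step
        return (edge // step) * step

    print(snap(1023))      # 960  -> multiple of 8, fine for SD
    print(snap(1048, 48))  # 1008 -> not a multiple of 32, breaks SDXL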
@@ -11,6 +11,12 @@ from PIL import Image
 from tqdm import tqdm

 import library.train_util as train_util
+from library.utils import setup_logging, resize_image
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)

 # from wd14 tagger
 IMAGE_SIZE = 448

@@ -36,8 +42,7 @@ def preprocess_image(image):
     pad_t = pad_y // 2
     image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

-    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
-    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
+    image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)

     image = image.astype(np.float32)
     return image
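
A quick check of the pad-to-square step above (values illustrative): a 300x200 input is padded to 300x300 with white rows split evenly top and bottom, then handed to the 448x448 resize:

    import numpy as np

    img = np.full((200, 300, 3), 128, np.uint8)  # h=200, w=300
    size = max(img.shape[:2])                    # 300
    pad_y = size - img.shape[0]                  # 100 -> pad_t = 50
    img = np.pad(img, ((50, 50), (0, 0), (0, 0)), mode="constant", constant_values=255)
    assert img.shape[:2] == (300, 300)           # square, ready for the 448 resize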
@@ -56,12 +61,12 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
         try:
             image = Image.open(img_path).convert("RGB")
             image = preprocess_image(image)
-            tensor = torch.tensor(image)
+            # tensor = torch.tensor(image)  # これ Tensor に変換する必要ないな……(;・∀・) (turns out there is no need to convert this to a Tensor)
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
             return None

-        return (tensor, img_path)
+        return (image, img_path)


 def collate_fn_remove_corrupted(batch):

@@ -75,36 +80,43 @@ def collate_fn_remove_corrupted(batch):


 def main(args):
+    # model location is model_dir + repo_id
+    # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
+    model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
+
     # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
     # (hf_hub_download used as-is reportedly has symlink problems, so work around it with cache_dir and force_filename)
     # depreacatedの警告が出るけどなくなったらその時 (it raises a deprecation warning; deal with removal when it happens)
     # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
-    if not os.path.exists(args.model_dir) or args.force_download:
-        print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
+    if not os.path.exists(model_location) or args.force_download:
+        os.makedirs(args.model_dir, exist_ok=True)
+        logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
         files = FILES
         if args.onnx:
             files = ["selected_tags.csv"]
             files += FILES_ONNX
+        else:
+            for file in SUB_DIR_FILES:
+                hf_hub_download(
+                    args.repo_id,
+                    file,
+                    subfolder=SUB_DIR,
+                    cache_dir=os.path.join(model_location, SUB_DIR),
+                    force_download=True,
+                    force_filename=file,
+                )
         for file in files:
-            hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
-        for file in SUB_DIR_FILES:
-            hf_hub_download(
-                args.repo_id,
-                file,
-                subfolder=SUB_DIR,
-                cache_dir=os.path.join(args.model_dir, SUB_DIR),
-                force_download=True,
-                force_filename=file,
-            )
+            hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
     else:
-        print("using existing wd14 tagger model")
+        logger.info("using existing wd14 tagger model")

-    # 画像を読み込む (load the images)
+    # モデルを読み込む (load the model)
     if args.onnx:
         import onnx
         import onnxruntime as ort

-        onnx_path = f"{args.model_dir}/model.onnx"
-        print("Running wd14 tagger with onnx")
-        print(f"loading onnx model: {onnx_path}")
+        onnx_path = f"{model_location}/model.onnx"
+        logger.info("Running wd14 tagger with onnx")
+        logger.info(f"loading onnx model: {onnx_path}")

         if not os.path.exists(onnx_path):
             raise Exception(
@@ -116,60 +128,112 @@ def main(args):
         input_name = model.graph.input[0].name
         try:
             batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
-        except:
+        except Exception:
             batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param

-        if args.batch_size != batch_size and type(batch_size) != str:
+        if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
             # some rebatch model may use 'N' as dynamic axes
-            print(
+            logger.warning(
                 f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
             )
             args.batch_size = batch_size

         del model

-        ort_sess = ort.InferenceSession(
-            onnx_path,
-            providers=["CUDAExecutionProvider"]
-            if "CUDAExecutionProvider" in ort.get_available_providers()
-            else ["CPUExecutionProvider"],
-        )
+        if "OpenVINOExecutionProvider" in ort.get_available_providers():
+            # requires provider options for gpu support
+            # fp16 causes nonsense outputs
+            ort_sess = ort.InferenceSession(
+                onnx_path,
+                providers=(["OpenVINOExecutionProvider"]),
+                provider_options=[{'device_type' : "GPU_FP32"}],
+            )
+        else:
+            ort_sess = ort.InferenceSession(
+                onnx_path,
+                providers=(
+                    ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
+                    ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
+                    ["CPUExecutionProvider"]
+                ),
+            )
     else:
         from tensorflow.keras.models import load_model

-        model = load_model(f"{args.model_dir}")
+        model = load_model(f"{model_location}")

     # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
     # 依存ライブラリを増やしたくないので自力で読むよ (read it ourselves to avoid adding a dependency)

-    with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
+    with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
         reader = csv.reader(f)
-        l = [row for row in reader]
-        header = l[0]  # tag_id,name,category,count
-        rows = l[1:]
+        line = [row for row in reader]
+        header = line[0]  # tag_id,name,category,count
+        rows = line[1:]
     assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"

-    general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
-    character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
+    rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
+    general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
+    character_tags = [row[1] for row in rows[0:] if row[2] == "4"]

+    # preprocess tags in advance
+    if args.character_tag_expand:
+        for i, tag in enumerate(character_tags):
+            if tag.endswith(")"):
+                # chara_name_(series) -> chara_name, series
+                # chara_name_(costume)_(series) -> chara_name_(costume), series
+                tags = tag.split("(")
+                character_tag = "(".join(tags[:-1])
+                if character_tag.endswith("_"):
+                    character_tag = character_tag[:-1]
+                series_tag = tags[-1].replace(")", "")
+                character_tags[i] = character_tag + args.caption_separator + series_tag
+
+    if args.remove_underscore:
+        rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
+        general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
+        character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
+
+    if args.tag_replacement is not None:
+        # escape , and ; in tag_replacement: wd14 tag names may contain , and ;
+        escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
+        tag_replacements = escaped_tag_replacements.split(";")
+        for tag_replacement in tag_replacements:
+            tags = tag_replacement.split(",")  # source, target
+            assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
+
+            source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
+            logger.info(f"replacing tag: {source} -> {target}")
+
+            if source in general_tags:
+                general_tags[general_tags.index(source)] = target
+            elif source in character_tags:
+                character_tags[character_tags.index(source)] = target
+            elif source in rating_tags:
+                rating_tags[rating_tags.index(source)] = target
+
     # 画像を読み込む (load the images)
     train_data_dir_path = Path(args.train_data_dir)
     image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

     tag_freq = {}

     caption_separator = args.caption_separator
     stripped_caption_separator = caption_separator.strip()
-    undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))
+    undesired_tags = args.undesired_tags.split(stripped_caption_separator)
+    undesired_tags = set([tag.strip() for tag in undesired_tags if tag.strip() != ""])
+
+    always_first_tags = None
+    if args.always_first_tags is not None:
+        always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]

     def run_batch(path_imgs):
         imgs = np.array([im for _, im in path_imgs])

         if args.onnx:
-            if len(imgs) < args.batch_size:
-                imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
+            # if len(imgs) < args.batch_size:
+            #     imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
             probs = ort_sess.run(None, {input_name: imgs})[0]  # onnx output numpy
             probs = probs[: len(path_imgs)]
         else:
@@ -177,22 +241,16 @@ def main(args):
             probs = probs.numpy()

         for (image_path, _), prob in zip(path_imgs, probs):
-            # 最初の4つはratingなので無視する (the first four are ratings, so skip them)
-            # # First 4 labels are actually ratings: pick one with argmax
-            # ratings_names = label_names[:4]
-            # rating_index = ratings_names["probs"].argmax()
-            # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
-
-            # それ以降はタグなのでconfidenceがthresholdより高いものを追加する (the rest are tags; add those whose confidence exceeds the threshold)
-            # Everything else is tags: pick any where prediction confidence > threshold
             combined_tags = []
-            general_tag_text = ""
+            rating_tag_text = ""
             character_tag_text = ""
+            general_tag_text = ""

+            # 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する (after the first four everything is a tag; add those at or above the threshold)
+            # First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
             for i, p in enumerate(prob[4:]):
                 if i < len(general_tags) and p >= args.general_threshold:
                     tag_name = general_tags[i]
-                    if args.remove_underscore and len(tag_name) > 3:  # ignore emoji tags like >_< and ^_^
-                        tag_name = tag_name.replace("_", " ")

                     if tag_name not in undesired_tags:
                         tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1

@@ -200,13 +258,37 @@ def main(args):
                         combined_tags.append(tag_name)
                 elif i >= len(general_tags) and p >= args.character_threshold:
                     tag_name = character_tags[i - len(general_tags)]
-                    if args.remove_underscore and len(tag_name) > 3:
-                        tag_name = tag_name.replace("_", " ")

                     if tag_name not in undesired_tags:
                         tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                         character_tag_text += caption_separator + tag_name
-                        combined_tags.append(tag_name)
+                        if args.character_tags_first:  # insert to the beginning
+                            combined_tags.insert(0, tag_name)
+                        else:
+                            combined_tags.append(tag_name)

+            # 最初の4つはratingなのでargmaxで選ぶ (the first four are ratings; pick one with argmax)
+            # First 4 labels are actually ratings: pick one with argmax
+            if args.use_rating_tags or args.use_rating_tags_as_last_tag:
+                ratings_probs = prob[:4]
+                rating_index = ratings_probs.argmax()
+                found_rating = rating_tags[rating_index]
+
+                if found_rating not in undesired_tags:
+                    tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
+                    rating_tag_text = found_rating
+                    if args.use_rating_tags:
+                        combined_tags.insert(0, found_rating)  # insert to the beginning
+                    else:
+                        combined_tags.append(found_rating)
+
+            # 一番最初に置くタグを指定する (specify tags that always go first)
+            # Always put some tags at the beginning
+            if always_first_tags is not None:
+                for tag in always_first_tags:
+                    if tag in combined_tags:
+                        combined_tags.remove(tag)
+                        combined_tags.insert(0, tag)

             # 先頭のカンマを取る (strip the leading comma)
             if len(general_tag_text) > 0:

@@ -237,7 +319,11 @@ def main(args):
             with open(caption_file, "wt", encoding="utf-8") as f:
                 f.write(tag_text + "\n")
                 if args.debug:
-                    print(f"\n{image_path}:\n  Character tags: {character_tag_text}\n  General tags: {general_tag_text}")
+                    logger.info("")
+                    logger.info(f"{image_path}:")
+                    logger.info(f"\tRating tags: {rating_tag_text}")
+                    logger.info(f"\tCharacter tags: {character_tag_text}")
+                    logger.info(f"\tGeneral tags: {general_tag_text}")

     # 読み込みの高速化のためにDataLoaderを使うオプション (option to use a DataLoader to speed up loading)
     if args.max_data_loader_n_workers is not None:

@@ -260,16 +346,14 @@ def main(args):
                 continue

             image, image_path = data
-            if image is not None:
-                image = image.detach().numpy()
-            else:
+            if image is None:
                 try:
                     image = Image.open(image_path)
                     if image.mode != "RGB":
                         image = image.convert("RGB")
                     image = preprocess_image(image)
                 except Exception as e:
-                    print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                    logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
                     continue
             b_imgs.append((image_path, image))
@@ -284,16 +368,18 @@ def main(args):

     if args.frequency_tags:
         sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
-        print("\nTag frequencies:")
+        print("Tag frequencies:")
         for tag, freq in sorted_tags:
             print(f"{tag}: {freq}")

-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
-    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument(
+        "train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
+    )
     parser.add_argument(
         "--repo_id",
         type=str,

@@ -307,9 +393,13 @@ def setup_parser() -> argparse.ArgumentParser:
         help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
     )
     parser.add_argument(
-        "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
+        "--force_download",
+        action="store_true",
+        help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
     )
-    parser.add_argument(
-        "--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
-    )
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
     parser.add_argument(
         "--max_data_loader_n_workers",
         type=int,

@@ -322,8 +412,12 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
     )
-    parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
-    parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
+    parser.add_argument(
+        "--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
+    )
+    parser.add_argument(
+        "--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値"
+    )
     parser.add_argument(
         "--general_threshold",
         type=float,

@@ -336,28 +430,67 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
     )
-    parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
+    parser.add_argument(
+        "--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
+    )
     parser.add_argument(
         "--remove_underscore",
         action="store_true",
         help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
     )
-    parser.add_argument("--debug", action="store_true", help="debug mode")
+    parser.add_argument(
+        "--debug", action="store_true", help="debug mode"
+    )
     parser.add_argument(
         "--undesired_tags",
         type=str,
         default="",
         help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
     )
-    parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
-    parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
-    parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
+    parser.add_argument(
+        "--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
+    )
+    parser.add_argument(
+        "--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
+    )
+    parser.add_argument(
+        "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
+    )
+    parser.add_argument(
+        "--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
+    )
+    parser.add_argument(
+        "--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
+    )
+    parser.add_argument(
+        "--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
+    )
+    parser.add_argument(
+        "--always_first_tags",
+        type=str,
+        default=None,
+        help="comma-separated list of tags to always put at the beginning, e.g. `1girl,1boy`"
+        + " / 必ず先頭に置くタグのカンマ区切りリスト、例 : `1girl,1boy`",
+    )
+    parser.add_argument(
+        "--caption_separator",
+        type=str,
+        default=", ",
+        help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
+    )
+    parser.add_argument(
+        "--tag_replacement",
+        type=str,
+        default=None,
+        help="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4`"
+        + " / タグの置換を `置換元1,置換先1;置換元2,置換先2; ...`で指定する。`\` で `,` と `;` をエスケープできる。例: `tag1,tag2;tag3,tag4`",
+    )
+    parser.add_argument(
+        "--character_tag_expand",
+        action="store_true",
+        help="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`"
+        + " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
    )

     return parser
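
The --tag_replacement escaping from the hunk above, extracted into a runnable fragment: "\," and "\;" are hidden behind throwaway placeholders before splitting, then restored (tag values illustrative):

    tag_replacement = r"black hair,dark hair;1girl\, solo,one girl"
    escaped = tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
    for pair in escaped.split(";"):
        source, target = [t.replace("@@@@", ",").replace("####", ";") for t in pair.split(",")]
        print(f"replacing tag: {source} -> {target}")
    # replacing tag: black hair -> dark hair
    # replacing tag: 1girl, solo -> one girl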
flux_minimal_inference.py (new file, 576 lines)
@@ -0,0 +1,576 @@
# Minimum Inference Code for FLUX

import argparse
import datetime
import math
import os
import random
from typing import Callable, List, Optional
import einops
import numpy as np

import torch
from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from safetensors.torch import load_file

from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from networks import oft_flux

init_ipex()


from library.utils import setup_logging, str_to_dtype

setup_logging()
import logging

logger = logging.getLogger(__name__)

import networks.lora_flux as lora_flux
from library import flux_models, flux_utils, sd3_utils, strategy_flux


def time_shift(mu: float, sigma: float, t: torch.Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear interpolation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
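# Illustrative check (annotation, not part of the file): time_shift(mu, 1.0, t)
# evaluates e^mu / (e^mu + (1/t - 1)^sigma). With sigma = 1 and mu = ln 3,
# t = 0.5 maps to 3 / (3 + 1) = 0.75, so larger mu (longer image token
# sequences) pushes every timestep toward the high-noise end of the schedule.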

def denoise(
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    t5_attn_mask: Optional[torch.Tensor] = None,
    neg_txt: Optional[torch.Tensor] = None,
    neg_vec: Optional[torch.Tensor] = None,
    neg_t5_attn_mask: Optional[torch.Tensor] = None,
    cfg_scale: Optional[float] = None,
):
    # this is ignored for schnell
    logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # prepare classifier free guidance
    if neg_txt is not None and neg_vec is not None:
        b_img_ids = torch.cat([img_ids, img_ids], dim=0)
        b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
        b_txt = torch.cat([neg_txt, txt], dim=0)
        b_vec = torch.cat([neg_vec, vec], dim=0)
        if t5_attn_mask is not None and neg_t5_attn_mask is not None:
            b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
        else:
            b_t5_attn_mask = None
    else:
        b_img_ids = img_ids
        b_txt_ids = txt_ids
        b_txt = txt
        b_vec = vec
        b_t5_attn_mask = t5_attn_mask

    for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        # classifier free guidance
        if neg_txt is not None and neg_vec is not None:
            b_img = torch.cat([img, img], dim=0)
        else:
            b_img = img

        pred = model(
            img=b_img,
            img_ids=b_img_ids,
            txt=b_txt,
            txt_ids=b_txt_ids,
            y=b_vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            txt_attention_mask=b_t5_attn_mask,
        )

        # classifier free guidance
        if neg_txt is not None and neg_vec is not None:
            pred_uncond, pred = torch.chunk(pred, 2, dim=0)
            pred = pred_uncond + cfg_scale * (pred - pred_uncond)

        img = img + (t_prev - t_curr) * pred

    return img
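# Notes on denoise() above (annotation, not part of the file): the update
# img = img + (t_prev - t_curr) * pred is a plain Euler step along the learned
# flow; timesteps descend from 1 toward 0, so (t_prev - t_curr) is negative.
# The combine pred_uncond + cfg_scale * (pred - pred_uncond) is standard
# classifier-free guidance and reduces to the conditional branch at cfg_scale = 1.0.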
def do_sample(
|
||||
accelerator: Optional[accelerate.Accelerator],
|
||||
model: flux_models.Flux,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
l_pooled: torch.Tensor,
|
||||
t5_out: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
num_steps: int,
|
||||
guidance: float,
|
||||
t5_attn_mask: Optional[torch.Tensor],
|
||||
is_schnell: bool,
|
||||
device: torch.device,
|
||||
flux_dtype: torch.dtype,
|
||||
neg_l_pooled: Optional[torch.Tensor] = None,
|
||||
neg_t5_out: Optional[torch.Tensor] = None,
|
||||
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
cfg_scale: Optional[float] = None,
|
||||
):
|
||||
logger.info(f"num_steps: {num_steps}")
|
||||
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
||||
|
||||
# denoise initial noise
|
||||
if accelerator:
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(
|
||||
model,
|
||||
img,
|
||||
img_ids,
|
||||
t5_out,
|
||||
txt_ids,
|
||||
l_pooled,
|
||||
timesteps,
|
||||
guidance,
|
||||
t5_attn_mask,
|
||||
neg_t5_out,
|
||||
neg_l_pooled,
|
||||
neg_t5_attn_mask,
|
||||
cfg_scale,
|
||||
)
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
||||
x = denoise(
|
||||
model,
|
||||
img,
|
||||
img_ids,
|
||||
t5_out,
|
||||
txt_ids,
|
||||
l_pooled,
|
||||
timesteps,
|
||||
guidance,
|
||||
t5_attn_mask,
|
||||
neg_t5_out,
|
||||
neg_l_pooled,
|
||||
neg_t5_attn_mask,
|
||||
cfg_scale,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def generate_image(
|
||||
model,
|
||||
clip_l: CLIPTextModel,
|
||||
t5xxl,
|
||||
ae,
|
||||
prompt: str,
|
||||
seed: Optional[int],
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
steps: Optional[int],
|
||||
guidance: float,
|
||||
negative_prompt: Optional[str],
|
||||
cfg_scale: float,
|
||||
):
|
||||
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
||||
logger.info(f"Seed: {seed}")
|
||||
|
||||
# make first noise with packed shape
|
||||
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
||||
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
||||
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
|
||||
noise = torch.randn(
|
||||
1,
|
||||
packed_latent_height * packed_latent_width,
|
||||
16 * 2 * 2,
|
||||
device=device,
|
||||
dtype=noise_dtype,
|
||||
generator=torch.Generator(device=device).manual_seed(seed),
|
||||
)
|
||||
|
||||
# prepare img and img ids
|
||||
|
||||
# this is needed only for img2img
|
||||
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
# if img.shape[0] == 1 and bs > 1:
|
||||
# img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# txt2img only needs img_ids
|
||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
||||
|
||||
# prepare fp8 models
|
||||
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
|
||||
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
||||
clip_l.to(clip_l_dtype) # fp8
|
||||
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
||||
clip_l.fp8_prepared = True
|
||||
|
||||
if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
|
||||
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
|
||||
|
||||
def prepare_fp8(text_encoder, target_dtype):
|
||||
def forward_hook(module):
|
||||
def forward(hidden_states):
|
||||
hidden_gelu = module.act(module.wi_0(hidden_states))
|
||||
hidden_linear = module.wi_1(hidden_states)
|
||||
hidden_states = hidden_gelu * hidden_linear
|
||||
hidden_states = module.dropout(hidden_states)
|
||||
|
||||
hidden_states = module.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
||||
# print("set", module.__class__.__name__, "to", target_dtype)
|
||||
module.to(target_dtype)
|
||||
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
||||
# print("set", module.__class__.__name__, "hooks")
|
||||
module.forward = forward_hook(module)
|
||||
|
||||
t5xxl.to(t5xxl_dtype)
|
||||
prepare_fp8(t5xxl.encoder, torch.bfloat16)
|
||||
t5xxl.fp8_prepared = True
|
||||
|
||||
# prepare embeddings
|
||||
logger.info("Encoding prompts...")
|
||||
clip_l = clip_l.to(device)
|
||||
t5xxl = t5xxl.to(device)
|
||||
|
||||
def encode(prpt: str):
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
||||
with torch.no_grad():
|
||||
if is_fp8(clip_l_dtype):
|
||||
with accelerator.autocast():
|
||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||
|
||||
if is_fp8(t5xxl_dtype):
|
||||
with accelerator.autocast():
|
||||
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||
)
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
||||
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||
)
|
||||
return l_pooled, t5_out, txt_ids, t5_attn_mask
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
|
||||
if negative_prompt:
|
||||
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
|
||||
else:
|
||||
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
|
||||
|
||||
# NaN check
|
||||
if torch.isnan(l_pooled).any():
|
||||
raise ValueError("NaN in l_pooled")
|
||||
if torch.isnan(t5_out).any():
|
||||
raise ValueError("NaN in t5_out")
|
||||
|
||||
if args.offload:
|
||||
clip_l = clip_l.cpu()
|
||||
t5xxl = t5xxl.cpu()
|
||||
# del clip_l, t5xxl
|
||||
device_utils.clean_memory()
|
||||
|
||||
# generate image
|
||||
logger.info("Generating image...")
|
||||
model = model.to(device)
|
||||
if steps is None:
|
||||
steps = 4 if is_schnell else 50
|
||||
|
||||
img_ids = img_ids.to(device)
|
||||
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
||||
|
||||
x = do_sample(
|
||||
accelerator,
|
||||
model,
|
||||
noise,
|
||||
img_ids,
|
||||
l_pooled,
|
||||
t5_out,
|
||||
txt_ids,
|
||||
steps,
|
||||
guidance,
|
||||
t5_attn_mask,
|
||||
is_schnell,
|
||||
device,
|
||||
flux_dtype,
|
||||
neg_l_pooled,
|
||||
neg_t5_out,
|
||||
neg_t5_attn_mask,
|
||||
cfg_scale,
|
||||
)
|
||||
if args.offload:
|
||||
model = model.cpu()
|
||||
# del model
|
||||
device_utils.clean_memory()
|
||||
|
||||
# unpack
|
||||
x = x.float()
|
||||
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
||||
|
||||
# decode
|
||||
logger.info("Decoding image...")
|
||||
ae = ae.to(device)
|
||||
with torch.no_grad():
|
||||
if is_fp8(ae_dtype):
|
||||
with accelerator.autocast():
|
||||
x = ae.decode(x)
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=ae_dtype):
|
||||
x = ae.decode(x)
|
||||
if args.offload:
|
||||
ae = ae.cpu()
|
||||
|
||||
x = x.clamp(-1, 1)
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
||||
|
||||
# save image
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
|
||||
img.save(output_path)
|
||||
|
||||
logger.info(f"Saved image to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    target_height = 768  # 1024
    target_width = 1360  # 1024

    # steps = 50  # 28 # 50
    # guidance_scale = 5
    # seed = 1  # None # 1

    device = get_preferred_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--clip_l", type=str, required=False)
    parser.add_argument("--t5xxl", type=str, required=False)
    parser.add_argument("--ae", type=str, required=False)
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
    parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
    parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
    parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
    parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
    parser.add_argument("--guidance", type=float, default=3.5)
    parser.add_argument("--negative_prompt", type=str, default=None)
    parser.add_argument("--cfg_scale", type=float, default=1.0)
    parser.add_argument("--offload", action="store_true", help="Offload to CPU")
    parser.add_argument(
        "--lora_weights",
        type=str,
        nargs="*",
        default=[],
        help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semicolon separated)",
    )
    parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights into the model")
    parser.add_argument("--width", type=int, default=target_width)
    parser.add_argument("--height", type=int, default=target_height)
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()
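    # example invocation (paths and the LoRA file are placeholders; assumes this script is
    # saved as flux_minimal_inference.py):
    #   python flux_minimal_inference.py --ckpt_path flux1-dev.safetensors --clip_l clip_l.safetensors \
    #       --t5xxl t5xxl_fp16.safetensors --ae ae.safetensors --offload --prompt "A photo of a cat" \
    #       --lora_weights "my_lora.safetensors;0.8"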

    seed = args.seed
    steps = args.steps
    guidance_scale = args.guidance

    def is_fp8(dt):
        return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]

    dtype = str_to_dtype(args.dtype)
    clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
    t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
    ae_dtype = str_to_dtype(args.ae_dtype, dtype)
    flux_dtype = str_to_dtype(args.flux_dtype, dtype)

    logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")

    loading_device = "cpu" if args.offload else device

    use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
    if any(use_fp8):
        accelerator = accelerate.Accelerator(mixed_precision="bf16")
    else:
        accelerator = None

    # load clip_l
    logger.info(f"Loading clip_l from {args.clip_l}...")
    clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
    clip_l.eval()

    logger.info(f"Loading t5xxl from {args.t5xxl}...")
    t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
    t5xxl.eval()

    # if is_fp8(clip_l_dtype):
    #     clip_l = accelerator.prepare(clip_l)
    # if is_fp8(t5xxl_dtype):
    #     t5xxl = accelerator.prepare(t5xxl)

    # DiT
    is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
    model.eval()
    logger.info(f"Casting model to {flux_dtype}")
    model.to(flux_dtype)  # make sure the model is in the requested dtype
    # if is_fp8(flux_dtype):
    #     model = accelerator.prepare(model)
    # if args.offload:
    #     model = model.to("cpu")

    t5xxl_max_length = 256 if is_schnell else 512
    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
    encoding_strategy = strategy_flux.FluxTextEncodingStrategy()

    # AE
    ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
    ae.eval()
    # if is_fp8(ae_dtype):
    #     ae = accelerator.prepare(ae)

    # LoRA
    lora_models: List[lora_flux.LoRANetwork] = []
    for weights_file in args.lora_weights:
        if ";" in weights_file:
            weights_file, multiplier = weights_file.split(";")
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        weights_sd = load_file(weights_file)
        is_lora = is_oft = False
        for key in weights_sd.keys():
            if key.startswith("lora"):
                is_lora = True
            if key.startswith("oft"):
                is_oft = True
            if is_lora or is_oft:
                break

        module = lora_flux if is_lora else oft_flux
        lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

        if args.merge_lora_weights:
            lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
        else:
            lora_model.apply_to([clip_l, t5xxl], model)
            info = lora_model.load_state_dict(weights_sd, strict=True)
            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
            lora_model.eval()
            lora_model.to(device)

        lora_models.append(lora_model)

    if not args.interactive:
        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )
    else:
        # interactive loop
        width = target_width
        height = target_height
        steps = None
        guidance = args.guidance
        cfg_scale = args.cfg_scale

        while True:
            print(
                "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
                " --n <negative prompt> (`-` for an empty negative prompt) --c <cfg_scale>"
            )
            prompt = input()
            if prompt == "":
                break

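            # e.g. "a cat on a table --w 1024 --h 1024 --s 20 --d 42 --g 3.5 --n blurry --c 1.0"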
            # parse options
            options = prompt.split("--")
            prompt = options[0].strip()
            seed = None
            negative_prompt = None
            for opt in options[1:]:
                try:
                    opt = opt.strip()
                    if opt.startswith("w"):
                        width = int(opt[1:].strip())
                    elif opt.startswith("h"):
                        height = int(opt[1:].strip())
                    elif opt.startswith("s"):
                        steps = int(opt[1:].strip())
                    elif opt.startswith("d"):
                        seed = int(opt[1:].strip())
                    elif opt.startswith("g"):
                        guidance = float(opt[1:].strip())
                    elif opt.startswith("m"):
                        multipliers = opt[1:].strip().split(",")
                        if len(multipliers) != len(lora_models):
                            logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
                            continue
                        for i, lora_model in enumerate(lora_models):
                            lora_model.set_multiplier(float(multipliers[i]))
                    elif opt.startswith("n"):
                        negative_prompt = opt[1:].strip()
                        if negative_prompt == "-":
                            negative_prompt = ""
                    elif opt.startswith("c"):
                        cfg_scale = float(opt[1:].strip())
                except ValueError as e:
                    logger.error(f"Invalid option: {opt}, {e}")

            generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)

    logger.info("Done!")

flux_train.py (new file)
@@ -0,0 +1,850 @@
|
||||
# training with captions

# Swap blocks between CPU and GPU:
# This implementation is inspired by and based on the work of 2kpr.
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
# The original idea has been adapted and extended to fit the current project's needs.

# Key features:
# - CPU offloading during forward and backward passes
# - Use of fused optimizer and grad_hook for efficient gradient processing
# - Per-block fused optimizer instances
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import time
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from library import utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.config_util as config_util
|
||||
|
||||
# import library.sdxl_train_util as sdxl_train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
deepspeed_utils.prepare_deepspeed_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
# temporary: backward compatibility for deprecated options. remove in the future
|
||||
if not args.skip_cache_check:
|
||||
args.skip_cache_check = args.skip_latents_validity_check
|
||||
|
||||
# assert (
|
||||
# not args.weighted_captions
|
||||
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning(
|
||||
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
||||
)
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
|
||||
logger.warning(
|
||||
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
|
||||
)
|
||||
args.gradient_checkpointing = True
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.cpu_offload_checkpointing, (
|
||||
"blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
||||
)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
set_seed(args.seed)  # initialize the random seed
|
||||
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
if args.cache_latents:
|
||||
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
|
||||
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
||||
)
|
||||
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

# prepare the dataset
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

train_dataset_group.verify_bucket_reso_steps(16)  # TODO confirm this is appropriate
|
||||
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
|
||||
)
|
||||
)
|
||||
t5xxl_max_token_length = (
|
||||
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
|
||||
|
||||
train_dataset_group.set_current_strategies()
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

# prepare the accelerator
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)

# prepare dtypes for mixed precision and cast models as needed
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)

# load models
|
||||
|
||||
# load VAE for caching latents
|
||||
ae = None
|
||||
if cache_latents:
|
||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
ae.to(accelerator.device, dtype=weight_dtype)
|
||||
ae.requires_grad_(False)
|
||||
ae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(ae, accelerator)
|
||||
|
||||
ae.to("cpu") # if no sampling, vae can be deleted
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare tokenize strategy
|
||||
if args.t5xxl_max_token_length is None:
|
||||
if is_schnell:
|
||||
t5xxl_max_token_length = 256
|
||||
else:
|
||||
t5xxl_max_token_length = 512
|
||||
else:
|
||||
t5xxl_max_token_length = args.t5xxl_max_token_length
|
||||
|
||||
flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)
|
||||
strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
|
||||
|
||||
# load clip_l, t5xxl for caching text encoder outputs
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
clip_l.eval()
|
||||
t5xxl.eval()
|
||||
clip_l.requires_grad_(False)
|
||||
t5xxl.requires_grad_(False)
|
||||
|
||||
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
# cache text encoder outputs
|
||||
sample_prompts_te_outputs = None
|
||||
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode with no grad here
|
||||
clip_l.to(accelerator.device)
|
||||
t5xxl.to(accelerator.device)
|
||||
|
||||
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
with accelerator.autocast():
|
||||
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
|
||||
|
||||
# cache sample prompt's embeddings to free text encoder's memory
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
||||
|
||||
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
||||
tokens_and_masks = flux_tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# now we can delete Text Encoders to free memory
|
||||
clip_l = None
|
||||
t5xxl = None
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# load FLUX
|
||||
_, flux = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
|
||||
|
||||
flux.requires_grad_(True)
|
||||
|
||||
# block swap
|
||||
|
||||
# backward compatibility
|
||||
if args.blocks_to_swap is None:
|
||||
blocks_to_swap = args.double_blocks_to_swap or 0
|
||||
if args.single_blocks_to_swap is not None:
|
||||
blocks_to_swap += args.single_blocks_to_swap // 2
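# (two single blocks are counted as one here; presumably a single block holds
# roughly half the parameters of a double block)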
|
||||
if blocks_to_swap > 0:
|
||||
logger.warning(
|
||||
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
|
||||
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
|
||||
)
|
||||
logger.info(
|
||||
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
|
||||
)
|
||||
args.blocks_to_swap = blocks_to_swap
|
||||
del blocks_to_swap
|
||||
|
||||
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if is_swapping_blocks:
|
||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
# This idea is based on 2kpr's great work. Thank you!
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
if not cache_latents:
|
||||
# load VAE here if not cached
|
||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
|
||||
ae.requires_grad_(False)
|
||||
ae.eval()
|
||||
ae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
training_models.append(flux)
|
||||
name_and_params = list(flux.named_parameters())
|
||||
# single param group for now
|
||||
params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
|
||||
param_names = [[n for n, _ in name_and_params]]
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for group in params_to_optimize:
|
||||
for p in group["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"number of trainable parameters: {n_params}")

# prepare the classes needed for training
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
|
||||
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
|
||||
# This balances memory usage and management complexity.
|
||||
|
||||
# split params into groups. currently different learning rates are not supported
|
||||
grouped_params = []
|
||||
param_group = {}
|
||||
for group in params_to_optimize:
|
||||
named_parameters = list(flux.named_parameters())
|
||||
assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
|
||||
for p, np in zip(group["params"], named_parameters):
|
||||
# determine target layer and block index for each parameter
|
||||
block_type = "other" # double, single or other
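# e.g. (illustrative parameter names):
#   "double_blocks.3.img_attn.qkv.weight" -> ("double", 3)
#   "single_blocks.10.linear1.weight"     -> ("single", 10)
#   "img_in.weight"                       -> ("other", -1)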
|
||||
if np[0].startswith("double_blocks"):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "double"
|
||||
elif np[0].startswith("single_blocks"):
|
||||
block_index = int(np[0].split(".")[1])
|
||||
block_type = "single"
|
||||
else:
|
||||
block_index = -1
|
||||
|
||||
param_group_key = (block_type, block_index)
|
||||
if param_group_key not in param_group:
|
||||
param_group[param_group_key] = []
|
||||
param_group[param_group_key].append(p)
|
||||
|
||||
block_types_and_indices = []
|
||||
for param_group_key, param_group in param_group.items():
|
||||
block_types_and_indices.append(param_group_key)
|
||||
grouped_params.append({"params": param_group, "lr": args.learning_rate})
|
||||
|
||||
num_params = 0
|
||||
for p in param_group:
|
||||
num_params += p.numel()
|
||||
accelerator.print(f"block {param_group_key}: {num_params} parameters")
|
||||
|
||||
# prepare optimizers for each group
|
||||
optimizers = []
|
||||
for group in grouped_params:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
|
||||
optimizers.append(optimizer)
|
||||
optimizer = optimizers[0] # avoid error in the following code
|
||||
|
||||
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
|
||||
|
||||
if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
||||
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
|
||||
optimizer_train_fn = lambda: None # dummy function
|
||||
optimizer_eval_fn = lambda: None # dummy function
|
||||
else:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
|
||||
# prepare dataloader
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
# some strategies can be None
|
||||
train_dataset_group.set_current_strategies()

# number of DataLoader processes: note that 0 cannot be used with persistent_workers
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)

# calculate the number of training steps
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)

# also send the number of training steps to the dataset side
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)

# prepare the lr scheduler
|
||||
if args.blockwise_fused_optimizers:
|
||||
# prepare lr schedulers for each optimizer
|
||||
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
|
||||
lr_scheduler = lr_schedulers[0] # avoid error in the following code
|
||||
else:
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

# experimental feature: perform fp16/bf16 training including gradients, casting the entire model to fp16/bf16
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
flux.to(weight_dtype)
|
||||
if clip_l is not None:
|
||||
clip_l.to(weight_dtype)
|
||||
t5xxl.to(weight_dtype) # TODO check works with fp16 or not
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
args.mixed_precision == "bf16"
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
flux.to(weight_dtype)
|
||||
if clip_l is not None:
|
||||
clip_l.to(weight_dtype)
|
||||
t5xxl.to(weight_dtype)
|
||||
|
||||
# if we don't cache text encoder outputs, move them to device
|
||||
if not args.cache_text_encoder_outputs:
|
||||
clip_l.to(accelerator.device)
|
||||
t5xxl.to(accelerator.device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux)
# most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
|
||||
else:
|
||||
# accelerator does some magic
# if we don't swap blocks, we can move the model to the device
|
||||
flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# experimental feature: perform fp16 training including gradients, patching PyTorch to enable grad scaling in fp16
|
||||
if args.full_fp16:
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine handles it.
|
||||
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)

# resume training if specified
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||
for parameter, param_name in zip(param_group["params"], param_name_group):
|
||||
if parameter.requires_grad:
|
||||
|
||||
def create_grad_hook(p_name, p_group):
|
||||
def grad_hook(tensor: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||
optimizer.step_param(tensor, p_group)
|
||||
tensor.grad = None
|
||||
|
||||
return grad_hook
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
|
||||
|
||||
elif args.blockwise_fused_optimizers:
|
||||
# prepare for additional optimizers and lr schedulers
|
||||
for i in range(1, len(optimizers)):
|
||||
optimizers[i] = accelerator.prepare(optimizers[i])
|
||||
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
|
||||
|
||||
# counters are used to determine when to step the optimizer
|
||||
global optimizer_hooked_count
|
||||
global num_parameters_per_group
|
||||
global parameter_optimizer_map
|
||||
|
||||
optimizer_hooked_count = {}
|
||||
num_parameters_per_group = [0] * len(optimizers)
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
for opt_idx, optimizer in enumerate(optimizers):
|
||||
for param_group in optimizer.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
|
||||
def grad_hook(parameter: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||
parameter_optimizer_map[parameter] = opt_idx
|
||||
num_parameters_per_group[opt_idx] += 1

# calculate the number of epochs
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
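# e.g. max_train_steps=1000 with 250 update steps per epoch -> ceil(1000 / 250) = 4 epochs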
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

# start training
|
||||
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# accelerator.print(
|
||||
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
# )
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
|
||||
|
||||
# For --sample_at_first
|
||||
optimizer_eval_fn()
|
||||
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
|
||||
optimizer_train_fn()
|
||||
if len(accelerator.trackers) > 0:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
epoch = 0 # avoid error when max_train_steps is 0
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# encode images to latents. images are [-1, 1]
|
||||
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)

# if NaNs are present, warn and replace them with zeros
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list
|
||||
else:
|
||||
# not cached or training, so get from text encoders
|
||||
tokens_and_masks = batch["input_ids_list"]
|
||||
with torch.no_grad():
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
|
||||
)
|
||||
if args.full_fp16:
|
||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
|
||||
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
# pack latents and get img_ids
|
||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
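# img_ids hold the spatial coordinates of each packed token, used by the model's
# rotary position embeddings (RoPE)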
|
||||
|
||||
# get guidance: ensure args.guidance_scale is float
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
||||
|
||||
# call model
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = flux(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
# flow matching loss: this is different from SD3
|
||||
target = noise - latents
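# the target is the straight-path velocity from data to noise:
# with x_t = (1 - t) * latents + t * noise, d(x_t)/dt = noise - latents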
|
||||
|
||||
# calculate loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"]  # per-sample weights
|
||||
loss = loss * loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
# backward
|
||||
accelerator.backward(loss)
|
||||
|
||||
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||
lr_scheduler.step()
|
||||
if args.blockwise_fused_optimizers:
|
||||
for i in range(1, len(optimizers)):
|
||||
lr_schedulers[i].step()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
optimizer_eval_fn()
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
)

# save the model at the specified step interval
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(flux),
|
||||
)
|
||||
optimizer_train_fn()

current_loss = loss.detach().item()  # this is a mean, so batch size should not matter
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
||||
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
optimizer_eval_fn()
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(flux),
|
||||
)
|
||||
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
)
|
||||
optimizer_train_fn()
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
# if is_main_process:
|
||||
flux = accelerator.unwrap_model(flux)
|
||||
|
||||
accelerator.end_training()
|
||||
optimizer_eval_fn()
|
||||
|
||||
if args.save_state or args.save_state_on_train_end:
|
||||
train_util.save_state_on_train_end(args, accelerator)

del accelerator  # delete this because memory is needed afterwards
|
||||
|
||||
if is_main_process:
|
||||
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux)
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser) # TODO split this
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
add_custom_train_arguments(parser) # TODO remove this from here
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
flux_train_utils.add_flux_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--mem_eff_save",
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fused_optimizer_groups",
|
||||
type=int,
|
||||
default=None,
|
||||
help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--blockwise_fused_optimizers",
|
||||
action="store_true",
|
||||
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_latents_validity_check",
|
||||
action="store_true",
|
||||
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--double_blocks_to_swap",
|
||||
type=int,
|
||||
default=None,
|
||||
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--single_blocks_to_swap",
|
||||
type=int,
|
||||
default=None,
|
||||
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
flux_train_control_net.py (new file)
@@ -0,0 +1,878 @@
|
||||
# training with captions
|
||||
|
||||
# Swap blocks between CPU and GPU:
|
||||
# This implementation is inspired by and based on the work of 2kpr.
|
||||
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
|
||||
# The original idea has been adapted and extended to fit the current project's needs.
|
||||
|
||||
# Key features:
|
||||
# - CPU offloading during forward and backward passes
|
||||
# - Use of fused optimizer and grad_hook for efficient gradient processing
|
||||
# - Per-block fused optimizer instances
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from multiprocessing import Value
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import toml
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import utils
|
||||
from library.device_utils import clean_memory_on_device, init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
import library.train_util as train_util
|
||||
from library import (
|
||||
deepspeed_utils,
|
||||
flux_train_utils,
|
||||
flux_utils,
|
||||
strategy_base,
|
||||
strategy_flux,
|
||||
)
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
from library.utils import add_logging_arguments, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.config_util as config_util
|
||||
|
||||
# import library.sdxl_train_util as sdxl_train_util
|
||||
from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
ConfigSanitizer,
|
||||
)
|
||||
from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
deepspeed_utils.prepare_deepspeed_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
# temporary: backward compatibility for deprecated options. remove in the future
|
||||
if not args.skip_cache_check:
|
||||
args.skip_cache_check = args.skip_latents_validity_check
|
||||
|
||||
# assert (
|
||||
# not args.weighted_captions
|
||||
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning(
|
||||
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
||||
)
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
|
||||
logger.warning(
|
||||
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
|
||||
)
|
||||
args.gradient_checkpointing = True
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.cpu_offload_checkpointing, (
|
||||
"blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
||||
)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
if args.seed is not None:
set_seed(args.seed)  # initialize the random seed
|
||||
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
if args.cache_latents:
|
||||
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
|
||||
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
||||
)
|
||||
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

# prepare the dataset
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.conditioning_data_dir, args.caption_extension
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

train_dataset_group.verify_bucket_reso_steps(16)  # TODO confirm this is appropriate
|
||||
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
|
||||
)
|
||||
)
|
||||
t5xxl_max_token_length = (
|
||||
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
|
||||
|
||||
train_dataset_group.set_current_strategies()
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

# prepare the accelerator
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)

# prepare dtypes for mixed precision and cast models as needed
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)

# load models
|
||||
|
||||
# load VAE for caching latents
|
||||
ae = None
|
||||
if cache_latents:
|
||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
ae.to(accelerator.device, dtype=weight_dtype)
|
||||
ae.requires_grad_(False)
|
||||
ae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(ae, accelerator)
|
||||
|
||||
ae.to("cpu") # if no sampling, vae can be deleted
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare tokenize strategy
|
||||
if args.t5xxl_max_token_length is None:
|
||||
if is_schnell:
|
||||
t5xxl_max_token_length = 256
|
||||
else:
|
||||
t5xxl_max_token_length = 512
|
||||
else:
|
||||
t5xxl_max_token_length = args.t5xxl_max_token_length
|
||||
|
||||
flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)
|
||||
strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
|
||||
|
||||
# load clip_l, t5xxl for caching text encoder outputs
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
clip_l.eval()
|
||||
t5xxl.eval()
|
||||
clip_l.requires_grad_(False)
|
||||
t5xxl.requires_grad_(False)
|
||||
|
||||
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
# cache text encoder outputs
|
||||
sample_prompts_te_outputs = None
|
||||
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode with no grad here
|
||||
clip_l.to(accelerator.device)
|
||||
t5xxl.to(accelerator.device)
|
||||
|
||||
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
with accelerator.autocast():
|
||||
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
|
||||
|
||||
# cache sample prompt's embeddings to free text encoder's memory
|
||||
if args.sample_prompts is not None:
|
||||
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
||||
|
||||
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
||||
tokens_and_masks = flux_tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# now we can delete Text Encoders to free memory
|
||||
clip_l = None
|
||||
t5xxl = None
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# load FLUX
|
||||
is_schnell, flux = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||
)
|
||||
flux.requires_grad_(False)
|
||||
|
||||
# load controlnet
|
||||
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
|
||||
controlnet = flux_utils.load_controlnet(
|
||||
args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
|
||||
)
|
||||
controlnet.train()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
if not args.deepspeed:
|
||||
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
|
||||
controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
|
||||
|
||||
# block swap
|
||||
|
||||
# backward compatibility
|
||||
if args.blocks_to_swap is None:
|
||||
blocks_to_swap = args.double_blocks_to_swap or 0
|
||||
if args.single_blocks_to_swap is not None:
|
||||
blocks_to_swap += args.single_blocks_to_swap // 2
|
||||
if blocks_to_swap > 0:
|
||||
logger.warning(
|
||||
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
|
||||
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
|
||||
)
|
||||
logger.info(
|
||||
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
|
||||
)
|
||||
args.blocks_to_swap = blocks_to_swap
|
||||
del blocks_to_swap
    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
    if is_swapping_blocks:
        # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
        # This idea is based on 2kpr's great work. Thank you!
        logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
        flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
        flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
        # ControlNet only has two blocks, so we can keep it on the GPU
        # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
    else:
        flux.to(accelerator.device)

    if not cache_latents:
        # load the VAE here if latents are not cached
        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
        ae.requires_grad_(False)
        ae.eval()
        ae.to(accelerator.device, dtype=weight_dtype)

    training_models = []
    params_to_optimize = []
    training_models.append(controlnet)
    name_and_params = list(controlnet.named_parameters())
    # single param group for now
    params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
    param_names = [[n for n, _ in name_and_params]]

    # calculate the number of trainable parameters
    n_params = 0
    for group in params_to_optimize:
        for p in group["params"]:
            n_params += p.numel()

    accelerator.print(f"number of trainable parameters: {n_params}")

    # prepare classes needed for training / 学習に必要なクラスを準備する
    accelerator.print("prepare optimizer, data loader etc.")

    if args.blockwise_fused_optimizers:
        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
        # This balances memory usage and management complexity.

        # split params into groups. currently different learning rates are not supported
        grouped_params = []
        param_group = {}
        for group in params_to_optimize:
            named_parameters = list(controlnet.named_parameters())
            assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
            for p, np in zip(group["params"], named_parameters):
                # determine the target layer and block index for each parameter
                block_type = "other"  # double, single or other
                if np[0].startswith("double_blocks"):
                    block_index = int(np[0].split(".")[1])
                    block_type = "double"
                elif np[0].startswith("single_blocks"):
                    block_index = int(np[0].split(".")[1])
                    block_type = "single"
                else:
                    block_index = -1

                param_group_key = (block_type, block_index)
                if param_group_key not in param_group:
                    param_group[param_group_key] = []
                param_group[param_group_key].append(p)

        block_types_and_indices = []
        for param_group_key, param_group in param_group.items():
            block_types_and_indices.append(param_group_key)
            grouped_params.append({"params": param_group, "lr": args.learning_rate})

            num_params = 0
            for p in param_group:
                num_params += p.numel()
            accelerator.print(f"block {param_group_key}: {num_params} parameters")

        # prepare an optimizer for each group
        optimizers = []
        for group in grouped_params:
            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
            optimizers.append(optimizer)
        optimizer = optimizers[0]  # avoid an error in the following code

        logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")

        if train_util.is_schedulefree_optimizer(optimizers[0], args):
            raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
        optimizer_train_fn = lambda: None  # dummy function
        optimizer_eval_fn = lambda: None  # dummy function
    else:
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

    # prepare the dataloader
    # strategies are set here because they cannot be referenced from another process; copy them together with the dataset
    # some strategies can be None
    train_dataset_group.set_current_strategies()

    # number of DataLoader workers: persistent_workers cannot be used with 0 / DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # calculate the number of training steps / 学習ステップ数を計算する
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        accelerator.print(
            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
        )

    # also send the number of training steps to the dataset side / データセット側にも学習ステップを送信
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # prepare the lr scheduler / lr schedulerを用意する
    if args.blockwise_fused_optimizers:
        # prepare an lr scheduler for each optimizer
        lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
        lr_scheduler = lr_schedulers[0]  # avoid an error in the following code
    else:
        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # experimental feature: train in fp16/bf16 including gradients, casting the whole model to fp16/bf16 / 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        accelerator.print("enable full fp16 training.")
        flux.to(weight_dtype)
        controlnet.to(weight_dtype)
        if clip_l is not None:
            clip_l.to(weight_dtype)
            t5xxl.to(weight_dtype)  # TODO check whether this works with fp16 or not
    elif args.full_bf16:
        assert (
            args.mixed_precision == "bf16"
        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
        accelerator.print("enable full bf16 training.")
        flux.to(weight_dtype)
        controlnet.to(weight_dtype)
        if clip_l is not None:
            clip_l.to(weight_dtype)
            t5xxl.to(weight_dtype)

    # if we don't cache text encoder outputs, move the encoders to the device
    if not args.cache_text_encoder_outputs:
        clip_l.to(accelerator.device)
        t5xxl.to(accelerator.device)

    clean_memory_on_device(accelerator.device)

    if args.deepspeed:
        ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
        # most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            ds_model, optimizer, train_dataloader, lr_scheduler
        )
        training_models = [ds_model]

    else:
        # accelerator does some magic
        # if we don't swap blocks, we can move the model to the device
        controlnet = accelerator.prepare(controlnet)  # , device_placement=[not is_swapping_blocks])
        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

    # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16 / 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
    if args.full_fp16:
        # During deepspeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the deepspeed engine does.
        # -> But we think it's ok to patch the accelerator even if deepspeed is enabled.
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training / resumeする
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    if args.fused_backward_pass:
        # use the fused optimizer for the backward pass: other optimizers will be supported in the future
        import library.adafactor_fused

        library.adafactor_fused.patch_adafactor_fused(optimizer)

        for param_group, param_name_group in zip(optimizer.param_groups, param_names):
            for parameter, param_name in zip(param_group["params"], param_name_group):
                if parameter.requires_grad:

                    def create_grad_hook(p_name, p_group):
                        def grad_hook(tensor: torch.Tensor):
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                            optimizer.step_param(tensor, p_group)
                            tensor.grad = None

                        return grad_hook

                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))

    elif args.blockwise_fused_optimizers:
        # prepare the additional optimizers and lr schedulers
        for i in range(1, len(optimizers)):
            optimizers[i] = accelerator.prepare(optimizers[i])
            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])

        # counters are used to determine when to step the optimizer
        global optimizer_hooked_count
        global num_parameters_per_group
        global parameter_optimizer_map

        optimizer_hooked_count = {}
        num_parameters_per_group = [0] * len(optimizers)
        parameter_optimizer_map = {}

        for opt_idx, optimizer in enumerate(optimizers):
            for param_group in optimizer.param_groups:
                for parameter in param_group["params"]:
                    if parameter.requires_grad:

                        def grad_hook(parameter: torch.Tensor):
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)

                            i = parameter_optimizer_map[parameter]
                            optimizer_hooked_count[i] += 1
                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                optimizers[i].step()
                                optimizers[i].zero_grad(set_to_none=True)

                        parameter.register_post_accumulate_grad_hook(grad_hook)
                        parameter_optimizer_map[parameter] = opt_idx
                        num_parameters_per_group[opt_idx] += 1
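        # A hedged note on the hooks above (comment only): parameters are counted
        # per optimizer at registration time; during backward each hook increments
        # optimizer_hooked_count[i], and once it equals num_parameters_per_group[i]
        # every grad owned by block i is ready, so that block's optimizer steps
        # and frees its grads immediately.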
    # calculate the number of epochs / epoch数を計算する
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # start training / 学習する
    # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    accelerator.print("running training / 学習開始")
    accelerator.print(f"  num examples / サンプル数: {train_dataset_group.num_train_images}")
    accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
    accelerator.print(
        f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
    )
    # accelerator.print(
    #     f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
    # )
    accelerator.print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers(
            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
            config=train_util.get_sanitized_config_or_none(args),
            init_kwargs=init_kwargs,
        )

    if is_swapping_blocks:
        flux.prepare_block_swap_before_forward()

    # For --sample_at_first
    optimizer_eval_fn()
    flux_train_utils.sample_images(
        accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
    )
    optimizer_train_fn()
    if len(accelerator.trackers) > 0:
        # log an empty object to commit the sample images to wandb
        accelerator.log({}, step=0)

    loss_recorder = train_util.LossRecorder()
    epoch = 0  # avoid an error when max_train_steps is 0
    for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step

            if args.blockwise_fused_optimizers:
                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset the counters for each step

            with accelerator.accumulate(*training_models):
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
                else:
                    with torch.no_grad():
                        # encode images to latents. images are [-1, 1]
                        latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)

                    # if NaNs are present, warn and replace them with zeros / NaNが含まれていれば警告を表示し0に置き換える
                    if torch.any(torch.isnan(latents)):
                        accelerator.print("NaN found in latents, replacing with zeros")
                        latents = torch.nan_to_num(latents, 0, out=latents)

                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                if text_encoder_outputs_list is not None:
                    text_encoder_conds = text_encoder_outputs_list
                else:
                    # not cached or training, so get the outputs from the text encoders
                    tokens_and_masks = batch["input_ids_list"]
                    with torch.no_grad():
                        input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
                        text_encoder_conds = text_encoding_strategy.encode_tokens(
                            flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
                        )
                        text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]

                # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps

                # sample the noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                # get the noisy model input and timesteps
                noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
                    args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
                )

                # pack latents and get img_ids
                packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
                packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
                img_ids = (
                    flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width)
                    .to(device=accelerator.device)
                    .to(weight_dtype)
                )
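                # A hedged note on the lines above (comment only): pack_latents folds
                # each 2x2 latent patch into one token, so a (b, 16, 128, 128) latent
                # becomes (b, 64*64, 64), and prepare_img_ids is assumed to emit one
                # (t, y, x)-style positional id per packed token for the model's RoPE.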
                # get guidance: ensure args.guidance_scale is float
                guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype)

                # call model
                l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
                if not args.apply_t5_attn_mask:
                    t5_attn_mask = None

                with accelerator.autocast():
                    block_samples, block_single_samples = controlnet(
                        img=packed_noisy_model_input,
                        img_ids=img_ids,
                        controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype),
                        txt=t5_out,
                        txt_ids=txt_ids,
                        y=l_pooled,
                        timesteps=timesteps / 1000,
                        guidance=guidance_vec,
                        txt_attention_mask=t5_attn_mask,
                    )
                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                    model_pred = flux(
                        img=packed_noisy_model_input,
                        img_ids=img_ids,
                        txt=t5_out,
                        txt_ids=txt_ids,
                        y=l_pooled,
                        block_controlnet_hidden_states=block_samples,
                        block_controlnet_single_hidden_states=block_single_samples,
                        timesteps=timesteps / 1000,
                        guidance=guidance_vec,
                        txt_attention_mask=t5_attn_mask,
                    )

                # unpack latents
                model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

                # apply the model prediction type
                model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

                # flow matching loss: this is different from SD3
                target = noise - latents
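                # A hedged derivation of this target (comment only): with the linear
                # flow-matching interpolation x_sigma = (1 - sigma) * latents + sigma * noise,
                # the velocity d x_sigma / d sigma = noise - latents, so the model
                # regresses a constant velocity rather than the noise itself.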
                # calculate the loss
                loss = train_util.conditional_loss(
                    model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
                )
                if weighting is not None:
                    loss = loss * weighting
                if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                    loss = apply_masked_loss(loss, batch)
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # per-sample weights / 各sampleごとのweight
                loss = loss * loss_weights
                loss = loss.mean()

                # backward
                accelerator.backward(loss)

                if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                        params_to_clip = []
                        for m in training_models:
                            params_to_clip.extend(m.parameters())
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                else:
                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hooks
                    lr_scheduler.step()
                    if args.blockwise_fused_optimizers:
                        for i in range(1, len(optimizers)):
                            lr_schedulers[i].step()

            # check whether the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                optimizer_eval_fn()
                flux_train_utils.sample_images(
                    accelerator,
                    args,
                    None,
                    global_step,
                    flux,
                    ae,
                    [clip_l, t5xxl],
                    sample_prompts_te_outputs,
                    controlnet=controlnet,
                )

                # save the model at the specified step interval / 指定ステップごとにモデルを保存
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            accelerator.unwrap_model(controlnet),
                        )
                optimizer_train_fn()

            current_loss = loss.detach().item()  # this is the mean, so batch size should not matter / 平均なのでbatch sizeは関係ないはず
            if len(accelerator.trackers) > 0:
                logs = {"loss": current_loss}
                train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)

                accelerator.log(logs, step=global_step)

            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
            avr_loss: float = loss_recorder.moving_average
            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if len(accelerator.trackers) > 0:
            logs = {"loss/epoch": loss_recorder.moving_average}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        optimizer_eval_fn()
        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
                    args,
                    True,
                    accelerator,
                    save_dtype,
                    epoch,
                    num_train_epochs,
                    global_step,
                    accelerator.unwrap_model(controlnet),
                )

        flux_train_utils.sample_images(
            accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
        )
        optimizer_train_fn()

    is_main_process = accelerator.is_main_process
    # if is_main_process:
    controlnet = accelerator.unwrap_model(controlnet)

    accelerator.end_training()
    optimizer_eval_fn()

    if args.save_state or args.save_state_on_train_end:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # delete this since memory is used afterwards / この後メモリを使うのでこれは消す

    if is_main_process:
        flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet)
        logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)  # TODO split this
    train_util.add_dataset_arguments(parser, False, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)  # TODO remove this from here
    train_util.add_dit_training_arguments(parser)
    flux_train_utils.add_flux_train_arguments(parser)

    parser.add_argument(
        "--mem_eff_save",
        action="store_true",
        help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
    )

    parser.add_argument(
        "--fused_optimizer_groups",
        type=int,
        default=None,
        help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
    )
    parser.add_argument(
        "--blockwise_fused_optimizers",
        action="store_true",
        help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためブロック単位のoptimizerを有効にする",
    )
    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",
        help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
    )
    parser.add_argument(
        "--double_blocks_to_swap",
        type=int,
        default=None,
        help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
    )
    parser.add_argument(
        "--single_blocks_to_swap",
        type=int,
        default=None,
        help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
    )
    parser.add_argument(
        "--cpu_offload_checkpointing",
        action="store_true",
        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    train(args)
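
Both `--fused_backward_pass` and `--blockwise_fused_optimizers` above hinge on `Tensor.register_post_accumulate_grad_hook`, which PyTorch added in 2.1. The following is a minimal, self-contained sketch of that pattern; the toy model and per-parameter SGD optimizers are hypothetical stand-ins for illustration, not code from this repository.

import torch

model = torch.nn.Linear(8, 8)
# one tiny optimizer per parameter, as the blockwise variant does per block
optimizers = {p: torch.optim.SGD([p], lr=1e-2) for p in model.parameters()}

def make_hook(p: torch.Tensor):
    def hook(param: torch.Tensor):
        # runs as soon as param.grad has been fully accumulated during backward
        optimizers[p].step()
        optimizers[p].zero_grad(set_to_none=True)  # free the grad right away
    return hook

for p in model.parameters():
    p.register_post_accumulate_grad_hook(make_hook(p))

loss = model(torch.randn(4, 8)).sum()
loss.backward()  # optimizer steps interleave with backward; no optimizer.step() afterwards

Because each gradient is consumed and freed inside the backward pass, the full set of gradients never coexists in memory, which is what makes the fused variants attractive for large models.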

flux_train_network.py (new file, 559 lines)
@@ -0,0 +1,559 @@
import argparse
import copy
import math
import random
from typing import Any, Optional, Union

import torch
from accelerate import Accelerator

from library.device_utils import clean_memory_on_device, init_ipex

init_ipex()

import train_network
from library import (
    flux_models,
    flux_train_utils,
    flux_utils,
    sd3_train_utils,
    strategy_base,
    strategy_flux,
    train_util,
)
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class FluxNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        self.sample_prompts_te_outputs = None
        self.is_schnell: Optional[bool] = None
        self.is_swapping_blocks: bool = False

    def assert_extra_args(
        self,
        args,
        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
        val_dataset_group: Optional[train_util.DatasetGroup],
    ):
        super().assert_extra_args(args, train_dataset_group, val_dataset_group)
        # sdxl_train_util.verify_sdxl_training_args(args)

        if args.fp8_base_unet:
            args.fp8_base = True  # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1

        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            logger.warning(
                "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
            )
            args.cache_text_encoder_outputs = True

        if args.cache_text_encoder_outputs:
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        # prepare CLIP-L/T5XXL training flags
        self.train_clip_l = not args.network_train_unet_only
        self.train_t5xxl = False  # default is False even if args.network_train_unet_only is False

        if args.max_token_length is not None:
            logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")

        assert (
            args.blocks_to_swap is None or args.blocks_to_swap == 0
        ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

        # deprecated split_mode option
        if args.split_mode:
            if args.blocks_to_swap is not None:
                logger.warning(
                    "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
                    " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
                )
            else:
                logger.warning(
                    "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
                    " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
                )
                args.blocks_to_swap = 18  # 18 is safe for most cases

        train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
        if val_dataset_group is not None:
            val_dataset_group.verify_bucket_reso_steps(32)  # TODO check this

    def load_target_model(self, args, weight_dtype, accelerator):
        # currently offload to cpu for some models

        # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
        loading_dtype = None if args.fp8_base else weight_dtype

        # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in the future
        self.is_schnell, model = flux_utils.load_flow_model(
            args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
        )
        if args.fp8_base:
            # check the dtype of the model
            if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
            elif model.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 FLUX model")
            else:
                logger.info(
                    "Cast FLUX model to fp8. This may take a while. You can reduce the time by using an fp8 checkpoint."
                    " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
                )
                model.to(torch.float8_e4m3fn)

        # if args.split_mode:
        #     model = self.prepare_split_model(model, weight_dtype, accelerator)

        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
        if self.is_swapping_blocks:
            # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
            logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
            model.enable_block_swap(args.blocks_to_swap, accelerator.device)

        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        clip_l.eval()

        # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
        if args.fp8_base and not args.fp8_base_unet:
            loading_dtype = None  # as is
        else:
            loading_dtype = weight_dtype

        # loading t5xxl to cpu takes a long time, so we should load to gpu in the future
        t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        t5xxl.eval()
        if args.fp8_base and not args.fp8_base_unet:
            # check the dtype of the model
            if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
            elif t5xxl.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 T5XXL model")

        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

        return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model

    def get_tokenize_strategy(self, args):
        _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)

        if args.t5xxl_max_token_length is None:
            if is_schnell:
                t5xxl_max_token_length = 256
            else:
                t5xxl_max_token_length = 512
        else:
            t5xxl_max_token_length = args.t5xxl_max_token_length

        logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
        return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
        return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]

    def get_latents_caching_strategy(self, args):
        latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
        return latents_caching_strategy

    def get_text_encoding_strategy(self, args):
        return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)

    def post_process_network(self, args, accelerator, network, text_encoders, unet):
        # check whether t5xxl is trained or not
        self.train_t5xxl = network.train_t5xxl

        if self.train_t5xxl and args.cache_text_encoder_outputs:
            raise ValueError(
                "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
            )

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        if args.cache_text_encoder_outputs:
            if self.train_clip_l and not self.train_t5xxl:
                return text_encoders[0:1]  # only CLIP-L is needed for encoding because T5XXL is cached
            else:
                return None  # no text encoders are needed for encoding because both are cached
        else:
            return text_encoders  # both CLIP-L and T5XXL are needed for encoding

    def get_text_encoders_train_flags(self, args, text_encoders):
        return [self.train_clip_l, self.train_t5xxl]

    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            # if a text encoder is trained, we still need tokenization, so is_partial is True
            return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk,
                args.text_encoder_batch_size,
                args.skip_cache_check,
                is_partial=self.train_clip_l or self.train_t5xxl,
                apply_t5_attn_mask=args.apply_t5_attn_mask,
            )
        else:
            return None

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption / メモリ消費を減らす
                logger.info("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)

            # When the TE is not trained, it is not prepared by accelerate, so we need explicit autocast
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
            text_encoders[1].to(accelerator.device)

            if text_encoders[1].dtype == torch.float8_e4m3fn:
                # if we loaded fp8 weights, the model is already fp8, so we use it as is
                self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
            else:
                # otherwise, we need to convert it to the target dtype
                text_encoders[1].to(weight_dtype)

            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

            # cache sample prompts
            if args.sample_prompts is not None:
                logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")

                tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
                text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

                prompts = train_util.load_prompts(args.sample_prompts)
                sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
                with accelerator.autocast(), torch.no_grad():
                    for prompt_dict in prompts:
                        for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                            if p not in sample_prompts_te_outputs:
                                logger.info(f"cache Text Encoder outputs for prompt: {p}")
                                tokens_and_masks = tokenize_strategy.tokenize(p)
                                sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                    tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
                                )
                self.sample_prompts_te_outputs = sample_prompts_te_outputs

            accelerator.wait_for_everyone()

            # move the encoders back to cpu
            if not self.is_train_text_encoder(args):
                logger.info("move CLIP-L back to cpu")
                text_encoders[0].to("cpu")
            logger.info("move t5XXL back to cpu")
            text_encoders[1].to("cpu")
            clean_memory_on_device(accelerator.device)

            if not args.lowram:
                logger.info("move vae and unet back to original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            # outputs are fetched from the Text Encoders every step, so keep them on the GPU / Text Encoderから毎回出力を取得するので、GPUに乗せておく
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device)

    # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
    #     noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
    #
    #     # get size embeddings
    #     orig_size = batch["original_sizes_hw"]
    #     crop_size = batch["crop_top_lefts"]
    #     target_size = batch["target_sizes_hw"]
    #     embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
    #
    #     # concat embeddings
    #     encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
    #     vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
    #     text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
    #
    #     noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
    #     return noise_pred

    def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
        text_encoders = text_encoder  # for compatibility
        text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

        flux_train_utils.sample_images(
            accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
        )
        # return

    """
    class FluxUpperLowerWrapper(torch.nn.Module):
        def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
            super().__init__()
            self.flux_upper = flux_upper
            self.flux_lower = flux_lower
            self.target_device = device

        def prepare_block_swap_before_forward(self):
            pass

        def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
            self.flux_lower.to("cpu")
            clean_memory_on_device(self.target_device)
            self.flux_upper.to(self.target_device)
            img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
            self.flux_upper.to("cpu")
            clean_memory_on_device(self.target_device)
            self.flux_lower.to(self.target_device)
            return self.flux_lower(img, txt, vec, pe, txt_attention_mask)

    wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
    clean_memory_on_device(accelerator.device)
    flux_train_utils.sample_images(
        accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
    )
    clean_memory_on_device(accelerator.device)
    """

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler

    def encode_images_to_latents(self, args, vae, images):
        return vae.encode(images)

    def shift_scale_latents(self, args, latents):
        return latents

    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet: flux_models.Flux,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        # sample the noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # get the noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
        )

        # pack latents and get img_ids
        packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
        packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
        img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

        # get guidance
        # ensure guidance_scale in args is float
        guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

        # ensure the hidden states will require grad
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)
            img_ids.requires_grad_(True)
            guidance_vec.requires_grad_(True)

        # Predict the noise residual
        l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
        if not args.apply_t5_attn_mask:
            t5_attn_mask = None

        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
            # grad is enabled even if the unet is not in train mode, because the Text Encoder is in train mode
            with torch.set_grad_enabled(is_train), accelerator.autocast():
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                model_pred = unet(
                    img=img,
                    img_ids=img_ids,
                    txt=t5_out,
                    txt_ids=txt_ids,
                    y=l_pooled,
                    timesteps=timesteps / 1000,
                    guidance=guidance_vec,
                    txt_attention_mask=t5_attn_mask,
                )
            return model_pred

        model_pred = call_dit(
            img=packed_noisy_model_input,
            img_ids=img_ids,
            t5_out=t5_out,
            txt_ids=txt_ids,
            l_pooled=l_pooled,
            timesteps=timesteps,
            guidance_vec=guidance_vec,
            t5_attn_mask=t5_attn_mask,
        )

        # unpack latents
        model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

        # apply the model prediction type
        model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

        # flow matching loss: this is different from SD3
        target = noise - latents

        # differential output preservation
        if "custom_attributes" in batch:
            diff_output_pr_indices = []
            for i, custom_attributes in enumerate(batch["custom_attributes"]):
                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
                    diff_output_pr_indices.append(i)

            if len(diff_output_pr_indices) > 0:
                network.set_multiplier(0.0)
                unet.prepare_block_swap_before_forward()
                with torch.no_grad():
                    model_pred_prior = call_dit(
                        img=packed_noisy_model_input[diff_output_pr_indices],
                        img_ids=img_ids[diff_output_pr_indices],
                        t5_out=t5_out[diff_output_pr_indices],
                        txt_ids=txt_ids[diff_output_pr_indices],
                        l_pooled=l_pooled[diff_output_pr_indices],
                        timesteps=timesteps[diff_output_pr_indices],
                        guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
                        t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
                    )
                network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step

                model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
                model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
                    args,
                    model_pred_prior,
                    noisy_model_input[diff_output_pr_indices],
                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
                )
                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
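                # A hedged note on this block (comment only): with the LoRA multiplier
                # set to 0.0, model_pred_prior is the base model's own prediction;
                # using it as the target trains the network to leave these samples
                # unchanged, which is the "differential output preservation" idea.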

        return model_pred, target, timesteps, weighting

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss

    def get_sai_model_spec(self, args):
        return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")

    def update_metadata(self, metadata, args):
        metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
        metadata["ss_weighting_scheme"] = args.weighting_scheme
        metadata["ss_logit_mean"] = args.logit_mean
        metadata["ss_logit_std"] = args.logit_std
        metadata["ss_mode_scale"] = args.mode_scale
        metadata["ss_guidance_scale"] = args.guidance_scale
        metadata["ss_timestep_sampling"] = args.timestep_sampling
        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
        metadata["ss_model_prediction_type"] = args.model_prediction_type
        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

    def is_text_encoder_not_needed_for_training(self, args):
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        if index == 0:  # CLIP-L
            return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
        else:  # T5XXL
            text_encoder.encoder.embed_tokens.requires_grad_(True)

    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
        if index == 0:  # CLIP-L
            logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
            text_encoder.to(te_weight_dtype)  # fp8
            text_encoder.text_model.embeddings.to(dtype=weight_dtype)
        else:  # T5XXL

            def prepare_fp8(text_encoder, target_dtype):
                def forward_hook(module):
                    def forward(hidden_states):
                        hidden_gelu = module.act(module.wi_0(hidden_states))
                        hidden_linear = module.wi_1(hidden_states)
                        hidden_states = hidden_gelu * hidden_linear
                        hidden_states = module.dropout(hidden_states)

                        hidden_states = module.wo(hidden_states)
                        return hidden_states

                    return forward

                for module in text_encoder.modules():
                    if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                        # print("set", module.__class__.__name__, "to", target_dtype)
                        module.to(target_dtype)
                    if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                        # print("set", module.__class__.__name__, "hooks")
                        module.forward = forward_hook(module)

            if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
                logger.info("T5XXL already prepared for fp8")
            else:
                logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
                text_encoder.to(te_weight_dtype)  # fp8
                prepare_fp8(text_encoder, weight_dtype)

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        if self.is_swapping_blocks:
            # prepare for the next forward: because the backward pass is not called, we need to prepare it here
            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        if not self.is_swapping_blocks:
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # when we swap blocks, device placement is handled manually instead of by accelerate
        flux: flux_models.Flux = unet
        flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
        accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
        accelerator.unwrap_model(flux).prepare_block_swap_before_forward()

        return flux


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    flux_train_utils.add_flux_train_arguments(parser)

    parser.add_argument(
        "--split_mode",
        action="store_true",
        # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
        # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
        help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
        " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = FluxNetworkTrainer()
    trainer.train(args)
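
Both trainers above route latents through flux_utils.pack_latents / unpack_latents around the DiT. The following is a minimal sketch of a packing with the documented (b, c, h*2, w*2) -> (b, h*w, c*4) layout and its inverse; the helper names and the exact patch ordering are assumptions for illustration, not the repository implementation.

import torch

def pack_latents(x: torch.Tensor) -> torch.Tensor:
    b, c, h, w = x.shape
    x = x.view(b, c, h // 2, 2, w // 2, 2)           # split into 2x2 patches
    x = x.permute(0, 2, 4, 1, 3, 5)                  # b, h/2, w/2, c, 2, 2
    return x.reshape(b, (h // 2) * (w // 2), c * 4)  # one token per patch

def unpack_latents(x: torch.Tensor, ph: int, pw: int) -> torch.Tensor:
    b, _, c4 = x.shape
    x = x.view(b, ph, pw, c4 // 4, 2, 2).permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, c4 // 4, ph * 2, pw * 2)

latents = torch.randn(1, 16, 64, 64)
packed = pack_latents(latents)                        # (1, 1024, 64)
assert torch.equal(unpack_latents(packed, 32, 32), latents)  # lossless round trip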

gen_img.py (new file, 3417 lines): diff suppressed because it is too large

library/adafactor_fused.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import math
import torch
from transformers import Adafactor

# stochastic rounding for bfloat16
# The implementation was provided by 2kpr. Thank you very much!


def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
    """
    copies source into target using stochastic rounding

    Args:
        target: the target tensor with dtype=bfloat16
        source: the source tensor with dtype=float32
    """
    # create a random 16 bit integer
    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))

    # add the random number to the lower 16 bits of the mantissa
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bits of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bits into the target tensor
    target.copy_(result.view(dtype=torch.float32))

    del result
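
# A hedged worked example (comment only): bfloat16 keeps the top 16 bits of a
# float32, so truncation drops the low 16 bits L of the mantissa. Adding a
# uniform random integer in [0, 65536) before masking makes the value round up
# with probability L / 65536 (e.g. L = 0x4000 rounds up 25% of the time),
# which keeps the rounding unbiased in expectation.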

@torch.no_grad()
def adafactor_step_param(self, p, group):
    if p.grad is None:
        return
    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("Adafactor does not support sparse gradients.")

    state = self.state[p]
    grad_shape = grad.shape

    factored, use_first_moment = Adafactor._get_options(group, grad_shape)
    # State Initialization
    if len(state) == 0:
        state["step"] = 0

        if use_first_moment:
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(grad)
        if factored:
            state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
            state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
        else:
            state["exp_avg_sq"] = torch.zeros_like(grad)

        state["RMS"] = 0
    else:
        if use_first_moment:
            state["exp_avg"] = state["exp_avg"].to(grad)
        if factored:
            state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
            state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
        else:
            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    state["step"] += 1
    state["RMS"] = Adafactor._rms(p_data_fp32)
    lr = Adafactor._get_lr(group, state)

    beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
    update = (grad**2) + group["eps"][0]
    if factored:
        exp_avg_sq_row = state["exp_avg_sq_row"]
        exp_avg_sq_col = state["exp_avg_sq_col"]

        exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
        exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

        # Approximation of the exponential moving average of the square of the gradient
        update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
        update.mul_(grad)
    else:
        exp_avg_sq = state["exp_avg_sq"]

        exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
        update = exp_avg_sq.rsqrt().mul_(grad)

    update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
    update.mul_(lr)

    if use_first_moment:
        exp_avg = state["exp_avg"]
        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
        update = exp_avg

    if group["weight_decay"] != 0:
        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

    p_data_fp32.add_(-update)

    # if p.dtype in {torch.float16, torch.bfloat16}:
    #     p.copy_(p_data_fp32)

    if p.dtype == torch.bfloat16:
        copy_stochastic_(p, p_data_fp32)
    elif p.dtype == torch.float16:
        p.copy_(p_data_fp32)


@torch.no_grad()
def adafactor_step(self, closure=None):
    """
    Performs a single optimization step

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            adafactor_step_param(self, p, group)

    return loss


def patch_adafactor_fused(optimizer: Adafactor):
    optimizer.step_param = adafactor_step_param.__get__(optimizer)
    optimizer.step = adafactor_step.__get__(optimizer)
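
A small usage sketch for the module above (assumption: run from the repository root so that library.adafactor_fused is importable), showing that stochastic rounding is unbiased on average where a plain bf16 cast is not:

import torch
from library.adafactor_fused import copy_stochastic_

# 1.0 + 2**-12 is not representable in bfloat16 (8 mantissa bits, step 2**-8 near 1.0)
src = torch.full((100000,), 1.0 + 2**-12, dtype=torch.float32)
dst = torch.empty_like(src, dtype=torch.bfloat16)

copy_stochastic_(dst, src)
print(dst.float().mean().item())                     # ~1.000244: the increment survives on average
print(src.to(torch.bfloat16).float().mean().item())  # 1.0: deterministic cast loses it entirely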
|
||||
@@ -10,13 +10,7 @@ import json
from pathlib import Path

# from toolz import curry
from typing import (
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from typing import Dict, List, Optional, Sequence, Tuple, Union

import toml
import voluptuous
@@ -40,10 +34,18 @@ from .train_util import (
    ControlNetDataset,
    DatasetGroup,
)
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_config_arguments(parser: argparse.ArgumentParser):
    parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
    parser.add_argument(
        "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
    )


# TODO: inherit Params class in Subset, Dataset
@@ -57,6 +59,8 @@ class BaseSubsetParams:
    caption_separator: str = (",",)
    keep_tokens: int = 0
    keep_tokens_separator: str = (None,)
    secondary_separator: Optional[str] = None
    enable_wildcard: bool = False
    color_aug: bool = False
    flip_aug: bool = False
    face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -68,6 +72,10 @@ class BaseSubsetParams:
    caption_tag_dropout_rate: float = 0.0
    token_warmup_min: int = 1
    token_warmup_step: float = 0
    custom_attributes: Optional[Dict[str, Any]] = None
    validation_seed: int = 0
    validation_split: float = 0.0
    resize_interpolation: Optional[str] = None


@dataclass
@@ -75,27 +83,31 @@ class DreamBoothSubsetParams(BaseSubsetParams):
    is_reg: bool = False
    class_tokens: Optional[str] = None
    caption_extension: str = ".caption"
    cache_info: bool = False
    alpha_mask: bool = False


@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
    metadata_file: Optional[str] = None
    alpha_mask: bool = False


@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
    conditioning_data_dir: str = None
    caption_extension: str = ".caption"
    cache_info: bool = False


@dataclass
class BaseDatasetParams:
    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
    max_token_length: int = None
    resolution: Optional[Tuple[int, int]] = None
    network_multiplier: float = 1.0
    debug_dataset: bool = False

    validation_seed: Optional[int] = None
    validation_split: float = 0.0
    resize_interpolation: Optional[str] = None

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -106,8 +118,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False
    prior_loss_weight: float = 1.0



@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
    batch_size: int = 1
@@ -178,10 +189,15 @@ class ConfigSanitizer:
        "shuffle_caption": bool,
        "keep_tokens": int,
        "keep_tokens_separator": str,
        "secondary_separator": str,
        "caption_separator": str,
        "enable_wildcard": bool,
        "token_warmup_min": int,
        "token_warmup_step": Any(float, int),
        "caption_prefix": str,
        "caption_suffix": str,
        "custom_attributes": dict,
        "resize_interpolation": str,
    }
    # DO means DropOut
    DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -193,18 +209,22 @@ class ConfigSanitizer:
    DB_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "class_tokens": str,
        "cache_info": bool,
    }
    DB_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        "is_reg": bool,
        "alpha_mask": bool,
    }
    # FT means FineTuning
    FT_SUBSET_DISTINCT_SCHEMA = {
        Required("metadata_file"): str,
        "image_dir": str,
        "alpha_mask": bool,
    }
    CN_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "cache_info": bool,
    }
    CN_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
@@ -219,8 +239,11 @@ class ConfigSanitizer:
        "enable_bucket": bool,
        "max_bucket_reso": int,
        "min_bucket_reso": int,
        "validation_seed": int,
        "validation_split": float,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "network_multiplier": float,
        "resize_interpolation": str,
    }

    # options handled by argparse but not handled by user config
@@ -241,9 +264,10 @@ class ConfigSanitizer:
    }

    def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
        assert (
            support_dreambooth or support_finetuning or support_controlnet
        ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
        assert support_dreambooth or support_finetuning or support_controlnet, (
            "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
            + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
        )

        self.db_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
@@ -310,7 +334,10 @@ class ConfigSanitizer:

            self.dataset_schema = validate_flex_dataset
        elif support_dreambooth:
            self.dataset_schema = self.db_dataset_schema
            if support_controlnet:
                self.dataset_schema = self.cn_dataset_schema
            else:
                self.dataset_schema = self.db_dataset_schema
        elif support_finetuning:
            self.dataset_schema = self.ft_dataset_schema
        elif support_controlnet:
@@ -345,7 +372,7 @@ class ConfigSanitizer:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: make the error message on validation failure easier to understand
            print("Invalid user config / ユーザ設定の形式が正しくないようです")
            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: by nature, the argument parser result does not need to be sanitized
@@ -355,7 +382,9 @@ class ConfigSanitizer:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
            logger.error(
                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
            )
            raise

    # NOTE: value would be overwritten by the latter dict if there is already the same key
@@ -441,114 +470,138 @@ class BlueprintGenerator:

        return default_value


def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
    datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        extra_dataset_params = {}

        if dataset_blueprint.is_controlnet:
            subset_klass = ControlNetSubset
            dataset_klass = ControlNetDataset
        elif dataset_blueprint.is_dreambooth:
            subset_klass = DreamBoothSubset
            dataset_klass = DreamBoothDataset
            # DreamBooth datasets support splitting training and validation datasets
            extra_dataset_params = {"is_training_dataset": True}
        else:
            subset_klass = FineTuningSubset
            dataset_klass = FineTuningDataset

        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
        datasets.append(dataset)

    # print info
    info = ""
    for i, dataset in enumerate(datasets):
        is_dreambooth = isinstance(dataset, DreamBoothDataset)
        is_controlnet = isinstance(dataset, ControlNetDataset)
        info += dedent(
            f"""\
            [Dataset {i}]
            batch_size: {dataset.batch_size}
            resolution: {(dataset.width, dataset.height)}
            enable_bucket: {dataset.enable_bucket}
            network_multiplier: {dataset.network_multiplier}
            """
        )
    val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
    for dataset_blueprint in dataset_group_blueprint.datasets:
        if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
            logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
            continue

        if dataset.enable_bucket:
            info += indent(
                dedent(
                    f"""\
                    min_bucket_reso: {dataset.min_bucket_reso}
                    max_bucket_reso: {dataset.max_bucket_reso}
                    bucket_reso_steps: {dataset.bucket_reso_steps}
                    bucket_no_upscale: {dataset.bucket_no_upscale}
                    \n"""
                ),
                "  ",
            )
        # if the dataset isn't setting a validation split, there is no current validation dataset
        if dataset_blueprint.params.validation_split == 0.0:
            continue

        extra_dataset_params = {}
        if dataset_blueprint.is_controlnet:
            subset_klass = ControlNetSubset
            dataset_klass = ControlNetDataset
        elif dataset_blueprint.is_dreambooth:
            subset_klass = DreamBoothSubset
            dataset_klass = DreamBoothDataset
            # DreamBooth datasets support splitting training and validation datasets
            extra_dataset_params = {"is_training_dataset": False}
        else:
            info += "\n"
            subset_klass = FineTuningSubset
            dataset_klass = FineTuningDataset

        for j, subset in enumerate(dataset.subsets):
            info += indent(
                dedent(
                    f"""\
                    [Subset {j} of Dataset {i}]
                    image_dir: "{subset.image_dir}"
                    image_count: {subset.img_count}
                    num_repeats: {subset.num_repeats}
                    shuffle_caption: {subset.shuffle_caption}
                    keep_tokens: {subset.keep_tokens}
                    keep_tokens_separator: {subset.keep_tokens_separator}
                    caption_dropout_rate: {subset.caption_dropout_rate}
                    caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
                    caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
                    caption_prefix: {subset.caption_prefix}
                    caption_suffix: {subset.caption_suffix}
                    color_aug: {subset.color_aug}
                    flip_aug: {subset.flip_aug}
                    face_crop_aug_range: {subset.face_crop_aug_range}
                    random_crop: {subset.random_crop}
                    token_warmup_min: {subset.token_warmup_min},
                    token_warmup_step: {subset.token_warmup_step},
                    """
                ),
                "  ",
            )
        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
        val_datasets.append(dataset)

        if is_dreambooth:
            info += indent(
                dedent(
                    f"""\
                    is_reg: {subset.is_reg}
                    class_tokens: {subset.class_tokens}
                    caption_extension: {subset.caption_extension}
                    \n"""
                ),
                "  ",
            )
        elif not is_controlnet:
            info += indent(
                dedent(
                    f"""\
                    metadata_file: {subset.metadata_file}
                    \n"""
                ),
                "  ",
            )
    def print_info(_datasets, dataset_type: str):
        info = ""
        for i, dataset in enumerate(_datasets):
            is_dreambooth = isinstance(dataset, DreamBoothDataset)
            is_controlnet = isinstance(dataset, ControlNetDataset)
            info += dedent(f"""\
                [{dataset_type} {i}]
                batch_size: {dataset.batch_size}
                resolution: {(dataset.width, dataset.height)}
                resize_interpolation: {dataset.resize_interpolation}
                enable_bucket: {dataset.enable_bucket}
            """)

    print(info)
            if dataset.enable_bucket:
                info += indent(dedent(f"""\
                    min_bucket_reso: {dataset.min_bucket_reso}
                    max_bucket_reso: {dataset.max_bucket_reso}
                    bucket_reso_steps: {dataset.bucket_reso_steps}
                    bucket_no_upscale: {dataset.bucket_no_upscale}
                \n"""), "  ")
            else:
                info += "\n"

            for j, subset in enumerate(dataset.subsets):
                info += indent(dedent(f"""\
                    [Subset {j} of {dataset_type} {i}]
                    image_dir: "{subset.image_dir}"
                    image_count: {subset.img_count}
                    num_repeats: {subset.num_repeats}
                    shuffle_caption: {subset.shuffle_caption}
                    keep_tokens: {subset.keep_tokens}
                    caption_dropout_rate: {subset.caption_dropout_rate}
                    caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
                    caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
                    caption_prefix: {subset.caption_prefix}
                    caption_suffix: {subset.caption_suffix}
                    color_aug: {subset.color_aug}
                    flip_aug: {subset.flip_aug}
                    face_crop_aug_range: {subset.face_crop_aug_range}
                    random_crop: {subset.random_crop}
                    token_warmup_min: {subset.token_warmup_min},
                    token_warmup_step: {subset.token_warmup_step},
                    alpha_mask: {subset.alpha_mask}
                    resize_interpolation: {subset.resize_interpolation}
                    custom_attributes: {subset.custom_attributes}
                """), "  ")

                if is_dreambooth:
                    info += indent(dedent(f"""\
                        is_reg: {subset.is_reg}
                        class_tokens: {subset.class_tokens}
                        caption_extension: {subset.caption_extension}
                    \n"""), "  ")
                elif not is_controlnet:
                    info += indent(dedent(f"""\
                        metadata_file: {subset.metadata_file}
                    \n"""), "  ")

        logger.info(info)

    print_info(datasets, "Dataset")

    if len(val_datasets) > 0:
        print_info(val_datasets, "Validation Dataset")

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no

    for i, dataset in enumerate(datasets):
        print(f"[Dataset {i}]")
        logger.info(f"[Prepare dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)

    return DatasetGroup(datasets)
    for i, dataset in enumerate(val_datasets):
        logger.info(f"[Prepare validation dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)

    return (
        DatasetGroup(datasets),
        DatasetGroup(val_datasets) if val_datasets else None
    )

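With this change the blueprint generator returns a pair instead of a single group. A sketch of how a caller written against the new signature might consume it (the variable names and the DataLoader usage are illustrative assumptions, not taken from this diff):

    # Hypothetical caller of the new two-value return:
    train_group, val_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
    if val_group is not None:
        # validation datasets were split off via validation_split > 0.0
        val_dataloader = torch.utils.data.DataLoader(val_group, batch_size=1, shuffle=False)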
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
@@ -557,7 +610,7 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
        try:
            n_repeats = int(tokens[0])
        except ValueError as e:
            print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
            logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
            return 0, ""
        caption_by_folder = "_".join(tokens[1:])
        return n_repeats, caption_by_folder
@@ -629,7 +682,7 @@ def load_user_config(file: str) -> dict:
        with open(file, "r") as f:
            config = json.load(f)
    except Exception:
        print(
        logger.error(
            f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
        )
        raise
@@ -637,7 +690,7 @@ def load_user_config(file: str) -> dict:
    try:
        config = toml.load(file)
    except Exception:
        print(
        logger.error(
            f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
        )
        raise
@@ -665,23 +718,26 @@ if __name__ == "__main__":
    argparse_namespace = parser.parse_args(remain)
    train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

    print("[argparse_namespace]")
    print(vars(argparse_namespace))
    logger.info("[argparse_namespace]")
    logger.info(f"{vars(argparse_namespace)}")

    user_config = load_user_config(config_args.dataset_config)

    print("\n[user_config]")
    print(user_config)
    logger.info("")
    logger.info("[user_config]")
    logger.info(f"{user_config}")

    sanitizer = ConfigSanitizer(
        config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
    )
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    print("\n[sanitized_user_config]")
    print(sanitized_user_config)
    logger.info("")
    logger.info("[sanitized_user_config]")
    logger.info(f"{sanitized_user_config}")

    blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

    print("\n[blueprint]")
    print(blueprint)
    logger.info("")
    logger.info("[blueprint]")
    logger.info(f"{blueprint}")

library/custom_offloading_utils.py (new file, 227 lines)
@@ -0,0 +1,227 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
import torch
import torch.nn as nn

from library.device_utils import clean_memory_on_device


def synchronize_device(device: torch.device):
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []

    # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
    # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
    #     print(module_to_cpu.__class__, module_to_cuda.__class__)
    #     if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
    #         weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
    for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
        if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
            module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
            if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
                weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
            else:
                if module_to_cuda.weight.data.device.type != device.type:
                    # print(
                    #     f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
                    # )
                    module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to cpu
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

    stream.synchronize()
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """
    not tested
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    # device to cpu
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

    synchronize_device(device)

    # cpu to device
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
        module_to_cuda.weight.data = cuda_data_view

    synchronize_device(device)


def weighs_to_device(layer: nn.Module, device: torch.device):
    for module in layer.modules():
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data = module.weight.data.to(device, non_blocking=True)


class Offloader:
    """
    common offloading class
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        self.num_blocks = num_blocks
        self.blocks_to_swap = blocks_to_swap
        self.device = device
        self.debug = debug

        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.futures = {}
        self.cuda_available = device.type == "cuda"

    def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
        if self.cuda_available:
            swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
        else:
            swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)

    def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
            if self.debug:
                start_time = time.perf_counter()
                print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")

            self.swap_weight_devices(block_to_cpu, block_to_cuda)

            if self.debug:
                print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
            return bidx_to_cpu, bidx_to_cuda  # , event

        block_to_cpu = blocks[block_idx_to_cpu]
        block_to_cuda = blocks[block_idx_to_cuda]

        self.futures[block_idx_to_cuda] = self.thread_pool.submit(
            move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
        )

    def _wait_blocks_move(self, block_idx):
        if block_idx not in self.futures:
            return

        if self.debug:
            print(f"Wait for block {block_idx}")
            start_time = time.perf_counter()

        future = self.futures.pop(block_idx)
        _, bidx_to_cuda = future.result()

        assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"

        if self.debug:
            print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


class ModelOffloader(Offloader):
    """
    supports forward offloading
    """

    def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        super().__init__(num_blocks, blocks_to_swap, device, debug)

        # register backward hooks
        self.remove_handles = []
        for i, block in enumerate(blocks):
            hook = self.create_backward_hook(blocks, i)
            if hook is not None:
                handle = block.register_full_backward_hook(hook)
                self.remove_handles.append(handle)

    def __del__(self):
        for handle in self.remove_handles:
            handle.remove()

    def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
        # -1 for 0-based index
        num_blocks_propagated = self.num_blocks - block_index - 1
        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
        waiting = block_index > 0 and block_index <= self.blocks_to_swap

        if not swapping and not waiting:
            return None

        # create hook
        block_idx_to_cpu = self.num_blocks - num_blocks_propagated
        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
        block_idx_to_wait = block_index - 1

        def backward_hook(module, grad_input, grad_output):
            if self.debug:
                print(f"Backward hook for block {block_index}")

            if swapping:
                self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
            if waiting:
                self._wait_blocks_move(block_idx_to_wait)
            return None

        return backward_hook

    def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return

        if self.debug:
            print("Prepare block devices before forward")

        for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
            b.to(self.device)
            weighs_to_device(b, self.device)  # make sure weights are on device

        for b in blocks[self.num_blocks - self.blocks_to_swap :]:
            b.to(self.device)  # move block to device first
            weighs_to_device(b, "cpu")  # make sure weights are on cpu

        synchronize_device(self.device)
        clean_memory_on_device(self.device)

    def wait_for_block(self, block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self._wait_blocks_move(block_idx)

    def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        if block_idx >= self.blocks_to_swap:
            return
        block_idx_to_cpu = block_idx
        block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
        self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
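A rough picture of how a model's forward pass might drive this offloader, as a sketch under assumptions (`blocks` is the model's transformer block list, `hidden_states` is the current activation; none of the loop names below come from this file):

    # Hypothetical forward loop using ModelOffloader (illustrative only):
    offloader = ModelOffloader(blocks, num_blocks=len(blocks), blocks_to_swap=4, device=torch.device("cuda"))
    offloader.prepare_block_devices_before_forward(blocks)
    x = hidden_states  # assumed input tensor
    for idx, block in enumerate(blocks):
        offloader.wait_for_block(idx)              # block idx's weights are now on GPU
        x = block(x)
        offloader.submit_move_blocks(blocks, idx)  # start swapping idx out for a tail block

The backward hooks registered in `__init__` apply the same prefetch/wait pattern in reverse order during backpropagation, overlapping CPU-GPU copies with compute on a single worker thread.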
@@ -1,8 +1,16 @@
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def prepare_scheduler_for_custom_training(noise_scheduler, device):
@@ -21,7 +29,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):

def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    # fix beta: zero terminal SNR
    print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
    logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def enforce_zero_terminal_snr(betas):
        # Convert betas to alphas_bar_sqrt
@@ -49,53 +57,58 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # print("original:", noise_scheduler.betas)
    # print("fixed:", betas)
    # logger.info(f"original: {noise_scheduler.betas}")
    # logger.info(f"fixed: {betas}")

    noise_scheduler.betas = betas
    noise_scheduler.alphas = alphas
    noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss

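In other words, this is min-SNR-gamma loss weighting (https://arxiv.org/abs/2303.09556): each sample's loss is scaled by min(SNR_t, gamma) / SNR_t, or min(SNR_t, gamma) / (SNR_t + 1) under v-prediction. A minimal numeric sketch of the epsilon-prediction case (the tensor values are made up for illustration):

    import torch

    snr = torch.tensor([0.5, 5.0, 50.0])  # hypothetical per-sample SNR values
    gamma = 5.0
    weight = torch.minimum(snr, torch.full_like(snr, gamma)) / snr
    # -> tensor([1.0, 1.0, 0.1]): low-noise (high-SNR) timesteps are down-weighted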
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    scale = get_snr_scale(timesteps, noise_scheduler)
    loss = loss * scale
    return loss


def get_snr_scale(timesteps, noise_scheduler):
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)
    # # show debug info
    # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
    # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
    return scale


def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
    scale = get_snr_scale(timesteps, noise_scheduler)
    # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
    # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
    loss = loss + loss / scale * v_pred_like_loss
    return loss

def apply_debiased_estimation(loss, timesteps, noise_scheduler):

def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    weight = 1/torch.sqrt(snr_t)
    if v_prediction:
        weight = 1 / (snr_t + 1)
    else:
        weight = 1 / torch.sqrt(snr_t)
    loss = weight * loss
    return loss


# TODO: this overlaps with train_util, consolidate into one of the two


@@ -268,7 +281,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


@@ -442,7 +455,7 @@ def get_weighted_text_embeddings(


# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
    b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
    u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
    for i in range(iterations):
@@ -455,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):


# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
    if noise_offset is None:
        return noise
    if adaptive_noise_scale is not None:
@@ -471,6 +484,25 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
    return noise


def apply_masked_loss(loss, batch) -> torch.FloatTensor:
    if "conditioning_images" in batch:
        # conditioning image is -1 to 1. we need to convert it to 0 to 1
        mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
        mask_image = mask_image / 2 + 0.5
        # print(f"conditioning_image: {mask_image.shape}")
    elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
        # alpha mask is 0 to 1
        mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)  # add channel dimension
        # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
    else:
        return loss

    # resize to the same size as the loss
    mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
    loss = loss * mask_image
    return loss
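The masked loss path above relies on element-wise broadcasting of a (B, 1, h, w) mask over a per-pixel (B, C, h, w) loss map. A tiny self-contained sketch of the same idea (shapes and values are illustrative, not from this repository):

    import torch
    import torch.nn.functional as F

    loss = torch.rand(2, 4, 64, 64)      # hypothetical per-pixel loss (B, C, h, w)
    mask = torch.zeros(2, 1, 512, 512)   # hypothetical alpha mask in [0, 1]
    mask[:, :, 128:384, 128:384] = 1.0   # only the center region contributes
    mask = F.interpolate(mask, size=loss.shape[2:], mode="area")  # match latent size
    masked = loss * mask                 # broadcasts over the channel dimension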
"""
##########################################
# Perlin Noise

library/deepspeed_utils.py (new file, 139 lines)
@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_deepspeed_arguments(parser: argparse.ArgumentParser):
    # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
    parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
    parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
    parser.add_argument(
        "--offload_optimizer_device",
        type=str,
        default=None,
        choices=[None, "cpu", "nvme"],
        help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
    )
    parser.add_argument(
        "--offload_optimizer_nvme_path",
        type=str,
        default=None,
        help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--offload_param_device",
        type=str,
        default=None,
        choices=[None, "cpu", "nvme"],
        help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--offload_param_nvme_path",
        type=str,
        default=None,
        help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--zero3_init_flag",
        action="store_true",
        help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
        "Only applicable with ZeRO Stage-3.",
    )
    parser.add_argument(
        "--zero3_save_16bit_model",
        action="store_true",
        help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
    )
    parser.add_argument(
        "--fp16_master_weights_and_gradients",
        action="store_true",
        help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
    )


def prepare_deepspeed_args(args: argparse.Namespace):
    if not args.deepspeed:
        return

    # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
    args.max_data_loader_n_workers = 1


def prepare_deepspeed_plugin(args: argparse.Namespace):
    if not args.deepspeed:
        return None

    try:
        import deepspeed
    except ImportError as e:
        logger.error(
            "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
        )
        exit(1)

    deepspeed_plugin = DeepSpeedPlugin(
        zero_stage=args.zero_stage,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_clipping=args.max_grad_norm,
        offload_optimizer_device=args.offload_optimizer_device,
        offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
        offload_param_device=args.offload_param_device,
        offload_param_nvme_path=args.offload_param_nvme_path,
        zero3_init_flag=args.zero3_init_flag,
        zero3_save_16bit_model=args.zero3_save_16bit_model,
    )
    deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
    deepspeed_plugin.deepspeed_config["train_batch_size"] = (
        args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
    )
    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
    if args.mixed_precision.lower() == "fp16":
        deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
    if args.full_fp16 or args.fp16_master_weights_and_gradients:
        if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
            deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
            logger.info("[DeepSpeed] full fp16 enable.")
        else:
            logger.info(
                "[DeepSpeed] full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
            )

    if args.offload_optimizer_device is not None:
        logger.info("[DeepSpeed] start to manually build cpu_adam.")
        deepspeed.ops.op_builder.CPUAdamBuilder().load()
        logger.info("[DeepSpeed] building cpu_adam done.")

    return deepspeed_plugin


# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
    # remove None from models
    models = {k: v for k, v in models.items() if v is not None}

    class DeepSpeedWrapper(torch.nn.Module):
        def __init__(self, **kw_models) -> None:
            super().__init__()
            self.models = torch.nn.ModuleDict()

            for key, model in kw_models.items():
                if isinstance(model, list):
                    model = torch.nn.ModuleList(model)
                assert isinstance(
                    model, torch.nn.Module
                ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
                self.models.update(torch.nn.ModuleDict({key: model}))

        def get_models(self):
            return self.models

    ds_model = DeepSpeedWrapper(**models)
    return ds_model
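As the comment notes, Accelerate's DeepSpeed integration prepares a single model, so multiple networks get bundled before `prepare`. A sketch of the intended call pattern (the `unet` and `text_encoders` names and the surrounding Accelerator setup are illustrative assumptions, not code from this diff):

    # Hypothetical call site:
    deepspeed_plugin = prepare_deepspeed_plugin(args)
    accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
    ds_model = prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoders)
    ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
    training_models = ds_model.get_models()  # access the wrapped modules by key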
library/device_utils.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import functools
import gc

import torch
try:
    # intel gpu support for pytorch older than 2.5
    # ipex is not needed after pytorch 2.5
    import intel_extension_for_pytorch as ipex  # noqa
except Exception:
    pass


try:
    HAS_CUDA = torch.cuda.is_available()
except Exception:
    HAS_CUDA = False

try:
    HAS_MPS = torch.backends.mps.is_available()
except Exception:
    HAS_MPS = False

try:
    HAS_XPU = torch.xpu.is_available()
except Exception:
    HAS_XPU = False


def clean_memory():
    gc.collect()
    if HAS_CUDA:
        torch.cuda.empty_cache()
    if HAS_XPU:
        torch.xpu.empty_cache()
    if HAS_MPS:
        torch.mps.empty_cache()


def clean_memory_on_device(device: torch.device):
    r"""
    Clean memory on the specified device, will be called from training scripts.
    """
    gc.collect()

    # device may be "cuda" or "cuda:0", so we need to check the type of device
    if device.type == "cuda":
        torch.cuda.empty_cache()
    if device.type == "xpu":
        torch.xpu.empty_cache()
    if device.type == "mps":
        torch.mps.empty_cache()


@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    r"""
    Do not call this function from training scripts. Use accelerator.device instead.
    """
    if HAS_CUDA:
        device = torch.device("cuda")
    elif HAS_XPU:
        device = torch.device("xpu")
    elif HAS_MPS:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"get_preferred_device() -> {device}")
    return device


def init_ipex():
    """
    Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.

    This function should run right after importing torch and before doing anything else.

    If xpu is not available, this function does nothing.
    """
    try:
        if HAS_XPU:
            from library.ipex import ipex_init

            is_initialized, error_message = ipex_init()
            if not is_initialized:
                print("failed to initialize ipex:", error_message)
        else:
            return
    except Exception as e:
        print("failed to initialize ipex:", e)
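A quick sketch of how a standalone tool (not a training script, per the docstring) might use these helpers; the tensor workload is made up for illustration:

    import torch
    from library.device_utils import get_preferred_device, clean_memory_on_device

    device = get_preferred_device()   # prefers cuda, then xpu, then mps, then cpu
    x = torch.randn(1024, 1024, device=device)
    y = x @ x                         # some hypothetical work
    del x, y
    clean_memory_on_device(device)    # gc plus the backend-specific cache flush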
library/flux_models.py (new file, 1493 lines)
(File diff suppressed because it is too large)

library/flux_train_utils.py (new file, 621 lines)
@@ -0,0 +1,621 @@
import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, PartialState
from transformers import CLIPTextModel
from tqdm import tqdm
from PIL import Image
from safetensors.torch import save_file

from library import flux_models, flux_utils, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from .utils import setup_logging, mem_eff_save_file

setup_logging()
import logging

logger = logging.getLogger(__name__)


# region sample images


def sample_images(
    accelerator: Accelerator,
    args: argparse.Namespace,
    epoch,
    steps,
    flux,
    ae,
    text_encoders,
    sample_prompts_te_outputs,
    prompt_replacement=None,
    controlnet=None
):
    if steps == 0:
        if not args.sample_at_first:
            return
    else:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            # sample_every_n_steps is ignored
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
                return

    logger.info("")
    logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
    if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
        logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
        return

    distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here

    # unwrap unet and text_encoder(s)
    flux = accelerator.unwrap_model(flux)
    if text_encoders is not None:
        text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
    if controlnet is not None:
        controlnet = accelerator.unwrap_model(controlnet)
    # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])

    prompts = train_util.load_prompts(args.sample_prompts)

    save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    # save random state to restore later
    rng_state = torch.get_rng_state()
    cuda_rng_state = None
    try:
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    except Exception:
        pass

    if distributed_state.num_processes <= 1:
        # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
        with torch.no_grad(), accelerator.autocast():
            for prompt_dict in prompts:
                sample_image_inference(
                    accelerator,
                    args,
                    flux,
                    text_encoders,
                    ae,
                    save_dir,
                    prompt_dict,
                    epoch,
                    steps,
                    sample_prompts_te_outputs,
                    prompt_replacement,
                    controlnet
                )
    else:
        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
        # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
        per_process_prompts = []  # list of lists
        for i in range(distributed_state.num_processes):
            per_process_prompts.append(prompts[i :: distributed_state.num_processes])

        with torch.no_grad():
            with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
                for prompt_dict in prompt_dict_lists[0]:
                    sample_image_inference(
                        accelerator,
                        args,
                        flux,
                        text_encoders,
                        ae,
                        save_dir,
                        prompt_dict,
                        epoch,
                        steps,
                        sample_prompts_te_outputs,
                        prompt_replacement,
                        controlnet
                    )

    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    clean_memory_on_device(accelerator.device)


def sample_image_inference(
    accelerator: Accelerator,
    args: argparse.Namespace,
    flux: flux_models.Flux,
    text_encoders: Optional[List[CLIPTextModel]],
    ae: flux_models.AutoEncoder,
    save_dir,
    prompt_dict,
    epoch,
    steps,
    sample_prompts_te_outputs,
    prompt_replacement,
    controlnet
):
    assert isinstance(prompt_dict, dict)
    # negative_prompt = prompt_dict.get("negative_prompt")
    sample_steps = prompt_dict.get("sample_steps", 20)
    width = prompt_dict.get("width", 512)
    height = prompt_dict.get("height", 512)
    scale = prompt_dict.get("scale", 3.5)
    seed = prompt_dict.get("seed")
    controlnet_image = prompt_dict.get("controlnet_image")
    prompt: str = prompt_dict.get("prompt", "")
    # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
        # if negative_prompt is not None:
        #     negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    else:
        # True random sample image generation
        torch.seed()
        torch.cuda.seed()

    # if negative_prompt is None:
    #     negative_prompt = ""
    height = max(64, height - height % 16)  # round to divisible by 16
    width = max(64, width - width % 16)  # round to divisible by 16
    logger.info(f"prompt: {prompt}")
    # logger.info(f"negative_prompt: {negative_prompt}")
    logger.info(f"height: {height}")
    logger.info(f"width: {width}")
    logger.info(f"sample_steps: {sample_steps}")
    logger.info(f"scale: {scale}")
    # logger.info(f"sample_sampler: {sampler_name}")
    if seed is not None:
        logger.info(f"seed: {seed}")

    # encode prompts
    tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
    encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

    text_encoder_conds = []
    if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
        text_encoder_conds = sample_prompts_te_outputs[prompt]
        print(f"Using cached text encoder outputs for prompt: {prompt}")
    if text_encoders is not None:
        print(f"Encoding prompt: {prompt}")
        tokens_and_masks = tokenize_strategy.tokenize(prompt)
        # strategy has apply_t5_attn_mask option
        encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)

        # if text_encoder_conds is not cached, use encoded_text_encoder_conds
        if len(text_encoder_conds) == 0:
            text_encoder_conds = encoded_text_encoder_conds
        else:
            # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
            for i in range(len(encoded_text_encoder_conds)):
                if encoded_text_encoder_conds[i] is not None:
                    text_encoder_conds[i] = encoded_text_encoder_conds[i]

    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds

    # sample image
    weight_dtype = ae.dtype  # TODO give dtype as argument
    packed_latent_height = height // 16
    packed_latent_width = width // 16
    noise = torch.randn(
        1,
        packed_latent_height * packed_latent_width,
        16 * 2 * 2,
        device=accelerator.device,
        dtype=weight_dtype,
        generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
    )
    timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)  # FLUX.1 dev -> shift=True
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
    t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None

    if controlnet_image is not None:
        controlnet_image = Image.open(controlnet_image).convert("RGB")
        controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
        controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
        controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)

    with accelerator.autocast(), torch.no_grad():
        x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)

    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)

    # latent to image
    clean_memory_on_device(accelerator.device)
    org_vae_device = ae.device  # will be on cpu
    ae.to(accelerator.device)  # distributed_state.device is same as accelerator.device
    with accelerator.autocast(), torch.no_grad():
        x = ae.decode(x)
    ae.to(org_vae_device)
    clean_memory_on_device(accelerator.device)

    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])

    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
    # but adding 'enum' to the filename should be enough

    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    i: int = prompt_dict["enum"]
    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
    image.save(os.path.join(save_dir, img_filename))

    # send images to wandb if enabled
    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        wandb_tracker = accelerator.get_tracker("wandb")

        import wandb

        # not to commit images to avoid inconsistency between training and logging steps
        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption


def time_shift(mu: float, sigma: float, t: torch.Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()

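For intuition: time_shift maps each t to exp(mu) / (exp(mu) + (1/t - 1)^sigma), so a larger mu (derived from a longer image token sequence) pushes the schedule toward high-noise timesteps. A small sketch of the effect (the comparison below is illustrative; values depend on mu):

    # Illustrative only: compare schedules for short vs long sequences.
    sched_256 = get_schedule(4, 256)    # short sequence -> mu near 0.5
    sched_4096 = get_schedule(4, 4096)  # long sequence  -> mu near 1.15
    # Both run from 1.0 down to 0.0, but the 4096-token schedule keeps its
    # intermediate timesteps closer to 1.0 (more steps spent at high noise).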
def denoise(
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    t5_attn_mask: Optional[torch.Tensor] = None,
    controlnet: Optional[flux_models.ControlNetFlux] = None,
    controlnet_img: Optional[torch.Tensor] = None,
):
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        model.prepare_block_swap_before_forward()
        if controlnet is not None:
            block_samples, block_single_samples = controlnet(
                img=img,
                img_ids=img_ids,
                controlnet_cond=controlnet_img,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                txt_attention_mask=t5_attn_mask,
            )
        else:
            block_samples = None
            block_single_samples = None
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            block_controlnet_hidden_states=block_samples,
            block_controlnet_single_hidden_states=block_single_samples,
            timesteps=t_vec,
            guidance=guidance_vec,
            txt_attention_mask=t5_attn_mask,
        )

        img = img + (t_prev - t_curr) * pred

    model.prepare_block_swap_before_forward()
    return img


# endregion


# region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    return sigma


def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting


def get_noisy_model_input_and_timesteps(
    args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    bsz, _, h, w = latents.shape
    assert bsz > 0, "Batch size not large enough"
    num_timesteps = noise_scheduler.config.num_train_timesteps
    if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
        # Simple random sigma-based noise sampling
        if args.timestep_sampling == "sigmoid":
            # https://github.com/XLabs-AI/x-flux/tree/main
            sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
        else:
            sigmas = torch.rand((bsz,), device=device)

        timesteps = sigmas * num_timesteps
    elif args.timestep_sampling == "shift":
        shift = args.discrete_flow_shift
        sigmas = torch.randn(bsz, device=device)
        sigmas = sigmas * args.sigmoid_scale  # larger scale for more uniform sampling
        sigmas = sigmas.sigmoid()
        sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_timesteps
    elif args.timestep_sampling == "flux_shift":
        sigmas = torch.randn(bsz, device=device)
        sigmas = sigmas * args.sigmoid_scale  # larger scale for more uniform sampling
        sigmas = sigmas.sigmoid()
        mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))  # we are pre-packed so must adjust for packed size
        sigmas = time_shift(mu, 1.0, sigmas)
        timesteps = sigmas * num_timesteps
    else:
        # Sample a random timestep for each image
        # for weighting schemes where we sample timesteps non-uniformly
        u = compute_density_for_timestep_sampling(
            weighting_scheme=args.weighting_scheme,
            batch_size=bsz,
            logit_mean=args.logit_mean,
            logit_std=args.logit_std,
            mode_scale=args.mode_scale,
        )
        indices = (u * num_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=device)
        sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)

    # Broadcast sigmas to latent shape
    sigmas = sigmas.view(-1, 1, 1, 1)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    if args.ip_noise_gamma:
        xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
        if args.ip_noise_gamma_random_strength:
            ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
        else:
            ip_noise_gamma = args.ip_noise_gamma
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
    else:
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

    return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas

||||
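
# Worked example (not part of the original diff): the forward process above is
# the linear flow-matching interpolation
#
#   x_t = (1 - σ) * x_0 + σ * ε,    σ ∈ [0, 1]
#
# so for σ = 0.25 a latent is mixed as 0.75 * x_0 + 0.25 * ε, and at σ = 1 the
# model input is pure noise. With ip_noise_gamma, ε is perturbed to ε + γ * ξ
# (input-perturbation noise) before mixing.
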

def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
    weighting = None
    if args.model_prediction_type == "raw":
        pass
    elif args.model_prediction_type == "additive":
        # add the model_pred to the noisy_model_input
        model_pred = model_pred + noisy_model_input
    elif args.model_prediction_type == "sigma_scaled":
        # apply sigma scaling
        model_pred = model_pred * (-sigmas) + noisy_model_input

        # these weighting schemes use a uniform timestep sampling
        # and instead post-weight the loss
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

    return model_pred, weighting
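
# Illustrative note (not part of the original diff): with the interpolation
# x_t = (1 - σ) x_0 + σ ε and a model predicting the velocity v ≈ ε - x_0, the
# "sigma_scaled" branch computes
#
#   model_pred * (-σ) + x_t = x_t - σ v ≈ x_0
#
# i.e. it converts the raw velocity prediction into an x_0-style estimate, and
# the loss is then post-weighted by compute_loss_weighting_for_sd3.
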

def save_models(
    ckpt_path: str,
    flux: flux_models.Flux,
    sai_metadata: Optional[dict],
    save_dtype: Optional[torch.dtype] = None,
    use_mem_eff_save: bool = False,
):
    state_dict = {}

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            if save_dtype is not None and v.dtype != save_dtype:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    update_sd("", flux.state_dict())

    if not use_mem_eff_save:
        save_file(state_dict, ckpt_path, metadata=sai_metadata)
    else:
        mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
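
# Usage sketch (hypothetical path/values, not part of the original diff):
#
#   sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False,
#                                                is_stable_diffusion_ckpt=True, flux="dev")
#   save_models("output/flux-finetuned.safetensors", flux, sai_metadata,
#               save_dtype=torch.bfloat16, use_mem_eff_save=True)
#
# Tensors are cast on CPU to save_dtype before writing, so the on-disk file can
# be bf16/fp16 even while training runs in a different precision.
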

def save_flux_model_on_train_end(
    args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
):
    def sd_saver(ckpt_file, epoch_no, global_step):
        sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
        save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)

    train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)

# Saving at epoch end and at step intervals is merged into one function, since the
# metadata includes epoch/step and the arguments are otherwise identical.
# on_epoch_end: True to save at the end of an epoch, False to save after a step interval
def save_flux_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    flux: flux_models.Flux,
):
    def sd_saver(ckpt_file, epoch_no, global_step):
        sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
        save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)

    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        True,
        True,
        epoch,
        num_train_epochs,
        global_step,
        sd_saver,
        None,
    )

# endregion


def add_flux_train_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--clip_l",
        type=str,
        help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
    )
    parser.add_argument(
        "--t5xxl",
        type=str,
        help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
    )
    parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
    parser.add_argument(
        "--controlnet_model_name_or_path",
        type=str,
        default=None,
        help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)",
    )
    parser.add_argument(
        "--t5xxl_max_token_length",
        type=int,
        default=None,
        help="maximum token length for T5-XXL. If omitted, 256 for schnell and 512 for dev"
        " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
    )
    parser.add_argument(
        "--apply_t5_attn_mask",
        action="store_true",
        help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
    )

    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.5,
        help="embedded guidance scale; the FLUX.1 dev variant is a guidance distilled model",
    )

    parser.add_argument(
        "--timestep_sampling",
        choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
        default="sigma",
        help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
        " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
    )
    parser.add_argument(
        "--sigmoid_scale",
        type=float,
        default=1.0,
        help='Scale factor for sigmoid timestep sampling (used when timestep-sampling is "sigmoid", "shift" or "flux_shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"、"shift"、"flux_shift"の場合に有効)。',
    )
    parser.add_argument(
        "--model_prediction_type",
        choices=["raw", "additive", "sigma_scaled"],
        default="sigma_scaled",
        help="How to interpret and process the model prediction: "
        "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
        " / モデル予測の解釈と処理方法:"
        "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
    )
    parser.add_argument(
        "--discrete_flow_shift",
        type=float,
        default=3.0,
        help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
    )
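
# Usage sketch (hypothetical file names and values, not part of the original diff):
#
#   parser = argparse.ArgumentParser()
#   add_flux_train_arguments(parser)
#   args = parser.parse_args(
#       ["--clip_l", "clip_l.safetensors", "--t5xxl", "t5xxl_fp16.safetensors",
#        "--ae", "ae.safetensors", "--timestep_sampling", "shift",
#        "--discrete_flow_shift", "3.1582"]
#   )
#   assert args.guidance_scale == 3.5  # default used for the dev variant
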
488	library/flux_utils.py	Normal file
@@ -0,0 +1,488 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union

import einops
import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from library import flux_models
from library.utils import load_safetensors

MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"

def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
    """
    Analyze the state of a checkpoint and determine whether it is Diffusers or BFL,
    dev or schnell, and how many blocks it has.

    Args:
        ckpt_path (str): Path to the checkpoint file or directory.

    Returns:
        Tuple[bool, bool, Tuple[int, int], List[str]]:
            - bool: Whether the checkpoint is in Diffusers format.
            - bool: Whether the checkpoint is schnell.
            - Tuple[int, int]: Number of double blocks and single blocks.
            - List[str]: List of checkpoint file paths that make up the model.
    """
    # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
    logger.info("Checking the state dict: Diffusers or BFL, dev or schnell")

    if os.path.isdir(ckpt_path):  # if ckpt_path is a directory, it is Diffusers
        ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
    if "00001-of-00003" in ckpt_path:
        ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
    else:
        ckpt_paths = [ckpt_path]

    keys = []
    for ckpt_path in ckpt_paths:
        with safe_open(ckpt_path, framework="pt") as f:
            keys.extend(f.keys())

    # if the keys have an annoying prefix, remove it
    if keys[0].startswith("model.diffusion_model."):
        keys = [key.replace("model.diffusion_model.", "") for key in keys]

    is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
    is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)

    # check number of double and single blocks
    if not is_diffusers:
        max_double_block_index = max(
            [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
        )
        max_single_block_index = max(
            [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
        )
    else:
        max_double_block_index = max(
            [
                int(key.split(".")[1])
                for key in keys
                if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
            ]
        )
        max_single_block_index = max(
            [
                int(key.split(".")[1])
                for key in keys
                if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
            ]
        )

    num_double_blocks = max_double_block_index + 1
    num_single_blocks = max_single_block_index + 1

    return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
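
# Usage sketch (hypothetical path, not part of the original diff):
#
#   is_diffusers, is_schnell, (n_double, n_single), paths = analyze_checkpoint_state(
#       "models/flux1-dev.safetensors"
#   )
#   # A BFL-format flux1-dev checkpoint resolves to is_diffusers=False,
#   # is_schnell=False, (19, 38) blocks, and a single-element path list.
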

def load_flow_model(
    ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> Tuple[bool, flux_models.Flux]:
    is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
    name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL

    # build model
    logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
    with torch.device("meta"):
        params = flux_models.configs[name].params

        # set the number of blocks
        if params.depth != num_double_blocks:
            logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
            params = replace(params, depth=num_double_blocks)
        if params.depth_single_blocks != num_single_blocks:
            logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
            params = replace(params, depth_single_blocks=num_single_blocks)

        model = flux_models.Flux(params)
        if dtype is not None:
            model = model.to(dtype)

    # load_sft doesn't support torch.device
    logger.info(f"Loading state dict from {ckpt_path}")
    sd = {}
    for ckpt_path in ckpt_paths:
        sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))

    # convert Diffusers to BFL
    if is_diffusers:
        logger.info("Converting Diffusers to BFL")
        sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
        logger.info("Converted Diffusers to BFL")

    # if the keys have an annoying prefix, remove it; the keys are homogeneous,
    # so the first key is checked as a representative
    for key in list(sd.keys()):
        new_key = key.replace("model.diffusion_model.", "")
        if new_key == key:
            break  # the model doesn't have the annoying prefix
        sd[new_key] = sd.pop(key)

    info = model.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded Flux: {info}")
    return is_schnell, model

def load_ae(
    ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.AutoEncoder:
    logger.info("Building AutoEncoder")
    with torch.device("meta"):
        # dev and schnell have the same AE params
        ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = ae.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded AE: {info}")
    return ae

def load_controlnet(
    ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
):
    logger.info("Building ControlNet")
    name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
    with torch.device(device):
        controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)

    if ckpt_path is not None:
        logger.info(f"Loading state dict from {ckpt_path}")
        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
        info = controlnet.load_state_dict(sd, strict=False, assign=True)
        logger.info(f"Loaded ControlNet: {info}")
    return controlnet

def load_clip_l(
    ckpt_path: Optional[str],
    dtype: torch.dtype,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[dict] = None,
) -> CLIPTextModel:
    logger.info("Building CLIP-L")
    CLIPL_CONFIG = {
        "_name_or_path": "clip-vit-large-patch14/",
        "architectures": ["CLIPModel"],
        "initializer_factor": 1.0,
        "logit_scale_init_value": 2.6592,
        "model_type": "clip",
        "projection_dim": 768,
        # "text_config": {
        "_name_or_path": "",
        "add_cross_attention": False,
        "architectures": None,
        "attention_dropout": 0.0,
        "bad_words_ids": None,
        "bos_token_id": 0,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": None,
        "decoder_start_token_id": None,
        "diversity_penalty": 0.0,
        "do_sample": False,
        "dropout": 0.0,
        "early_stopping": False,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": 2,
        "finetuning_task": None,
        "forced_bos_token_id": None,
        "forced_eos_token_id": None,
        "hidden_act": "quick_gelu",
        "hidden_size": 768,
        "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
        "initializer_factor": 1.0,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": False,
        "is_encoder_decoder": False,
        "label2id": {"LABEL_0": 0, "LABEL_1": 1},
        "layer_norm_eps": 1e-05,
        "length_penalty": 1.0,
        "max_length": 20,
        "max_position_embeddings": 77,
        "min_length": 0,
        "model_type": "clip_text_model",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 12,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_hidden_layers": 12,
        "num_return_sequences": 1,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
        "pad_token_id": 1,
        "prefix": None,
        "problem_type": None,
        "projection_dim": 768,
        "pruned_heads": {},
        "remove_invalid_values": False,
        "repetition_penalty": 1.0,
        "return_dict": True,
        "return_dict_in_generate": False,
        "sep_token_id": None,
        "task_specific_params": None,
        "temperature": 1.0,
        "tie_encoder_decoder": False,
        "tie_word_embeddings": True,
        "tokenizer_class": None,
        "top_k": 50,
        "top_p": 1.0,
        "torch_dtype": None,
        "torchscript": False,
        "transformers_version": "4.16.0.dev0",
        "use_bfloat16": False,
        "vocab_size": 49408,
        "hidden_act": "gelu",
        "hidden_size": 1280,
        "intermediate_size": 5120,
        "num_attention_heads": 20,
        "num_hidden_layers": 32,
        # },
        # "text_config_dict": {
        "hidden_size": 768,
        "intermediate_size": 3072,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "projection_dim": 768,
        # },
        # "torch_dtype": "float32",
        # "transformers_version": None,
    }
    config = CLIPConfig(**CLIPL_CONFIG)
    with init_empty_weights():
        clip = CLIPTextModel._from_config(config)

    if state_dict is not None:
        sd = state_dict
    else:
        logger.info(f"Loading state dict from {ckpt_path}")
        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = clip.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded CLIP-L: {info}")
    return clip

def load_t5xxl(
    ckpt_path: str,
    dtype: Optional[torch.dtype],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[dict] = None,
) -> T5EncoderModel:
    T5_CONFIG_JSON = """
{
    "architectures": [
        "T5EncoderModel"
    ],
    "classifier_dropout": 0.0,
    "d_ff": 10240,
    "d_kv": 64,
    "d_model": 4096,
    "decoder_start_token_id": 0,
    "dense_act_fn": "gelu_new",
    "dropout_rate": 0.1,
    "eos_token_id": 1,
    "feed_forward_proj": "gated-gelu",
    "initializer_factor": 1.0,
    "is_encoder_decoder": true,
    "is_gated_act": true,
    "layer_norm_epsilon": 1e-06,
    "model_type": "t5",
    "num_decoder_layers": 24,
    "num_heads": 64,
    "num_layers": 24,
    "output_past": true,
    "pad_token_id": 0,
    "relative_attention_max_distance": 128,
    "relative_attention_num_buckets": 32,
    "tie_word_embeddings": false,
    "torch_dtype": "float16",
    "transformers_version": "4.41.2",
    "use_cache": true,
    "vocab_size": 32128
}
"""
    config = json.loads(T5_CONFIG_JSON)
    config = T5Config(**config)
    with init_empty_weights():
        t5xxl = T5EncoderModel._from_config(config)

    if state_dict is not None:
        sd = state_dict
    else:
        logger.info(f"Loading state dict from {ckpt_path}")
        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = t5xxl.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded T5xxl: {info}")
    return t5xxl

def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
    # nn.Embedding is the first layer, but it could be cast to bfloat16 or float32,
    # so read the dtype from an attention weight instead
    return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype

def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
    img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
    img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
    return img_ids
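
# Illustrative sketch (not part of the original diff): img_ids carries a (y, x)
# position per packed latent token, with channel 0 left as zeros:
#
#   ids = prepare_img_ids(batch_size=1, packed_latent_height=2, packed_latent_width=2)
#   # ids.shape == (1, 4, 3); ids[0] ==
#   # [[0., 0., 0.],
#   #  [0., 0., 1.],
#   #  [0., 1., 0.],
#   #  [0., 1., 1.]]   # rows are (pad, y, x) in row-major token order
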

def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
    """
    x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
    """
    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
    return x


def pack_latents(x: torch.Tensor) -> torch.Tensor:
    """
    x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
    """
    x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    return x
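
# Round-trip sketch (not part of the original diff): pack_latents groups each
# 2x2 patch of the latent into one token, and unpack_latents inverts it exactly:
#
#   x = torch.randn(1, 16, 64, 64)            # [b, c, H, W] with H, W even
#   packed = pack_latents(x)                  # [1, 32*32, 16*2*2] == [1, 1024, 64]
#   restored = unpack_latents(packed, 32, 32) # packed_latent_height = H // 2
#   assert torch.equal(x, restored)
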

# region Diffusers

NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38

BFL_TO_DIFFUSERS_MAP = {
    "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
    "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
    "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
    "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
    "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
    "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
    "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
    "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
    "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
    "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
    "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
    "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
    "txt_in.weight": ["context_embedder.weight"],
    "txt_in.bias": ["context_embedder.bias"],
    "img_in.weight": ["x_embedder.weight"],
    "img_in.bias": ["x_embedder.bias"],
    "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
    "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
    "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
    "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
    "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
    "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
    "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
    "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
    "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
    "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
    "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
    "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
    "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
    "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
    "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
    "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
    "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
    "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
    "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
    "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
    "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
    "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
    "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
    "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
    "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
    "single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
    "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
    "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
    "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
    "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
    "single_blocks.().linear2.weight": ["proj_out.weight"],
    "single_blocks.().linear2.bias": ["proj_out.bias"],
    "final_layer.linear.weight": ["proj_out.weight"],
    "final_layer.linear.bias": ["proj_out.bias"],
    "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
    "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}

def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
    # make reverse map from diffusers map
    diffusers_to_bfl_map = {}  # key: diffusers_key, value: (index, bfl_key)
    for b in range(num_double_blocks):
        for key, weights in BFL_TO_DIFFUSERS_MAP.items():
            if key.startswith("double_blocks."):
                block_prefix = f"transformer_blocks.{b}."
                for i, weight in enumerate(weights):
                    diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
    for b in range(num_single_blocks):
        for key, weights in BFL_TO_DIFFUSERS_MAP.items():
            if key.startswith("single_blocks."):
                block_prefix = f"single_transformer_blocks.{b}."
                for i, weight in enumerate(weights):
                    diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
    for key, weights in BFL_TO_DIFFUSERS_MAP.items():
        if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
            for i, weight in enumerate(weights):
                diffusers_to_bfl_map[weight] = (i, key)
    return diffusers_to_bfl_map
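
# Illustrative sketch (not part of the original diff): the reverse map sends a
# Diffusers key to (index, BFL key), where index is the slot in the fused tensor:
#
#   m = make_diffusers_to_bfl_map(19, 38)
#   m["transformer_blocks.0.attn.to_q.weight"]  # -> (0, "double_blocks.0.img_attn.qkv.weight")
#   m["transformer_blocks.0.attn.to_k.weight"]  # -> (1, "double_blocks.0.img_attn.qkv.weight")
#   m["x_embedder.weight"]                      # -> (0, "img_in.weight")
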

def convert_diffusers_sd_to_bfl(
    diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
) -> dict[str, torch.Tensor]:
    diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)

    # iterate over three safetensors files to reduce memory usage
    flux_sd = {}
    for diffusers_key, tensor in diffusers_sd.items():
        if diffusers_key in diffusers_to_bfl_map:
            index, bfl_key = diffusers_to_bfl_map[diffusers_key]
            if bfl_key not in flux_sd:
                flux_sd[bfl_key] = []
            flux_sd[bfl_key].append((index, tensor))
        else:
            logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
            raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")

    # concat tensors if multiple tensors are mapped to a single key, sort by index
    for key, values in flux_sd.items():
        if len(values) == 1:
            flux_sd[key] = values[0][1]
        else:
            flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])

    # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    if "final_layer.adaLN_modulation.1.weight" in flux_sd:
        flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
    if "final_layer.adaLN_modulation.1.bias" in flux_sd:
        flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])

    return flux_sd


# endregion
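
# Illustrative note (not part of the original diff): keys that map to the same
# BFL tensor are concatenated in index order along dim 0, e.g. the separate
# attn.to_q/to_k/to_v weights become one fused double_blocks.N.img_attn.qkv.weight
# of shape [3 * hidden, hidden]. swap_scale_shift swaps the two halves of the
# final AdaLN parameters because Diffusers stores them as (scale, shift) while
# the BFL final layer expects (shift, scale).
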
@@ -4,7 +4,10 @@ from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread

from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
    api = HfApi(
@@ -33,9 +36,9 @@ def upload(
    try:
        api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
    except Exception as e:  # RepositoryNotFoundError has been confirmed so far, but catch broadly in case something else is raised
        print("===========================================")
        print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
        print("===========================================")
        logger.error("===========================================")
        logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
        logger.error("===========================================")

    is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())

@@ -56,9 +59,9 @@ def upload(
            path_in_repo=path_in_repo,
        )
    except Exception as e:  # RuntimeError has been confirmed, but catch broadly in case something else is raised
        print("===========================================")
        print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
        print("===========================================")
        logger.error("===========================================")
        logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
        logger.error("===========================================")

    if args.async_upload and not force_sync_upload:
        fire_in_thread(uploader)
@@ -2,169 +2,217 @@ import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
try:
    import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
    legacy = True
except Exception:
    legacy = False
from .hijacks import ipex_hijacks

# pylint: disable=protected-access, missing-function-docstring, line-too-long

def ipex_init(): # pylint: disable=too-many-statements
    try:
        # Replace cuda with xpu:
        torch.cuda.current_device = torch.xpu.current_device
        torch.cuda.current_stream = torch.xpu.current_stream
        torch.cuda.device = torch.xpu.device
        torch.cuda.device_count = torch.xpu.device_count
        torch.cuda.device_of = torch.xpu.device_of
        torch.cuda.get_device_name = torch.xpu.get_device_name
        torch.cuda.get_device_properties = torch.xpu.get_device_properties
        torch.cuda.init = torch.xpu.init
        torch.cuda.is_available = torch.xpu.is_available
        torch.cuda.is_initialized = torch.xpu.is_initialized
        torch.cuda.is_current_stream_capturing = lambda: False
        torch.cuda.set_device = torch.xpu.set_device
        torch.cuda.stream = torch.xpu.stream
        torch.cuda.synchronize = torch.xpu.synchronize
        torch.cuda.Event = torch.xpu.Event
        torch.cuda.Stream = torch.xpu.Stream
        torch.cuda.FloatTensor = torch.xpu.FloatTensor
        torch.Tensor.cuda = torch.Tensor.xpu
        torch.Tensor.is_cuda = torch.Tensor.is_xpu
        torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
        torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
        torch.cuda._initialized = torch.xpu.lazy_init._initialized
        torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
        torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
        torch.cuda._tls = torch.xpu.lazy_init._tls
        torch.cuda.threading = torch.xpu.lazy_init.threading
        torch.cuda.traceback = torch.xpu.lazy_init.traceback
        torch.cuda.Optional = torch.xpu.Optional
        torch.cuda.__cached__ = torch.xpu.__cached__
        torch.cuda.__loader__ = torch.xpu.__loader__
        torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
        torch.cuda.Tuple = torch.xpu.Tuple
        torch.cuda.streams = torch.xpu.streams
        torch.cuda._lazy_new = torch.xpu._lazy_new
        torch.cuda.FloatStorage = torch.xpu.FloatStorage
        torch.cuda.Any = torch.xpu.Any
        torch.cuda.__doc__ = torch.xpu.__doc__
        torch.cuda.default_generators = torch.xpu.default_generators
        torch.cuda.HalfTensor = torch.xpu.HalfTensor
        torch.cuda._get_device_index = torch.xpu._get_device_index
        torch.cuda.__path__ = torch.xpu.__path__
        torch.cuda.Device = torch.xpu.Device
        torch.cuda.IntTensor = torch.xpu.IntTensor
        torch.cuda.ByteStorage = torch.xpu.ByteStorage
        torch.cuda.set_stream = torch.xpu.set_stream
        torch.cuda.BoolStorage = torch.xpu.BoolStorage
        torch.cuda.os = torch.xpu.os
        torch.cuda.torch = torch.xpu.torch
        torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
        torch.cuda.Union = torch.xpu.Union
        torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
        torch.cuda.ShortTensor = torch.xpu.ShortTensor
        torch.cuda.LongTensor = torch.xpu.LongTensor
        torch.cuda.IntStorage = torch.xpu.IntStorage
        torch.cuda.LongStorage = torch.xpu.LongStorage
        torch.cuda.__annotations__ = torch.xpu.__annotations__
        torch.cuda.__package__ = torch.xpu.__package__
        torch.cuda.__builtins__ = torch.xpu.__builtins__
        torch.cuda.CharTensor = torch.xpu.CharTensor
        torch.cuda.List = torch.xpu.List
        torch.cuda._lazy_init = torch.xpu._lazy_init
        torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
        torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
        torch.cuda.ByteTensor = torch.xpu.ByteTensor
        torch.cuda.StreamContext = torch.xpu.StreamContext
        torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
        torch.cuda.ShortStorage = torch.xpu.ShortStorage
        torch.cuda._lazy_call = torch.xpu._lazy_call
        torch.cuda.HalfStorage = torch.xpu.HalfStorage
        torch.cuda.random = torch.xpu.random
        torch.cuda._device = torch.xpu._device
        torch.cuda.classproperty = torch.xpu.classproperty
        torch.cuda.__name__ = torch.xpu.__name__
        torch.cuda._device_t = torch.xpu._device_t
        torch.cuda.warnings = torch.xpu.warnings
        torch.cuda.__spec__ = torch.xpu.__spec__
        torch.cuda.BoolTensor = torch.xpu.BoolTensor
        torch.cuda.CharStorage = torch.xpu.CharStorage
        torch.cuda.__file__ = torch.xpu.__file__
        torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
        # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
        if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
            return True, "Skipping IPEX hijack"
        else:
            try: # force xpu device on torch compile and triton
                torch._inductor.utils.GPU_TYPES = ["xpu"]
                torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
                from triton import backends as triton_backends # pylint: disable=import-error
                triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
            except Exception:
                pass
            # Replace cuda with xpu:
            torch.cuda.current_device = torch.xpu.current_device
            torch.cuda.current_stream = torch.xpu.current_stream
            torch.cuda.device = torch.xpu.device
            torch.cuda.device_count = torch.xpu.device_count
            torch.cuda.device_of = torch.xpu.device_of
            torch.cuda.get_device_name = torch.xpu.get_device_name
            torch.cuda.get_device_properties = torch.xpu.get_device_properties
            torch.cuda.init = torch.xpu.init
            torch.cuda.is_available = torch.xpu.is_available
            torch.cuda.is_initialized = torch.xpu.is_initialized
            torch.cuda.is_current_stream_capturing = lambda: False
            torch.cuda.set_device = torch.xpu.set_device
            torch.cuda.stream = torch.xpu.stream
            torch.cuda.Event = torch.xpu.Event
            torch.cuda.Stream = torch.xpu.Stream
            torch.Tensor.cuda = torch.Tensor.xpu
            torch.Tensor.is_cuda = torch.Tensor.is_xpu
            torch.nn.Module.cuda = torch.nn.Module.xpu
            torch.cuda.Optional = torch.xpu.Optional
            torch.cuda.__cached__ = torch.xpu.__cached__
            torch.cuda.__loader__ = torch.xpu.__loader__
            torch.cuda.Tuple = torch.xpu.Tuple
            torch.cuda.streams = torch.xpu.streams
            torch.cuda.Any = torch.xpu.Any
            torch.cuda.__doc__ = torch.xpu.__doc__
            torch.cuda.default_generators = torch.xpu.default_generators
            torch.cuda._get_device_index = torch.xpu._get_device_index
            torch.cuda.__path__ = torch.xpu.__path__
            torch.cuda.set_stream = torch.xpu.set_stream
            torch.cuda.torch = torch.xpu.torch
            torch.cuda.Union = torch.xpu.Union
            torch.cuda.__annotations__ = torch.xpu.__annotations__
            torch.cuda.__package__ = torch.xpu.__package__
            torch.cuda.__builtins__ = torch.xpu.__builtins__
            torch.cuda.List = torch.xpu.List
            torch.cuda._lazy_init = torch.xpu._lazy_init
            torch.cuda.StreamContext = torch.xpu.StreamContext
            torch.cuda._lazy_call = torch.xpu._lazy_call
            torch.cuda.random = torch.xpu.random
            torch.cuda._device = torch.xpu._device
            torch.cuda.__name__ = torch.xpu.__name__
            torch.cuda._device_t = torch.xpu._device_t
            torch.cuda.__spec__ = torch.xpu.__spec__
            torch.cuda.__file__ = torch.xpu.__file__
            # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing

            # Memory:
            torch.cuda.memory = torch.xpu.memory
            if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
                torch.xpu.empty_cache = lambda: None
            torch.cuda.empty_cache = torch.xpu.empty_cache
            torch.cuda.memory_stats = torch.xpu.memory_stats
            torch.cuda.memory_summary = torch.xpu.memory_summary
            torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
            torch.cuda.memory_allocated = torch.xpu.memory_allocated
            torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
            torch.cuda.memory_reserved = torch.xpu.memory_reserved
            torch.cuda.memory_cached = torch.xpu.memory_reserved
            torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
            torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
            torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
            torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
            torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
            torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
            torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
            if legacy:
                torch.cuda.os = torch.xpu.os
                torch.cuda.Device = torch.xpu.Device
                torch.cuda.warnings = torch.xpu.warnings
                torch.cuda.classproperty = torch.xpu.classproperty
                torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
                if float(ipex.__version__[:3]) < 2.3:
                    torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
                    torch.cuda._initialized = torch.xpu.lazy_init._initialized
                    torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
                    torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
                    torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
                    torch.cuda._tls = torch.xpu.lazy_init._tls
                    torch.cuda.threading = torch.xpu.lazy_init.threading
                    torch.cuda.traceback = torch.xpu.lazy_init.traceback
                    torch.cuda._lazy_new = torch.xpu._lazy_new

            # RNG:
            torch.cuda.get_rng_state = torch.xpu.get_rng_state
            torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
            torch.cuda.set_rng_state = torch.xpu.set_rng_state
            torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
            torch.cuda.manual_seed = torch.xpu.manual_seed
            torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
            torch.cuda.seed = torch.xpu.seed
            torch.cuda.seed_all = torch.xpu.seed_all
            torch.cuda.initial_seed = torch.xpu.initial_seed
            torch.cuda.FloatTensor = torch.xpu.FloatTensor
            torch.cuda.FloatStorage = torch.xpu.FloatStorage
            torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
            torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
            torch.cuda.HalfTensor = torch.xpu.HalfTensor
            torch.cuda.HalfStorage = torch.xpu.HalfStorage
            torch.cuda.ByteTensor = torch.xpu.ByteTensor
            torch.cuda.ByteStorage = torch.xpu.ByteStorage
            torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
            torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
            torch.cuda.ShortTensor = torch.xpu.ShortTensor
            torch.cuda.ShortStorage = torch.xpu.ShortStorage
            torch.cuda.LongTensor = torch.xpu.LongTensor
            torch.cuda.LongStorage = torch.xpu.LongStorage
            torch.cuda.IntTensor = torch.xpu.IntTensor
            torch.cuda.IntStorage = torch.xpu.IntStorage
            torch.cuda.CharTensor = torch.xpu.CharTensor
            torch.cuda.CharStorage = torch.xpu.CharStorage
            torch.cuda.BoolTensor = torch.xpu.BoolTensor
            torch.cuda.BoolStorage = torch.xpu.BoolStorage
            torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
            torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage

            # AMP:
            torch.cuda.amp = torch.xpu.amp
            if not hasattr(torch.cuda.amp, "common"):
                torch.cuda.amp.common = contextlib.nullcontext()
                torch.cuda.amp.common.amp_definitely_not_available = lambda: False
            try:
                torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
            except Exception: # pylint: disable=broad-exception-caught
                try:
                    from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
                    gradscaler_init()
                    torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
                except Exception: # pylint: disable=broad-exception-caught
                    torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
            if not legacy or float(ipex.__version__[:3]) >= 2.3:
                torch.cuda._initialization_lock = torch.xpu._initialization_lock
                torch.cuda._initialized = torch.xpu._initialized
                torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
                torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker
                torch.cuda._queued_calls = torch.xpu._queued_calls
                torch.cuda._tls = torch.xpu._tls
                torch.cuda.threading = torch.xpu.threading
                torch.cuda.traceback = torch.xpu.traceback

            # C
            torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
            ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
            ipex._C._DeviceProperties.major = 2023
            ipex._C._DeviceProperties.minor = 2
            # Memory:
            if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
                torch.xpu.empty_cache = lambda: None
                torch.cuda.empty_cache = torch.xpu.empty_cache

            # Fix functions with ipex:
            torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
            torch._utils._get_available_device_type = lambda: "xpu"
            torch.has_cuda = True
            torch.cuda.has_half = True
            torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
            torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
            torch.version.cuda = "11.7"
            torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
            torch.cuda.get_device_properties.major = 11
            torch.cuda.get_device_properties.minor = 7
            torch.cuda.ipc_collect = lambda *args, **kwargs: None
            torch.cuda.utilization = lambda *args, **kwargs: 0
            if legacy:
                torch.cuda.memory_summary = torch.xpu.memory_summary
                torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
                torch.cuda.memory = torch.xpu.memory
                torch.cuda.memory_stats = torch.xpu.memory_stats
                torch.cuda.memory_allocated = torch.xpu.memory_allocated
                torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
                torch.cuda.memory_reserved = torch.xpu.memory_reserved
                torch.cuda.memory_cached = torch.xpu.memory_reserved
                torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
                torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
                torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
                torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
                torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
                torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
                torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats

            ipex_hijacks()
            if not torch.xpu.has_fp64_dtype():
                # RNG:
                torch.cuda.get_rng_state = torch.xpu.get_rng_state
                torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
                torch.cuda.set_rng_state = torch.xpu.set_rng_state
                torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
                torch.cuda.manual_seed = torch.xpu.manual_seed
                torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
                torch.cuda.seed = torch.xpu.seed
                torch.cuda.seed_all = torch.xpu.seed_all
                torch.cuda.initial_seed = torch.xpu.initial_seed

            # AMP:
            if legacy:
                torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
                torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
                torch.cuda.amp = torch.xpu.amp
                if float(ipex.__version__[:3]) < 2.3:
                    torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
                    torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype

            if not hasattr(torch.cuda.amp, "common"):
                torch.cuda.amp.common = contextlib.nullcontext()
                torch.cuda.amp.common.amp_definitely_not_available = lambda: False

            try:
                torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
            except Exception: # pylint: disable=broad-exception-caught
                try:
                    from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
                    gradscaler_init()
                    torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
                except Exception: # pylint: disable=broad-exception-caught
                    torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler

            # C
            if legacy and float(ipex.__version__[:3]) < 2.3:
                torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
                ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
                ipex._C._DeviceProperties.major = 12
                ipex._C._DeviceProperties.minor = 1
            else:
                torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
                torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
                torch._C._XpuDeviceProperties.major = 12
                torch._C._XpuDeviceProperties.minor = 1

            # Fix functions with ipex:
            # torch.xpu.mem_get_info always returns the total memory as free memory
            torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
            torch.cuda.mem_get_info = torch.xpu.mem_get_info
            torch._utils._get_available_device_type = lambda: "xpu"
            torch.has_cuda = True
            torch.cuda.has_half = True
            torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
            torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
            torch.backends.cuda.is_built = lambda *args, **kwargs: True
            torch.version.cuda = "12.1"
            torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
            torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
            torch.cuda.get_device_properties.major = 12
            torch.cuda.get_device_properties.minor = 1
            torch.cuda.ipc_collect = lambda *args, **kwargs: None
            torch.cuda.utilization = lambda *args, **kwargs: 0

            device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
            try:
                from .diffusers import ipex_diffusers
                ipex_diffusers()
                ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
            except Exception: # pylint: disable=broad-exception-caught
                pass
            torch.cuda.is_xpu_hijacked = True
    except Exception as e:
        return False, e
    return True, None
@@ -1,175 +1,119 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache
from functools import cache, wraps

# pylint: disable=protected-access, missing-function-docstring, line-too-long

# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers

sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))

# Find something divisible with the input_tokens
@cache
def find_slice_size(slice_size, slice_block_size):
    while (slice_size * slice_block_size) > attention_slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            slice_size = 1
            break
    return slice_size
def find_split_size(original_size, slice_block_size, slice_rate=2):
    split_size = original_size
    while True:
        if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
            return split_size
        split_size = split_size - 1
        if split_size <= 1:
            return 1
    return split_size
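
# Worked example (not part of the original diff): find_split_size searches
# downward for the largest divisor of original_size whose block fits the budget.
# With original_size=32 and slice_block_size=0.1 against the default slice_rate=2,
# it returns 16, since 16 * 0.1 = 1.6 <= 2 and 32 % 16 == 0, while the full
# 32 * 0.1 = 3.2 would exceed the budget.
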

# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, query_element_size):
    if len(query_shape) == 3:
        batch_size_attention, query_tokens, shape_three = query_shape
        shape_four = 1
    else:
        batch_size_attention, query_tokens, shape_three, shape_four = query_shape
def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
    batch_size, attn_heads, query_len, _ = query_shape
    _, _, key_len, _ = key_shape

    slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
    block_size = batch_size_attention * slice_block_size
    slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024

    split_slice_size = batch_size_attention
    split_2_slice_size = query_tokens
    split_3_slice_size = shape_three
    split_batch_size = batch_size
    split_head_size = attn_heads
    split_query_size = query_len

    do_split = False
    do_split_2 = False
    do_split_3 = False
    do_batch_split = False
    do_head_split = False
    do_query_split = False

    if block_size > sdpa_slice_trigger_rate:
        do_split = True
        split_slice_size = find_slice_size(split_slice_size, slice_block_size)
        if split_slice_size * slice_block_size > attention_slice_rate:
            slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
            do_split_2 = True
            split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
            if split_2_slice_size * slice_2_block_size > attention_slice_rate:
                slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
                do_split_3 = True
                split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
    if batch_size * slice_batch_size >= trigger_rate:
        do_batch_split = True
        split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)

    return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
        if split_batch_size * slice_batch_size > slice_rate:
            slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
            do_head_split = True
            split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)

# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
    batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
    slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
    block_size = batch_size_attention * slice_block_size
        if split_head_size * slice_head_size > slice_rate:
            slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
            do_query_split = True
            split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)

    split_slice_size = batch_size_attention
    split_2_slice_size = input_tokens
    split_3_slice_size = mat2_atten_shape
    return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size

    do_split = False
    do_split_2 = False
    do_split_3 = False

    if block_size > attention_slice_rate:
        do_split = True
        split_slice_size = find_slice_size(split_slice_size, slice_block_size)
        if split_slice_size * slice_block_size > attention_slice_rate:
            slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
            do_split_2 = True
            split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
            if split_2_slice_size * slice_2_block_size > attention_slice_rate:
                slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
                do_split_3 = True
                split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)

    return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
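
# Worked example (not part of the original diff): the slice_batch_size term in
# the new find_sdpa_slice_sizes estimates the attention matrix in GB per batch
# element: attn_heads * query_len * key_len * element_size / 1024**3. For
# 24 heads, 4096 query/key tokens and fp16 (2 bytes) that is
# 24 * 4096 * 4096 * 2 / 1024**3 ≈ 0.75 GB, so a batch of 4 (about 3 GB) reaches
# the function's default trigger_rate of 3 and gets split along the batch dimension.
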
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
    if input.device.type != "xpu":
        return original_torch_bmm(input, mat2, out=out)
    do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)

    # Slice BMM
    if do_split:
        batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
        hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
                for i2 in range(input_tokens // split_2_slice_size):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
                    if do_split_3:
                        for i3 in range(mat2_atten_shape // split_3_slice_size):  # pylint: disable=invalid-name
                            start_idx_3 = i3 * split_3_slice_size
                            end_idx_3 = (i3 + 1) * split_3_slice_size
                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
                                input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                out=out
                            )
                    else:
                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
                            input[start_idx:end_idx, start_idx_2:end_idx_2],
                            mat2[start_idx:end_idx, start_idx_2:end_idx_2],
                            out=out
                        )
            else:
                hidden_states[start_idx:end_idx] = original_torch_bmm(
                    input[start_idx:end_idx],
                    mat2[start_idx:end_idx],
                    out=out
                )
    else:
        return original_torch_bmm(input, mat2, out=out)
    return hidden_states

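# Hypothetical usage, not in the diff: once installed over torch.bmm the
# wrapper is transparent to callers; non-XPU tensors fall through unchanged.
a = torch.randn(4, 128, 64)
b = torch.randn(4, 64, 128)
out = torch_bmm_32_bit(a, b)  # CPU path: identical to torch.bmm(a, b)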
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    if query.device.type != "xpu":
        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
    do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())

    # Slice SDPA
    if do_split:
        batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
        hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
                for i2 in range(query_tokens // split_2_slice_size):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
                    if do_split_3:
                        for i3 in range(shape_three // split_3_slice_size):  # pylint: disable=invalid-name
                            start_idx_3 = i3 * split_3_slice_size
                            end_idx_3 = (i3 + 1) * split_3_slice_size
                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
                                query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
                                attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
                                dropout_p=dropout_p, is_causal=is_causal
                            )
                    else:
                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
                            query[start_idx:end_idx, start_idx_2:end_idx_2],
                            key[start_idx:end_idx, start_idx_2:end_idx_2],
                            value[start_idx:end_idx, start_idx_2:end_idx_2],
                            attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
                            dropout_p=dropout_p, is_causal=is_causal
                        )
            else:
                hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
                    query[start_idx:end_idx],
                    key[start_idx:end_idx],
                    value[start_idx:end_idx],
                    attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
                    dropout_p=dropout_p, is_causal=is_causal
                )
    else:
        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
    return hidden_states

@wraps(torch.nn.functional.scaled_dot_product_attention)
def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
    if query.device.type != "xpu":
        return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
    is_unsqueezed = False
    if len(query.shape) == 3:
        query = query.unsqueeze(0)
        is_unsqueezed = True
    if len(key.shape) == 3:
        key = key.unsqueeze(0)
    if len(value.shape) == 3:
        value = value.unsqueeze(0)
    do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)

    # Slice SDPA
    if do_batch_split:
        batch_size, attn_heads, query_len, _ = query.shape
        _, _, _, head_dim = value.shape
        hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
        if attn_mask is not None:
            attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
        for ib in range(batch_size // split_batch_size):
            start_idx = ib * split_batch_size
            end_idx = (ib + 1) * split_batch_size
            if do_head_split:
                for ih in range(attn_heads // split_head_size):  # pylint: disable=invalid-name
                    start_idx_h = ih * split_head_size
                    end_idx_h = (ih + 1) * split_head_size
                    if do_query_split:
                        for iq in range(query_len // split_query_size):  # pylint: disable=invalid-name
                            start_idx_q = iq * split_query_size
                            end_idx_q = (iq + 1) * split_query_size
                            hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
                                query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
                                key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
                                value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
                                attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
                                dropout_p=dropout_p, is_causal=is_causal, **kwargs
                            )
                    else:
                        hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
                            query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
                            key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
                            value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
                            attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
                            dropout_p=dropout_p, is_causal=is_causal, **kwargs
                        )
            else:
                hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
                    query[start_idx:end_idx, :, :, :],
                    key[start_idx:end_idx, :, :, :],
                    value[start_idx:end_idx, :, :, :],
                    attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
                    dropout_p=dropout_p, is_causal=is_causal, **kwargs
                )
        torch.xpu.synchronize(query.device)
    else:
        hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
    if is_unsqueezed:
        hidden_states = hidden_states.squeeze(0)
    return hidden_states

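# Assumed wiring, not shown in this hunk: the wrapper keeps the original
# signature via functools.wraps, so it can replace the torch function in place.
torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention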
@@ -1,310 +1,47 @@
import os
from functools import wraps
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
import diffusers  # 0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache
import diffusers  # pylint: disable=import-error

# pylint: disable=protected-access, missing-function-docstring, line-too-long

attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

@cache
def find_slice_size(slice_size, slice_block_size):
    while (slice_size * slice_block_size) > attention_slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            slice_size = 1
            break
    return slice_size

@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
    if len(query_shape) == 3:
        batch_size_attention, query_tokens, shape_three = query_shape
        shape_four = 1
    else:
        batch_size_attention, query_tokens, shape_three, shape_four = query_shape
    if slice_size is not None:
        batch_size_attention = slice_size

    slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
    block_size = batch_size_attention * slice_block_size

    split_slice_size = batch_size_attention
    split_2_slice_size = query_tokens
    split_3_slice_size = shape_three

    do_split = False
    do_split_2 = False
    do_split_3 = False

    if query_device_type != "xpu":
        return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size

    if block_size > attention_slice_rate:
        do_split = True
        split_slice_size = find_slice_size(split_slice_size, slice_block_size)
        if split_slice_size * slice_block_size > attention_slice_rate:
            slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
            do_split_2 = True
            split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
            if split_2_slice_size * slice_2_block_size > attention_slice_rate:
                slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
                do_split_3 = True
                split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)

    return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size

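# Worked example, not part of the diff, assuming the default
# IPEX_ATTENTION_SLICE_RATE of 4 (MB): with a per-item block of 1.5 MB the
# halving loop goes 16 -> 8 -> 4 -> 2 and stops at 2, since 2 * 1.5 <= 4.
assert find_slice_size(16, 1.5) == 2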
class SlicedAttnProcessor:  # pylint: disable=too-few-public-methods
    r"""
    Processor for implementing sliced attention.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
                 encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor:  # pylint: disable=too-many-statements, too-many-locals, too-many-branches

        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, shape_three = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        ####################################################################
        # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
        _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)

        for i in range(batch_size_attention // split_slice_size):
            start_idx = i * split_slice_size
            end_idx = (i + 1) * split_slice_size
            if do_split_2:
                for i2 in range(query_tokens // split_2_slice_size):  # pylint: disable=invalid-name
                    start_idx_2 = i2 * split_2_slice_size
                    end_idx_2 = (i2 + 1) * split_2_slice_size
                    if do_split_3:
                        for i3 in range(shape_three // split_3_slice_size):  # pylint: disable=invalid-name
                            start_idx_3 = i3 * split_3_slice_size
                            end_idx_3 = (i3 + 1) * split_3_slice_size

                            query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
                            key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
                            attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None

                            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                            del query_slice
                            del key_slice
                            del attn_mask_slice
                            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])

                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
                            del attn_slice
                    else:
                        query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
                        key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
                        attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None

                        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                        del query_slice
                        del key_slice
                        del attn_mask_slice
                        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])

                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
                        del attn_slice
            else:
                query_slice = query[start_idx:end_idx]
                key_slice = key[start_idx:end_idx]
                attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                del query_slice
                del key_slice
                del attn_mask_slice
                attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

                hidden_states[start_idx:end_idx] = attn_slice
                del attn_slice
        ####################################################################

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

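# Hypothetical wiring, not in the diff: the class follows the diffusers
# SlicedAttnProcessor interface, so it would be installed the same way, e.g.
#   unet.set_attn_processor(SlicedAttnProcessor(slice_size=2))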
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
                 encoder_hidden_states=None, attention_mask=None,
                 temb=None, scale: float = 1.0) -> torch.Tensor:  # pylint: disable=too-many-statements, too-many-locals, too-many-branches

        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        ####################################################################
        # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
        batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
        hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
        do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)

        if do_split:
            for i in range(batch_size_attention // split_slice_size):
                start_idx = i * split_slice_size
                end_idx = (i + 1) * split_slice_size
                if do_split_2:
                    for i2 in range(query_tokens // split_2_slice_size):  # pylint: disable=invalid-name
                        start_idx_2 = i2 * split_2_slice_size
                        end_idx_2 = (i2 + 1) * split_2_slice_size
                        if do_split_3:
                            for i3 in range(shape_three // split_3_slice_size):  # pylint: disable=invalid-name
                                start_idx_3 = i3 * split_3_slice_size
                                end_idx_3 = (i3 + 1) * split_3_slice_size

                                query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
                                key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
                                attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None

                                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                                del query_slice
                                del key_slice
                                del attn_mask_slice
                                attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])

                                hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
                                del attn_slice
                        else:
                            query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
                            key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
                            attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None

                            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                            del query_slice
                            del key_slice
                            del attn_mask_slice
                            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])

                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
                            del attn_slice
                else:
                    query_slice = query[start_idx:end_idx]
                    key_slice = key[start_idx:end_idx]
                    attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

                    attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                    del query_slice
                    del key_slice
                    del attn_mask_slice
                    attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

                    hidden_states[start_idx:end_idx] = attn_slice
                    del attn_slice
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
        ####################################################################
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

def ipex_diffusers():
    # ARC GPUs can't allocate more than 4GB to a single block:
    diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
    diffusers.models.attention_processor.AttnProcessor = AttnProcessor

# Diffusers FreeU
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
@wraps(diffusers.utils.torch_utils.fourier_filter)
def fourier_filter(x_in, threshold, scale):
    return_dtype = x_in.dtype
    return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)


# fp64 error
class FluxPosEmbed(torch.nn.Module):
    def __init__(self, theta: int, axes_dim):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        for i in range(n_axes):
            cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=torch.float32,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
    diffusers.utils.torch_utils.fourier_filter = fourier_filter
    if not device_supports_fp64:
        diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed

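# Hypothetical usage, not in the diff: the fp32 rotary embedding drop-in takes
# ids of shape (seq_len, n_axes); the theta and axes_dim values are made up.
pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
freqs_cos, freqs_sin = pos_embed(torch.zeros(64, 3))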
@@ -5,7 +5,7 @@ import intel_extension_for_pytorch._C as core  # pylint: disable=import-error, un

# pylint: disable=protected-access, missing-function-docstring, line-too-long

device_supports_fp64 = torch.xpu.has_fp64_dtype()
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state

@@ -1,6 +1,20 @@
import contextlib
import os
from functools import wraps
from contextlib import nullcontext
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
import numpy as np

device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
    try:
        x = torch.ones((33000, 33000), dtype=torch.float32, device="xpu")
        del x
        torch.xpu.empty_cache()
        can_allocate_plus_4gb = True
    except Exception:
        can_allocate_plus_4gb = False
else:
    can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')

# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return

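# Sizing note, illustrative only: the probe tensor above is 33000 * 33000 fp32
# elements, i.e. about 4.06 GB, just over the 4 GB single-allocation limit.
print(33000 * 33000 * 4 / 1024 / 1024 / 1024)  # ~4.06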
@@ -11,7 +25,7 @@ class DummyDataParallel(torch.nn.Module):  # pylint: disable=missing-class-docstr
        return module.to("xpu")

def return_null_context(*args, **kwargs):  # pylint: disable=unused-argument
    return contextlib.nullcontext()
    return nullcontext()

@property
def is_cuda(self):
@@ -21,21 +35,23 @@ def check_device(device):
    return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))

def return_xpu(device):
    return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
    return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"


# Autocast
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
    if len(args) > 0 and args[0] == "cuda":
        return original_autocast("xpu", *args[1:], **kwargs)
    else:
        return original_autocast(*args, **kwargs)

original_autocast_init = torch.amp.autocast_mode.autocast.__init__
@wraps(torch.amp.autocast_mode.autocast.__init__)
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
    if device_type == "cuda":
        return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
    else:
        return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)

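# Quick sketch, not part of the diff, of the device rewrite; these hold for
# both return_xpu variants shown above:
assert return_xpu("cuda:0") == "xpu:0"
assert return_xpu(1) == "xpu:1"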
# Latent Antialias CPU Offload:
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):  # pylint: disable=too-many-arguments
    if antialias or align_corners is not None:
        if mode in {'bicubic', 'bilinear'}:
            return_device = tensor.device
            return_dtype = tensor.dtype
            return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -44,42 +60,72 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
    return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
                                align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)


# Diffusers Float64 (Alchemist GPUs don't support 64 bit):
original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray):
    if ndarray.dtype == float:
        return original_from_numpy(ndarray.astype('float32'))
    else:
        return original_from_numpy(ndarray)

original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
    if check_device(device):
        device = return_xpu(device)
    if isinstance(data, np.ndarray) and data.dtype == float and not (
        (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
        return original_as_tensor(data, dtype=torch.float32, device=device)
    else:
        return original_as_tensor(data, dtype=dtype, device=device)

if torch.xpu.has_fp64_dtype():
if can_allocate_plus_4gb:
    original_torch_bmm = torch.bmm
    original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
    # 32 bit attention workarounds for Alchemist:
    try:
        from .attention import torch_bmm_32_bit as original_torch_bmm
        from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
        from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
    except Exception:  # pylint: disable=broad-exception-caught
        original_torch_bmm = torch.bmm
        original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention

@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
    if query.dtype != key.dtype:
        key = key.to(dtype=query.dtype)
    if query.dtype != value.dtype:
        value = value.to(dtype=query.dtype)
    if attn_mask is not None and query.dtype != attn_mask.dtype:
        attn_mask = attn_mask.to(dtype=query.dtype)
    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)

# Data Type Errors:
original_torch_bmm = torch.bmm
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(input.dtype)
    return original_torch_bmm(input, mat2, out=out)

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    if query.dtype != key.dtype:
        key = key.to(dtype=query.dtype)
    if query.dtype != value.dtype:
        value = value.to(dtype=query.dtype)
    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
# Diffusers FreeU
original_fft_fftn = torch.fft.fftn
@wraps(torch.fft.fftn)
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
    return_dtype = input.dtype
    return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)

# Diffusers FreeU
original_fft_ifftn = torch.fft.ifftn
@wraps(torch.fft.ifftn)
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
    return_dtype = input.dtype
    return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)

# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
@wraps(torch.nn.functional.group_norm)
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
    if weight is not None and input.dtype != weight.data.dtype:
        input = input.to(dtype=weight.data.dtype)
@@ -89,6 +135,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):

# A1111 BF16
original_functional_layer_norm = torch.nn.functional.layer_norm
@wraps(torch.nn.functional.layer_norm)
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    if weight is not None and input.dtype != weight.data.dtype:
        input = input.to(dtype=weight.data.dtype)
@@ -98,6 +145,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1

# Training
original_functional_linear = torch.nn.functional.linear
@wraps(torch.nn.functional.linear)
def functional_linear(input, weight, bias=None):
    if input.dtype != weight.data.dtype:
        input = input.to(dtype=weight.data.dtype)
@@ -105,7 +153,17 @@ def functional_linear(input, weight, bias=None):
        bias.data = bias.data.to(dtype=weight.data.dtype)
    return original_functional_linear(input, weight, bias=bias)

original_functional_conv1d = torch.nn.functional.conv1d
@wraps(torch.nn.functional.conv1d)
def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if input.dtype != weight.data.dtype:
        input = input.to(dtype=weight.data.dtype)
    if bias is not None and bias.data.dtype != weight.data.dtype:
        bias.data = bias.data.to(dtype=weight.data.dtype)
    return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if input.dtype != weight.data.dtype:
        input = input.to(dtype=weight.data.dtype)
@@ -113,16 +171,19 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
        bias.data = bias.data.to(dtype=weight.data.dtype)
    return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

# A1111 Embedding BF16
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
    if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
        return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
    else:
        return original_torch_cat(tensor, *args, **kwargs)

# LTX Video
original_functional_conv3d = torch.nn.functional.conv3d
@wraps(torch.nn.functional.conv3d)
def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if input.dtype != weight.data.dtype:
        input = input.to(dtype=weight.data.dtype)
    if bias is not None and bias.data.dtype != weight.data.dtype:
        bias.data = bias.data.to(dtype=weight.data.dtype)
    return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

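# Hedged example, not in the diff: the dtype guards let an fp32 activation
# meet fp16 weights without the dtype-mismatch error the raw op would raise.
w = torch.nn.Linear(4, 4).half()
y = functional_linear(torch.randn(1, 4), w.weight, w.bias)  # input cast to fp16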
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@wraps(torch.nn.functional.pad)
def functional_pad(input, pad, mode='constant', value=None):
    if mode == 'reflect' and input.dtype == torch.bfloat16:
        return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
@@ -131,13 +192,21 @@ def functional_pad(input, pad, mode='constant', value=None):


original_torch_tensor = torch.tensor
def torch_tensor(*args, device=None, **kwargs):
    if check_device(device):
        return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_tensor(*args, device=device, **kwargs)

@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
    global device_supports_fp64
    if check_device(device):
        device = return_xpu(device)
    if not device_supports_fp64:
        if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
            if dtype == torch.float64:
                dtype = torch.float32
            elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
                dtype = torch.float32
    return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)

original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
    if check_device(device):
        return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
@@ -145,13 +214,25 @@ def Tensor_to(self, device=None, *args, **kwargs):
        return original_Tensor_to(self, device, *args, **kwargs)

original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
    if check_device(device):
        return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
    else:
        return original_Tensor_cuda(self, device, *args, **kwargs)

original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
    if device is None:
        device = "xpu"
    if check_device(device):
        return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
    else:
        return original_Tensor_pin_memory(self, device, *args, **kwargs)

original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
    if check_device(device):
        return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
@@ -159,6 +240,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
    return original_UntypedStorage_init(*args, device=device, **kwargs)

original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
    if check_device(device):
        return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
@@ -166,6 +248,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
    return original_UntypedStorage_cuda(self, device, *args, **kwargs)

original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
    if check_device(device):
        return original_torch_empty(*args, device=return_xpu(device), **kwargs)
@@ -173,13 +256,17 @@ def torch_empty(*args, device=None, **kwargs):
    return original_torch_empty(*args, device=device, **kwargs)

original_torch_randn = torch.randn
def torch_randn(*args, device=None, **kwargs):
@wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs):
    if dtype is bytes:
        dtype = None
    if check_device(device):
        return original_torch_randn(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_randn(*args, device=device, **kwargs)

original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
    if check_device(device):
        return original_torch_ones(*args, device=return_xpu(device), **kwargs)
@@ -187,62 +274,94 @@ def torch_ones(*args, device=None, **kwargs):
    return original_torch_ones(*args, device=device, **kwargs)

original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
    if check_device(device):
        return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_zeros(*args, device=device, **kwargs)

original_torch_full = torch.full
@wraps(torch.full)
def torch_full(*args, device=None, **kwargs):
    if check_device(device):
        return original_torch_full(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_full(*args, device=device, **kwargs)

original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
    if check_device(device):
        return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_linspace(*args, device=device, **kwargs)

original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
    if map_location is None:
        map_location = "xpu"
    if check_device(map_location):
        return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
    else:
        return original_torch_load(f, *args, map_location=map_location, **kwargs)

original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
    if check_device(device):
        return original_torch_Generator(return_xpu(device))
    else:
        return original_torch_Generator(device)

original_torch_load = torch.load
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
    if check_device(map_location):
        return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
    else:
        return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)

@wraps(torch.cuda.synchronize)
def torch_cuda_synchronize(device=None):
    if check_device(device):
        return torch.xpu.synchronize(return_xpu(device))
    else:
        return torch.xpu.synchronize(device)

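# Hypothetical usage, not in the diff ("model.ckpt" is a placeholder path):
# with the load hijack installed, CUDA map_locations are redirected to XPU.
state_dict = torch_load("model.ckpt", map_location="cuda:0")  # lands on xpu:0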
# Hijack Functions:
def ipex_hijacks():
def ipex_hijacks(legacy=True):
    global device_supports_fp64, can_allocate_plus_4gb
    if legacy and float(torch.__version__[:3]) < 2.5:
        torch.nn.functional.interpolate = interpolate
    torch.tensor = torch_tensor
    torch.Tensor.to = Tensor_to
    torch.Tensor.cuda = Tensor_cuda
    torch.Tensor.pin_memory = Tensor_pin_memory
    torch.UntypedStorage.__init__ = UntypedStorage_init
    torch.UntypedStorage.cuda = UntypedStorage_cuda
    torch.empty = torch_empty
    torch.randn = torch_randn
    torch.ones = torch_ones
    torch.zeros = torch_zeros
    torch.full = torch_full
    torch.linspace = torch_linspace
    torch.Generator = torch_Generator
    torch.load = torch_load
    torch.Generator = torch_Generator
    torch.cuda.synchronize = torch_cuda_synchronize

    torch.backends.cuda.sdp_kernel = return_null_context
    torch.nn.DataParallel = DummyDataParallel
    torch.UntypedStorage.is_cuda = is_cuda
    torch.autocast = ipex_autocast
    torch.amp.autocast_mode.autocast.__init__ = autocast_init

    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
    torch.nn.functional.group_norm = functional_group_norm
    torch.nn.functional.layer_norm = functional_layer_norm
    torch.nn.functional.linear = functional_linear
    torch.nn.functional.conv1d = functional_conv1d
    torch.nn.functional.conv2d = functional_conv2d
    torch.nn.functional.interpolate = interpolate
    torch.nn.functional.conv3d = functional_conv3d
    torch.nn.functional.pad = functional_pad

    torch.bmm = torch_bmm
    torch.cat = torch_cat
    if not torch.xpu.has_fp64_dtype():
        torch.fft.fftn = fft_fftn
        torch.fft.ifftn = fft_ifftn
    if not device_supports_fp64:
        torch.from_numpy = from_numpy
        torch.as_tensor = as_tensor
    return device_supports_fp64, can_allocate_plus_4gb

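# Assumed entry point, not shown in this hunk: the installer is called once at
# startup, before any model code touches torch.
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks()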
@@ -1,24 +0,0 @@
import torch


def init_ipex():
    """
    Try to import `intel_extension_for_pytorch`, and apply
    the hijacks using `library.ipex.ipex_init`.

    If IPEX is not installed, this function does nothing.
    """
    try:
        import intel_extension_for_pytorch as ipex  # noqa
    except ImportError:
        return

    try:
        from library.ipex import ipex_init

        if torch.xpu.is_available():
            is_initialized, error_message = ipex_init()
            if not is_initialized:
                print("failed to initialize ipex:", error_message)
    except Exception as e:
        print("failed to initialize ipex:", e)
186  library/jpeg_xl_util.py  Normal file
@@ -0,0 +1,186 @@
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
# Added partial read support for up to 200x speedup

import os
from typing import List, Tuple

class JXLBitstream:
    """
    A stream of bits with methods for easy handling.
    """

    def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
        self.shift = 0
        self.bitstream = bytearray()
        self.file = file
        self.offset = offset
        self.offsets = offsets
        if self.offsets:
            self.offset = self.offsets[0][1]
        self.previous_data_len = 0
        self.index = 0
        self.file.seek(self.offset)

    def get_bits(self, length: int = 1) -> int:
        if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
            self.partial_to_read_length = length
            if self.shift < self.previous_data_len + self.offsets[self.index][2]:
                self.partial_read(0, length)
            self.bitstream.extend(self.file.read(self.partial_to_read_length))
        else:
            self.bitstream.extend(self.file.read(length))
        bitmask = 2**length - 1
        bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
        self.shift += length
        return bits

    def partial_read(self, current_length: int, length: int) -> None:
        self.previous_data_len += self.offsets[self.index][2]
        to_read_length = self.previous_data_len - (self.shift + current_length)
        self.bitstream.extend(self.file.read(to_read_length))
        current_length += to_read_length
        self.partial_to_read_length -= to_read_length
        self.index += 1
        self.file.seek(self.offsets[self.index][1])
        if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
            self.partial_read(current_length, length)


def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int, int]:
    """
    Decodes the actual codestream.
    JXL codestream specification: http://www-internal/2022/18181-1
    """

    # Convert codestream to int within an object to get some handy methods.
    codestream = JXLBitstream(file, offset=offset, offsets=offsets)

    # Skip signature
    codestream.get_bits(16)

    # SizeHeader
    div8 = codestream.get_bits(1)
    if div8:
        height = 8 * (1 + codestream.get_bits(5))
    else:
        distribution = codestream.get_bits(2)
        match distribution:
            case 0:
                height = 1 + codestream.get_bits(9)
            case 1:
                height = 1 + codestream.get_bits(13)
            case 2:
                height = 1 + codestream.get_bits(18)
            case 3:
                height = 1 + codestream.get_bits(30)
    ratio = codestream.get_bits(3)
    if div8 and not ratio:
        width = 8 * (1 + codestream.get_bits(5))
    elif not ratio:
        distribution = codestream.get_bits(2)
        match distribution:
            case 0:
                width = 1 + codestream.get_bits(9)
            case 1:
                width = 1 + codestream.get_bits(13)
            case 2:
                width = 1 + codestream.get_bits(18)
            case 3:
                width = 1 + codestream.get_bits(30)
    else:
        match ratio:
            case 1:
                width = height
            case 2:
                width = (height * 12) // 10
            case 3:
                width = (height * 4) // 3
            case 4:
                width = (height * 3) // 2
            case 5:
                width = (height * 16) // 9
            case 6:
                width = (height * 5) // 4
            case 7:
                width = (height * 2) // 1
    return width, height


def decode_container(file) -> Tuple[int, int]:
    """
    Parses the ISOBMFF container, extracts the codestream, and decodes it.
    JXL container specification: http://www-internal/2022/18181-2
    """

    def parse_box(file, file_start: int) -> dict:
        file.seek(file_start)
        LBox = int.from_bytes(file.read(4), "big")
        XLBox = None
        if 1 < LBox <= 8:
            raise ValueError(f"Invalid LBox at byte {file_start}.")
        if LBox == 1:
            file.seek(file_start + 8)
            XLBox = int.from_bytes(file.read(8), "big")
            if XLBox <= 16:
                raise ValueError(f"Invalid XLBox at byte {file_start}.")
        if XLBox:
            header_length = 16
            box_length = XLBox
        else:
            header_length = 8
            if LBox == 0:
                box_length = os.fstat(file.fileno()).st_size - file_start
            else:
                box_length = LBox
        file.seek(file_start + 4)
        box_type = file.read(4)
        file.seek(file_start)
        return {
            "length": box_length,
            "type": box_type,
            "offset": header_length,
        }

    file.seek(0)
    # Reject files missing required boxes. These two boxes are required to be at
    # the start and contain no values, so we can manually check their presence.
    # Signature box. (Redundant as has already been checked.)
    if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
        raise ValueError("Invalid signature box.")
    # File Type box.
    if file.read(20) != bytes.fromhex(
        "00000014 66747970 6A786C20 00000000 6A786C20"
    ):
        raise ValueError("Invalid file type box.")

    offset = 0
    offsets = []
    data_offset_not_found = True
    container_pointer = 32
    file_size = os.fstat(file.fileno()).st_size
    while data_offset_not_found:
        box = parse_box(file, container_pointer)
        match box["type"]:
            case b"jxlc":
                offset = container_pointer + box["offset"]
                data_offset_not_found = False
            case b"jxlp":
                file.seek(container_pointer + box["offset"])
                index = int.from_bytes(file.read(4), "big")
                offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
        container_pointer += box["length"]
        if container_pointer >= file_size:
            data_offset_not_found = False

    if offsets:
        offsets.sort(key=lambda i: i[0])
    file.seek(0)

    return decode_codestream(file, offset=offset, offsets=offsets)


def get_jxl_size(path: str) -> Tuple[int, int]:
    with open(path, "rb") as file:
        if file.read(2) == bytes.fromhex("FF0A"):
            return decode_codestream(file)
        return decode_container(file)
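# Hypothetical usage, not part of the diff ("image.jxl" is a placeholder path):
# only the size header is parsed, so this stays fast even for large files.
width, height = get_jxl_size("image.jxl")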
@@ -17,7 +17,6 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging


try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
@@ -626,7 +625,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if height % 8 != 0 or width % 8 != 0:
        print(height, width)
        logger.info(f'{height} {width}')
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if (callback_steps is None) or (

@@ -3,16 +3,20 @@

import math
import os

import torch

from library.ipex_interop import init_ipex
from library.device_utils import init_ipex

init_ipex()

import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

# Model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
@@ -639,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # rename or add position_ids
    # remove position_ids for newer transformer, which causes error :(
    ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
    if ANOTHER_POSITION_IDS_KEY in new_sd:
        # waifu diffusion v1.4
        position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
        del new_sd[ANOTHER_POSITION_IDS_KEY]
    else:
        position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)

    new_sd["text_model.embeddings.position_ids"] = position_ids
    if "text_model.embeddings.position_ids" in new_sd:
        del new_sd["text_model.embeddings.position_ids"]

    return new_sd


@@ -944,7 +947,7 @@ def convert_vae_state_dict(vae_state_dict):
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
                # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
                new_state_dict[k] = reshape_weight_for_sd(v)

    return new_state_dict
@@ -1002,7 +1005,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt

    unet = UNet2DConditionModel(**unet_config).to(device)
    info = unet.load_state_dict(converted_unet_checkpoint)
    print("loading u-net:", info)
    logger.info(f"loading u-net: {info}")

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
@@ -1010,7 +1013,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt

    vae = AutoencoderKL(**vae_config).to(device)
    info = vae.load_state_dict(converted_vae_checkpoint)
    print("loading vae:", info)
    logger.info(f"loading vae: {info}")

    # convert text_model
    if v2:
@@ -1044,7 +1047,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
        # logging.set_verbosity_error()  # don't show annoying warning
        # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        # logging.set_verbosity_warning()
        # print(f"config: {text_model.config}")
        # logger.info(f"config: {text_model.config}")
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=768,
@@ -1067,7 +1070,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
        )
        text_model = CLIPTextModel._from_config(cfg)
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
        print("loading text encoder:", info)
        logger.info(f"loading text encoder: {info}")

    return text_model, vae, unet

@@ -1142,7 +1145,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals

    # whether to fabricate the final layers etc.
    if make_dummy_weights:
        print("make dummy weights for resblock.23, text_projection and logit scale.")
        logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
        keys = list(new_sd.keys())
        for key in keys:
            if key.startswith("transformer.resblocks.22."):
@@ -1261,14 +1264,14 @@ VAE_PREFIX = "first_stage_model."


def load_vae(vae_id, dtype):
|
||||
print(f"load VAE: {vae_id}")
|
||||
logger.info(f"load VAE: {vae_id}")
|
||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||
# Diffusers local/remote
|
||||
try:
|
||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
||||
except EnvironmentError as e:
|
||||
print(f"exception occurs in loading vae: {e}")
|
||||
print("retry with subfolder='vae'")
|
||||
logger.error(f"exception occurs in loading vae: {e}")
|
||||
logger.error("retry with subfolder='vae'")
|
||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
||||
return vae
|
||||
|
||||
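# Usage sketch (illustrative, not part of the diff): load_vae accepts either a
# Diffusers model directory / Hub id or a single VAE checkpoint file. The repo id
# below and the import path are only examples.
import torch
from library import model_util

vae = model_util.load_vae("stabilityai/sd-vae-ft-mse", torch.float16)
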
@@ -1340,13 +1343,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)

if __name__ == "__main__":
    resos = make_bucket_resolutions((512, 768))
    print(len(resos))
    print(resos)
    logger.info(f"{len(resos)}")
    logger.info(f"{resos}")
    aspect_ratios = [w / h for w, h in resos]
    print(aspect_ratios)
    logger.info(f"{aspect_ratios}")

    ars = set()
    for ar in aspect_ratios:
        if ar in ars:
            print("error! duplicate ar:", ar)
            logger.error(f"error! duplicate ar: {ar}")
        ars.add(ar)

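# Quick check (illustrative): the buckets returned by make_bucket_resolutions((512, 768))
# all have sides divisible by `divisible`, which is what the __main__ block above probes
# for duplicates. The import path is an assumption.
from library.model_util import make_bucket_resolutions

resos = make_bucket_resolutions((512, 768), min_size=256, max_size=1024, divisible=64)
assert all(w % 64 == 0 and h % 64 == 0 for w, h in resos)
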
@@ -113,6 +113,10 @@ import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
@@ -1380,7 +1384,7 @@ class UNet2DConditionModel(nn.Module):
    ):
        super().__init__()
        assert sample_size is not None, "sample_size must be specified"
        print(
        logger.info(
            f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
        )

@@ -1514,7 +1518,7 @@ class UNet2DConditionModel(nn.Module):
    def set_gradient_checkpointing(self, value=False):
        modules = self.down_blocks + [self.mid_block] + self.up_blocks
        for module in modules:
            print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
            logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
            module.gradient_checkpointing = value

# endregion
@@ -1709,14 +1713,14 @@ class InferUNet2DConditionModel:

    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
        if ds_depth_1 is None:
            print("Deep Shrink is disabled.")
            logger.info("Deep Shrink is disabled.")
            self.ds_depth_1 = None
            self.ds_timesteps_1 = None
            self.ds_depth_2 = None
            self.ds_timesteps_2 = None
            self.ds_ratio = None
        else:
            print(
            logger.info(
                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
            )
            self.ds_depth_1 = ds_depth_1
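# Usage sketch (illustrative): enabling Deep Shrink on the inference wrapper above.
# Assumes `unet` is an already-built UNet2DConditionModel; the depth/timestep values
# are examples, not recommendations.
infer_unet = InferUNet2DConditionModel(unet)
infer_unet.set_deep_shrink(3, ds_timesteps_1=650, ds_ratio=0.5)  # shrink hidden states at depth 3 on early (noisy) steps
infer_unet.set_deep_shrink(None)  # disable again
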
@@ -5,6 +5,12 @@ from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

r"""
# Metadata Example
@@ -51,12 +57,18 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3"  # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"

ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"

IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"

PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -109,7 +121,12 @@ def build_metadata(
    merged_from: Optional[str] = None,
    timesteps: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
    sd3: Optional[str] = None,
    flux: Optional[str] = None,
):
    """
    sd3: only supports "m", flux: only supports "dev"
    """
    # if state_dict is None, hash is not calculated

    metadata = {}
@@ -122,6 +139,13 @@

    if sdxl:
        arch = ARCH_SD_XL_V1_BASE
    elif sd3 is not None:
        arch = ARCH_SD3_M + "-" + sd3
    elif flux is not None:
        if flux == "dev":
            arch = ARCH_FLUX_1_DEV
        else:
            arch = ARCH_FLUX_1_UNKNOWN
    elif v2:
        if v_parameterization:
            arch = ARCH_SD_V2_768_V
@@ -138,9 +162,12 @@
    metadata["modelspec.architecture"] = arch

    if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
        is_stable_diffusion_ckpt = True  # default is stable diffusion ckpt if not lora and not textual_inversion

    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
    if flux is not None:
        # Flux
        impl = IMPL_FLUX
    elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
        # Stable Diffusion ckpt, TI, SDXL LoRA
        impl = IMPL_STABILITY_AI
    else:
@@ -198,7 +225,7 @@ def build_metadata(
        reso = (reso[0], reso[0])
    else:
        # resolution is defined in dataset, so use default
        if sdxl:
        if sdxl or sd3 is not None or flux is not None:
            reso = 1024
        elif v2 and v_parameterization:
            reso = 768
@@ -209,7 +236,9 @@

    metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"

    if v_parameterization:
    if flux is not None:
        del metadata["modelspec.prediction_type"]
    elif v_parameterization:
        metadata["modelspec.prediction_type"] = PRED_TYPE_V
    else:
        metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
@@ -231,8 +260,8 @@
    # # assert all values are filled
    # assert all([v is not None for v in metadata.values()]), metadata
    if not all([v is not None for v in metadata.values()]):
        print(f"Internal error: some metadata values are None: {metadata}")
        logger.error(f"Internal error: some metadata values are None: {metadata}")

    return metadata

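# Usage sketch (illustrative): building model-spec metadata for an SD3 checkpoint
# with the function above. Keyword names and values are examples; check the full
# signature in this file before relying on them.
import time

sai_metadata = build_metadata(
    None,  # state_dict; None skips hash calculation
    v2=False,
    v_parameterization=False,
    sdxl=False,
    lora=False,
    textual_inversion=False,
    timestamp=time.time(),
    title="my-sd3-finetune",
    reso=1024,
    is_stable_diffusion_ckpt=True,
    sd3="m",  # architecture becomes "stable-diffusion-3-m"
)
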
@@ -246,7 +275,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
    if not model.endswith(".safetensors"):
        return {}

    with safetensors.safe_open(model, framework="pt") as f:
        metadata = f.metadata()
    if metadata is None:
1428  library/sd3_models.py  (new file; diff suppressed because it is too large)

945  library/sd3_train_utils.py  (new file)
@@ -0,0 +1,945 @@
import argparse
import math
import os
import toml
import json
import time
from typing import Dict, List, Optional, Tuple, Union

import torch
from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel

from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

# from transformers import CLIPTokenizer
# from library import model_util
# , sdxl_model_util, train_util, sdxl_original_unet
# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from library import sd3_models, sd3_utils, strategy_base, train_util

def save_models(
    ckpt_path: str,
    mmdit: Optional[sd3_models.MMDiT],
    vae: Optional[sd3_models.SDVAE],
    clip_l: Optional[CLIPTextModelWithProjection],
    clip_g: Optional[CLIPTextModelWithProjection],
    t5xxl: Optional[T5EncoderModel],
    sai_metadata: Optional[dict],
    save_dtype: Optional[torch.dtype] = None,
):
    r"""
    Save models to checkpoint file. Only supports unified checkpoint format.
    """

    state_dict = {}

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    update_sd("model.diffusion_model.", mmdit.state_dict())
    update_sd("first_stage_model.", vae.state_dict())

    # do not support unified checkpoint format for now
    # if clip_l is not None:
    #     update_sd("text_encoders.clip_l.", clip_l.state_dict())
    # if clip_g is not None:
    #     update_sd("text_encoders.clip_g.", clip_g.state_dict())
    # if t5xxl is not None:
    #     update_sd("text_encoders.t5xxl.", t5xxl.state_dict())

    save_file(state_dict, ckpt_path, metadata=sai_metadata)

    if clip_l is not None:
        clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
        save_file(clip_l.state_dict(), clip_l_path)
    if clip_g is not None:
        clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
        save_file(clip_g.state_dict(), clip_g_path)
    if t5xxl is not None:
        t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
        t5xxl_state_dict = t5xxl.state_dict()

        # replace "shared.weight" with a copy of it to avoid annoying shared tensor error on safetensors.save_file
        shared_weight = t5xxl_state_dict["shared.weight"]
        shared_weight_copy = shared_weight.detach().clone()
        t5xxl_state_dict["shared.weight"] = shared_weight_copy

        save_file(t5xxl_state_dict, t5xxl_path)

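# Usage sketch (illustrative): saving an SD3 checkpoint plus separate text-encoder
# files with save_models above. Assumes the models and `sai_metadata` are already
# built; the output path is an example.
save_models(
    "output/sd3-ft.safetensors",
    mmdit, vae, clip_l, clip_g, t5xxl,
    sai_metadata, save_dtype=torch.bfloat16,
)
# writes sd3-ft.safetensors, sd3-ft_clip_l.safetensors, sd3-ft_clip_g.safetensors, sd3-ft_t5xxl.safetensors
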
def save_sd3_model_on_train_end(
    args: argparse.Namespace,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    clip_l: Optional[CLIPTextModelWithProjection],
    clip_g: Optional[CLIPTextModelWithProjection],
    t5xxl: Optional[T5EncoderModel],
    mmdit: sd3_models.MMDiT,
    vae: sd3_models.SDVAE,
):
    def sd_saver(ckpt_file, epoch_no, global_step):
        sai_metadata = train_util.get_sai_model_spec(
            None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
        )
        save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)

    train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)

# Saving at epoch end and at step intervals is merged into one function, because the
# metadata contains epoch/step and the arguments are otherwise identical.
# on_epoch_end: True saves at the end of an epoch, False saves at a step interval
def save_sd3_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    clip_l: Optional[CLIPTextModelWithProjection],
    clip_g: Optional[CLIPTextModelWithProjection],
    t5xxl: Optional[T5EncoderModel],
    mmdit: sd3_models.MMDiT,
    vae: sd3_models.SDVAE,
):
    def sd_saver(ckpt_file, epoch_no, global_step):
        sai_metadata = train_util.get_sai_model_spec(
            None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
        )
        save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)

    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        True,
        True,
        epoch,
        num_train_epochs,
        global_step,
        sd_saver,
        None,
    )

def add_sd3_training_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--clip_l",
        type=str,
        required=False,
        help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用",
    )
    parser.add_argument(
        "--clip_g",
        type=str,
        required=False,
        help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用",
    )
    parser.add_argument(
        "--t5xxl",
        type=str,
        required=False,
        help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
    )
    parser.add_argument(
        "--save_clip",
        action="store_true",
        help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
    )
    parser.add_argument(
        "--save_t5xxl",
        action="store_true",
        help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
    )

    parser.add_argument(
        "--t5xxl_device",
        type=str,
        default=None,
        help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
    )
    parser.add_argument(
        "--t5xxl_dtype",
        type=str,
        default=None,
        help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用",
    )

    parser.add_argument(
        "--t5xxl_max_token_length",
        type=int,
        default=256,
        help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
    )
    parser.add_argument(
        "--apply_lg_attn_mask",
        action="store_true",
        help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する",
    )
    parser.add_argument(
        "--apply_t5_attn_mask",
        action="store_true",
        help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
    )
    parser.add_argument(
        "--clip_l_dropout_rate",
        type=float,
        default=0.0,
        help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
    )
    parser.add_argument(
        "--clip_g_dropout_rate",
        type=float,
        default=0.0,
        help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
    )
    parser.add_argument(
        "--t5_dropout_rate",
        type=float,
        default=0.0,
        help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
    )
    parser.add_argument(
        "--pos_emb_random_crop_rate",
        type=float,
        default=0.0,
        help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
        " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
    )
    parser.add_argument(
        "--enable_scaled_pos_embed",
        action="store_true",
        help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
        " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
    )

    # Dependencies on the Diffusers noise sampler have been removed for clarity in training

    parser.add_argument(
        "--training_shift",
        type=float,
        default=1.0,
        help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. / タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
    )

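# Example invocation (illustrative) of a training script that calls
# add_sd3_training_arguments; the script name is an assumption:
#
#   accelerate launch sd3_train.py \
#     --clip_l clip_l.safetensors --clip_g clip_g.safetensors --t5xxl t5xxl.safetensors \
#     --t5xxl_max_token_length 256 --apply_t5_attn_mask --training_shift 1.0
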
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
    assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
    if args.v_parameterization:
        logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")

    if args.clip_skip is not None:
        logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

    # if args.multires_noise_iterations:
    #     logger.info(
    #         f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
    #     )
    # else:
    #     if args.noise_offset is None:
    #         args.noise_offset = DEFAULT_NOISE_OFFSET
    #     elif args.noise_offset != DEFAULT_NOISE_OFFSET:
    #         logger.info(
    #             f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
    #         )
    #     logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")

    assert (
        not hasattr(args, "weighted_captions") or not args.weighted_captions
    ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

    if supportTextEncoderCaching:
        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            args.cache_text_encoder_outputs = True
            logger.warning(
                "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
                + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
            )


# temporarily copied from sd3_minimal_inference.py

def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
    start = sampling.timestep(sampling.sigma_max)
    end = sampling.timestep(sampling.sigma_min)
    timesteps = torch.linspace(start, end, steps)
    sigs = []
    for x in range(len(timesteps)):
        ts = timesteps[x]
        sigs.append(sampling.sigma(ts))
    sigs += [0.0]
    return torch.FloatTensor(sigs)


def max_denoise(model_sampling, sigmas):
    max_sigma = float(model_sampling.sigma_max)
    sigma = float(sigmas[0])
    return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma

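# Quick check (illustrative): with 20 steps the helper above returns 21 sigmas
# running from sigma_max down to the appended 0.0, and max_denoise is True for
# the first step because sigmas[0] equals sigma_max.
sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0)
sigmas = get_all_sigmas(sampling, 20)
assert len(sigmas) == 21 and float(sigmas[-1]) == 0.0
assert max_denoise(sampling, sigmas)
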
def do_sample(
    height: int,
    width: int,
    seed: int,
    cond: Tuple[torch.Tensor, torch.Tensor],
    neg_cond: Tuple[torch.Tensor, torch.Tensor],
    mmdit: sd3_models.MMDiT,
    steps: int,
    guidance_scale: float,
    dtype: torch.dtype,
    device: str,
):
    latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
    latent = latent.to(dtype).to(device)

    # noise = get_noise(seed, latent).to(device)
    if seed is not None:
        generator = torch.manual_seed(seed)
    else:
        generator = None
    noise = (
        torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu")
        .to(latent.dtype)
        .to(device)
    )

    model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0)  # 3.0 is for SD3

    sigmas = get_all_sigmas(model_sampling, steps).to(device)

    noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))

    c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
    y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)

    x = noise_scaled.to(device).to(dtype)
    # print(x.shape)

    # with torch.no_grad():
    for i in tqdm(range(len(sigmas) - 1)):
        sigma_hat = sigmas[i]

        timestep = model_sampling.timestep(sigma_hat).float()
        timestep = torch.FloatTensor([timestep, timestep]).to(device)

        x_c_nc = torch.cat([x, x], dim=0)
        # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)

        mmdit.prepare_block_swap_before_forward()
        model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
        model_output = model_output.float()
        batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)

        pos_out, neg_out = batched.chunk(2)
        denoised = neg_out + (pos_out - neg_out) * guidance_scale
        # print(denoised.shape)

        # d = to_d(x, sigma_hat, denoised)
        dims_to_append = x.ndim - sigma_hat.ndim
        sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
        # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
        # convert the denoiser output to a Karras ODE derivative
        d = (x - denoised) / sigma_hat_dims

        dt = sigmas[i + 1] - sigma_hat

        # Euler method
        x = x + d * dt
        x = x.to(dtype)

    mmdit.prepare_block_swap_before_forward()
    return x

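# Usage sketch (illustrative): sampling one latent with the Euler loop above.
# Assumes `cond` and `neg_cond` are (context, pooled) tuples from the SD3 text
# encoders and `mmdit` is a loaded sd3_models.MMDiT.
latents = do_sample(
    1024, 1024, 42, cond, neg_cond, mmdit,
    steps=30, guidance_scale=4.5, dtype=torch.bfloat16, device="cuda",
)
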
def sample_images(
    accelerator: Accelerator,
    args: argparse.Namespace,
    epoch,
    steps,
    mmdit,
    vae,
    text_encoders,
    sample_prompts_te_outputs,
    prompt_replacement=None,
):
    if steps == 0:
        if not args.sample_at_first:
            return
    else:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            # sample_every_n_steps is ignored
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
                return

    logger.info("")
    logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
    if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
        logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
        return

    distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here

    # unwrap unet and text_encoder(s)
    mmdit = accelerator.unwrap_model(mmdit)
    text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
    # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])

    prompts = train_util.load_prompts(args.sample_prompts)

    save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    # save random state to restore later
    rng_state = torch.get_rng_state()
    cuda_rng_state = None
    try:
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    except Exception:
        pass

    if distributed_state.num_processes <= 1:
        # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
        with torch.no_grad(), accelerator.autocast():
            for prompt_dict in prompts:
                sample_image_inference(
                    accelerator,
                    args,
                    mmdit,
                    text_encoders,
                    vae,
                    save_dir,
                    prompt_dict,
                    epoch,
                    steps,
                    sample_prompts_te_outputs,
                    prompt_replacement,
                )
    else:
        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
        # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
        per_process_prompts = []  # list of lists
        for i in range(distributed_state.num_processes):
            per_process_prompts.append(prompts[i :: distributed_state.num_processes])

        with torch.no_grad():
            with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
                for prompt_dict in prompt_dict_lists[0]:
                    sample_image_inference(
                        accelerator,
                        args,
                        mmdit,
                        text_encoders,
                        vae,
                        save_dir,
                        prompt_dict,
                        epoch,
                        steps,
                        sample_prompts_te_outputs,
                        prompt_replacement,
                    )

    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    clean_memory_on_device(accelerator.device)

def sample_image_inference(
    accelerator: Accelerator,
    args: argparse.Namespace,
    mmdit: sd3_models.MMDiT,
    text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
    vae: sd3_models.SDVAE,
    save_dir,
    prompt_dict,
    epoch,
    steps,
    sample_prompts_te_outputs,
    prompt_replacement,
):
    assert isinstance(prompt_dict, dict)
    negative_prompt = prompt_dict.get("negative_prompt")
    sample_steps = prompt_dict.get("sample_steps", 30)
    width = prompt_dict.get("width", 512)
    height = prompt_dict.get("height", 512)
    scale = prompt_dict.get("scale", 7.5)
    seed = prompt_dict.get("seed")
    # controlnet_image = prompt_dict.get("controlnet_image")
    prompt: str = prompt_dict.get("prompt", "")
    # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
        if negative_prompt is not None:
            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    else:
        # True random sample image generation
        torch.seed()
        torch.cuda.seed()

    if negative_prompt is None:
        negative_prompt = ""

    height = max(64, height - height % 8)  # round to divisible by 8
    width = max(64, width - width % 8)  # round to divisible by 8
    logger.info(f"prompt: {prompt}")
    logger.info(f"negative_prompt: {negative_prompt}")
    logger.info(f"height: {height}")
    logger.info(f"width: {width}")
    logger.info(f"sample_steps: {sample_steps}")
    logger.info(f"scale: {scale}")
    # logger.info(f"sample_sampler: {sampler_name}")
    if seed is not None:
        logger.info(f"seed: {seed}")

    # encode prompts
    tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
    encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

    def encode_prompt(prpt):
        text_encoder_conds = []
        if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
            text_encoder_conds = sample_prompts_te_outputs[prpt]
            print(f"Using cached text encoder outputs for prompt: {prpt}")
        if text_encoders is not None:
            print(f"Encoding prompt: {prpt}")
            tokens_and_masks = tokenize_strategy.tokenize(prpt)
            # strategy has apply_t5_attn_mask option
            encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)

            # if text_encoder_conds is not cached, use encoded_text_encoder_conds
            if len(text_encoder_conds) == 0:
                text_encoder_conds = encoded_text_encoder_conds
            else:
                # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
                for i in range(len(encoded_text_encoder_conds)):
                    if encoded_text_encoder_conds[i] is not None:
                        text_encoder_conds[i] = encoded_text_encoder_conds[i]
        return text_encoder_conds

    lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
    cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

    # encode negative prompts
    lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
    neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

    # sample image
    clean_memory_on_device(accelerator.device)
    with accelerator.autocast(), torch.no_grad():
        # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
        latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)

    # latent to image
    clean_memory_on_device(accelerator.device)
    org_vae_device = vae.device  # will be on cpu
    vae.to(accelerator.device)
    latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
    image = vae.decode(latents)
    vae.to(org_vae_device)
    clean_memory_on_device(accelerator.device)

    image = image.float()
    image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
    decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
    decoded_np = decoded_np.astype(np.uint8)

    image = Image.fromarray(decoded_np)
    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
    # but adding 'enum' to the filename should be enough

    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    i: int = prompt_dict["enum"]
    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
    image.save(os.path.join(save_dir, img_filename))

    # send images to wandb if enabled
    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        wandb_tracker = accelerator.get_tracker("wandb")

        import wandb

        # not to commit images to avoid inconsistency between training and logging steps
        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption

# region Diffusers


from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import BaseOutput

@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor

class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

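    # Quick check (illustrative): the shift above maps sigma to
    # shift * sigma / (1 + (shift - 1) * sigma); the endpoints are fixed (0 -> 0,
    # 1 -> 1) and intermediate values move toward the noisy side for shift > 1:
    #   s = torch.tensor([0.0, 0.5, 1.0])
    #   3.0 * s / (1 + (3.0 - 1) * s)  # -> tensor([0.0000, 0.7500, 1.0000])
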
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        timesteps = sigmas * self.config.num_train_timesteps
        self.timesteps = timesteps.to(device=device)
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility

        # if self.config.prediction_type == "vector_field":

        denoised = sample - model_output * sigma
        # 2. Convert to an ODE derivative
        derivative = (sample - denoised) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps

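# Usage sketch (illustrative): driving the scheduler above. `model` stands in for
# any user-provided callable that predicts the flow-matching vector field.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
scheduler.set_timesteps(28, device="cpu")
sample = torch.randn(1, 16, 64, 64)
for t in scheduler.timesteps:
    model_output = model(sample, t)  # assumption: user-provided predictor
    sample = scheduler.step(model_output, t, sample).prev_sample
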
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma

def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u

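# Quick check (illustrative): "logit_normal" draws u = sigmoid(N(logit_mean, logit_std)),
# concentrating timesteps around sigmoid(logit_mean); the fallback is uniform on [0, 1).
u = compute_density_for_timestep_sampling("logit_normal", batch_size=4, logit_mean=0.0, logit_std=1.0)
assert u.shape == (4,) and bool(((u > 0) & (u < 1)).all())
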
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting


# endregion

def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    bsz = latents.shape[0]

    # Sample a random timestep for each image
    # for weighting schemes where we sample timesteps non-uniformly
    u = compute_density_for_timestep_sampling(
        weighting_scheme=args.weighting_scheme,
        batch_size=bsz,
        logit_mean=args.logit_mean,
        logit_std=args.logit_std,
        mode_scale=args.mode_scale,
    )
    t_min = args.min_timestep if args.min_timestep is not None else 0
    t_max = args.max_timestep if args.max_timestep is not None else 1000
    shift = args.training_shift

    # weighting shift: a value > 1 shifts the distribution to the noisy side (focus more on overall structure),
    # a value < 1 shifts it towards the less-noisy side (focus more on details)
    u = (u * shift) / (1 + (shift - 1) * u)

    indices = (u * (t_max - t_min) + t_min).long()
    timesteps = indices.to(device=device, dtype=dtype)

    # sigmas according to flow matching
    sigmas = timesteps / 1000
    sigmas = sigmas.view(-1, 1, 1, 1)
    noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

    return noisy_model_input, timesteps, sigmas

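# Quick check (illustrative): the interpolation above is
# noisy = sigma * noise + (1 - sigma) * latents, i.e. pure latents at sigma=0 and
# pure noise at sigma=1.
latents, noise = torch.zeros(1, 16, 8, 8), torch.ones(1, 16, 8, 8)
for s in (0.0, 0.25, 1.0):
    noisy = s * noise + (1.0 - s) * latents
    assert float(noisy.flatten()[0]) == s
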
302  library/sd3_utils.py  (new file)
@@ -0,0 +1,302 @@
from dataclasses import dataclass
import math
import re
from typing import Dict, List, Optional, Union
import torch
import safetensors
from safetensors.torch import load_file
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from library import sd3_models

# TODO move some of these functions to model_util.py
from library import sdxl_model_util

# region models

# TODO remove dependency on flux_utils
from library.utils import load_safetensors
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl

def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
    logger.info("Analyzing state dict state...")

    # analyze configs
    patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
    depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
    num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
    pos_embed_max_size = round(math.sqrt(num_patches))
    adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
    context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
    qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None

    # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
    x_block_self_attn_layers = []
    re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
    for key in list(state_dict.keys()):
        m = re_attn.search(key)
        if m:
            x_block_self_attn_layers.append(int(m.group(1)))

    context_embedder_in_features = context_shape[1]
    context_embedder_out_features = context_shape[0]

    # only supports 3-5-large, 3-5-medium or 3-medium
    if qk_norm is not None:
        if len(x_block_self_attn_layers) == 0:
            model_type = "3-5-large"
        else:
            model_type = "3-5-medium"
    else:
        model_type = "3-medium"

    params = sd3_models.SD3Params(
        patch_size=patch_size,
        depth=depth,
        num_patches=num_patches,
        pos_embed_max_size=pos_embed_max_size,
        adm_in_channels=adm_in_channels,
        qk_norm=qk_norm,
        x_block_self_attn_layers=x_block_self_attn_layers,
        context_embedder_in_features=context_embedder_in_features,
        context_embedder_out_features=context_embedder_out_features,
        model_type=model_type,
    )
    logger.info(f"Analyzed state dict state: {params}")
    return params

def load_mmdit(
    state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
) -> sd3_models.MMDiT:
    mmdit_sd = {}

    mmdit_prefix = "model.diffusion_model."
    for k in list(state_dict.keys()):
        if k.startswith(mmdit_prefix):
            mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)

    # load MMDiT
    logger.info("Building MMDiT")
    params = analyze_state_dict_state(mmdit_sd)
    with init_empty_weights():
        mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)

    logger.info("Loading state dict...")
    info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
    logger.info(f"Loaded MMDiT: {info}")
    return mmdit

def load_clip_l(
    clip_l_path: Optional[str],
    dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[Dict] = None,
):
    clip_l_sd = None
    if state_dict is not None:
        if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
            # found clip_l: remove prefix "text_encoders.clip_l."
            logger.info("clip_l is included in the checkpoint")
            clip_l_sd = {}
            prefix = "text_encoders.clip_l."
            for k in list(state_dict.keys()):
                if k.startswith(prefix):
                    clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
    elif clip_l_path is None:
        logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
        return None

    # load clip_l
    logger.info("Building CLIP-L")
    config = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    with init_empty_weights():
        clip = CLIPTextModelWithProjection(config)

    if clip_l_sd is None:
        logger.info(f"Loading state dict from {clip_l_path}")
        clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)

    if "text_projection.weight" not in clip_l_sd:
        logger.info("Adding text_projection.weight to clip_l_sd")
        clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)

    info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
    logger.info(f"Loaded CLIP-L: {info}")
    return clip

def load_clip_g(
    clip_g_path: Optional[str],
    dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[Dict] = None,
):
    clip_g_sd = None
    if state_dict is not None:
        if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
            # found clip_g: remove prefix "text_encoders.clip_g."
            logger.info("clip_g is included in the checkpoint")
            clip_g_sd = {}
            prefix = "text_encoders.clip_g."
            for k in list(state_dict.keys()):
                if k.startswith(prefix):
                    clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
    elif clip_g_path is None:
        logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
        return None

    # load clip_g
    logger.info("Building CLIP-G")
    config = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    with init_empty_weights():
        clip = CLIPTextModelWithProjection(config)

    if clip_g_sd is None:
        logger.info(f"Loading state dict from {clip_g_path}")
        clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
    logger.info(f"Loaded CLIP-G: {info}")
    return clip

def load_t5xxl(
    t5xxl_path: Optional[str],
    dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[Dict] = None,
):
    t5xxl_sd = None
    if state_dict is not None:
        if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
            # found t5xxl: remove prefix "text_encoders.t5xxl."
            logger.info("t5xxl is included in the checkpoint")
            t5xxl_sd = {}
            prefix = "text_encoders.t5xxl."
            for k in list(state_dict.keys()):
                if k.startswith(prefix):
                    t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
    elif t5xxl_path is None:
        logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
        return None

    return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)

def load_vae(
    vae_path: Optional[str],
    vae_dtype: Optional[Union[str, torch.dtype]],
    device: Optional[Union[str, torch.device]],
    disable_mmap: bool = False,
    state_dict: Optional[Dict] = None,
):
    vae_sd = {}
    if vae_path:
        logger.info(f"Loading VAE from {vae_path}...")
        vae_sd = load_safetensors(vae_path, device, disable_mmap)
    else:
        # remove prefix "first_stage_model."
        vae_sd = {}
        vae_prefix = "first_stage_model."
        for k in list(state_dict.keys()):
            if k.startswith(vae_prefix):
                vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)

    logger.info("Building VAE")
    vae = sd3_models.SDVAE(vae_dtype, device)
    logger.info("Loading state dict...")
    info = vae.load_state_dict(vae_sd)
    logger.info(f"Loaded VAE: {info}")
    vae.to(device=device, dtype=vae_dtype)  # make sure it's on the right device and in the right dtype
    return vae


# endregion

class ModelSamplingDiscreteFlow:
    """Helper for sampler scheduling (i.e. timestep/sigma calculations) for Discrete Flow models"""

    def __init__(self, shift=1.0):
        self.shift = shift
        timesteps = 1000
        self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        return sigma * 1000

    def sigma(self, timestep: torch.Tensor):
        timestep = timestep / 1000.0
        if self.shift == 1.0:
            return timestep
        return self.shift * timestep / (1 + (self.shift - 1) * timestep)

    def calculate_denoised(self, sigma, model_output, model_input):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        # assert max_denoise is False, "max_denoise not implemented"
        # max_denoise is always True, I'm not sure why it's there
        return sigma * noise + (1.0 - sigma) * latent_image
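# Quick check (illustrative): with shift=1.0 sigma(t) is simply t/1000; a larger
# shift pushes mid-schedule sigmas toward the noisy side (0.5 -> 0.75 for shift=3).
ms1 = ModelSamplingDiscreteFlow(shift=1.0)
ms3 = ModelSamplingDiscreteFlow(shift=3.0)
assert float(ms1.sigma(torch.tensor(500.0))) == 0.5
assert abs(float(ms3.sigma(torch.tensor(500.0))) - 0.75) < 1e-6
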
@@ -13,12 +13,20 @@ from tqdm import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.models import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import logging
from PIL import Image

from library import sdxl_model_util, sdxl_train_util, train_util
from library import (
    sdxl_model_util,
    sdxl_train_util,
    strategy_base,
    strategy_sdxl,
    train_util,
    sdxl_original_unet,
    sdxl_original_control_net,
)


try:
@@ -537,7 +545,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
        vae: AutoencoderKL,
        text_encoder: List[CLIPTextModel],
        tokenizer: List[CLIPTokenizer],
        unet: UNet2DConditionModel,
        unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet],
        scheduler: SchedulerMixin,
        # clip_skip: int,
        safety_checker: StableDiffusionSafetyChecker,
@@ -594,74 +602,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
            return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        max_embeddings_multiples,
        is_sdxl_text_encoder2,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
            max_embeddings_multiples=max_embeddings_multiples,
            clip_skip=self.clip_skip,
            is_sdxl_text_encoder2=is_sdxl_text_encoder2,
        )
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)  # ??
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
        if text_pool is not None:
            text_pool = text_pool.repeat(1, num_images_per_prompt)
            text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)

        if do_classifier_free_guidance:
            bs_embed, seq_len, _ = uncond_embeddings.shape
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
            if uncond_pool is not None:
                uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
                uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)

            return text_embeddings, text_pool, uncond_embeddings, uncond_pool

        return text_embeddings, text_pool, None, None

    def check_inputs(self, prompt, height, width, strength, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@@ -792,7 +732,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
max_embeddings_multiples: Optional[int] = 3,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
controlnet=None,
|
||||
controlnet: sdxl_original_control_net.SdxlControlNet = None,
|
||||
controlnet_image=None,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
@@ -896,32 +836,24 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
# 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す
|
||||
# To simplify the implementation, switch the tokenzer/text encoder and call it twice
|
||||
text_embeddings_list = []
|
||||
text_pool = None
|
||||
uncond_embeddings_list = []
|
||||
uncond_pool = None
|
||||
for i in range(len(self.tokenizers)):
|
||||
self.tokenizer = self.tokenizers[i]
|
||||
self.text_encoder = self.text_encoders[i]
|
||||
tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
max_embeddings_multiples,
|
||||
is_sdxl_text_encoder2=i == 1,
|
||||
text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt)
|
||||
hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy, self.text_encoders, text_input_ids, text_weights
|
||||
)
|
||||
text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "")
|
||||
hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy, self.text_encoders, input_ids, weights
|
||||
)
|
||||
text_embeddings_list.append(text_embeddings)
|
||||
uncond_embeddings_list.append(uncond_embeddings)
|
||||
|
||||
if tp1 is not None:
|
||||
text_pool = tp1
|
||||
if up1 is not None:
|
||||
uncond_pool = up1
|
||||
uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
|
||||
else:
|
||||
uncond_embeddings = None
|
||||
uncond_pool = None
|
||||
|
||||
unet_dtype = self.unet.dtype
|
||||
dtype = unet_dtype
|
||||
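Editor's note: the rewritten encoding path above only works if the SDXL strategies were registered beforehand, since get_strategy() returns whatever set_strategy() stored for the process. A minimal sketch (the constructor arguments are illustrative; check the actual strategy_sdxl signatures):

from library import strategy_base, strategy_sdxl

strategy_base.TokenizeStrategy.set_strategy(strategy_sdxl.SdxlTokenizeStrategy(max_length=None))
strategy_base.TextEncodingStrategy.set_strategy(strategy_sdxl.SdxlTextEncodingStrategy())
# later, inside the pipeline call:
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()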
@@ -970,23 +902,23 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # create size embs and concat embeddings for SDXL
        orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
        orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype)
        crop_size = torch.zeros_like(orig_size)
        target_size = orig_size
        embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
        embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype)

        # make conditionings
        text_pool = text_pool.to(device, dtype)
        if do_classifier_free_guidance:
            text_embeddings = torch.cat(text_embeddings_list, dim=2)
            uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
            text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
            text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype)

            cond_vector = torch.cat([text_pool, embs], dim=1)
            uncond_vector = torch.cat([uncond_pool, embs], dim=1)
            vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
            uncond_pool = uncond_pool.to(device, dtype)
            cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype)
            uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype)
            vector_embedding = torch.cat([uncond_vector, cond_vector])
        else:
            text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
            vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
            text_embedding = text_embeddings.to(device, dtype)
            vector_embedding = torch.cat([text_pool, embs], dim=1)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
@@ -994,22 +926,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            unet_additional_args = {}
            if controlnet is not None:
                down_block_res_samples, mid_block_res_sample = controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=controlnet_image,
                    conditioning_scale=1.0,
                    guess_mode=False,
                    return_dict=False,
                )
                unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
                unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
                # FIXME SD1 ControlNet is not working

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
            if controlnet is not None:
                input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image)
                noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add)
            else:
                noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
            noise_pred = noise_pred.to(dtype)  # U-Net changes dtype in LoRA training

            # perform guidance

@@ -1,4 +1,5 @@
import torch
import safetensors
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
@@ -7,7 +8,12 @@ from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
@@ -131,7 +137,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):

    # temporary workaround for text_projection.weight.weight for Playground-v2
    if "text_projection.weight.weight" in new_sd:
        print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
        logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
        new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
        del new_sd["text_projection.weight.weight"]

@@ -160,17 +166,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
    raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))


def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
    # model_version is reserved for future use
    # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching

    # Load the state dict
    if model_util.is_safetensors(ckpt_path):
        checkpoint = None
        try:
            state_dict = load_file(ckpt_path, device=map_location)
        except:
            state_dict = load_file(ckpt_path)  # prevent device invalid Error
        if disable_mmap:
            state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
        else:
            try:
                state_dict = load_file(ckpt_path, device=map_location)
            except:
                state_dict = load_file(ckpt_path)  # prevent device invalid Error
        epoch = None
        global_step = None
    else:
@@ -186,20 +195,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
        checkpoint = None

    # U-Net
    print("building U-Net")
    logger.info("building U-Net")
    with init_empty_weights():
        unet = sdxl_original_unet.SdxlUNet2DConditionModel()

    print("loading U-Net from checkpoint")
    logger.info("loading U-Net from checkpoint")
    unet_sd = {}
    for k in list(state_dict.keys()):
        if k.startswith("model.diffusion_model."):
            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
    info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
    print("U-Net: ", info)
    logger.info(f"U-Net: {info}")

    # Text Encoders
    print("building text encoders")
    logger.info("building text encoders")

    # Text Encoder 1 is the same as Stability AI's SDXL
    text_model1_cfg = CLIPTextConfig(
@@ -252,7 +261,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
    with init_empty_weights():
        text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

    print("loading text encoders from checkpoint")
    logger.info("loading text encoders from checkpoint")
    te1_sd = {}
    te2_sd = {}
    for k in list(state_dict.keys()):
@@ -266,22 +275,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
        te1_sd.pop("text_model.embeddings.position_ids")

    info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
    print("text encoder 1:", info1)
    logger.info(f"text encoder 1: {info1}")

    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
    info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)  # remain fp32
    print("text encoder 2:", info2)
    logger.info(f"text encoder 2: {info2}")

    # prepare vae
    print("building VAE")
    logger.info("building VAE")
    vae_config = model_util.create_vae_diffusers_config()
    with init_empty_weights():
        vae = AutoencoderKL(**vae_config)

    print("loading VAE from checkpoint")
    logger.info("loading VAE from checkpoint")
    converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
    info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
    print("VAE:", info)
    logger.info(f"VAE: {info}")

    ckpt_info = (epoch, global_step) if epoch is not None else None
    return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
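Editor's note: with disable_mmap=True, the whole checkpoint is read into memory and deserialized with safetensors.torch.load(), trading a temporary full copy in RAM for the slow mmap page-ins seen on WSL. An illustrative call (the checkpoint path is a placeholder):

text_model1, text_model2, vae, unet, logit_scale, ckpt_info = load_models_from_sdxl_checkpoint(
    MODEL_VERSION_SDXL_BASE_V1_0, "sd_xl_base_1.0.safetensors", "cpu", disable_mmap=True
)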
library/sdxl_original_control_net.py (new file, 272 lines)
@@ -0,0 +1,272 @@
# some parts are modified from Diffusers library (Apache License 2.0)

import math
from types import SimpleNamespace
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from library import sdxl_original_unet
from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl


class ControlNetConditioningEmbedding(nn.Module):
    def __init__(self):
        super().__init__()

        dims = [16, 32, 96, 256]

        self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
        self.blocks = nn.ModuleList([])

        for i in range(len(dims) - 1):
            channel_in = dims[i]
            channel_out = dims[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
        nn.init.zeros_(self.conv_out.weight)  # zero module weight
        nn.init.zeros_(self.conv_out.bias)  # zero module bias

    def forward(self, x):
        x = self.conv_in(x)
        x = F.silu(x)
        for block in self.blocks:
            x = block(x)
            x = F.silu(x)
        x = self.conv_out(x)
        return x
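Editor's note: because conv_out is zero-initialized, a freshly constructed conditioning embedding contributes nothing, so a new ControlNet starts out as an identity on the base U-Net. A minimal sanity check (editor's sketch, run in this module's context):

emb = ControlNetConditioningEmbedding()
cond = torch.rand(1, 3, 1024, 1024)
out = emb(cond)
print(out.shape)        # torch.Size([1, 320, 128, 128]) after the three stride-2 blocks
print(out.abs().max())  # tensor(0.) at initialization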
class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
    def __init__(self, multiplier: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.multiplier = multiplier

        # remove unet layers
        self.output_blocks = nn.ModuleList([])
        del self.out

        self.controlnet_cond_embedding = ControlNetConditioningEmbedding()

        dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
        self.controlnet_down_blocks = nn.ModuleList([])
        for dim in dims:
            self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
            nn.init.zeros_(self.controlnet_down_blocks[-1].weight)  # zero module weight
            nn.init.zeros_(self.controlnet_down_blocks[-1].bias)  # zero module bias

        self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
        nn.init.zeros_(self.controlnet_mid_block.weight)  # zero module weight
        nn.init.zeros_(self.controlnet_mid_block.bias)  # zero module bias

    def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
        unet_sd = unet.state_dict()
        unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}
        sd = super().state_dict()
        sd.update(unet_sd)
        info = super().load_state_dict(sd, strict=True, assign=True)
        return info

    def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
        # convert state_dict to SAI format
        unet_sd = {}
        for k in list(state_dict.keys()):
            if not k.startswith("controlnet_"):
                unet_sd[k] = state_dict.pop(k)
        unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
        state_dict.update(unet_sd)
        super().load_state_dict(state_dict, strict=strict, assign=assign)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        # convert state_dict to Diffusers format
        state_dict = super().state_dict(destination, prefix, keep_vars)
        control_net_sd = {}
        for k in list(state_dict.keys()):
            if k.startswith("controlnet_"):
                control_net_sd[k] = state_dict.pop(k)
        state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
        state_dict.update(control_net_sd)
        return state_dict

    def forward(
        self,
        x: torch.Tensor,
        timesteps: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        cond_image: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # broadcast timesteps to batch dimension
        timesteps = timesteps.expand(x.shape[0])

        t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
        t_emb = t_emb.to(x.dtype)
        emb = self.time_embed(t_emb)

        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
        emb = emb + self.label_emb(y)

        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
                    x = layer(x, context)
                else:
                    x = layer(x)
            return x

        h = x
        multiplier = self.multiplier if self.multiplier is not None else 1.0
        hs = []
        for i, module in enumerate(self.input_blocks):
            h = call_module(module, h, emb, context)
            if i == 0:
                h = self.controlnet_cond_embedding(cond_image) + h
            hs.append(self.controlnet_down_blocks[i](h) * multiplier)

        h = call_module(self.middle_block, h, emb, context)
        h = self.controlnet_mid_block(h) * multiplier

        return hs, h


class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
    """
    This class is for training purposes only.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
        # broadcast timesteps to batch dimension
        timesteps = timesteps.expand(x.shape[0])

        hs = []
        t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
        t_emb = t_emb.to(x.dtype)
        emb = self.time_embed(t_emb)

        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
        emb = emb + self.label_emb(y)

        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
                    x = layer(x, context)
                else:
                    x = layer(x)
            return x

        h = x
        for module in self.input_blocks:
            h = call_module(module, h, emb, context)
            hs.append(h)

        h = call_module(self.middle_block, h, emb, context)
        h = h + mid_add

        for module in self.output_blocks:
            resi = hs.pop() + input_resi_add.pop()
            h = torch.cat([h, resi], dim=1)
            h = call_module(module, h, emb, context)

        h = h.type(x.dtype)
        h = call_module(self.out, h, emb, context)

        return h


if __name__ == "__main__":
    import time

    logger.info("create unet")
    unet = SdxlControlledUNet()
    unet.to("cuda", torch.bfloat16)
    unet.set_use_sdpa(True)
    unet.set_gradient_checkpointing(True)
    unet.train()

    logger.info("create control_net")
    control_net = SdxlControlNet()
    control_net.to("cuda")
    control_net.set_use_sdpa(True)
    control_net.set_gradient_checkpointing(True)
    control_net.train()

    logger.info("Initialize control_net from unet")
    control_net.init_from_unet(unet)

    unet.requires_grad_(False)
    control_net.requires_grad_(True)

    # pseudo training loop for checking memory usage
    logger.info("preparing optimizer")

    # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9)  # not working

    import bitsandbytes

    optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3)  # not working
    # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3)  # working at 23.5 GB with torch2
    # optimizer = bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3)  # working at 23.5 GB with torch2

    # import transformers
    # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True)  # working at 22.2GB with torch2

    scaler = torch.cuda.amp.GradScaler(enabled=True)

    logger.info("start training")
    steps = 10
    batch_size = 1

    for step in range(steps):
        logger.info(f"step {step}")
        if step == 1:
            time_start = time.perf_counter()

        x = torch.randn(batch_size, 4, 128, 128).cuda()  # 1024x1024
        t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
        txt = torch.randn(batch_size, 77, 2048).cuda()
        vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
        cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda()

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
            input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
            output = unet(x, t, txt, vector, input_resi_add, mid_add)
            target = torch.randn_like(output)
            loss = torch.nn.functional.mse_loss(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    time_end = time.perf_counter()
    logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")

    logger.info("finish training")
    sd = control_net.state_dict()

    from safetensors.torch import save_file

    save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")
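Editor's note: SdxlControlNet.forward() returns (hs, h), where hs holds one projected residual per input block (nine for SDXL) and h is the mid-block residual; SdxlControlledUNet.forward() then consumes hs back-to-front via input_resi_add.pop() while walking output_blocks. A minimal wiring sketch (editor's illustration; tensor shapes as in the __main__ above):

input_resi_add, mid_add = control_net(x, t, context, y, cond_image)
out = controlled_unet(x, t, context, y, input_resi_add=input_resi_add, mid_add=mid_add)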
@@ -30,7 +30,12 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
@@ -332,7 +337,7 @@ class ResnetBlock2D(nn.Module):

    def forward(self, x, emb):
        if self.training and self.gradient_checkpointing:
            # print("ResnetBlock2D: gradient_checkpointing")
            # logger.info("ResnetBlock2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
@@ -366,7 +371,7 @@ class Downsample2D(nn.Module):

    def forward(self, hidden_states):
        if self.training and self.gradient_checkpointing:
            # print("Downsample2D: gradient_checkpointing")
            # logger.info("Downsample2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
@@ -653,7 +658,7 @@ class BasicTransformerBlock(nn.Module):

    def forward(self, hidden_states, context=None, timestep=None):
        if self.training and self.gradient_checkpointing:
            # print("BasicTransformerBlock: checkpointing")
            # logger.info("BasicTransformerBlock: checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
@@ -796,7 +801,7 @@ class Upsample2D(nn.Module):

    def forward(self, hidden_states, output_size=None):
        if self.training and self.gradient_checkpointing:
            # print("Upsample2D: gradient_checkpointing")
            # logger.info("Upsample2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
@@ -1046,7 +1051,7 @@ class SdxlUNet2DConditionModel(nn.Module):
        for block in blocks:
            for module in block:
                if hasattr(module, "set_use_memory_efficient_attention"):
                    # print(module.__class__.__name__)
                    # logger.info(module.__class__.__name__)
                    module.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa: bool) -> None:
@@ -1061,7 +1066,7 @@ class SdxlUNet2DConditionModel(nn.Module):
        for block in blocks:
            for module in block.modules():
                if hasattr(module, "gradient_checkpointing"):
                    # print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
                    # logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
                    module.gradient_checkpointing = value

# endregion
@@ -1071,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
        timesteps = timesteps.expand(x.shape[0])

        hs = []
        t_emb = get_timestep_embedding(timesteps, self.model_channels)  # , repeat_only=False)
        t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)  # , repeat_only=False)
        t_emb = t_emb.to(x.dtype)
        emb = self.time_embed(t_emb)

@@ -1083,7 +1088,7 @@ class SdxlUNet2DConditionModel(nn.Module):
        def call_module(module, h, emb, context):
            x = h
            for layer in module:
                # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
                # logger.info(f"{layer.__class__.__name__} {x.dtype} {emb.dtype} {context.dtype if context is not None else None}")
                if isinstance(layer, ResnetBlock2D):
                    x = layer(x, emb)
                elif isinstance(layer, Transformer2DModel):
@@ -1129,20 +1134,20 @@ class InferSdxlUNet2DConditionModel:
    # call original model's methods
    def __getattr__(self, name):
        return getattr(self.delegate, name)

    def __call__(self, *args, **kwargs):
        return self.delegate(*args, **kwargs)

    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
        if ds_depth_1 is None:
            print("Deep Shrink is disabled.")
            logger.info("Deep Shrink is disabled.")
            self.ds_depth_1 = None
            self.ds_timesteps_1 = None
            self.ds_depth_2 = None
            self.ds_timesteps_2 = None
            self.ds_ratio = None
        else:
            print(
            logger.info(
                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
            )
            self.ds_depth_1 = ds_depth_1
@@ -1151,9 +1156,9 @@ class InferSdxlUNet2DConditionModel:
        self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
        self.ds_ratio = ds_ratio

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
        r"""
        current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
        current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet.
        """
        _self = self.delegate

@@ -1161,7 +1166,7 @@ class InferSdxlUNet2DConditionModel:
        timesteps = timesteps.expand(x.shape[0])

        hs = []
        t_emb = get_timestep_embedding(timesteps, _self.model_channels)  # , repeat_only=False)
        t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0)  # , repeat_only=False)
        t_emb = t_emb.to(x.dtype)
        emb = _self.time_embed(t_emb)

@@ -1204,6 +1209,8 @@ class InferSdxlUNet2DConditionModel:
            hs.append(h)

        h = call_module(_self.middle_block, h, emb, context)
        if mid_add is not None:
            h = h + mid_add

        for module in _self.output_blocks:
            # Deep Shrink
@@ -1212,7 +1219,11 @@ class InferSdxlUNet2DConditionModel:
                # print("upsample", h.shape, hs[-1].shape)
                h = resize_like(h, hs[-1])

            h = torch.cat([h, hs.pop()], dim=1)
            resi = hs.pop()
            if input_resi_add is not None:
                resi = resi + input_resi_add.pop()

            h = torch.cat([h, resi], dim=1)
            h = call_module(module, h, emb, context)

        # Deep Shrink: in case of depth 0
@@ -1229,7 +1240,7 @@ class InferSdxlUNet2DConditionModel:
if __name__ == "__main__":
    import time

    print("create unet")
    logger.info("create unet")
    unet = SdxlUNet2DConditionModel()

    unet.to("cuda")
@@ -1238,7 +1249,7 @@ if __name__ == "__main__":
    unet.train()

    # pseudo training loop for checking memory usage
    print("preparing optimizer")
    logger.info("preparing optimizer")

    # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9)  # not working

@@ -1253,12 +1264,12 @@ if __name__ == "__main__":

    scaler = torch.cuda.amp.GradScaler(enabled=True)

    print("start training")
    logger.info("start training")
    steps = 10
    batch_size = 1

    for step in range(steps):
        print(f"step {step}")
        logger.info(f"step {step}")
        if step == 1:
            time_start = time.perf_counter()

@@ -1278,4 +1289,4 @@ if __name__ == "__main__":
    optimizer.zero_grad(set_to_none=True)

    time_end = time.perf_counter()
    print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
    logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
@@ -1,14 +1,23 @@
import argparse
import gc
import math
import os
from typing import Optional

import torch
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -17,11 +26,10 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"


def load_target_model(args, accelerator, model_version: str, weight_dtype):
    # load models for each process
    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
            logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

            (
                load_stable_diffusion_format,
@@ -38,6 +46,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                weight_dtype,
                accelerator.device if args.lowram else "cpu",
                model_dtype,
                args.disable_mmap_load_safetensors,
            )

            # work on low-ram device
@@ -47,22 +56,21 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                unet.to(accelerator.device)
                vae.to(accelerator.device)

            gc.collect()
            torch.cuda.empty_cache()
            clean_memory_on_device(accelerator.device)
        accelerator.wait_for_everyone()

    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


def _load_target_model(
    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
):
    # model_dtype only works with full fp16/bf16
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers

    if load_stable_diffusion_format:
        print(f"load StableDiffusion checkpoint: {name_or_path}")
        logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
        (
            text_encoder1,
            text_encoder2,
@@ -70,13 +78,13 @@ def _load_target_model(
            unet,
            logit_scale,
            ckpt_info,
        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
    else:
        # Diffusers model is loaded to CPU
        from diffusers import StableDiffusionXLPipeline

        variant = "fp16" if weight_dtype == torch.float16 else None
        print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
        logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
        try:
            try:
                pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -84,12 +92,12 @@ def _load_target_model(
                )
            except EnvironmentError as ex:
                if variant is not None:
                    print("try to load fp32 model")
                    logger.info("try to load fp32 model")
                    pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
                else:
                    raise ex
        except EnvironmentError as ex:
            print(
            logger.error(
                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
            )
            raise ex
@@ -112,7 +120,7 @@ def _load_target_model(
        with init_empty_weights():
            unet = sdxl_original_unet.SdxlUNet2DConditionModel()  # overwrite unet
        sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
        print("U-Net converted to original U-Net")
        logger.info("U-Net converted to original U-Net")

    logit_scale = None
    ckpt_info = None
@@ -120,13 +128,13 @@ def _load_target_model(
    # load the additional VAE
    if vae_path is not None:
        vae = model_util.load_vae(vae_path, weight_dtype)
        print("additional VAE loaded")
        logger.info("additional VAE loaded")

    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


def load_tokenizers(args: argparse.Namespace):
    print("prepare tokenizers")
    logger.info("prepare tokenizers")

    original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
    tokenizers = []
@@ -135,14 +143,14 @@ def load_tokenizers(args: argparse.Namespace):
        if args.tokenizer_cache_dir:
            local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
            if os.path.exists(local_tokenizer_path):
                print(f"load tokenizer from cache: {local_tokenizer_path}")
                logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
                tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)

        if tokenizer is None:
            tokenizer = CLIPTokenizer.from_pretrained(original_path)

        if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
            print(f"save Tokenizer to cache: {local_tokenizer_path}")
            logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
            tokenizer.save_pretrained(local_tokenizer_path)

        if i == 1:
@@ -151,7 +159,7 @@ def load_tokenizers(args: argparse.Namespace):
        tokenizers.append(tokenizer)

    if hasattr(args, "max_token_length") and args.max_token_length is not None:
        print(f"update token length: {args.max_token_length}")
        logger.info(f"update token length: {args.max_token_length}")

    return tokenizers

@@ -318,7 +326,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
    )


def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
    parser.add_argument(
        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
    )
@@ -327,41 +335,46 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
        action="store_true",
        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
    )
    parser.add_argument(
        "--disable_mmap_load_safetensors",
        action="store_true",
        help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
    )


def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
    assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
    if args.v_parameterization:
        print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")

    if args.clip_skip is not None:
        print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
        logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

    # if args.multires_noise_iterations:
    #     print(
    #     logger.info(
    #         f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
    #     )
    # else:
    #     if args.noise_offset is None:
    #         args.noise_offset = DEFAULT_NOISE_OFFSET
    #     elif args.noise_offset != DEFAULT_NOISE_OFFSET:
    #         print(
    #         logger.info(
    #             f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
    #         )
    #     print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
    #     logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")

    assert (
        not hasattr(args, "weighted_captions") or not args.weighted_captions
    ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
    # assert (
    #     not hasattr(args, "weighted_captions") or not args.weighted_captions
    # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

    if supportTextEncoderCaching:
        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            args.cache_text_encoder_outputs = True
            print(
            logger.warning(
                "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
                + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
            )


def sample_images(*args, **kwargs):
    from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline

    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
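Editor's note: the new --disable_mmap_load_safetensors flag feeds args.disable_mmap_load_safetensors into load_target_model() above. An illustrative invocation (the script name and the remaining arguments depend on your training setup):

accelerate launch sdxl_train.py --pretrained_model_name_or_path sd_xl_base_1.0.safetensors --disable_mmap_load_safetensors ...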
@@ -26,7 +26,10 @@ from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoder_kl import AutoencoderKLOutput

from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

def slice_h(x, num_slices):
    # slice with pad 1 on both sides: to eliminate the side effect of conv2d padding
@@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
    # sliced_tensor = torch.chunk(x, num_div, dim=1)
    # sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
    # sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
    # print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
    # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
    # normed_tensor = []
    # for i in range(num_div):
    #     n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
@@ -243,7 +246,7 @@ class SlicingEncoder(nn.Module):

        self.num_slices = num_slices
        div = num_slices / (2 ** (len(self.down_blocks) - 1))  # deeper layers need less slicing, so reduce the divisor accordingly
        # print(f"initial divisor: {div}")
        # logger.info(f"initial divisor: {div}")
        if div >= 2:
            div = int(div)
            for resnet in self.mid_block.resnets:
@@ -253,11 +256,11 @@ class SlicingEncoder(nn.Module):
        for i, down_block in enumerate(self.down_blocks[::-1]):
            if div >= 2:
                div = int(div)
                # print(f"down block: {i} divisor: {div}")
                # logger.info(f"down block: {i} divisor: {div}")
                for resnet in down_block.resnets:
                    resnet.forward = wrapper(resblock_forward, resnet, div)
                if down_block.downsamplers is not None:
                    # print("has downsample")
                    # logger.info("has downsample")
                    for downsample in down_block.downsamplers:
                        downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
            div *= 2
@@ -307,7 +310,7 @@ class SlicingEncoder(nn.Module):
    def downsample_forward(self, _self, num_slices, hidden_states):
        assert hidden_states.shape[1] == _self.channels
        assert _self.use_conv and _self.padding == 0
        print("downsample forward", num_slices, hidden_states.shape)
        logger.info(f"downsample forward {num_slices} {hidden_states.shape}")

        org_device = hidden_states.device
        cpu_device = torch.device("cpu")
@@ -350,7 +353,7 @@ class SlicingEncoder(nn.Module):
            hidden_states = torch.cat([hidden_states, x], dim=2)

        hidden_states = hidden_states.to(org_device)
        # print("downsample forward done", hidden_states.shape)
        # logger.info(f"downsample forward done {hidden_states.shape}")
        return hidden_states


@@ -426,7 +429,7 @@ class SlicingDecoder(nn.Module):

        self.num_slices = num_slices
        div = num_slices / (2 ** (len(self.up_blocks) - 1))
        print(f"initial divisor: {div}")
        logger.info(f"initial divisor: {div}")
        if div >= 2:
            div = int(div)
            for resnet in self.mid_block.resnets:
@@ -436,11 +439,11 @@ class SlicingDecoder(nn.Module):
        for i, up_block in enumerate(self.up_blocks):
            if div >= 2:
                div = int(div)
                # print(f"up block: {i} divisor: {div}")
                # logger.info(f"up block: {i} divisor: {div}")
                for resnet in up_block.resnets:
                    resnet.forward = wrapper(resblock_forward, resnet, div)
                if up_block.upsamplers is not None:
                    # print("has upsample")
                    # logger.info("has upsample")
                    for upsample in up_block.upsamplers:
                        upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
            div *= 2
@@ -528,7 +531,7 @@ class SlicingDecoder(nn.Module):
            del x

        hidden_states = torch.cat(sliced, dim=2)
        # print("us hidden_states", hidden_states.shape)
        # logger.info(f"us hidden_states {hidden_states.shape}")
        del sliced

        hidden_states = hidden_states.to(org_device)

library/strategy_base.py (new file, 570 lines)
@@ -0,0 +1,570 @@
# base class for platform strategies. this file defines the interface for strategies

import os
import re
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection


# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class TokenizeStrategy:
    _strategy = None  # strategy instance: actual strategy class

    _re_attention = re.compile(
        r"""\\\(|
            \\\)|
            \\\[|
            \\]|
            \\\\|
            \\|
            \(|
            \[|
            :([+-]?[.\d]+)\)|
            \)|
            ]|
            [^\\()\[\]:]+|
            :
        """,
        re.X,
    )

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        return cls._strategy

    def _load_tokenizer(
        self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
    ) -> Any:
        tokenizer = None
        if tokenizer_cache_dir:
            local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
            if os.path.exists(local_tokenizer_path):
                logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
                tokenizer = model_class.from_pretrained(local_tokenizer_path)  # same for v1 and v2

        if tokenizer is None:
            tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)

        if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
            logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
            tokenizer.save_pretrained(local_tokenizer_path)

        return tokenizer

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        raise NotImplementedError

    def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
        """
        raise NotImplementedError

    def _get_weighted_input_ids(
        self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        max_length includes starting and ending tokens.
        """

        def parse_prompt_attention(text):
            """
            Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
            Accepted tokens are:
              (abc) - increases attention to abc by a multiplier of 1.1
              (abc:3.12) - increases attention to abc by a multiplier of 3.12
              [abc] - decreases attention to abc by a multiplier of 1.1
              \( - literal character '('
              \[ - literal character '['
              \) - literal character ')'
              \] - literal character ']'
              \\ - literal character '\'
              anything else - just text
            >>> parse_prompt_attention('normal text')
            [['normal text', 1.0]]
            >>> parse_prompt_attention('an (important) word')
            [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
            >>> parse_prompt_attention('(unbalanced')
            [['unbalanced', 1.1]]
            >>> parse_prompt_attention('\(literal\]')
            [['(literal]', 1.0]]
            >>> parse_prompt_attention('(unnecessary)(parens)')
            [['unnecessaryparens', 1.1]]
            >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
            [['a ', 1.0],
             ['house', 1.5730000000000004],
             [' ', 1.1],
             ['on', 1.0],
             [' a ', 1.1],
             ['hill', 0.55],
             [', sun, ', 1.1],
             ['sky', 1.4641000000000006],
             ['.', 1.1]]
            """

            res = []
            round_brackets = []
            square_brackets = []

            round_bracket_multiplier = 1.1
            square_bracket_multiplier = 1 / 1.1

            def multiply_range(start_position, multiplier):
                for p in range(start_position, len(res)):
                    res[p][1] *= multiplier

            for m in TokenizeStrategy._re_attention.finditer(text):
                text = m.group(0)
                weight = m.group(1)

                if text.startswith("\\"):
                    res.append([text[1:], 1.0])
                elif text == "(":
                    round_brackets.append(len(res))
                elif text == "[":
                    square_brackets.append(len(res))
                elif weight is not None and len(round_brackets) > 0:
                    multiply_range(round_brackets.pop(), float(weight))
                elif text == ")" and len(round_brackets) > 0:
                    multiply_range(round_brackets.pop(), round_bracket_multiplier)
                elif text == "]" and len(square_brackets) > 0:
                    multiply_range(square_brackets.pop(), square_bracket_multiplier)
                else:
                    res.append([text, 1.0])

            for pos in round_brackets:
                multiply_range(pos, round_bracket_multiplier)

            for pos in square_brackets:
                multiply_range(pos, square_bracket_multiplier)

            if len(res) == 0:
                res = [["", 1.0]]

            # merge runs of identical weights
            i = 0
            while i + 1 < len(res):
                if res[i][1] == res[i + 1][1]:
                    res[i][0] += res[i + 1][0]
                    res.pop(i + 1)
                else:
                    i += 1

            return res

        def get_prompts_with_weights(text: str, max_length: int):
            r"""
            Tokenize a prompt and return its tokens with the weight of each token. max_length does not include the starting and ending tokens.

            No padding, starting or ending token is included.
            """
            truncated = False

            texts_and_weights = parse_prompt_attention(text)
            tokens = []
            weights = []
            for word, weight in texts_and_weights:
                # tokenize and discard the starting and the ending token
                token = tokenizer(word).input_ids[1:-1]
                tokens += token
                # copy the weight by length of token
                weights += [weight] * len(token)
                # stop if the text is too long (longer than truncation limit)
                if len(tokens) > max_length:
                    truncated = True
                    break
            # truncate
            if len(tokens) > max_length:
                truncated = True
                tokens = tokens[:max_length]
                weights = weights[:max_length]
            if truncated:
                logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
            return tokens, weights

        def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
            r"""
            Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
            """
            tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
            weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
            return tokens, weights

        if max_length is None:
            max_length = tokenizer.model_max_length

        tokens, weights = get_prompts_with_weights(text, max_length - 2)
        tokens, weights = pad_tokens_and_weights(
            tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
        )
        return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
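Editor's note: a worked example of the helpers above (a sketch run in this module's context; it assumes the standard CLIP tokenizer). "(important:1.5)" raises the weight of the tokens of "important" to 1.5, and padding restores the fixed 77-slot layout:

tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids, weights = TokenizeStrategy()._get_weighted_input_ids(tok, "a (important:1.5) word", max_length=77)
print(ids.shape, weights.shape)  # torch.Size([1, 77]) torch.Size([1, 77])
# weights is 1.0 everywhere except the positions covered by "important", which are 1.5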
|
||||
def _get_input_ids(
|
||||
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
for SD1.5/2.0/SDXL
|
||||
TODO support batch input
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = tokenizer.model_max_length - 2
|
||||
|
||||
if weighted:
|
||||
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
|
||||
else:
|
||||
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
|
||||
|
||||
if max_length > tokenizer.model_max_length:
|
||||
input_ids = input_ids.squeeze(0)
|
||||
iids_list = []
|
||||
if tokenizer.pad_token_id == tokenizer.eos_token_id:
|
||||
# v1
|
||||
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
|
||||
ids_chunk = (
|
||||
input_ids[0].unsqueeze(0),
|
||||
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0),
|
||||
)
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
iids_list.append(ids_chunk)
|
||||
else:
|
||||
# v2 or SDXL
|
||||
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
||||
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
||||
ids_chunk = (
|
||||
input_ids[0].unsqueeze(0), # BOS
|
||||
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0),
|
||||
) # PAD or EOS
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
|
||||
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
||||
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
||||
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
|
||||
ids_chunk[-1] = tokenizer.eos_token_id
|
||||
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
||||
if ids_chunk[1] == tokenizer.pad_token_id:
|
||||
ids_chunk[1] = tokenizer.eos_token_id
|
||||
|
||||
iids_list.append(ids_chunk)
|
||||
|
||||
input_ids = torch.stack(iids_list) # 3,77
|
||||
|
||||
if weighted:
|
||||
weights = weights.squeeze(0)
|
||||
new_weights = torch.ones(input_ids.shape)
|
||||
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
||||
b = i // (tokenizer.model_max_length - 2)
|
||||
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
|
||||
weights = new_weights
|
||||
|
||||
if weighted:
|
||||
return input_ids, weights
|
||||
return input_ids
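
# a standalone sketch of the chunking arithmetic above (values are illustrative): with
# model_max_length == 77 and max_length == 227 (225 content tokens + BOS + EOS), the loop
# starts land on the beginning of each 75-token chunk
_chunk_starts_example = list(range(1, 227 - 77 + 2, 77 - 2))
assert _chunk_starts_example == [1, 76, 151]  # three chunks, each re-wrapped with BOS and EOS/PAD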


class TextEncodingStrategy:
    _strategy = None  # strategy instance: actual strategy class

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
        return cls._strategy

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Encode tokens into embeddings and outputs.
        :param tokens: list of token tensors for each TextModel
        :return: list of output embeddings for each architecture
        """
        raise NotImplementedError

    def encode_tokens_with_weights(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Encode tokens into embeddings and outputs.
        :param tokens: list of token tensors for each TextModel
        :param weights: list of weight tensors for each TextModel
        :return: list of output embeddings for each architecture
        """
        raise NotImplementedError
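
# a minimal sketch of how these strategy singletons are wired up (the subclass and model
# names here are hypothetical, not part of the library):
#
#     class MyTextEncodingStrategy(TextEncodingStrategy):
#         def encode_tokens(self, tokenize_strategy, models, tokens):
#             text_encoder = models[0]
#             return [text_encoder(tokens[0].to(text_encoder.device))[0]]
#
#     TextEncodingStrategy.set_strategy(MyTextEncodingStrategy())
#     strategy = TextEncodingStrategy.get_strategy()  # the instance set above
#     # a second set_strategy call raises RuntimeError by design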


class TextEncoderOutputsCachingStrategy:
    _strategy = None  # strategy instance: actual strategy class

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: Optional[int],
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        is_weighted: bool = False,
    ) -> None:
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
        self._is_partial = is_partial
        self._is_weighted = is_weighted

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
        return cls._strategy

    @property
    def cache_to_disk(self):
        return self._cache_to_disk

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def is_partial(self):
        return self._is_partial

    @property
    def is_weighted(self):
        return self._is_weighted

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        raise NotImplementedError

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        raise NotImplementedError

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        raise NotImplementedError

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
    ):
        raise NotImplementedError


class LatentsCachingStrategy:
    # TODO commonize utility functions to this class, such as npz handling etc.

    _strategy = None  # strategy instance: actual strategy class

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
        return cls._strategy

    @property
    def cache_to_disk(self):
        return self._cache_to_disk

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def cache_suffix(self):
        raise NotImplementedError

    def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
        w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
        return int(w), int(h)

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        raise NotImplementedError

    def is_disk_cached_latents_expected(
        self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
    ) -> bool:
        raise NotImplementedError

    def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        raise NotImplementedError

    def _default_is_disk_cached_latents_expected(
        self,
        latents_stride: int,
        bucket_reso: Tuple[int, int],
        npz_path: str,
        flip_aug: bool,
        alpha_mask: bool,
        multi_resolution: bool = False,
    ):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)

        # e.g. "_32x64", HxW
        key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""

        try:
            npz = np.load(npz_path)
            if "latents" + key_reso_suffix not in npz:
                return False
            if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
                return False
            if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True
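
    # a worked example of the key layout checked above (values are illustrative): a bucket of
    # (W, H) == (512, 1024) with latents_stride == 8 expects latents under "latents_128x64"
    # (HxW), flipped latents under "latents_flipped_128x64" and masks under "alpha_mask_128x64"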

    # TODO remove circular dependency for ImageInfo
    def _default_cache_batch_latents(
        self,
        encode_by_vae,
        vae_device,
        vae_dtype,
        image_infos: List,
        flip_aug: bool,
        alpha_mask: bool,
        random_crop: bool,
        multi_resolution: bool = False,
    ):
        """
        Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
        """
        from library import train_util  # import here to avoid circular import

        img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
            image_infos, alpha_mask, random_crop
        )
        img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)

        with torch.no_grad():
            latents_tensors = encode_by_vae(img_tensor).to("cpu")
        if flip_aug:
            img_tensor = torch.flip(img_tensor, dims=[3])
            with torch.no_grad():
                flipped_latents = encode_by_vae(img_tensor).to("cpu")
        else:
            flipped_latents = [None] * len(latents_tensors)

        # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
        for i in range(len(image_infos)):
            info = image_infos[i]
            latents = latents_tensors[i]
            flipped_latent = flipped_latents[i]
            alpha_mask = alpha_masks[i]
            original_size = original_sizes[i]
            crop_ltrb = crop_ltrbs[i]

            latents_size = latents.shape[1:3]  # H, W
            key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else ""  # e.g. "_32x64", HxW

            if self.cache_to_disk:
                self.save_latents_to_disk(
                    info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
                )
            else:
                info.latents_original_size = original_size
                info.latents_crop_ltrb = crop_ltrb
                info.latents = latents
                if flip_aug:
                    info.latents_flipped = flipped_latent
                info.alpha_mask = alpha_mask

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """
        for SD/SDXL
        """
        return self._default_load_latents_from_disk(None, npz_path, bucket_reso)

    def _default_load_latents_from_disk(
        self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        if latents_stride is None:
            key_reso_suffix = ""
        else:
            latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
            key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}"  # e.g. "_32x64", HxW

        npz = np.load(npz_path)
        if "latents" + key_reso_suffix not in npz:
            raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")

        latents = npz["latents" + key_reso_suffix]
        original_size = npz["original_size" + key_reso_suffix].tolist()
        crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
        flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
        alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
        return latents, original_size, crop_ltrb, flipped_latents, alpha_mask

    def save_latents_to_disk(
        self,
        npz_path,
        latents_tensor,
        original_size,
        crop_ltrb,
        flipped_latents_tensor=None,
        alpha_mask=None,
        key_reso_suffix="",
    ):
        kwargs = {}

        if os.path.exists(npz_path):
            # load existing npz and update it
            npz = np.load(npz_path)
            for key in npz.files:
                kwargs[key] = npz[key]

        kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
        kwargs["original_size" + key_reso_suffix] = np.array(original_size)
        kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
        if flipped_latents_tensor is not None:
            kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
        if alpha_mask is not None:
            kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
        np.savez(npz_path, **kwargs)
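
# a standalone round-trip sketch for the npz cache format above (paths and shapes are
# illustrative; guarded so it only runs when executed manually):
if __name__ == "__main__":
    _demo = {
        "latents_64x64": np.zeros((4, 64, 64), dtype=np.float32),
        "original_size_64x64": np.array([512, 512]),
        "crop_ltrb_64x64": np.array([0, 0, 512, 512]),
    }
    np.savez("latents_cache_demo.npz", **_demo)
    _npz = np.load("latents_cache_demo.npz")
    assert "latents_64x64" in _npz and "alpha_mask_64x64" not in _npz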

library/strategy_flux.py (new file, 271 lines)
@@ -0,0 +1,271 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast

from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"


class FluxTokenizeStrategy(TokenizeStrategy):
    def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.t5xxl_max_length = t5xxl_max_length
        self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text

        l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")

        t5_attn_mask = t5_tokens["attention_mask"]
        l_tokens = l_tokens["input_ids"]
        t5_tokens = t5_tokens["input_ids"]

        return [l_tokens, t5_tokens, t5_attn_mask]


class FluxTextEncodingStrategy(TextEncodingStrategy):
    def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
        """
        Args:
            apply_t5_attn_mask: Default value for apply_t5_attn_mask.
        """
        self.apply_t5_attn_mask = apply_t5_attn_mask

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        apply_t5_attn_mask: Optional[bool] = None,
    ) -> List[torch.Tensor]:
        # supports single model inference

        if apply_t5_attn_mask is None:
            apply_t5_attn_mask = self.apply_t5_attn_mask

        clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
        l_tokens, t5_tokens = tokens[:2]
        t5_attn_mask = tokens[2] if len(tokens) > 2 else None

        # clip_l is None when using T5 only
        if clip_l is not None and l_tokens is not None:
            l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
        else:
            l_pooled = None

        # t5xxl is None when using CLIP only
        if t5xxl is not None and t5_tokens is not None:
            # t5_out is [b, max length, 4096]
            attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
            t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
            # if zero_pad_t5_output:
            #     t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
            txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
        else:
            t5_out = None
            txt_ids = None
            t5_attn_mask = None  # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one

        return [l_pooled, t5_out, txt_ids, t5_attn_mask]  # returns t5_attn_mask for attention mask in transformer
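
# a usage sketch for the two classes above (commented out: loading CLIP-L/T5-XXL weights is
# expensive; clip_l and t5xxl stand for loaded text encoder modules and are not defined here):
#
#     tokenize_strategy = FluxTokenizeStrategy(512)
#     encoding_strategy = FluxTextEncodingStrategy(apply_t5_attn_mask=True)
#     tokens = tokenize_strategy.tokenize("a photo of a cat")
#     l_pooled, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
#         tokenize_strategy, [clip_l, t5xxl], tokens
#     )
#     # l_pooled: (1, 768), t5_out: (1, 512, 4096), txt_ids: (1, 512, 3)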


class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_t5_attn_mask: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_t5_attn_mask = apply_t5_attn_mask

        self.warn_fp8_weights = False

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "l_pooled" not in npz:
                return False
            if "t5_out" not in npz:
                return False
            if "txt_ids" not in npz:
                return False
            if "t5_attn_mask" not in npz:
                return False
            if "apply_t5_attn_mask" not in npz:
                return False
            npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
            if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        l_pooled = data["l_pooled"]
        t5_out = data["t5_out"]
        txt_ids = data["txt_ids"]
        t5_attn_mask = data["t5_attn_mask"]
        # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
        return [l_pooled, t5_out, txt_ids, t5_attn_mask]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        if not self.warn_fp8_weights:
            if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
                logger.warning(
                    "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
                    " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
                )
                self.warn_fp8_weights = True

        flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
            l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)

        if l_pooled.dtype == torch.bfloat16:
            l_pooled = l_pooled.float()
        if t5_out.dtype == torch.bfloat16:
            t5_out = t5_out.float()
        if txt_ids.dtype == torch.bfloat16:
            txt_ids = txt_ids.float()

        l_pooled = l_pooled.cpu().numpy()
        t5_out = t5_out.cpu().numpy()
        txt_ids = txt_ids.cpu().numpy()
        t5_attn_mask = tokens_and_masks[2].cpu().numpy()

        for i, info in enumerate(infos):
            l_pooled_i = l_pooled[i]
            t5_out_i = t5_out[i]
            txt_ids_i = txt_ids[i]
            t5_attn_mask_i = t5_attn_mask[i]
            apply_t5_attn_mask_i = self.apply_t5_attn_mask

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    l_pooled=l_pooled_i,
                    t5_out=t5_out_i,
                    txt_ids=txt_ids_i,
                    t5_attn_mask=t5_attn_mask_i,
                    apply_t5_attn_mask=apply_t5_attn_mask_i,
                )
            else:
                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)


class FluxLatentsCachingStrategy(LatentsCachingStrategy):
    FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
        )

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)  # support multi-resolution

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(
            encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
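
# a standalone sketch of the latents cache path format above (the path is illustrative):
_example_npz = os.path.splitext("/data/img.png")[0] + f"_{1024:04d}x{768:04d}" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
assert _example_npz == "/data/img_1024x0768_flux.npz"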


if __name__ == "__main__":
    # test code for FluxTokenizeStrategy
    strategy = FluxTokenizeStrategy(256)
    text = "hello world"

    # tokenize returns [l_tokens, t5_tokens, t5_attn_mask]; FLUX has no CLIP-G tokenizer
    l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize(text)
    # print(l_tokens.shape)
    print(l_tokens)
    print(t5_tokens)
    print(t5_attn_mask)

    texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
    l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
    t5_tokens_2 = strategy.t5xxl(
        texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    print(l_tokens_2)
    print(t5_tokens_2)

    # compare
    print(torch.allclose(l_tokens[0], l_tokens_2["input_ids"][0]))
    print(torch.allclose(t5_tokens[0], t5_tokens_2["input_ids"][0]))

    text = ",".join(["hello world! this is long text"] * 50)
    l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize(text)
    print(l_tokens)
    print(t5_tokens)
    print(t5_attn_mask)

    print(f"model max length l: {strategy.clip_l.model_max_length}")
    print(f"model max length t5: {strategy.t5xxl.model_max_length}")

library/strategy_sd.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union

import torch
from transformers import CLIPTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from this repo; v2 and v2.1 share the same tokenizer spec


class SdTokenizeStrategy(TokenizeStrategy):
    def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
        """
        max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
        """
        logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
        if v2:
            self.tokenizer = self._load_tokenizer(
                CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
            )
        else:
            self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

        if max_length is None:
            self.max_length = self.tokenizer.model_max_length
        else:
            self.max_length = max_length + 2

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text
        return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]

    def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        text = [text] if isinstance(text, str) else text
        tokens_list = []
        weights_list = []
        for t in text:
            tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
            tokens_list.append(tokens)
            weights_list.append(weights)
        return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
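
# a usage sketch for SdTokenizeStrategy (commented out: instantiation downloads the tokenizer):
#
#     strategy = SdTokenizeStrategy(v2=False, max_length=225)  # self.max_length becomes 227
#     tokens = strategy.tokenize(["a long prompt ..."])[0]
#     # tokens.shape == (1, 3, 77): 225 content tokens split into three <BOS>...<EOS> chunks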


class SdTextEncodingStrategy(TextEncodingStrategy):
    def __init__(self, clip_skip: Optional[int] = None) -> None:
        self.clip_skip = clip_skip

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        text_encoder = models[0]
        tokens = tokens[0]
        sd_tokenize_strategy = tokenize_strategy  # type: SdTokenizeStrategy

        # tokens: b,n,77
        b_size = tokens.size()[0]
        max_token_length = tokens.size()[1] * tokens.size()[2]
        model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
        tokens = tokens.reshape((-1, model_max_length))  # batch_size*3, 77

        tokens = tokens.to(text_encoder.device)

        if self.clip_skip is None:
            encoder_hidden_states = text_encoder(tokens)[0]
        else:
            enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
            encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
            encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

        # bs*3, 77, 768 or 1024
        encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

        if max_token_length != model_max_length:
            v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
            if not v1:
                # v2: re-join the triplets of "<BOS>...<EOS> <PAD> ..." into a single "<BOS>...<EOS> <PAD> ..." sequence
                # (honestly not sure whether this implementation is right)
                states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
                for i in range(1, max_token_length, model_max_length):
                    chunk = encoder_hidden_states[:, i : i + model_max_length - 2]  # from after <BOS> to before the last token
                    if i > 0:
                        for j in range(len(chunk)):
                            # compare against eos_token_id, not the eos_token string
                            if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token_id:
                                # empty prompt, i.e. the "<BOS> <EOS> <PAD> ..." pattern
                                chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
                    states_list.append(chunk)  # from after <BOS> to before <EOS>
                states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
                encoder_hidden_states = torch.cat(states_list, dim=1)
            else:
                # v1: re-join the triplets of "<BOS>...<EOS>" into a single "<BOS>...<EOS>" sequence
                states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
                for i in range(1, max_token_length, model_max_length):
                    states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])  # from after <BOS> to before <EOS>
                states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))  # <EOS>
                encoder_hidden_states = torch.cat(states_list, dim=1)

        return [encoder_hidden_states]

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens_list: List[torch.Tensor],
        weights_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]

        weights = weights_list[0].to(encoder_hidden_states.device)

        # apply weights
        if weights.shape[1] == 1:  # no max_token_length
            # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
            encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
        else:
            # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
            for i in range(weights.shape[1]):
                encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
                    :, i, 1:-1
                ].unsqueeze(-1)

        return [encoder_hidden_states]
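
# a standalone sketch of the weighted-slice arithmetic above: with n chunks of 75 content
# tokens re-joined as <BOS> + n*75 tokens + <EOS>, chunk i occupies columns [i*75+1, i*75+76)
_spans_example = [(i * 75 + 1, i * 75 + 76) for i in range(3)]
assert _spans_example == [(1, 76), (76, 151), (151, 226)]  # BOS at 0, EOS at 226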


class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
    # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
    # and we keep the old npz for the backward compatibility.

    SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
    SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
    SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"

    def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        self.sd = sd
        self.suffix = (
            SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
        )

    @property
    def cache_suffix(self) -> str:
        return self.suffix

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        # support old .npz
        old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
        if os.path.exists(old_npz_file):
            return old_npz_file
        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)

library/strategy_sd3.py (new file, 420 lines)
@@ -0,0 +1,420 @@
import os
import glob
import random
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel

from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"


class Sd3TokenizeStrategy(TokenizeStrategy):
    def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.t5xxl_max_length = t5xxl_max_length
        self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.clip_g.pad_token_id = 0  # use 0 as pad token for clip_g

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text

        l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
        g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")

        l_attn_mask = l_tokens["attention_mask"]
        g_attn_mask = g_tokens["attention_mask"]
        t5_attn_mask = t5_tokens["attention_mask"]
        l_tokens = l_tokens["input_ids"]
        g_tokens = g_tokens["input_ids"]
        t5_tokens = t5_tokens["input_ids"]

        return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
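
# a usage sketch for Sd3TokenizeStrategy (commented out: instantiation downloads three tokenizers):
#
#     strategy = Sd3TokenizeStrategy(256)
#     l_tokens, g_tokens, t5_tokens, l_mask, g_mask, t5_mask = strategy.tokenize("a cat")
#     # l_tokens/g_tokens: (1, 77); t5_tokens: (1, 256); the masks have matching shapes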


class Sd3TextEncodingStrategy(TextEncodingStrategy):
    def __init__(
        self,
        apply_lg_attn_mask: Optional[bool] = None,
        apply_t5_attn_mask: Optional[bool] = None,
        l_dropout_rate: float = 0.0,
        g_dropout_rate: float = 0.0,
        t5_dropout_rate: float = 0.0,
    ) -> None:
        """
        Args:
            apply_lg_attn_mask: Default value for apply_lg_attn_mask in encode_tokens.
            apply_t5_attn_mask: Default value for apply_t5_attn_mask in encode_tokens.
        """
        self.apply_lg_attn_mask = apply_lg_attn_mask
        self.apply_t5_attn_mask = apply_t5_attn_mask
        self.l_dropout_rate = l_dropout_rate
        self.g_dropout_rate = g_dropout_rate
        self.t5_dropout_rate = t5_dropout_rate

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        apply_lg_attn_mask: Optional[bool] = None,  # None falls back to the instance default (a False default made the fallback unreachable)
        apply_t5_attn_mask: Optional[bool] = None,  # None falls back to the instance default
        enable_dropout: bool = True,
    ) -> List[torch.Tensor]:
        """
        returned embeddings are not masked
        """
        clip_l, clip_g, t5xxl = models
        clip_l: Optional[CLIPTextModel]
        clip_g: Optional[CLIPTextModelWithProjection]
        t5xxl: Optional[T5EncoderModel]

        if apply_lg_attn_mask is None:
            apply_lg_attn_mask = self.apply_lg_attn_mask
        if apply_t5_attn_mask is None:
            apply_t5_attn_mask = self.apply_t5_attn_mask

        l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens

        # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings

        if l_tokens is None or clip_l is None:
            assert g_tokens is None, "g_tokens must be None if l_tokens is None"
            lg_out = None
            lg_pooled = None
            l_attn_mask = None
            g_attn_mask = None
        else:
            assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"

            # drop some members of the batch: we do not call clip_l and clip_g for dropped members
            batch_size, l_seq_len = l_tokens.shape
            g_seq_len = g_tokens.shape[1]

            non_drop_l_indices = []
            non_drop_g_indices = []
            for i in range(l_tokens.shape[0]):
                drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
                drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
                if not drop_l:
                    non_drop_l_indices.append(i)
                if not drop_g:
                    non_drop_g_indices.append(i)

            # filter out dropped members
            if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
                l_tokens = l_tokens[non_drop_l_indices]
                l_attn_mask = l_attn_mask[non_drop_l_indices]
            if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
                g_tokens = g_tokens[non_drop_g_indices]
                g_attn_mask = g_attn_mask[non_drop_g_indices]

            # call clip_l for non-dropped members
            if len(non_drop_l_indices) > 0:
                nd_l_attn_mask = l_attn_mask.to(clip_l.device)
                prompt_embeds = clip_l(
                    l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
                )
                nd_l_pooled = prompt_embeds[0]
                nd_l_out = prompt_embeds.hidden_states[-2]
            if len(non_drop_g_indices) > 0:
                nd_g_attn_mask = g_attn_mask.to(clip_g.device)
                prompt_embeds = clip_g(
                    g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
                )
                nd_g_pooled = prompt_embeds[0]
                nd_g_out = prompt_embeds.hidden_states[-2]

            # fill in the dropped members
            if len(non_drop_l_indices) == batch_size:
                l_pooled = nd_l_pooled
                l_out = nd_l_out
            else:
                # model output is always float32 because the models are wrapped by Accelerator
                l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
                l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
                l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
                if len(non_drop_l_indices) > 0:
                    l_pooled[non_drop_l_indices] = nd_l_pooled
                    l_out[non_drop_l_indices] = nd_l_out
                    l_attn_mask[non_drop_l_indices] = nd_l_attn_mask

            if len(non_drop_g_indices) == batch_size:
                g_pooled = nd_g_pooled
                g_out = nd_g_out
            else:
                g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
                g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
                g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
                if len(non_drop_g_indices) > 0:
                    g_pooled[non_drop_g_indices] = nd_g_pooled
                    g_out[non_drop_g_indices] = nd_g_out
                    g_attn_mask[non_drop_g_indices] = nd_g_attn_mask

            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
            lg_out = torch.cat([l_out, g_out], dim=-1)

        if t5xxl is None or t5_tokens is None:
            t5_out = None
            t5_attn_mask = None
        else:
            # drop some members of the batch: we do not call t5xxl for dropped members
            batch_size, t5_seq_len = t5_tokens.shape
            non_drop_t5_indices = []
            for i in range(t5_tokens.shape[0]):
                drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
                if not drop_t5:
                    non_drop_t5_indices.append(i)

            # filter out dropped members
            if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
                t5_tokens = t5_tokens[non_drop_t5_indices]
                t5_attn_mask = t5_attn_mask[non_drop_t5_indices]

            # call t5xxl for non-dropped members
            if len(non_drop_t5_indices) > 0:
                nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
                nd_t5_out, _ = t5xxl(
                    t5_tokens.to(t5xxl.device),
                    nd_t5_attn_mask if apply_t5_attn_mask else None,
                    return_dict=False,
                    output_hidden_states=True,
                )

            # fill in the dropped members
            if len(non_drop_t5_indices) == batch_size:
                t5_out = nd_t5_out
            else:
                t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
                t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
                if len(non_drop_t5_indices) > 0:
                    t5_out[non_drop_t5_indices] = nd_t5_out
                    t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask

        # masks are used for attention masking in transformer
        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
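
    # a sketch of the dropout bookkeeping above (values are illustrative): with batch_size == 3
    # and non_drop_l_indices == [0, 2], only rows 0 and 2 are encoded; row 1 keeps the zero
    # tensors allocated in the fill-in branch, so its embedding and attention mask stay zeroed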

    def drop_cached_text_encoder_outputs(
        self,
        lg_out: torch.Tensor,
        t5_out: torch.Tensor,
        lg_pooled: torch.Tensor,
        l_attn_mask: torch.Tensor,
        g_attn_mask: torch.Tensor,
        t5_attn_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # dropout for cached outputs: zero out the embeddings (and masks) of randomly selected members of the batch
        if lg_out is not None:
            for i in range(lg_out.shape[0]):
                drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
                if drop_l:
                    lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
                    lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
                    if l_attn_mask is not None:
                        l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
                drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
                if drop_g:
                    lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
                    lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
                    if g_attn_mask is not None:
                        g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])

        if t5_out is not None:
            for i in range(t5_out.shape[0]):
                drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
                if drop_t5:
                    t5_out[i] = torch.zeros_like(t5_out[i])
                    if t5_attn_mask is not None:
                        t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])

        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]

    def concat_encodings(
        self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
        if t5_out is None:
            t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
        return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
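
# a standalone sketch of the context assembly in concat_encodings above: CLIP L+G features
# (768 + 1280 = 2048 channels) are zero-padded to 4096 and concatenated with T5 along the
# sequence axis (shapes are illustrative)
_lg_example = torch.nn.functional.pad(torch.zeros(1, 77, 2048), (0, 4096 - 2048))
_ctx_example = torch.cat([_lg_example, torch.zeros(1, 256, 4096)], dim=-2)
assert _ctx_example.shape == (1, 77 + 256, 4096)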


class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_lg_attn_mask: bool = False,
        apply_t5_attn_mask: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_lg_attn_mask = apply_lg_attn_mask
        self.apply_t5_attn_mask = apply_t5_attn_mask

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "lg_out" not in npz:
                return False
            if "lg_pooled" not in npz:
                return False
            if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz:  # necessary even if not used
                return False
            if "apply_lg_attn_mask" not in npz:
                return False
            if "t5_out" not in npz:
                return False
            if "t5_attn_mask" not in npz:
                return False
            npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
            if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
                return False
            if "apply_t5_attn_mask" not in npz:
                return False
            npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
            if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        lg_out = data["lg_out"]
        lg_pooled = data["lg_pooled"]
        t5_out = data["t5_out"]

        l_attn_mask = data["clip_l_attn_mask"]
        g_attn_mask = data["clip_g_attn_mask"]
        t5_attn_mask = data["t5_attn_mask"]

        # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            # always disable dropout during caching
            lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
                tokenize_strategy,
                models,
                tokens_and_masks,
                apply_lg_attn_mask=self.apply_lg_attn_mask,
                apply_t5_attn_mask=self.apply_t5_attn_mask,
                enable_dropout=False,
            )

        if lg_out.dtype == torch.bfloat16:
            lg_out = lg_out.float()
        if lg_pooled.dtype == torch.bfloat16:
            lg_pooled = lg_pooled.float()
        if t5_out.dtype == torch.bfloat16:
            t5_out = t5_out.float()

        lg_out = lg_out.cpu().numpy()
        lg_pooled = lg_pooled.cpu().numpy()
        t5_out = t5_out.cpu().numpy()

        l_attn_mask = tokens_and_masks[3].cpu().numpy()
        g_attn_mask = tokens_and_masks[4].cpu().numpy()
        t5_attn_mask = tokens_and_masks[5].cpu().numpy()

        for i, info in enumerate(infos):
            lg_out_i = lg_out[i]
            t5_out_i = t5_out[i]
            lg_pooled_i = lg_pooled[i]
            l_attn_mask_i = l_attn_mask[i]
            g_attn_mask_i = g_attn_mask[i]
            t5_attn_mask_i = t5_attn_mask[i]
            apply_lg_attn_mask = self.apply_lg_attn_mask
            apply_t5_attn_mask = self.apply_t5_attn_mask

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    lg_out=lg_out_i,
                    lg_pooled=lg_pooled_i,
                    t5_out=t5_out_i,
                    clip_l_attn_mask=l_attn_mask_i,
                    clip_g_attn_mask=g_attn_mask_i,
                    t5_attn_mask=t5_attn_mask_i,
                    apply_lg_attn_mask=apply_lg_attn_mask,
                    apply_t5_attn_mask=apply_t5_attn_mask,
                )
            else:
                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)


class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
    SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
        )

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)  # support multi-resolution

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(
            encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)

library/strategy_sdxl.py (new file, 306 lines)
@@ -0,0 +1,306 @@
import os
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy


from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"


class SdxlTokenizeStrategy(TokenizeStrategy):
    def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
        self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
        self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
        self.tokenizer2.pad_token_id = 0  # use 0 as pad token for tokenizer2

        if max_length is None:
            self.max_length = self.tokenizer1.model_max_length
        else:
            self.max_length = max_length + 2

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text
        return (
            torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
            torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
        )

    def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        text = [text] if isinstance(text, str) else text
        tokens1_list, tokens2_list = [], []
        weights1_list, weights2_list = [], []
        for t in text:
            tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True)
            tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True)
            tokens1_list.append(tokens1)
            tokens2_list.append(tokens2)
            weights1_list.append(weights1)
            weights2_list.append(weights2)
        return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [
            torch.stack(weights1_list, dim=0),
            torch.stack(weights2_list, dim=0),
        ]


class SdxlTextEncodingStrategy(TextEncodingStrategy):
    def __init__(self) -> None:
        pass

    def _pool_workaround(
        self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
    ):
        r"""
        workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
        instead of the hidden states for the EOS token
        If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output

        Original code from CLIP's pooling function:

        \# text_embeds.shape = [batch_size, sequence_length, transformer.width]
        \# take features from the eot embedding (eot_token is the highest number in each sequence)
        \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]
        """

        # input_ids: b*n,77
        # find index for EOS token

        # the following code does not work if one of the input_ids has multiple EOS tokens (very odd case)
        # eos_token_index = torch.where(input_ids == eos_token_id)[1]
        # eos_token_index = eos_token_index.to(device=last_hidden_state.device)

        # Create a mask where the EOS tokens are
        eos_token_mask = (input_ids == eos_token_id).int()

        # Use argmax to find the first occurrence of the EOS token for each element in the batch
        eos_token_index = torch.argmax(eos_token_mask, dim=1)  # this will be 0 if there is no EOS token, it's fine
        eos_token_index = eos_token_index.to(device=last_hidden_state.device)

        # get hidden states for EOS token
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
        ]

        # apply projection: projection may be of different dtype than last_hidden_state
        pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
        pooled_output = pooled_output.to(last_hidden_state.dtype)

        return pooled_output
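
    # a worked example of the EOS pooling above (token ids are illustrative): for
    # input_ids == [[<BOS>, 320, 2368, <EOS>, <PAD>, <PAD>]] with eos_token_id == 49407,
    # (input_ids == eos_token_id).int() is [[0, 0, 0, 1, 0, 0]] and argmax returns 3, the
    # first EOS position; this stays correct even when Textual Inversion adds token ids
    # larger than the EOS id, which would break the stock argmax-over-ids pooling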
|
||||
|
||||
    def _get_hidden_states_sdxl(
        self,
        input_ids1: torch.Tensor,
        input_ids2: torch.Tensor,
        tokenizer1: CLIPTokenizer,
        tokenizer2: CLIPTokenizer,
        text_encoder1: Union[CLIPTextModel, torch.nn.Module],
        text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
        unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
    ):
        # input_ids: b,n,77 -> b*n, 77
        b_size = input_ids1.size()[0]
        if input_ids1.size()[1] == 1:
            max_token_length = None
        else:
            max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
        input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
        input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
        input_ids1 = input_ids1.to(text_encoder1.device)
        input_ids2 = input_ids2.to(text_encoder2.device)

        # text_encoder1
        enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
        hidden_states1 = enc_out["hidden_states"][11]

        # text_encoder2
        enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
        hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer

        # pool2 = enc_out["text_embeds"]
        unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
        pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)

        # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
        n_size = 1 if max_token_length is None else max_token_length // 75
        hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
        hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))

        if max_token_length is not None:
            # bs*3, 77, 768 or 1024
            # encoder1: turn the triple <BOS>...<EOS> chunks back into a single <BOS>...<EOS>
            states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
            for i in range(1, max_token_length, tokenizer1.model_max_length):
                states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
            states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
            hidden_states1 = torch.cat(states_list, dim=1)

            # v2: turn the triple <BOS>...<EOS> <PAD> ... chunks back into a single <BOS>...<EOS> <PAD> ...
            # (honestly not sure this implementation is right)
            states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
            for i in range(1, max_token_length, tokenizer2.model_max_length):
                chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
                # this causes an error:
                # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
                # if i > 1:
                #     for j in range(len(chunk)):  # batch_size
                #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # empty, i.e. the <BOS> <EOS> <PAD> ... pattern
                #             chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
                states_list.append(chunk)  # from after <BOS> to before <EOS>
            states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
            hidden_states2 = torch.cat(states_list, dim=1)

            # for pool, use the first of the n chunks
            pool2 = pool2[::n_size]

        return hidden_states1, hidden_states2, pool2

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Args:
            tokenize_strategy: TokenizeStrategy
            models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
                If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required.
            tokens: List of tokens, for text_encoder1 and text_encoder2
        """
        if len(models) == 2:
            text_encoder1, text_encoder2 = models
            unwrapped_text_encoder2 = None
        else:
            text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
        tokens1, tokens2 = tokens
        sdxl_tokenize_strategy = tokenize_strategy  # type: SdxlTokenizeStrategy
        tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2

        hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
            tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
        )
        return [hidden_states1, hidden_states2, pool2]

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens_list: List[torch.Tensor],
        weights_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list)

        weights_list = [weights.to(hidden_states1.device) for weights in weights_list]

        # apply weights
        if weights_list[0].shape[1] == 1:  # no max_token_length
            # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
            hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2)
            hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2)
        else:
            # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
            for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]):
                for i in range(weight.shape[1]):
                    hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[
                        :, i, 1:-1
                    ].unsqueeze(-1)

        return [hidden_states1, hidden_states2, pool2]

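To make the chunk bookkeeping above concrete, here is a minimal, self-contained sketch of the same <BOS>/<EOS> splicing that `_get_hidden_states_sdxl` performs, using dummy tensors (all shapes are assumed for illustration; no tokenizers or encoders are needed):

import torch

# a prompt tokenized into n = 3 chunks of 77 tokens, already reshaped to (b, n*77, dim)
b_size, n, chunk_len, dim = 2, 3, 77, 768
max_token_length = n * chunk_len
hidden = torch.randn(b_size, n * chunk_len, dim)

# keep one global <BOS> and <EOS>, plus the 75 content tokens of each chunk
states = [hidden[:, 0].unsqueeze(1)]  # global <BOS>
for i in range(1, max_token_length, chunk_len):
    states.append(hidden[:, i : i + chunk_len - 2])
states.append(hidden[:, -1].unsqueeze(1))  # global <EOS>
spliced = torch.cat(states, dim=1)

print(spliced.shape)  # torch.Size([2, 227, 768]), i.e. n*75 + 2 tokens
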
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        is_weighted: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        hidden_state1 = data["hidden_state1"]
        hidden_state2 = data["hidden_state2"]
        pool2 = data["pool2"]
        return [hidden_state1, hidden_state2, pool2]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        sdxl_text_encoding_strategy = text_encoding_strategy  # type: SdxlTextEncodingStrategy
        captions = [info.caption for info in infos]

        if self.is_weighted:
            tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
            with torch.no_grad():
                hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights(
                    tokenize_strategy, models, tokens_list, weights_list
                )
        else:
            tokens1, tokens2 = tokenize_strategy.tokenize(captions)
            with torch.no_grad():
                hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
                    tokenize_strategy, models, [tokens1, tokens2]
                )

        if hidden_state1.dtype == torch.bfloat16:
            hidden_state1 = hidden_state1.float()
        if hidden_state2.dtype == torch.bfloat16:
            hidden_state2 = hidden_state2.float()
        if pool2.dtype == torch.bfloat16:
            pool2 = pool2.float()

        hidden_state1 = hidden_state1.cpu().numpy()
        hidden_state2 = hidden_state2.cpu().numpy()
        pool2 = pool2.cpu().numpy()

        for i, info in enumerate(infos):
            hidden_state1_i = hidden_state1[i]
            hidden_state2_i = hidden_state2[i]
            pool2_i = pool2[i]

            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_state1=hidden_state1_i,
                    hidden_state2=hidden_state2_i,
                    pool2=pool2_i,
                )
            else:
                info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]

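The on-disk cache above is a plain npz with three arrays. A minimal round-trip sketch of the same layout (shapes are assumed for illustration; SDXL's two text encoders produce 768- and 1280-dim states):

import numpy as np

hidden_state1 = np.zeros((77, 768), dtype=np.float32)
hidden_state2 = np.zeros((77, 1280), dtype=np.float32)
pool2 = np.zeros((1280,), dtype=np.float32)

np.savez("sample_te_outputs.npz", hidden_state1=hidden_state1, hidden_state2=hidden_state2, pool2=pool2)

data = np.load("sample_te_outputs.npz")
assert all(k in data for k in ("hidden_state1", "hidden_state2", "pool2"))
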
[File diff suppressed because it is too large]

library/utils.py (+691)
@@ -1,6 +1,695 @@
import logging
import sys
import threading
from typing import *
import json
import struct

import torch
import torch.nn as nn
from torchvision import transforms
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
import cv2
from PIL import Image
import numpy as np
from safetensors.torch import load_file


def fire_in_thread(f, *args, **kwargs):
    threading.Thread(target=f, args=args, kwargs=kwargs).start()


# region Logging


def add_logging_arguments(parser):
    parser.add_argument(
        "--console_log_level",
        type=str,
        default=None,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
    )
    parser.add_argument(
        "--console_log_file",
        type=str,
        default=None,
        help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
    )
    parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")


def setup_logging(args=None, log_level=None, reset=False):
    if logging.root.handlers:
        if reset:
            # remove all handlers
            for handler in logging.root.handlers[:]:
                logging.root.removeHandler(handler)
        else:
            return

    # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
    if log_level is None and args is not None:
        log_level = args.console_log_level
    if log_level is None:
        log_level = "INFO"
    log_level = getattr(logging, log_level)

    msg_init = None
    if args is not None and args.console_log_file:
        handler = logging.FileHandler(args.console_log_file, mode="w")
    else:
        handler = None
        if not args or not args.console_log_simple:
            try:
                from rich.console import Console
                from rich.logging import RichHandler

                handler = RichHandler(console=Console(stderr=True))
            except ImportError:
                # print("rich is not installed, using basic logging")
                msg_init = "rich is not installed, using basic logging"

    if handler is None:
        handler = logging.StreamHandler(sys.stdout)  # same as print
        handler.propagate = False

    formatter = logging.Formatter(
        fmt="%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    logging.root.setLevel(log_level)
    logging.root.addHandler(handler)

    if msg_init is not None:
        logger = logging.getLogger(__name__)
        logger.info(msg_init)


setup_logging()
logger = logging.getLogger(__name__)

# endregion

# region PyTorch utils


def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to cpu
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

    stream.synchronize()
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def weighs_to_device(layer: nn.Module, device: torch.device):
    for module in layer.modules():
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data = module.weight.data.to(device, non_blocking=True)


def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """
    Convert a string to a torch.dtype

    Args:
        s: string representation of the dtype
        default_dtype: default dtype to return if s is None

    Returns:
        torch.dtype: the corresponding torch.dtype

    Raises:
        ValueError: if the dtype is not supported

    Examples:
        >>> str_to_dtype("float32")
        torch.float32
        >>> str_to_dtype("fp32")
        torch.float32
        >>> str_to_dtype("float16")
        torch.float16
        >>> str_to_dtype("fp16")
        torch.float16
        >>> str_to_dtype("bfloat16")
        torch.bfloat16
        >>> str_to_dtype("bf16")
        torch.bfloat16
        >>> str_to_dtype("fp8")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fn")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fnuz")
        torch.float8_e4m3fnuz
        >>> str_to_dtype("fp8_e5m2")
        torch.float8_e5m2
        >>> str_to_dtype("fp8_e5m2fnuz")
        torch.float8_e5m2fnuz
    """
    if s is None:
        return default_dtype
    if s in ["bf16", "bfloat16"]:
        return torch.bfloat16
    elif s in ["fp16", "float16"]:
        return torch.float16
    elif s in ["fp32", "float32", "float"]:
        return torch.float32
    elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
        return torch.float8_e4m3fn
    elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
        return torch.float8_e4m3fnuz
    elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
        return torch.float8_e5m2
    elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
        return torch.float8_e5m2fnuz
    elif s in ["fp8", "float8"]:
        return torch.float8_e4m3fn  # default fp8
    else:
        raise ValueError(f"Unsupported dtype: {s}")


def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """
    memory efficient save file
    """

    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    print(f"Using memory efficient save file: {filename}")

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)

        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # Direct GPU to disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)


class MemoryEfficientSafeOpen:
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # adjust offset by header size
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # process float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # convert to the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if available
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            # # convert to float16 if float8 is not supported
            # print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
            # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")


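A quick round-trip sketch of the writer and reader above (the file name is a placeholder; everything else uses only functions defined in this file):

import torch

tensors = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
mem_eff_save_file(tensors, "tmp_model.safetensors", metadata={"step": 100})  # non-str values are stringified with a warning

with MemoryEfficientSafeOpen("tmp_model.safetensors") as f:
    print(f.metadata())  # {'step': '100'}
    for key in f.keys():
        t = f.get_tensor(key)  # each tensor is read from disk only when requested
        print(key, tuple(t.shape), t.dtype)
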
def load_safetensors(
    path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
    if disable_mmap:
        # return safetensors.torch.load(open(path, "rb").read())
        # use experimental loader
        # logger.info(f"Loading without mmap (experimental)")
        state_dict = {}
        with MemoryEfficientSafeOpen(path) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
        return state_dict
    else:
        try:
            state_dict = load_file(path, device=device)
        except:
            state_dict = load_file(path)  # prevent device invalid Error
        if dtype is not None:
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(dtype=dtype)
        return state_dict


# endregion

# region Image utils


def pil_resize(image, size, interpolation):
    has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False

    if has_alpha:
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
    else:
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    resized_pil = pil_image.resize(size, resample=interpolation)

    # Convert back to cv2 format
    if has_alpha:
        resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
    else:
        resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)

    return resized_cv2


def resize_image(
    image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None
):
    """
    Resize image with resize interpolation. Defaults to "area" when downscaling and "lanczos" when upscaling.

    Args:
        image: numpy.ndarray
        width: int Original image width
        height: int Original image height
        resized_width: int Resized image width
        resized_height: int Resized image height
        resize_interpolation: Optional[str] Resize interpolation method: "lanczos", "area", "bilinear", "bicubic", "nearest", "box"

    Returns:
        image
    """

    # Ensure all size parameters are actual integers
    width = int(width)
    height = int(height)
    resized_width = int(resized_width)
    resized_height = int(resized_height)

    if resize_interpolation is None:
        if width >= resized_width and height >= resized_height:
            resize_interpolation = "area"
        else:
            resize_interpolation = "lanczos"

    # we use PIL for lanczos (for backward compatibility) and box, cv2 for others
    use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]

    resized_size = (resized_width, resized_height)
    if use_pil:
        interpolation = get_pil_interpolation(resize_interpolation)
        image = pil_resize(image, resized_size, interpolation=interpolation)
        logger.debug(f"resize image using {resize_interpolation} (PIL)")
    else:
        interpolation = get_cv2_interpolation(resize_interpolation)
        image = cv2.resize(image, resized_size, interpolation=interpolation)
        logger.debug(f"resize image using {resize_interpolation} (cv2)")

    return image


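For example, a minimal call on a dummy BGR array (the shape is arbitrary), letting the default interpolation kick in for a downscale:

import numpy as np

img = np.zeros((512, 768, 3), dtype=np.uint8)  # h, w, c, as produced by cv2.imread
small = resize_image(img, width=768, height=512, resized_width=384, resized_height=256)
print(small.shape)  # (256, 384, 3), resized with "area" since both dimensions shrink
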
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
    """
    Convert interpolation value to cv2 interpolation integer

    https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
    """
    if interpolation is None:
        return None

    if interpolation == "lanczos" or interpolation == "lanczos4":
        # Lanczos interpolation over 8x8 neighborhood
        return cv2.INTER_LANCZOS4
    elif interpolation == "nearest":
        # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
        return cv2.INTER_NEAREST_EXACT
    elif interpolation == "bilinear" or interpolation == "linear":
        # bilinear interpolation
        return cv2.INTER_LINEAR
    elif interpolation == "bicubic" or interpolation == "cubic":
        # bicubic interpolation
        return cv2.INTER_CUBIC
    elif interpolation == "area":
        # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
        return cv2.INTER_AREA
    elif interpolation == "box":
        # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
        return cv2.INTER_AREA
    else:
        return None


def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
    """
    Convert interpolation value to PIL interpolation

    https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
    """
    if interpolation is None:
        return None

    if interpolation == "lanczos":
        return Image.Resampling.LANCZOS
    elif interpolation == "nearest":
        # Pick one nearest pixel from the input image. Ignore all other input pixels.
        return Image.Resampling.NEAREST
    elif interpolation == "bilinear" or interpolation == "linear":
        # For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
        return Image.Resampling.BILINEAR
    elif interpolation == "bicubic" or interpolation == "cubic":
        # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
        return Image.Resampling.BICUBIC
    elif interpolation == "area":
        # Image.Resampling.BOX may be more appropriate if upscaling
        # Area interpolation is related to cv2.INTER_AREA
        # Produces a sharper image than Resampling.BILINEAR, doesn't have dislocations on local level like with Resampling.BOX.
        return Image.Resampling.HAMMING
    elif interpolation == "box":
        # Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
        return Image.Resampling.BOX
    else:
        return None


def validate_interpolation_fn(interpolation_str: str) -> bool:
    """
    Check if an interpolation method name is supported
    """
    return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]

# endregion

# TODO make inf_utils.py
# region Gradual Latent hires fix


class GradualLatent:
    def __init__(
        self,
        ratio,
        start_timesteps,
        every_n_steps,
        ratio_step,
        s_noise=1.0,
        gaussian_blur_ksize=None,
        gaussian_blur_sigma=0.5,
        gaussian_blur_strength=0.5,
        unsharp_target_x=True,
    ):
        self.ratio = ratio
        self.start_timesteps = start_timesteps
        self.every_n_steps = every_n_steps
        self.ratio_step = ratio_step
        self.s_noise = s_noise
        self.gaussian_blur_ksize = gaussian_blur_ksize
        self.gaussian_blur_sigma = gaussian_blur_sigma
        self.gaussian_blur_strength = gaussian_blur_strength
        self.unsharp_target_x = unsharp_target_x

    def __str__(self) -> str:
        return (
            f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
            + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
            + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
            + f"unsharp_target_x={self.unsharp_target_x})"
        )

    def apply_unshark_mask(self, x: torch.Tensor):
        if self.gaussian_blur_ksize is None:
            return x
        blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
        # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
        mask = (x - blurred) * self.gaussian_blur_strength
        sharpened = x + mask
        return sharpened

    def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
        org_dtype = x.dtype
        if org_dtype == torch.bfloat16:
            x = x.float()

        x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)

        # apply unsharp mask / アンシャープマスクを適用する
        if unsharp and self.gaussian_blur_ksize:
            x = self.apply_unshark_mask(x)

        return x


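A minimal sketch of GradualLatent.interpolate on a dummy latent (the parameter values and shapes are illustrative only):

import torch

gl = GradualLatent(
    ratio=0.5, start_timesteps=800, every_n_steps=5, ratio_step=0.125,
    gaussian_blur_ksize=3, gaussian_blur_sigma=0.5, gaussian_blur_strength=0.5,
)
latent = torch.randn(1, 4, 64, 64)
upscaled = gl.interpolate(latent, (128, 128))  # bicubic upscale followed by the unsharp mask
print(upscaled.shape)  # torch.Size([1, 4, 128, 128])
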
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.resized_size = None
        self.gradual_latent = None

    def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
        self.resized_size = size
        self.gradual_latent = gradual_latent

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """

        if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            # logger.warning(
            print(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")

        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        device = model_output.device
        if self.resized_size is None:
            prev_sample = sample + derivative * dt

            noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=device, generator=generator
            )
            s_noise = 1.0
        else:
            print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
            s_noise = self.gradual_latent.s_noise

            if self.gradual_latent.unsharp_target_x:
                prev_sample = sample + derivative * dt
                prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
            else:
                sample = self.gradual_latent.interpolate(sample, self.resized_size)
                derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
                prev_sample = sample + derivative * dt

            noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
                (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
                dtype=model_output.dtype,
                device=device,
                generator=generator,
            )

        prev_sample = prev_sample + noise * sigma_up * s_noise

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)


# endregion

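Wired in, the subclass behaves as a drop-in replacement for the stock ancestral Euler scheduler; a minimal self-contained sketch of the hook (all values are arbitrary):

scheduler = EulerAncestralDiscreteSchedulerGL()
scheduler.set_gradual_latent_params((96, 96), GradualLatent(0.75, 800, 5, 0.125))
print(scheduler.resized_size)  # (96, 96) -- used by step() to resize prev_sample and the noise
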
@@ -2,10 +2,13 @@ import argparse
import os
import torch
from safetensors.torch import load_file

+ from library.utils import setup_logging
+ setup_logging()
+ import logging
+ logger = logging.getLogger(__name__)

def main(file):
-     print(f"loading: {file}")
+     logger.info(f"loading: {file}")
    if os.path.splitext(file)[1] == ".safetensors":
        sd = load_file(file)
    else:
@@ -15,7 +18,7 @@ def main(file):

    keys = list(sd.keys())
    for key in keys:
-         if "lora_up" in key or "lora_down" in key:
+         if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
            values.append((key, sd[key]))
    print(f"number of LoRA modules: {len(values)}")

@@ -2,7 +2,10 @@ import os
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
+ from library.utils import setup_logging
+ setup_logging()
+ import logging
+ logger = logging.getLogger(__name__)

# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
@@ -125,7 +128,7 @@ class LLLiteModule(torch.nn.Module):
            return

        # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
-         # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
+         # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
        cx = self.conditioning1(cond_image)
        if not self.is_conv2d:
            # reshape / b,c,h,w -> b,h*w,c
@@ -155,7 +158,7 @@ class LLLiteModule(torch.nn.Module):
        cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
        if self.use_zeros_for_batch_uncond:
            cx[0::2] = 0.0  # uncond is zero
-         # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
+         # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")

        # reduce the input dimensionality with down and concatenate it with the conditioning image embedding
        # concatenating along the channel dimension instead of adding, in the hope that the network mixes them well
@@ -286,7 +289,7 @@ class ControlNetLLLite(torch.nn.Module):

        # create module instances
        self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
-         print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
+         logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")

    def forward(self, x):
        return x  # dummy
@@ -319,7 +322,7 @@ class ControlNetLLLite(torch.nn.Module):
        return info

    def apply_to(self):
-         print("applying LLLite for U-Net...")
+         logger.info("applying LLLite for U-Net...")
        for module in self.unet_modules:
            module.apply_to()
            self.add_module(module.lllite_name, module)
@@ -374,19 +377,19 @@ if __name__ == "__main__":
    # sdxl_original_unet.USE_REENTRANT = False

    # test shape etc
-     print("create unet")
+     logger.info("create unet")
    unet = sdxl_original_unet.SdxlUNet2DConditionModel()
    unet.to("cuda").to(torch.float16)

-     print("create ControlNet-LLLite")
+     logger.info("create ControlNet-LLLite")
    control_net = ControlNetLLLite(unet, 32, 64)
    control_net.apply_to()
    control_net.to("cuda")

-     print(control_net)
+     logger.info(control_net)

-     # print number of parameters
-     print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
+     # logger.info number of parameters
+     logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")

    input()
@@ -398,12 +401,12 @@ if __name__ == "__main__":

    # # visualize
    # import torchviz
-     # print("run visualize")
+     # logger.info("run visualize")
    # controlnet.set_control(conditioning_image)
    # output = unet(x, t, ctx, y)
-     # print("make_dot")
+     # logger.info("make_dot")
    # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
-     # print("render")
+     # logger.info("render")
    # image.format = "svg"  # "png"
    # image.render("NeuralNet")  # すごく時間がかかるので注意 / be careful because it takes a long time
    # input()
@@ -414,12 +417,12 @@ if __name__ == "__main__":

    scaler = torch.cuda.amp.GradScaler(enabled=True)

-     print("start training")
+     logger.info("start training")
    steps = 10

    sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
    for step in range(steps):
-         print(f"step {step}")
+         logger.info(f"step {step}")

        batch_size = 1
        conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
@@ -439,7 +442,7 @@ if __name__ == "__main__":
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
-         print(sample_param)
+         logger.info(f"{sample_param}")

    # from safetensors.torch import save_file


@@ -6,7 +6,12 @@ import re
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
+ from library.utils import setup_logging
+
+ setup_logging()
+ import logging
+
+ logger = logging.getLogger(__name__)

# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
@@ -100,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR):
        add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)

        self.cond_image = None
-         self.cond_emb = None

    def set_cond_image(self, cond_image):
        self.cond_image = cond_image
-         self.cond_emb = None

    def forward(self, x):
        if not self.enabled:
            return super().forward(x)

-         if self.cond_emb is None:
-             self.cond_emb = self.lllite_conditioning1(self.cond_image)
-         cx = self.cond_emb
+         cx = self.lllite_conditioning1(self.cond_image)  # make forward and backward compatible

        # reshape / b,c,h,w -> b,h*w,c
        n, c, h, w = cx.shape
@@ -156,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D):
        if not self.enabled:
            return super().forward(x)

-         if self.cond_emb is None:
-             self.cond_emb = self.lllite_conditioning1(self.cond_image)
-         cx = self.cond_emb
+         cx = self.lllite_conditioning1(self.cond_image)

        cx = torch.cat([cx, self.down(x)], dim=1)
        cx = self.mid(cx)
@@ -270,7 +269,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond

        # create module instances
        self.lllite_modules = apply_to_modules(self, target_modules)
-         print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
+         logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")

    # def prepare_optimizer_params(self):
    def prepare_params(self):
@@ -281,8 +280,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
                train_params.append(p)
            else:
                non_train_params.append(p)
-         print(f"count of trainable parameters: {len(train_params)}")
-         print(f"count of non-trainable parameters: {len(non_train_params)}")
+         logger.info(f"count of trainable parameters: {len(train_params)}")
+         logger.info(f"count of non-trainable parameters: {len(non_train_params)}")

        for p in non_train_params:
            p.requires_grad_(False)
@@ -388,7 +387,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
        matches = pattern.findall(module_name)
        if matches is not None:
            for m in matches:
-                 print(module_name, m)
+                 logger.info(f"{module_name} {m}")
                module_name = module_name.replace(m, m.replace("_", "@"))
        module_name = module_name.replace("_", ".")
        module_name = module_name.replace("@", "_")
@@ -407,7 +406,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond


def replace_unet_linear_and_conv2d():
-     print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
+     logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
    sdxl_original_unet.torch.nn.Linear = LLLiteLinear
    sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d

@@ -419,10 +418,10 @@ if __name__ == "__main__":
    replace_unet_linear_and_conv2d()

    # test shape etc
-     print("create unet")
+     logger.info("create unet")
    unet = SdxlUNet2DConditionModelControlNetLLLite()

-     print("enable ControlNet-LLLite")
+     logger.info("enable ControlNet-LLLite")
    unet.apply_lllite(32, 64, None, False, 1.0)
    unet.to("cuda")  # .to(torch.float16)

@@ -439,14 +438,14 @@ if __name__ == "__main__":
    # unet_sd[converted_key] = model_sd[key]

    # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
-     # print(info)
+     # logger.info(info)

-     # print(unet)
+     # logger.info(unet)

-     # print number of parameters
+     # logger.info number of parameters
    params = unet.prepare_params()
-     print("number of parameters", sum(p.numel() for p in params))
-     # print("type any key to continue")
+     logger.info(f"number of parameters {sum(p.numel() for p in params)}")
+     # logger.info("type any key to continue")
    # input()

    unet.set_use_memory_efficient_attention(True, False)
@@ -455,12 +454,12 @@ if __name__ == "__main__":

    # # visualize
    # import torchviz
-     # print("run visualize")
+     # logger.info("run visualize")
    # controlnet.set_control(conditioning_image)
    # output = unet(x, t, ctx, y)
-     # print("make_dot")
+     # logger.info("make_dot")
    # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
-     # print("render")
+     # logger.info("render")
    # image.format = "svg"  # "png"
    # image.render("NeuralNet")  # すごく時間がかかるので注意 / be careful because it takes a long time
    # input()
@@ -471,13 +470,13 @@ if __name__ == "__main__":

    scaler = torch.cuda.amp.GradScaler(enabled=True)

-     print("start training")
+     logger.info("start training")
    steps = 10
    batch_size = 1

    sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
    for step in range(steps):
-         print(f"step {step}")
+         logger.info(f"step {step}")

        conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
        x = torch.randn(batch_size, 4, 128, 128).cuda()
@@ -494,9 +493,9 @@ if __name__ == "__main__":
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
-         print(sample_param)
+         logger.info(sample_param)

    # from safetensors.torch import save_file

-     # print("save weights")
+     # logger.info("save weights")
    # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)

networks/convert_flux_lora.py (new file, 434 lines)
@@ -0,0 +1,434 @@
# convert key mapping and data format from some LoRA format to another
"""
Original LoRA format: Based on Black Forest Labs, QKV and MLP are unified into one module
alpha is a scalar for each LoRA module

0 to 18
lora_unet_double_blocks_0_img_attn_proj.alpha torch.Size([])
lora_unet_double_blocks_0_img_attn_proj.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_attn_proj.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_img_attn_qkv.alpha torch.Size([])
lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight torch.Size([9216, 4])
lora_unet_double_blocks_0_img_mlp_0.alpha torch.Size([])
lora_unet_double_blocks_0_img_mlp_0.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_mlp_0.lora_up.weight torch.Size([12288, 4])
lora_unet_double_blocks_0_img_mlp_2.alpha torch.Size([])
lora_unet_double_blocks_0_img_mlp_2.lora_down.weight torch.Size([4, 12288])
lora_unet_double_blocks_0_img_mlp_2.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_img_mod_lin.alpha torch.Size([])
lora_unet_double_blocks_0_img_mod_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_mod_lin.lora_up.weight torch.Size([18432, 4])
lora_unet_double_blocks_0_txt_attn_proj.alpha torch.Size([])
lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_txt_attn_qkv.alpha torch.Size([])
lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight torch.Size([9216, 4])
lora_unet_double_blocks_0_txt_mlp_0.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight torch.Size([12288, 4])
lora_unet_double_blocks_0_txt_mlp_2.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight torch.Size([4, 12288])
lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_txt_mod_lin.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight torch.Size([18432, 4])

0 to 37
lora_unet_single_blocks_0_linear1.alpha torch.Size([])
lora_unet_single_blocks_0_linear1.lora_down.weight torch.Size([4, 3072])
lora_unet_single_blocks_0_linear1.lora_up.weight torch.Size([21504, 4])
lora_unet_single_blocks_0_linear2.alpha torch.Size([])
lora_unet_single_blocks_0_linear2.lora_down.weight torch.Size([4, 15360])
lora_unet_single_blocks_0_linear2.lora_up.weight torch.Size([3072, 4])
lora_unet_single_blocks_0_modulation_lin.alpha torch.Size([])
lora_unet_single_blocks_0_modulation_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_single_blocks_0_modulation_lin.lora_up.weight torch.Size([9216, 4])
"""
"""
ai-toolkit: Based on Diffusers, QKV and MLP are separated into 3 modules.
A is down, B is up. No alpha for each LoRA module.

0 to 18
transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight torch.Size([12288, 16])
transformer.transformer_blocks.0.ff.net.2.lora_A.weight torch.Size([16, 12288])
transformer.transformer_blocks.0.ff.net.2.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight torch.Size([12288, 16])
transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight torch.Size([16, 12288])
transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.norm1.linear.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.norm1.linear.lora_B.weight torch.Size([18432, 16])
transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight torch.Size([18432, 16])

0 to 37
transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.norm.linear.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.norm.linear.lora_B.weight torch.Size([9216, 16])
transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight torch.Size([12288, 16])
transformer.single_transformer_blocks.0.proj_out.lora_A.weight torch.Size([16, 15360])
transformer.single_transformer_blocks.0.proj_out.lora_B.weight torch.Size([3072, 16])
"""
"""
xlabs: Unknown format.
0 to 18
double_blocks.0.processor.proj_lora1.down.weight torch.Size([16, 3072])
double_blocks.0.processor.proj_lora1.up.weight torch.Size([3072, 16])
double_blocks.0.processor.proj_lora2.down.weight torch.Size([16, 3072])
double_blocks.0.processor.proj_lora2.up.weight torch.Size([3072, 16])
double_blocks.0.processor.qkv_lora1.down.weight torch.Size([16, 3072])
double_blocks.0.processor.qkv_lora1.up.weight torch.Size([9216, 16])
double_blocks.0.processor.qkv_lora2.down.weight torch.Size([16, 3072])
double_blocks.0.processor.qkv_lora2.up.weight torch.Size([9216, 16])
"""


import argparse
from safetensors.torch import save_file
from safetensors import safe_open
import torch

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def convert_to_sd_scripts(sds_sd, ait_sd, sds_key, ait_key):
    ait_down_key = ait_key + ".lora_A.weight"
    if ait_down_key not in ait_sd:
        return
    ait_up_key = ait_key + ".lora_B.weight"

    down_weight = ait_sd.pop(ait_down_key)
    sds_sd[sds_key + ".lora_down.weight"] = down_weight
    sds_sd[sds_key + ".lora_up.weight"] = ait_sd.pop(ait_up_key)
    rank = down_weight.shape[0]
    sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(rank, dtype=down_weight.dtype, device=down_weight.device)


def convert_to_sd_scripts_cat(sds_sd, ait_sd, sds_key, ait_keys):
    ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
    if ait_down_keys[0] not in ait_sd:
        return
    ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]

    down_weights = [ait_sd.pop(k) for k in ait_down_keys]
    up_weights = [ait_sd.pop(k) for k in ait_up_keys]

    # lora_down is concatenated along dim=0, so rank is multiplied by the number of splits
    rank = down_weights[0].shape[0]
    num_splits = len(ait_keys)
    sds_sd[sds_key + ".lora_down.weight"] = torch.cat(down_weights, dim=0)

    merged_up_weights = torch.zeros(
        (sum(w.shape[0] for w in up_weights), rank * num_splits),
        dtype=up_weights[0].dtype,
        device=up_weights[0].device,
    )

    i = 0
    for j, up_weight in enumerate(up_weights):
        merged_up_weights[i : i + up_weight.shape[0], j * rank : (j + 1) * rank] = up_weight
        i += up_weight.shape[0]

    sds_sd[sds_key + ".lora_up.weight"] = merged_up_weights

    # set alpha to new_rank
    new_rank = rank * num_splits
    sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(new_rank, dtype=down_weights[0].dtype, device=down_weights[0].device)

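As a sanity check on the block-diagonal merge above, a small self-contained sketch (the rank and output dims are made up):

import torch

rank, out_dims = 4, [8, 8, 8]  # 3 splits, e.g. Q/K/V
ups = [torch.randn(d, rank) for d in out_dims]
merged = torch.zeros(sum(out_dims), rank * len(ups))
i = 0
for j, up in enumerate(ups):
    merged[i : i + up.shape[0], j * rank : (j + 1) * rank] = up
    i += up.shape[0]
print(merged.shape)  # torch.Size([24, 12]); off-diagonal blocks stay zero
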
def convert_ai_toolkit_to_sd_scripts(ait_sd):
    sds_sd = {}
    for i in range(19):
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0"
        )
        convert_to_sd_scripts_cat(
            sds_sd,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_attn_qkv",
            [
                f"transformer.transformer_blocks.{i}.attn.to_q",
                f"transformer.transformer_blocks.{i}.attn.to_k",
                f"transformer.transformer_blocks.{i}.attn.to_v",
            ],
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj"
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2"
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear"
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out"
        )
        convert_to_sd_scripts_cat(
            sds_sd,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
            [
                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
            ],
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2"
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear"
        )

    for i in range(38):
        convert_to_sd_scripts_cat(
            sds_sd,
            ait_sd,
            f"lora_unet_single_blocks_{i}_linear1",
            [
                f"transformer.single_transformer_blocks.{i}.attn.to_q",
                f"transformer.single_transformer_blocks.{i}.attn.to_k",
                f"transformer.single_transformer_blocks.{i}.attn.to_v",
                f"transformer.single_transformer_blocks.{i}.proj_mlp",
            ],
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out"
        )
        convert_to_sd_scripts(
            sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear"
        )

    if len(ait_sd) > 0:
        logger.warning(f"Unsupported keys for sd-scripts: {ait_sd.keys()}")
    return sds_sd

def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
    if sds_key + ".lora_down.weight" not in sds_sd:
        return
    down_weight = sds_sd.pop(sds_key + ".lora_down.weight")

    # scale weight by alpha and dim
    rank = down_weight.shape[0]
    alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
    scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
    # print(f"rank: {rank}, alpha: {alpha}, scale: {scale}")

    # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
    scale_down = scale
    scale_up = 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    # print(f"scale: {scale}, scale_down: {scale_down}, scale_up: {scale_up}")

    ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
    ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up

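# Illustrative sketch (standalone, not from the repository): the scale folding in
# convert_to_ai_toolkit keeps the effective delta intact, because multiplying
# lora_down by scale_down and lora_up by scale_up scales their product by
# scale_down * scale_up; the while loop only redistributes the factor.
#
# import torch
#
# rank, alpha = 16, 8.0
# down = torch.randn(rank, 32)
# up = torch.randn(32, rank)
# scale = alpha / rank  # 0.5, the factor the forward pass would apply
#
# scale_down, scale_up = scale, 1.0
# while scale_down * 2 < scale_up:  # same redistribution loop as above
#     scale_down *= 2
#     scale_up /= 2
#
# folded = (up * scale_up) @ (down * scale_down)
# assert torch.allclose(folded, scale * (up @ down), atol=1e-5)
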
def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
    if sds_key + ".lora_down.weight" not in sds_sd:
        return
    down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
    up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
    sd_lora_rank = down_weight.shape[0]

    # scale weight by alpha and dim
    alpha = sds_sd.pop(sds_key + ".alpha")
    scale = alpha / sd_lora_rank

    # calculate scale_down and scale_up
    scale_down = scale
    scale_up = 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2

    down_weight = down_weight * scale_down
    up_weight = up_weight * scale_up

    # calculate dims if not provided
    num_splits = len(ait_keys)
    if dims is None:
        dims = [up_weight.shape[0] // num_splits] * num_splits
    else:
        assert sum(dims) == up_weight.shape[0]

    # check upweight is sparse or not
    is_sparse = False
    if sd_lora_rank % num_splits == 0:
        ait_rank = sd_lora_rank // num_splits
        is_sparse = True
        i = 0
        for j in range(len(dims)):
            for k in range(len(dims)):
                if j == k:
                    continue
                is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0)
            i += dims[j]
        if is_sparse:
            logger.info(f"weight is sparse: {sds_key}")

    # make ai-toolkit weight
    ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
    ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
    if not is_sparse:
        # down_weight is copied to each split
        ait_sd.update({k: down_weight for k in ait_down_keys})

        # up_weight is split to each split
        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
    else:
        # down_weight is chunked to each split
        ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})

        # up_weight is sparse: only non-zero values are copied to each split
        i = 0
        for j in range(len(dims)):
            ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
            i += dims[j]

def convert_sd_scripts_to_ai_toolkit(sds_sd):
    ait_sd = {}
    for i in range(19):
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0"
        )
        convert_to_ai_toolkit_cat(
            sds_sd,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_attn_qkv",
            [
                f"transformer.transformer_blocks.{i}.attn.to_q",
                f"transformer.transformer_blocks.{i}.attn.to_k",
                f"transformer.transformer_blocks.{i}.attn.to_v",
            ],
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj"
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2"
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear"
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out"
        )
        convert_to_ai_toolkit_cat(
            sds_sd,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
            [
                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
            ],
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2"
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear"
        )

    for i in range(38):
        convert_to_ai_toolkit_cat(
            sds_sd,
            ait_sd,
            f"lora_unet_single_blocks_{i}_linear1",
            [
                f"transformer.single_transformer_blocks.{i}.attn.to_q",
                f"transformer.single_transformer_blocks.{i}.attn.to_k",
                f"transformer.single_transformer_blocks.{i}.attn.to_v",
                f"transformer.single_transformer_blocks.{i}.proj_mlp",
            ],
            dims=[3072, 3072, 3072, 12288],
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out"
        )
        convert_to_ai_toolkit(
            sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear"
        )

    if len(sds_sd) > 0:
        logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
    return ait_sd

def main(args):
    # load source safetensors
    logger.info(f"Loading source file {args.src_path}")
    state_dict = {}
    with safe_open(args.src_path, framework="pt") as f:
        metadata = f.metadata()
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)

    logger.info(f"Converting {args.src} to {args.dst} format")
    if args.src == "ai-toolkit" and args.dst == "sd-scripts":
        state_dict = convert_ai_toolkit_to_sd_scripts(state_dict)
    elif args.src == "sd-scripts" and args.dst == "ai-toolkit":
        state_dict = convert_sd_scripts_to_ai_toolkit(state_dict)

        # eliminate 'shared tensors'
        for k in list(state_dict.keys()):
            state_dict[k] = state_dict[k].detach().clone()
    else:
        raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported")

    # save destination safetensors
    logger.info(f"Saving destination file {args.dst_path}")
    save_file(state_dict, args.dst_path, metadata=metadata)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert LoRA format")
    parser.add_argument("--src", type=str, default="ai-toolkit", help="source format, ai-toolkit or sd-scripts")
    parser.add_argument("--dst", type=str, default="sd-scripts", help="destination format, ai-toolkit or sd-scripts")
    parser.add_argument("--src_path", type=str, default=None, help="source path")
    parser.add_argument("--dst_path", type=str, default=None, help="destination path")
    args = parser.parse_args()
    main(args)

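# Illustrative usage (standalone sketch, not from the repository): the same
# conversion the CLI above performs, done programmatically. File names are
# placeholders.
#
# from safetensors import safe_open
# from safetensors.torch import save_file
#
# src_path = "flux_lora_ai_toolkit.safetensors"  # hypothetical input file
# state_dict = {}
# with safe_open(src_path, framework="pt") as f:
#     metadata = f.metadata()
#     for k in f.keys():
#         state_dict[k] = f.get_tensor(k)
#
# sds_sd = convert_ai_toolkit_to_sd_scripts(state_dict)  # defined above
# save_file(sds_sd, "flux_lora_sd_scripts.safetensors", metadata=metadata)
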
@@ -12,9 +12,17 @@
 import math
 import os
 import random
-from typing import List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
 import torch
 from torch import nn
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)


 class DyLoRAModule(torch.nn.Module):
@@ -165,7 +173,15 @@ class DyLoRAModule(torch.nn.Module):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


-def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    vae: AutoencoderKL,
+    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
+    unet,
+    **kwargs,
+):
     if network_dim is None:
         network_dim = 4  # default
     if network_alpha is None:
@@ -182,6 +198,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_alpha = 1.0
     else:
         conv_alpha = float(conv_alpha)
+
     if unit is not None:
         unit = int(unit)
     else:
@@ -197,6 +214,16 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         unit=unit,
         varbose=True,
     )
+
+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network

@@ -223,7 +250,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
         elif "lora_down" in key:
             dim = value.size()[0]
             modules_dim[lora_name] = dim
-            # print(lora_name, value.size(), dim)
+            # logger.info(f"{lora_name} {value.size()} {dim}")

     # support old LoRA without alpha
     for key in modules_dim.keys():
@@ -241,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
 class DyLoRANetwork(torch.nn.Module):
     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"

@@ -266,12 +293,16 @@ class DyLoRANetwork(torch.nn.Module):
         self.alpha = alpha
         self.apply_to_conv = apply_to_conv

+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
-            print(f"create LoRA network from weights")
+            logger.info("create LoRA network from weights")
         else:
-            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
+            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
         if self.apply_to_conv:
-            print(f"apply LoRA to Conv2d with kernel size (3,3).")
+            logger.info("apply LoRA to Conv2d with kernel size (3,3).")

         # create module instances
         def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
@@ -307,8 +338,22 @@ class DyLoRANetwork(torch.nn.Module):
                 loras.append(lora)
             return loras

-        self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
-        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        self.text_encoder_loras = []
+        for i, text_encoder in enumerate(text_encoders):
+            if len(text_encoders) > 1:
+                index = i + 1
+                logger.info(f"create LoRA for Text Encoder {index}")
+            else:
+                index = None
+                logger.info("create LoRA for Text Encoder")
+
+            text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            self.text_encoder_loras.extend(text_encoder_loras)
+
+        # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

         # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
         target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -316,7 +361,15 @@ class DyLoRANetwork(torch.nn.Module):
             target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

         self.unet_loras = create_modules(True, unet, target_modules)
-        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+        logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
+        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
+        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
@@ -336,12 +389,12 @@ class DyLoRANetwork(torch.nn.Module):

     def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
         if apply_text_encoder:
-            print("enable LoRA for text encoder")
+            logger.info("enable LoRA for text encoder")
         else:
             self.text_encoder_loras = []

         if apply_unet:
-            print("enable LoRA for U-Net")
+            logger.info("enable LoRA for U-Net")
         else:
             self.unet_loras = []

@@ -359,12 +412,12 @@ class DyLoRANetwork(torch.nn.Module):
             apply_unet = True

         if apply_text_encoder:
-            print("enable LoRA for text encoder")
+            logger.info("enable LoRA for text encoder")
         else:
             self.text_encoder_loras = []

         if apply_unet:
-            print("enable LoRA for U-Net")
+            logger.info("enable LoRA for U-Net")
         else:
             self.unet_loras = []

|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
print(f"weights are merged")
|
||||
logger.info(f"weights are merged")
|
||||
"""
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
def assemble_params(loras, lr, ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
for name, param in lora.named_parameters():
|
||||
if ratio is not None and "lora_B" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
continue
|
||||
|
||||
params.append(param_data)
|
||||
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
if self.unet_loras:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
return all_params
|
||||
|
||||
|
||||
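# Illustrative sketch (standalone, not from the repository): what the LoRA+ param
# groups built by assemble_params above amount to when handed to an optimizer.
# Tensors and numbers are stand-ins.
#
# import torch
#
# lora_A = torch.nn.Parameter(torch.randn(4, 32))  # "down" weight, base learning rate
# lora_B = torch.nn.Parameter(torch.zeros(32, 4))  # "up" weight, gets lr * ratio
# base_lr, ratio = 1e-4, 16.0
#
# optimizer = torch.optim.AdamW(
#     [
#         {"params": [lora_A], "lr": base_lr},
#         {"params": [lora_B], "lr": base_lr * ratio},
#     ]
# )
# for group in optimizer.param_groups:
#     print(group["lr"])  # 1e-4 then 1.6e-3
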
@@ -10,7 +10,10 @@ from safetensors.torch import load_file, save_file, safe_open
 from tqdm import tqdm
 from library import train_util, model_util
 import numpy as np
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 def load_state_dict(file_name):
     if model_util.is_safetensors(file_name):
@@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit):
         rank = value.size()[0]
         if rank > max_rank:
             max_rank = rank
-    print(f"Max rank: {max_rank}")
+    logger.info(f"Max rank: {max_rank}")

     rank = unit
     split_models = []
     new_alpha = None
     while rank < max_rank:
-        print(f"Splitting rank {rank}")
+        logger.info(f"Splitting rank {rank}")
         new_sd = {}
         for key, value in lora_sd.items():
             if "lora_down" in key:
@@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit):
                 # for some reason, applying the scale breaks the result...
                 # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
                 # scale = math.sqrt(this_rank / rank)  # rank is > unit
-                # print(key, value.size(), this_rank, rank, value, scale)
+                # logger.info(key, value.size(), this_rank, rank, value, scale)
                 # new_alpha = value * scale  # always same
                 # new_sd[key] = new_alpha
                 new_sd[key] = value
@@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit):


 def split(args):
-    print("loading Model...")
+    logger.info("loading Model...")
     lora_sd, metadata = load_state_dict(args.model)

-    print("Splitting Model...")
+    logger.info("Splitting Model...")
     original_rank, split_models = split_lora_model(lora_sd, args.unit)

     comment = metadata.get("ss_training_comment", "")
@@ -94,7 +97,7 @@ def split(args):
         filename, ext = os.path.splitext(args.save_to)
         model_file_name = filename + f"-{new_rank:04d}{ext}"

-        print(f"saving model to: {model_file_name}")
+        logger.info(f"saving model to: {model_file_name}")
         save_to_file(model_file_name, state_dict, new_metadata)

@@ -11,7 +11,10 @@ from safetensors.torch import load_file, save_file
 from tqdm import tqdm
 from library import sai_model_spec, model_util, sdxl_model_util
 import lora
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 # CLAMP_QUANTILE = 0.99
 # MIN_DIFF = 1e-1
@@ -66,14 +69,14 @@ def svd(

     # load models
     if not sdxl:
-        print(f"loading original SD model : {model_org}")
+        logger.info(f"loading original SD model : {model_org}")
         text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
         text_encoders_o = [text_encoder_o]
         if load_dtype is not None:
             text_encoder_o = text_encoder_o.to(load_dtype)
             unet_o = unet_o.to(load_dtype)

-        print(f"loading tuned SD model : {model_tuned}")
+        logger.info(f"loading tuned SD model : {model_tuned}")
         text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
         text_encoders_t = [text_encoder_t]
         if load_dtype is not None:
@@ -85,7 +88,7 @@ def svd(
         device_org = load_original_model_to if load_original_model_to else "cpu"
         device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"

-        print(f"loading original SDXL model : {model_org}")
+        logger.info(f"loading original SDXL model : {model_org}")
         text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
             sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
         )
@@ -95,7 +98,7 @@ def svd(
             text_encoder_o2 = text_encoder_o2.to(load_dtype)
             unet_o = unet_o.to(load_dtype)

-        print(f"loading original SDXL model : {model_tuned}")
+        logger.info(f"loading original SDXL model : {model_tuned}")
         text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
             sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
         )
@@ -135,7 +138,7 @@ def svd(
             # Text Encoder might be same
             if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
                 text_encoder_different = True
-                print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
+                logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")

             diffs[lora_name] = diff

@@ -144,7 +147,7 @@ def svd(
         del text_encoder

     if not text_encoder_different:
-        print("Text encoder is same. Extract U-Net only.")
+        logger.warning("Text encoder is same. Extract U-Net only.")
         lora_network_o.text_encoder_loras = []
         diffs = {}  # clear diffs

@@ -166,7 +169,7 @@ def svd(
     del unet_t

     # make LoRA with svd
-    print("calculating by svd")
+    logger.info("calculating by svd")
     lora_weights = {}
     with torch.no_grad():
         for lora_name, mat in tqdm(list(diffs.items())):
@@ -185,7 +188,7 @@ def svd(
             if device:
                 mat = mat.to(device)

-            # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
+            # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
             rank = min(rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim

             if conv2d:
@@ -230,7 +233,7 @@ def svd(
     lora_network_save.apply_to(text_encoders_o, unet_o)  # create internal module references for state_dict

     info = lora_network_save.load_state_dict(lora_sd)
-    print(f"Loading extracted LoRA weights: {info}")
+    logger.info(f"Loading extracted LoRA weights: {info}")

     dir_name = os.path.dirname(save_to)
     if dir_name and not os.path.exists(dir_name):
@@ -257,7 +260,7 @@ def svd(
     metadata.update(sai_metadata)

     lora_network_save.save_weights(save_to, save_dtype, metadata)
-    print(f"LoRA weights are saved to: {save_to}")
+    logger.info(f"LoRA weights are saved to: {save_to}")


 def setup_parser() -> argparse.ArgumentParser:
networks/flux_extract_lora.py (new file, +219 lines)
@@ -0,0 +1,219 @@
# extract approximating LoRA by svd from two FLUX models
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo!

import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
from library import flux_utils, sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import MemoryEfficientSafeOpen
from library.utils import setup_logging
from networks import lora_flux

setup_logging()
import logging

logger = logging.getLogger(__name__)

# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1


def save_to_file(file_name, state_dict, metadata, dtype):
    if dtype is not None:
        for key in list(state_dict.keys()):
            if type(state_dict[key]) == torch.Tensor:
                state_dict[key] = state_dict[key].to(dtype)

    save_file(state_dict, file_name, metadata=metadata)


def svd(
    model_org=None,
    model_tuned=None,
    save_to=None,
    dim=4,
    device=None,
    save_precision=None,
    clamp_quantile=0.99,
    min_diff=0.01,
    no_metadata=False,
    mem_eff_safe_open=False,
):
    def str_to_dtype(p):
        if p == "float":
            return torch.float
        if p == "fp16":
            return torch.float16
        if p == "bf16":
            return torch.bfloat16
        return None

    calc_dtype = torch.float
    save_dtype = str_to_dtype(save_precision)
    store_device = "cpu"

    # open models
    lora_weights = {}
    if not mem_eff_safe_open:
        # use original safetensors.safe_open
        open_fn = lambda fn: safe_open(fn, framework="pt")
    else:
        logger.info("Using memory efficient safe_open")
        open_fn = lambda fn: MemoryEfficientSafeOpen(fn)

    with open_fn(model_org) as f_org:
        # filter keys
        keys = []
        for key in f_org.keys():
            if not ("single_block" in key or "double_block" in key):
                continue
            if ".bias" in key:
                continue
            if "norm" in key:
                continue
            keys.append(key)

        with open_fn(model_tuned) as f_tuned:
            for key in tqdm(keys):
                # get tensors and calculate difference
                value_o = f_org.get_tensor(key)
                value_t = f_tuned.get_tensor(key)
                mat = value_t.to(calc_dtype) - value_o.to(calc_dtype)
                del value_o, value_t

                # extract LoRA weights
                if device:
                    mat = mat.to(device)
                out_dim, in_dim = mat.size()[0:2]
                rank = min(dim, in_dim, out_dim)  # LoRA rank cannot exceed the original dim

                mat = mat.squeeze()

                U, S, Vh = torch.linalg.svd(mat)

                U = U[:, :rank]
                S = S[:rank]
                U = U @ torch.diag(S)

                Vh = Vh[:rank, :]

                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, clamp_quantile)
                low_val = -hi_val

                U = U.clamp(low_val, hi_val)
                Vh = Vh.clamp(low_val, hi_val)

                U = U.to(store_device, dtype=save_dtype).contiguous()
                Vh = Vh.to(store_device, dtype=save_dtype).contiguous()

                # print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}")
                lora_weights[key] = (U, Vh)
                del mat, U, S, Vh

    # make state dict for LoRA
    lora_sd = {}
    for key, (up_weight, down_weight) in lora_weights.items():
        lora_name = key.replace(".weight", "").replace(".", "_")
        lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name
        lora_sd[lora_name + ".lora_up.weight"] = up_weight
        lora_sd[lora_name + ".lora_down.weight"] = down_weight
        lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])  # same as rank

    # minimum metadata
    net_kwargs = {}
    metadata = {
        "ss_v2": str(False),
        "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1,
        "ss_network_module": "networks.lora_flux",
        "ss_network_dim": str(dim),
        "ss_network_alpha": str(float(dim)),
        "ss_network_args": json.dumps(net_kwargs),
    }

    if not no_metadata:
        title = os.path.splitext(os.path.basename(save_to))[0]
        sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
        metadata.update(sai_metadata)

    save_to_file(save_to, lora_sd, metadata, save_dtype)

    logger.info(f"LoRA weights saved to {save_to}")

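# Illustrative sketch (standalone, not from the repository): the SVD truncation used
# above is the standard low-rank factorization. Keeping the top-rank singular
# triplets of the weight delta and folding the singular values into U gives the best
# rank-r approximation (Eckart-Young); U becomes lora_up and Vh becomes lora_down.
#
# import torch
#
# mat = torch.randn(64, 48)  # stands in for (tuned - original) weight difference
# rank = 4
#
# U, S, Vh = torch.linalg.svd(mat)
# up = U[:, :rank] @ torch.diag(S[:rank])  # lora_up,   (64, rank)
# down = Vh[:rank, :]                      # lora_down, (rank, 48)
#
# err = torch.linalg.norm(mat - up @ down) / torch.linalg.norm(mat)
# print(f"relative error of the rank-{rank} approximation: {err:.3f}")
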
def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_precision",
        type=str,
        default=None,
        choices=[None, "float", "fp16", "bf16"],
        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
    )
    parser.add_argument(
        "--model_org",
        type=str,
        default=None,
        required=True,
        help="Original model: safetensors file / 元モデル、safetensors",
    )
    parser.add_argument(
        "--model_tuned",
        type=str,
        default=None,
        required=True,
        help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
    )
    parser.add_argument(
        "--mem_eff_safe_open",
        action="store_true",
        help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough."
        " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。",
    )
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        required=True,
        help="destination file name: safetensors file / 保存先のファイル名、safetensors",
    )
    parser.add_argument(
        "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)"
    )
    parser.add_argument(
        "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
    )
    parser.add_argument(
        "--clamp_quantile",
        type=float,
        default=0.99,
        help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
    )
    # parser.add_argument(
    #     "--min_diff",
    #     type=float,
    #     default=0.01,
    #     help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
    #     + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
    # )
    parser.add_argument(
        "--no_metadata",
        action="store_true",
        help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    svd(**vars(args))

networks/flux_merge_lora.py (new file, +765 lines)
@@ -0,0 +1,765 @@
import argparse
import math
import os
import time
from typing import Any, Dict, Union

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm

from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file

setup_logging()
import logging

logger = logging.getLogger(__name__)

import lora_flux as lora_flux
from library import sai_model_spec, train_util


def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == ".safetensors":
        sd = load_file(file_name)
        metadata = train_util.load_metadata_from_safetensors(file_name)
    else:
        sd = torch.load(file_name, map_location="cpu")
        metadata = {}

    for key in list(sd.keys()):
        if type(sd[key]) == torch.Tensor:
            sd[key] = sd[key].to(dtype)

    return sd, metadata


def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False):
    if dtype is not None:
        logger.info(f"converting to {dtype}...")
        for key in tqdm(list(state_dict.keys())):
            if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point:
                state_dict[key] = state_dict[key].to(dtype)

    logger.info(f"saving to: {file_name}")
    if mem_eff_save:
        mem_eff_save_file(state_dict, file_name, metadata=metadata)
    else:
        save_file(state_dict, file_name, metadata=metadata)


def merge_to_flux_model(
    loading_device,
    working_device,
    flux_path: str,
    clip_l_path: str,
    t5xxl_path: str,
    models,
    ratios,
    merge_dtype,
    save_dtype,
    mem_eff_load_save=False,
):
    # create module map without loading state_dict
    lora_name_to_module_key = {}
    if flux_path is not None:
        logger.info(f"loading keys from FLUX.1 model: {flux_path}")
        with safe_open(flux_path, framework="pt", device=loading_device) as flux_file:
            keys = list(flux_file.keys())
            for key in keys:
                if key.endswith(".weight"):
                    module_name = ".".join(key.split(".")[:-1])
                    lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
                    lora_name_to_module_key[lora_name] = key

    lora_name_to_clip_l_key = {}
    if clip_l_path is not None:
        logger.info(f"loading keys from clip_l model: {clip_l_path}")
        with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file:
            keys = list(clip_l_file.keys())
            for key in keys:
                if key.endswith(".weight"):
                    module_name = ".".join(key.split(".")[:-1])
                    lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_")
                    lora_name_to_clip_l_key[lora_name] = key

    lora_name_to_t5xxl_key = {}
    if t5xxl_path is not None:
        logger.info(f"loading keys from t5xxl model: {t5xxl_path}")
        with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file:
            keys = list(t5xxl_file.keys())
            for key in keys:
                if key.endswith(".weight"):
                    module_name = ".".join(key.split(".")[:-1])
                    lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_")
                    lora_name_to_t5xxl_key[lora_name] = key

    flux_state_dict = {}
    clip_l_state_dict = {}
    t5xxl_state_dict = {}
    if mem_eff_load_save:
        if flux_path is not None:
            with MemoryEfficientSafeOpen(flux_path) as flux_file:
                for key in tqdm(flux_file.keys()):
                    flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device)  # dtype is not changed

        if clip_l_path is not None:
            with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file:
                for key in tqdm(clip_l_file.keys()):
                    clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device)

        if t5xxl_path is not None:
            with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file:
                for key in tqdm(t5xxl_file.keys()):
                    t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device)
    else:
        if flux_path is not None:
            flux_state_dict = load_file(flux_path, device=loading_device)
        if clip_l_path is not None:
            clip_l_state_dict = load_file(clip_l_path, device=loading_device)
        if t5xxl_path is not None:
            t5xxl_state_dict = load_file(t5xxl_path, device=loading_device)

    for model, ratio in zip(models, ratios):
        logger.info(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)  # loading on CPU

        logger.info(f"merging...")
        for key in tqdm(list(lora_sd.keys())):
            if "lora_down" in key:
                lora_name = key[: key.rfind(".lora_down")]
                up_key = key.replace("lora_down", "lora_up")
                alpha_key = key[: key.index("lora_down")] + "alpha"

                if lora_name in lora_name_to_module_key:
                    module_weight_key = lora_name_to_module_key[lora_name]
                    state_dict = flux_state_dict
                elif lora_name in lora_name_to_clip_l_key:
                    module_weight_key = lora_name_to_clip_l_key[lora_name]
                    state_dict = clip_l_state_dict
                elif lora_name in lora_name_to_t5xxl_key:
                    module_weight_key = lora_name_to_t5xxl_key[lora_name]
                    state_dict = t5xxl_state_dict
                else:
                    logger.warning(
                        f"no module found for LoRA weight: {key}. Skipping..."
                        f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。"
                    )
                    continue

                down_weight = lora_sd.pop(key)
                up_weight = lora_sd.pop(up_key)

                dim = down_weight.size()[0]
                alpha = lora_sd.pop(alpha_key, dim)
                scale = alpha / dim

                # W <- W + U * D
                weight = state_dict[module_weight_key]

                weight = weight.to(working_device, merge_dtype)
                up_weight = up_weight.to(working_device, merge_dtype)
                down_weight = down_weight.to(working_device, merge_dtype)

                # logger.info(module_name, down_weight.size(), up_weight.size())
                if len(weight.size()) == 2:
                    # linear
                    weight = weight + ratio * (up_weight @ down_weight) * scale
                elif down_weight.size()[2:4] == (1, 1):
                    # conv2d 1x1
                    weight = (
                        weight
                        + ratio
                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                        * scale
                    )
                else:
                    # conv2d 3x3
                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                    # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                    weight = weight + ratio * conved * scale

                state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
                del up_weight
                del down_weight
                del weight

        if len(lora_sd) > 0:
            logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}")

    return flux_state_dict, clip_l_state_dict, t5xxl_state_dict

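# Illustrative sketch (standalone, not from the repository): the linear-branch merge
# rule used above, W <- W + ratio * (up @ down) * (alpha / rank), checked on toy
# tensors. After merging, the layer output equals the base output plus the scaled
# LoRA delta.
#
# import torch
#
# out_dim, in_dim, rank = 8, 8, 2
# weight = torch.randn(out_dim, in_dim)
# down = torch.randn(rank, in_dim)
# up = torch.randn(out_dim, rank)
# alpha, ratio = 1.0, 0.8
# scale = alpha / rank
#
# merged = weight + ratio * (up @ down) * scale
# x = torch.randn(in_dim)
# assert torch.allclose(merged @ x, weight @ x + ratio * scale * (up @ (down @ x)), atol=1e-5)
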
def merge_to_flux_model_diffusers(
    loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
):
    logger.info(f"loading keys from FLUX.1 model: {flux_model}")
    if mem_eff_load_save:
        flux_state_dict = {}
        with MemoryEfficientSafeOpen(flux_model) as flux_file:
            for key in tqdm(flux_file.keys()):
                flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device)  # dtype is not changed
    else:
        flux_state_dict = load_file(flux_model, device=loading_device)

    def create_key_map(n_double_layers, n_single_layers):
        key_map = {}
        for index in range(n_double_layers):
            prefix_from = f"transformer_blocks.{index}"
            prefix_to = f"double_blocks.{index}"

            for end in ("weight", "bias"):
                k = f"{prefix_from}.attn."
                qkv_img = f"{prefix_to}.img_attn.qkv.{end}"
                qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}"

                key_map[f"{k}to_q.{end}"] = qkv_img
                key_map[f"{k}to_k.{end}"] = qkv_img
                key_map[f"{k}to_v.{end}"] = qkv_img
                key_map[f"{k}add_q_proj.{end}"] = qkv_txt
                key_map[f"{k}add_k_proj.{end}"] = qkv_txt
                key_map[f"{k}add_v_proj.{end}"] = qkv_txt

            block_map = {
                "attn.to_out.0.weight": "img_attn.proj.weight",
                "attn.to_out.0.bias": "img_attn.proj.bias",
                "norm1.linear.weight": "img_mod.lin.weight",
                "norm1.linear.bias": "img_mod.lin.bias",
                "norm1_context.linear.weight": "txt_mod.lin.weight",
                "norm1_context.linear.bias": "txt_mod.lin.bias",
                "attn.to_add_out.weight": "txt_attn.proj.weight",
                "attn.to_add_out.bias": "txt_attn.proj.bias",
                "ff.net.0.proj.weight": "img_mlp.0.weight",
                "ff.net.0.proj.bias": "img_mlp.0.bias",
                "ff.net.2.weight": "img_mlp.2.weight",
                "ff.net.2.bias": "img_mlp.2.bias",
                "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
                "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
                "ff_context.net.2.weight": "txt_mlp.2.weight",
                "ff_context.net.2.bias": "txt_mlp.2.bias",
                "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
                "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
                "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
                "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
            }

            for k, v in block_map.items():
                key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"

        for index in range(n_single_layers):
            prefix_from = f"single_transformer_blocks.{index}"
            prefix_to = f"single_blocks.{index}"

            for end in ("weight", "bias"):
                k = f"{prefix_from}.attn."
                qkv = f"{prefix_to}.linear1.{end}"
                key_map[f"{k}to_q.{end}"] = qkv
                key_map[f"{k}to_k.{end}"] = qkv
                key_map[f"{k}to_v.{end}"] = qkv
                key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv

            block_map = {
                "norm.linear.weight": "modulation.lin.weight",
                "norm.linear.bias": "modulation.lin.bias",
                "proj_out.weight": "linear2.weight",
                "proj_out.bias": "linear2.bias",
                "attn.norm_q.weight": "norm.query_norm.scale",
                "attn.norm_k.weight": "norm.key_norm.scale",
            }

            for k, v in block_map.items():
                key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"

        # add as-is keys
        values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())])
        values.sort()
        key_map.update({v: v for v in values})

        return key_map

    key_map = create_key_map(18, 38)  # 18 double layers, 38 single layers

    def find_matching_key(flux_dict, lora_key):
        lora_key = lora_key.replace("diffusion_model.", "")
        lora_key = lora_key.replace("transformer.", "")
        lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up")
        lora_key = lora_key.replace("single_transformer_blocks", "single_blocks")
        lora_key = lora_key.replace("transformer_blocks", "double_blocks")

        double_block_map = {
            "attn.to_out.0": "img_attn.proj",
            "norm1.linear": "img_mod.lin",
            "norm1_context.linear": "txt_mod.lin",
            "attn.to_add_out": "txt_attn.proj",
            "ff.net.0.proj": "img_mlp.0",
            "ff.net.2": "img_mlp.2",
            "ff_context.net.0.proj": "txt_mlp.0",
            "ff_context.net.2": "txt_mlp.2",
            "attn.norm_q": "img_attn.norm.query_norm",
            "attn.norm_k": "img_attn.norm.key_norm",
            "attn.norm_added_q": "txt_attn.norm.query_norm",
            "attn.norm_added_k": "txt_attn.norm.key_norm",
            "attn.to_q": "img_attn.qkv",
            "attn.to_k": "img_attn.qkv",
            "attn.to_v": "img_attn.qkv",
            "attn.add_q_proj": "txt_attn.qkv",
            "attn.add_k_proj": "txt_attn.qkv",
            "attn.add_v_proj": "txt_attn.qkv",
        }
        single_block_map = {
            "norm.linear": "modulation.lin",
            "proj_out": "linear2",
            "attn.norm_q": "norm.query_norm",
            "attn.norm_k": "norm.key_norm",
            "attn.to_q": "linear1",
            "attn.to_k": "linear1",
            "attn.to_v": "linear1",
            "proj_mlp": "linear1",
        }

        # same key exists in both single_block_map and double_block_map, so we must care about single/double
        # print("lora_key before double_block_map", lora_key)
        for old, new in double_block_map.items():
            if "double" in lora_key:
                lora_key = lora_key.replace(old, new)
        # print("lora_key before single_block_map", lora_key)
        for old, new in single_block_map.items():
            if "single" in lora_key:
                lora_key = lora_key.replace(old, new)
        # print("lora_key after mapping", lora_key)

        if lora_key in key_map:
            flux_key = key_map[lora_key]
            logger.info(f"Found matching key: {flux_key}")
            return flux_key

        # If not found in key_map, try partial matching
        potential_key = lora_key + ".weight"
        logger.info(f"Searching for key: {potential_key}")
        matches = [k for k in flux_dict.keys() if potential_key in k]
        if matches:
            logger.info(f"Found matching key: {matches[0]}")
            return matches[0]
        return None

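    # Illustrative walk-through (illustrative note, not from the repository): a
    # diffusers-style name such as "transformer.transformer_blocks.3.attn.to_q" is
    # normalized to "double_blocks.3.img_attn.qkv" by the maps above; the partial
    # match then resolves it to the fused "double_blocks.3.img_attn.qkv.weight"
    # tensor in the model.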
    merged_keys = set()
    for model, ratio in zip(models, ratios):
        logger.info(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)

        logger.info("merging...")
        for key in lora_sd.keys():
            if "lora_down" in key or "lora_A" in key:
                lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")]
                up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B")
                alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha"

                logger.info(f"Processing LoRA key: {lora_name}")
                flux_key = find_matching_key(flux_state_dict, lora_name)

                if flux_key is None:
                    logger.warning(f"no module found for LoRA weight: {key}")
                    continue

                logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}")

                down_weight = lora_sd[key]
                up_weight = lora_sd[up_key]

                dim = down_weight.size()[0]
                alpha = lora_sd.get(alpha_key, dim)
                scale = alpha / dim

                weight = flux_state_dict[flux_key]

                weight = weight.to(working_device, merge_dtype)
                up_weight = up_weight.to(working_device, merge_dtype)
                down_weight = down_weight.to(working_device, merge_dtype)

                # print(up_weight.size(), down_weight.size(), weight.size())

                if lora_name.startswith("transformer."):
                    if "qkv" in flux_key or "linear1" in flux_key:  # combined qkv or qkv+mlp
                        update = ratio * (up_weight @ down_weight) * scale
                        # print(update.shape)

                        if "img_attn" in flux_key or "txt_attn" in flux_key:
                            q, k, v = torch.chunk(weight, 3, dim=0)
                            if "to_q" in lora_name or "add_q_proj" in lora_name:
                                q += update.reshape(q.shape)
                            elif "to_k" in lora_name or "add_k_proj" in lora_name:
                                k += update.reshape(k.shape)
                            elif "to_v" in lora_name or "add_v_proj" in lora_name:
                                v += update.reshape(v.shape)
                            weight = torch.cat([q, k, v], dim=0)
                        elif "linear1" in flux_key:
                            q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0)
                            mlp = weight[int(update.shape[-1] * 3) :]
                            # print(q.shape, k.shape, v.shape, mlp.shape)
                            if "to_q" in lora_name:
                                q += update.reshape(q.shape)
                            elif "to_k" in lora_name:
                                k += update.reshape(k.shape)
                            elif "to_v" in lora_name:
                                v += update.reshape(v.shape)
                            elif "proj_mlp" in lora_name:
                                mlp += update.reshape(mlp.shape)
                            weight = torch.cat([q, k, v, mlp], dim=0)
                    else:
                        if len(weight.size()) == 2:
                            weight = weight + ratio * (up_weight @ down_weight) * scale
                        elif down_weight.size()[2:4] == (1, 1):
                            weight = (
                                weight
                                + ratio
                                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                                * scale
                            )
                        else:
                            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                            weight = weight + ratio * conved * scale
                else:
                    if len(weight.size()) == 2:
                        weight = weight + ratio * (up_weight @ down_weight) * scale
                    elif down_weight.size()[2:4] == (1, 1):
                        weight = (
                            weight
                            + ratio
                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                            * scale
                        )
                    else:
                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                        weight = weight + ratio * conved * scale

                flux_state_dict[flux_key] = weight.to(loading_device, save_dtype)
                merged_keys.add(flux_key)
                del up_weight
                del down_weight
                del weight

    logger.info(f"Merged keys: {sorted(list(merged_keys))}")
    return flux_state_dict

def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
    base_alphas = {}  # alpha for merged model
    base_dims = {}

    merged_sd = {}
    base_model = None
    for model, ratio in zip(models, ratios):
        logger.info(f"loading: {model}")
        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

        if lora_metadata is not None:
            if base_model is None:
                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)

        # get alpha and dim
        alphas = {}  # alpha for current model
        dims = {}  # dims for current model
        for key in lora_sd.keys():
            if "alpha" in key:
                lora_module_name = key[: key.rfind(".alpha")]
                alpha = float(lora_sd[key].detach().numpy())
                alphas[lora_module_name] = alpha
                if lora_module_name not in base_alphas:
                    base_alphas[lora_module_name] = alpha
            elif "lora_down" in key:
                lora_module_name = key[: key.rfind(".lora_down")]
                dim = lora_sd[key].size()[0]
                dims[lora_module_name] = dim
                if lora_module_name not in base_dims:
                    base_dims[lora_module_name] = dim

        for lora_module_name in dims.keys():
            if lora_module_name not in alphas:
                alpha = dims[lora_module_name]
                alphas[lora_module_name] = alpha
                if lora_module_name not in base_alphas:
                    base_alphas[lora_module_name] = alpha

        logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        # merge
        logger.info("merging...")
        for key in tqdm(lora_sd.keys()):
            if "alpha" in key:
                continue

            if "lora_up" in key and concat:
                concat_dim = 1
            elif "lora_down" in key and concat:
                concat_dim = 0
            else:
                concat_dim = None

            lora_module_name = key[: key.rfind(".lora_")]

            base_alpha = base_alphas[lora_module_name]
            alpha = alphas[lora_module_name]

            scale = math.sqrt(alpha / base_alpha) * ratio
            scale = abs(scale) if "lora_up" in key else scale  # handle negative weights

            if key in merged_sd:
                assert (
                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
                ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。"
                if concat_dim is not None:
                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
                else:
                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
            else:
                merged_sd[key] = lora_sd[key] * scale

    # set alpha to sd
    for lora_module_name, alpha in base_alphas.items():
        key = lora_module_name + ".alpha"
        merged_sd[key] = torch.tensor(alpha)
        if shuffle:
            key_down = lora_module_name + ".lora_down.weight"
            key_up = lora_module_name + ".lora_up.weight"
            dim = merged_sd[key_down].shape[0]
            perm = torch.randperm(dim)
            merged_sd[key_down] = merged_sd[key_down][perm]
            merged_sd[key_up] = merged_sd[key_up][:, perm]

    logger.info("merged model")
    logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    # check all dims are same
    dims_list = list(set(base_dims.values()))
    alphas_list = list(set(base_alphas.values()))
    all_same_dims = True
    all_same_alphas = True
    for dims in dims_list:
        if dims != dims_list[0]:
            all_same_dims = False
            break
    for alphas in alphas_list:
        if alphas != alphas_list[0]:
            all_same_alphas = False
            break

    # build minimum metadata
    dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
    alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
    metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None)

    return merged_sd, metadata

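# Illustrative sketch (standalone, not from the repository): with ratio = 1, the
# sqrt(alpha / base_alpha) factor above re-expresses a LoRA under the merged model's
# alpha without changing its effective forward-pass delta, since the factor is
# applied to both lora_up and lora_down.
#
# import math
# import torch
#
# dim, alpha, base_alpha = 4, 8.0, 4.0  # source alpha vs the merged model's alpha
# down = torch.randn(dim, 16)
# up = torch.randn(16, dim)
#
# s = math.sqrt(alpha / base_alpha)
# before = (alpha / dim) * (up @ down)                  # delta under the original alpha
# after = (base_alpha / dim) * ((up * s) @ (down * s))  # delta under the merged alpha
# assert torch.allclose(before, after, atol=1e-5)
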
def merge(args):
    if args.models is None:
        args.models = []
    if args.ratios is None:
        args.ratios = []

    assert len(args.models) == len(
        args.ratios
    ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

    merge_dtype = str_to_dtype(args.precision)
    save_dtype = str_to_dtype(args.save_precision)
    if save_dtype is None:
        save_dtype = merge_dtype

    assert (
        args.save_to or args.clip_l_save_to or args.t5xxl_save_to
    ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください"
    dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to)
    if not os.path.exists(dest_dir):
        logger.info(f"creating directory: {dest_dir}")
        os.makedirs(dest_dir)

    if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None:
        if not args.diffusers:
            assert (args.clip_l is None and args.clip_l_save_to is None) or (
                args.clip_l is not None and args.clip_l_save_to is not None
            ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください"
            assert (args.t5xxl is None and args.t5xxl_save_to is None) or (
                args.t5xxl is not None and args.t5xxl_save_to is not None
            ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください"
            flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model(
                args.loading_device,
                args.working_device,
                args.flux_model,
                args.clip_l,
                args.t5xxl,
                args.models,
                args.ratios,
                merge_dtype,
                save_dtype,
                args.mem_eff_load_save,
            )
        else:
            assert (
                args.clip_l is None and args.t5xxl is None
            ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません"
            flux_state_dict = merge_to_flux_model_diffusers(
                args.loading_device,
                args.working_device,
                args.flux_model,
                args.models,
                args.ratios,
                merge_dtype,
                save_dtype,
                args.mem_eff_load_save,
            )
            clip_l_state_dict = None
            t5xxl_state_dict = None

        if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0):
            sai_metadata = None
        else:
            merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models)
            title = os.path.splitext(os.path.basename(args.save_to))[0]
            sai_metadata = sai_model_spec.build_metadata(
                None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev"
            )

        if flux_state_dict is not None and len(flux_state_dict) > 0:
            logger.info(f"saving FLUX model to: {args.save_to}")
            save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)

        if clip_l_state_dict is not None and len(clip_l_state_dict) > 0:
            logger.info(f"saving clip_l model to: {args.clip_l_save_to}")
            save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save)

        if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0:
            logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}")
            save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save)

    else:
        flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

        logger.info("calculating hashes and creating metadata...")

        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
        metadata["sshs_legacy_hash"] = legacy_hash

        if not args.no_metadata:
            merged_from = sai_model_spec.build_merged_from(args.models)
            title = os.path.splitext(os.path.basename(args.save_to))[0]
            sai_metadata = sai_model_spec.build_metadata(
                flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
            )
            metadata.update(sai_metadata)

        logger.info(f"saving model to: {args.save_to}")
        save_to_file(args.save_to, flux_state_dict, save_dtype, metadata)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_precision",
        type=str,
        default=None,
        help="precision in saving, same to merging if omitted. supported types: "
        "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
        " / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="float",
        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
    )
    parser.add_argument(
        "--flux_model",
        type=str,
        default=None,
        help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
    )
    parser.add_argument(
        "--clip_l",
        type=str,
        default=None,
        help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)",
    )
    parser.add_argument(
        "--t5xxl",
        type=str,
        default=None,
        help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)",
    )
    parser.add_argument(
        "--mem_eff_load_save",
        action="store_true",
        help="use custom memory efficient load and save functions for FLUX.1 model"
        " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
    )
    parser.add_argument(
        "--loading_device",
        type=str,
        default="cpu",
|
||||
help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--working_device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="device to work (merge). Merging LoRA models are done on CPU."
|
||||
+ " / 作業(マージ)するデバイス。LoRAモデルのマージはCPUで行われます。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_l_save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル",
|
||||
)
|
||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concat",
|
||||
action="store_true",
|
||||
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
|
||||
+ "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shuffle",
|
||||
action="store_true",
|
||||
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diffusers",
|
||||
action="store_true",
|
||||
help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
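Hypothetical invocations (editor's sketch: the script path and file names are placeholders, but the flags match the parser defined above). With --flux_model the LoRAs are merged into the checkpoint; without it they are merged into a single LoRA file.

python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors \
    --save_to merged.safetensors --models lora_a.safetensors lora_b.safetensors \
    --ratios 1.0 0.5 --loading_device cpu --working_device cuda

python networks/flux_merge_lora.py --save_to merged_lora.safetensors \
    --models lora_a.safetensors lora_b.safetensors --ratios 0.7 0.3 --concat --shuffle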

networks/lora.py
@@ -11,7 +11,13 @@ from transformers import CLIPTextModel
import numpy as np
import torch
+import re
from library.utils import setup_logging
+from library.sdxl_original_unet import SdxlUNet2DConditionModel

setup_logging()
import logging

logger = logging.getLogger(__name__)

+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")

@@ -46,7 +52,7 @@ class LoRAModule(torch.nn.Module):
        # if limit_rank:
        #     self.lora_dim = min(lora_dim, in_dim, out_dim)
        #     if self.lora_dim != lora_dim:
-        #         print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+        #         logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
        # else:
        self.lora_dim = lora_dim

@@ -177,7 +183,7 @@ class LoRAInfModule(LoRAModule):
            else:
                # conv2d 3x3
                conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
-                # print(conved.size(), weight.size(), module.stride, module.padding)
+                # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                weight = weight + self.multiplier * conved * self.scale

            # set weight to org_module

@@ -216,7 +222,7 @@ class LoRAInfModule(LoRAModule):
        self.region_mask = None

    def default_forward(self, x):
-        # print("default_forward", self.lora_name, x.size())
+        # logger.info(f"default_forward {self.lora_name} {x.size()}")
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

    def forward(self, x):

@@ -242,13 +248,13 @@ class LoRAInfModule(LoRAModule):
            area = x.size()[1]

        mask = self.network.mask_dic.get(area, None)
-        if mask is None:
-            # raise ValueError(f"mask is None for resolution {area}")
+        if mask is None or len(x.size()) == 2:
            # emb_layers in SDXL doesn't have mask
-            # print(f"mask is None for resolution {area}, {x.size()}")
+            # if "emb" not in self.lora_name:
+            #     print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
            mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
            return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
-        if len(x.size()) != 4:
+        if len(x.size()) == 3:
            mask = torch.reshape(mask, (1, -1, 1))
        return mask

@@ -263,6 +269,8 @@ class LoRAInfModule(LoRAModule):
        lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        mask = self.get_mask_for_x(lx)
        # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
+        # if mask.ndim > lx.ndim:  # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
+        #     mask = mask.squeeze(-1)
        lx = lx * mask

        x = self.org_forward(x)

@@ -291,7 +299,7 @@ class LoRAInfModule(LoRAModule):
        if has_real_uncond:
            query[-self.network.batch_size :] = x[-self.network.batch_size :]

-        # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
+        # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
        return query

    def sub_prompt_forward(self, x):

@@ -306,7 +314,7 @@ class LoRAInfModule(LoRAModule):
        lx = x[emb_idx :: self.network.num_sub_prompts]
        lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale

-        # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
+        # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")

        x = self.org_forward(x)
        x[emb_idx :: self.network.num_sub_prompts] += lx

@@ -314,7 +322,7 @@ class LoRAInfModule(LoRAModule):
        return x

    def to_out_forward(self, x):
-        # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
+        # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")

        if self.network.is_last_network:
            masks = [None] * self.network.num_sub_prompts

@@ -332,7 +340,7 @@ class LoRAInfModule(LoRAModule):
            )
            self.network.shared[self.lora_name] = (lx, masks)

-        # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
+        # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
        lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
        masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)

@@ -351,7 +359,7 @@ class LoRAInfModule(LoRAModule):
        if has_real_uncond:
            out[-self.network.batch_size :] = x[-self.network.batch_size :]  # real_uncond

-        # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
+        # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
        # if num_sub_prompts > num of LoRAs, fill with zero
        for i in range(len(masks)):
            if masks[i] is None:

@@ -374,18 +382,18 @@ class LoRAInfModule(LoRAModule):
                x1 = x1 + lx1
                out[self.network.batch_size + i] = x1

-        # print("to_out_forward", x.size(), out.size(), has_real_uncond)
+        # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
        return out


-def parse_block_lr_kwargs(nw_kwargs):
+def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]:
    down_lr_weight = nw_kwargs.get("down_lr_weight", None)
    mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
    up_lr_weight = nw_kwargs.get("up_lr_weight", None)

    # if none of these is specified, block lr is disabled and None is returned
    if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
-        return None, None, None
+        return None

    # extract learning rate weight for each block
    if down_lr_weight is not None:

@@ -394,18 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs):
        down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]

    if mid_lr_weight is not None:
-        mid_lr_weight = float(mid_lr_weight)
+        mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")]

    if up_lr_weight is not None:
        if "," in up_lr_weight:
            up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]

-    down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
-        down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
+    return get_block_lr_weight(
+        is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
    )
-
-    return down_lr_weight, mid_lr_weight, up_lr_weight
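For reference, a sketch of the network_args that reach this parser (editor's example, values invented): each key takes comma-separated floats, down/up also accept a preset curve name such as "cosine+0.2" (handled by get_list below), and block_lr_zero_threshold zeroes out entries at or below the threshold.

example_kwargs = {
    "down_lr_weight": "0,0,0,0,0.5,0.5,1,1,1",  # 9 entries for SDXL input blocks
    "mid_lr_weight": "1,1,1",  # 3 entries for SDXL mid blocks
    "up_lr_weight": "cosine+0.2",  # preset curve plus a base offset
    "block_lr_zero_threshold": "0.1",
}
block_lr_weight = parse_block_lr_kwargs(True, example_kwargs)  # is_sdxl=True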


def create_network(
    multiplier: float,

@@ -417,6 +423,9 @@ def create_network(
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
+    # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
+    is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
+
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:

@@ -434,21 +443,21 @@ def create_network(

    # block dim/alpha/lr
    block_dims = kwargs.get("block_dims", None)
-    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
+    block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)

    # enable per-block dim (rank) if any of the above is specified
-    if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
+    if block_dims is not None or block_lr_weight is not None:
        block_alphas = kwargs.get("block_alphas", None)
        conv_block_dims = kwargs.get("conv_block_dims", None)
        conv_block_alphas = kwargs.get("conv_block_alphas", None)

        block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
-            block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
+            is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
        )

        # remove block dim/alpha without learning rate
        block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
-            block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
+            is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight
        )

    else:

@@ -481,10 +490,20 @@ def create_network(
        conv_block_dims=conv_block_dims,
        conv_block_alphas=conv_block_alphas,
        varbose=True,
+        is_sdxl=is_sdxl,
    )

-    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
-        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
+    if block_lr_weight is not None:
+        network.set_block_lr_weight(block_lr_weight)

    return network

@@ -494,9 +513,13 @@ def create_network(
# block_dims and block_alphas are either both None or both set
# conv_dim and conv_alpha are either both None or both set
def get_block_dims_and_alphas(
-    block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
+    is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
):
-    num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
+    if not is_sdxl:
+        num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
+    else:
+        # 1 + 9 + 3 + 9 + 1 = 23, no LoRA for emb_layers (0)
+        num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
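The arithmetic behind the two branches (editor's note, using the class constants defined further down): for SD1/2 the U-Net is counted as 12 down + 1 mid + 12 up blocks, for SDXL as 1 emb + 9 input + 3 mid + 9 output + 1 out.

assert LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS == 25  # SD1/2
assert 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 == 23  # SDXL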

    def parse_ints(s):
        return [int(i) for i in s.split(",")]

@@ -507,11 +530,14 @@ def get_block_dims_and_alphas(
    # parse block_dims and block_alphas; values are always filled in
    if block_dims is not None:
        block_dims = parse_ints(block_dims)
-        assert (
-            len(block_dims) == num_total_blocks
-        ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
+        assert len(block_dims) == num_total_blocks, (
+            f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given"
+            + f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})"
+        )
    else:
-        print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        logger.warning(
+            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
+        )
        block_dims = [network_dim] * num_total_blocks

    if block_alphas is not None:

@@ -520,7 +546,7 @@ def get_block_dims_and_alphas(
            len(block_alphas) == num_total_blocks
        ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
    else:
-        print(
+        logger.warning(
            f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
        )
        block_alphas = [network_alpha] * num_total_blocks

@@ -540,13 +566,13 @@ def get_block_dims_and_alphas(
        else:
            if conv_alpha is None:
                conv_alpha = 1.0
-            print(
+            logger.warning(
                f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
            )
            conv_block_alphas = [conv_alpha] * num_total_blocks
    else:
        if conv_dim is not None:
-            print(
+            logger.warning(
                f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
            )
            conv_block_dims = [conv_dim] * num_total_blocks

@@ -558,15 +584,25 @@ def get_block_dims_and_alphas(
    return block_dims, block_alphas, conv_block_dims, conv_block_alphas


-# define per-block multipliers for layer-wise learning rates; may be called from outside
+# define per-block multipliers for layer-wise learning rates; kept outside the class so it can be called externally
+# the return value is a single list of per-block multipliers
def get_block_lr_weight(
-    down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
-) -> Tuple[List[float], List[float], List[float]]:
+    is_sdxl,
+    down_lr_weight: Union[str, List[float]],
+    mid_lr_weight: List[float],
+    up_lr_weight: Union[str, List[float]],
+    zero_threshold: float,
+) -> Optional[List[float]]:
    # do nothing and keep the previous behavior when no parameter is given
    if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
-        return None, None, None
+        return None

-    max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down layers in the full model
+    if not is_sdxl:
+        max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS
+        max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS
+    else:
+        max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS
+        max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS

    def get_list(name_with_suffix) -> List[float]:
        import math

@@ -576,17 +612,20 @@ def get_block_lr_weight(
        base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0

        if name == "cosine":
-            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
+            return [
+                math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
+                for i in reversed(range(max_len_for_down_or_up))
+            ]
        elif name == "sine":
-            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
+            return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)]
        elif name == "linear":
-            return [i / (max_len - 1) + base_lr for i in range(max_len)]
+            return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)]
        elif name == "reverse_linear":
-            return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
+            return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))]
        elif name == "zeros":
-            return [0.0 + base_lr] * max_len
+            return [0.0 + base_lr] * max_len_for_down_or_up
        else:
-            print(
+            logger.error(
                "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
                % (name)
            )

@@ -597,99 +636,176 @@ def get_block_lr_weight(
    if type(up_lr_weight) == str:
        up_lr_weight = get_list(up_lr_weight)

-    if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
-        print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
-        print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
-        up_lr_weight = up_lr_weight[:max_len]
-        down_lr_weight = down_lr_weight[:max_len]
+    if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or (
+        down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up
+    ):
+        logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up)
+        logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up)
+        up_lr_weight = up_lr_weight[:max_len_for_down_or_up]
+        down_lr_weight = down_lr_weight[:max_len_for_down_or_up]

-    if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
-        print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
-        print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
+    if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid:
+        logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid)
+        logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid)
+        mid_lr_weight = mid_lr_weight[:max_len_for_mid]

-    if down_lr_weight != None and len(down_lr_weight) < max_len:
-        down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
-    if up_lr_weight != None and len(up_lr_weight) < max_len:
-        up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
+    if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or (
+        down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up
+    ):
+        logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up)
+        logger.warning(
+            "down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up
+        )
+
+    if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up:
+        down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight))
+    if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up:
+        up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight))
+
+    if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid:
+        logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid)
+        logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid)
+        mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight))

    if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
-        print("apply block learning rate / 階層別学習率を適用します。")
+        logger.info("apply block learning rate / 階層別学習率を適用します。")
        if down_lr_weight != None:
            down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
-            print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
+            logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
        else:
-            print("down_lr_weight: all 1.0, すべて1.0")
+            down_lr_weight = [1.0] * max_len_for_down_or_up
+            logger.info("down_lr_weight: all 1.0, すべて1.0")

        if mid_lr_weight != None:
-            mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
-            print("mid_lr_weight:", mid_lr_weight)
+            mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight]
+            logger.info(f"mid_lr_weight: {mid_lr_weight}")
        else:
-            print("mid_lr_weight: 1.0")
+            mid_lr_weight = [1.0] * max_len_for_mid
+            logger.info("mid_lr_weight: all 1.0, すべて1.0")

        if up_lr_weight != None:
            up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
-            print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
+            logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
        else:
-            print("up_lr_weight: all 1.0, すべて1.0")
+            up_lr_weight = [1.0] * max_len_for_down_or_up
+            logger.info("up_lr_weight: all 1.0, すべて1.0")

-    return down_lr_weight, mid_lr_weight, up_lr_weight
+    lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight
+
+    if is_sdxl:
+        lr_weight = [1.0] + lr_weight + [1.0]  # add 1.0 for emb_layers and out
+
+    assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or (
+        is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
+    ), f"lr_weight length is invalid: {len(lr_weight)}"
+
+    return lr_weight
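A worked example of the flattening above (editor's sketch, values invented): for SDXL the three per-section lists are concatenated, then padded with a leading 1.0 for emb_layers and a trailing 1.0 for out, giving 1 + 9 + 3 + 9 + 1 = 23 entries, indexed by get_block_index below.

down = [0.0, 0.0, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0]  # 9 input blocks
mid = [1.0, 1.0, 1.0]  # 3 mid blocks
up = [1.0] * 9  # 9 output blocks
lr_weight = [1.0] + down + mid + up + [1.0]
assert len(lr_weight) == 23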


# exclude blocks whose lr_weight is 0 from block_dims; kept callable from outside
def remove_block_dims_and_alphas(
-    block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
+    is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]]
):
    # set 0 to block dim without learning rate to remove the block
-    if down_lr_weight != None:
-        for i, lr in enumerate(down_lr_weight):
+    if block_lr_weight is not None:
+        for i, lr in enumerate(block_lr_weight):
            if lr == 0:
                block_dims[i] = 0
                if conv_block_dims is not None:
                    conv_block_dims[i] = 0
-    if mid_lr_weight != None:
-        if mid_lr_weight == 0:
-            block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
-            if conv_block_dims is not None:
-                conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
-    if up_lr_weight != None:
-        for i, lr in enumerate(up_lr_weight):
-            if lr == 0:
-                block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
-                if conv_block_dims is not None:
-                    conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0

    return block_dims, block_alphas, conv_block_dims, conv_block_alphas


# kept callable from outside
-def get_block_index(lora_name: str) -> int:
+def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
    block_idx = -1  # invalid lora name
+    if not is_sdxl:
-    m = RE_UPDOWN.search(lora_name)
-    if m:
-        g = m.groups()
-        i = int(g[1])
-        j = int(g[3])
-        if g[2] == "resnets":
-            idx = 3 * i + j
-        elif g[2] == "attentions":
-            idx = 3 * i + j
-        elif g[2] == "upsamplers" or g[2] == "downsamplers":
-            idx = 3 * i + 2
-
+        m = RE_UPDOWN.search(lora_name)
+        if m:
+            g = m.groups()
+            i = int(g[1])
+            j = int(g[3])
+            if g[2] == "resnets":
+                idx = 3 * i + j
+            elif g[2] == "attentions":
+                idx = 3 * i + j
+            elif g[2] == "upsamplers" or g[2] == "downsamplers":
+                idx = 3 * i + 2
+
-        if g[0] == "down":
-            block_idx = 1 + idx  # no LoRA corresponds to index 0
-        elif g[0] == "up":
-            block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
-
-    elif "mid_block_" in lora_name:
-        block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12
+            if g[0] == "down":
+                block_idx = 1 + idx  # no LoRA corresponds to index 0
+            elif g[0] == "up":
+                block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
+        elif "mid_block_" in lora_name:
+            block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12
+    else:
+        # copy from sdxl_train
+        if lora_name.startswith("lora_unet_"):
+            name = lora_name[len("lora_unet_") :]
+            if name.startswith("time_embed_") or name.startswith("label_emb_"):  # No LoRA
+                block_idx = 0  # 0
+            elif name.startswith("input_blocks_"):  # 1-9
+                block_idx = 1 + int(name.split("_")[2])
+            elif name.startswith("middle_block_"):  # 10-12
+                block_idx = 10 + int(name.split("_")[2])
+            elif name.startswith("output_blocks_"):  # 13-21
+                block_idx = 13 + int(name.split("_")[2])
+            elif name.startswith("out_"):  # 22, out, no LoRA
+                block_idx = 22

    return block_idx
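A few hypothetical module names and the indices the branches above assign to them (editor's sketch, traced by hand):

assert get_block_index("lora_unet_down_blocks_1_attentions_0_proj_in") == 4  # 1 + (3 * 1 + 0)
assert get_block_index("lora_unet_mid_block_attentions_0_proj_in") == 12  # LoRANetwork.NUM_OF_BLOCKS
assert get_block_index("lora_unet_input_blocks_4_1_proj_in", is_sdxl=True) == 5  # 1 + 4
assert get_block_index("lora_unet_middle_block_1_proj_in", is_sdxl=True) == 11  # 10 + 1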


def convert_diffusers_to_sai_if_needed(weights_sd):
    # only supports U-Net LoRA modules

    found_up_down_blocks = False
    for k in list(weights_sd.keys()):
        if "down_blocks" in k:
            found_up_down_blocks = True
            break
        if "up_blocks" in k:
            found_up_down_blocks = True
            break
    if not found_up_down_blocks:
        return

    from library.sdxl_model_util import make_unet_conversion_map

    unet_conversion_map = make_unet_conversion_map()
    unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}

    # # add extra conversion
    # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"

    logger.info(f"Converting LoRA keys from Diffusers to SAI")
    lora_unet_prefix = "lora_unet_"
    for k in list(weights_sd.keys()):
        if not k.startswith(lora_unet_prefix):
            continue

        unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]

        # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
        found = False
        for hf_module_name, sd_module_name in unet_conversion_map.items():
            if hf_module_name in unet_module_name:
                new_key = (
                    lora_unet_prefix
                    + unet_module_name.replace(hf_module_name, sd_module_name)
                    + k[len(lora_unet_prefix) + len(unet_module_name) :]
                )
                weights_sd[new_key] = weights_sd.pop(k)
                found = True
                break

        if not found:
            logger.warning(f"Key {k} is not found in unet_conversion_map")


# Create network from weights for inference; weights are not loaded here (because they can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+    # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
+    is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
+
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

@@ -698,6 +814,10 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
        else:
            weights_sd = torch.load(file, map_location="cpu")

+    # if keys are Diffusers based, convert to SAI based
+    if is_sdxl:
+        convert_diffusers_to_sai_if_needed(weights_sd)
+
    # get dim/alpha mapping
    modules_dim = {}
    modules_alpha = {}

@@ -711,7 +831,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim
-            # print(lora_name, value.size(), dim)
+            # logger.info(lora_name, value.size(), dim)

    # support old LoRA without alpha
    for key in modules_dim.keys():

@@ -721,23 +841,32 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
    module_class = LoRAInfModule if for_inference else LoRAModule

    network = LoRANetwork(
-        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+        text_encoder,
+        unet,
+        multiplier=multiplier,
+        modules_dim=modules_dim,
+        modules_alpha=modules_alpha,
+        module_class=module_class,
+        is_sdxl=is_sdxl,
    )

    # block lr
-    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
-    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
-        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
+    block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
+    if block_lr_weight is not None:
+        network.set_block_lr_weight(block_lr_weight)

    return network, weights_sd


class LoRANetwork(torch.nn.Module):
    NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
+    NUM_OF_MID_BLOCKS = 1
+    SDXL_NUM_OF_BLOCKS = 9  # number of input/output blocks in the SDXL model; total = 1 (base) + 9 (input) + 3 (mid) + 9 (output) + 1 (out) = 23
+    SDXL_NUM_OF_MID_BLOCKS = 3

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

@@ -765,6 +894,7 @@ class LoRANetwork(torch.nn.Module):
        modules_alpha: Optional[Dict[str, int]] = None,
        module_class: Type[object] = LoRAModule,
        varbose: Optional[bool] = False,
+        is_sdxl: Optional[bool] = False,
    ) -> None:
        """
        LoRA network: many arguments, but the patterns are as follows

@@ -785,21 +915,31 @@ class LoRANetwork(torch.nn.Module):
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
        if modules_dim is not None:
-            print(f"create LoRA network from weights")
+            logger.info(f"create LoRA network from weights")
        elif block_dims is not None:
-            print(f"create LoRA network from block_dims")
-            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
-            print(f"block_dims: {block_dims}")
-            print(f"block_alphas: {block_alphas}")
+            logger.info(f"create LoRA network from block_dims")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
+            logger.info(f"block_dims: {block_dims}")
+            logger.info(f"block_alphas: {block_alphas}")
            if conv_block_dims is not None:
-                print(f"conv_block_dims: {conv_block_dims}")
-                print(f"conv_block_alphas: {conv_block_alphas}")
+                logger.info(f"conv_block_dims: {conv_block_dims}")
+                logger.info(f"conv_block_alphas: {conv_block_alphas}")
        else:
-            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
            if self.conv_lora_dim is not None:
-                print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+                logger.info(
+                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+                )

        # create module instances
        def create_modules(

@@ -840,7 +980,7 @@ class LoRANetwork(torch.nn.Module):
                    alpha = modules_alpha[lora_name]
                elif is_unet and block_dims is not None:
                    # block_dims specified for the U-Net
-                    block_idx = get_block_index(lora_name)
+                    block_idx = get_block_index(lora_name, is_sdxl)
                    if is_linear or is_conv2d_1x1:
                        dim = block_dims[block_idx]
                        alpha = block_alphas[block_idx]

@@ -884,15 +1024,15 @@ class LoRANetwork(torch.nn.Module):
        for i, text_encoder in enumerate(text_encoders):
            if len(text_encoders) > 1:
                index = i + 1
-                print(f"create LoRA for Text Encoder {index}:")
+                logger.info(f"create LoRA for Text Encoder {index}:")
            else:
                index = None
-                print(f"create LoRA for Text Encoder:")
+                logger.info(f"create LoRA for Text Encoder:")

            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            self.text_encoder_loras.extend(text_encoder_loras)
            skipped_te += skipped
-        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+        logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE

@@ -900,19 +1040,17 @@ class LoRANetwork(torch.nn.Module):
            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
-        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+        logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        skipped = skipped_te + skipped_un
        if varbose and len(skipped) > 0:
-            print(
+            logger.warning(
                f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
            )
            for name in skipped:
-                print(f"\t{name}")
+                logger.info(f"\t{name}")

-        self.up_lr_weight: List[float] = None
-        self.down_lr_weight: List[float] = None
-        self.mid_lr_weight: float = None
+        self.block_lr_weight = None
        self.block_lr = False

        # assertion

@@ -926,6 +1064,10 @@ class LoRANetwork(torch.nn.Module):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

+    def set_enabled(self, is_enabled):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.enabled = is_enabled
+
    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file
@@ -939,12 +1081,12 @@ class LoRANetwork(torch.nn.Module):

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        if apply_text_encoder:
-            print("enable LoRA for text encoder")
+            logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
        else:
            self.text_encoder_loras = []

        if apply_unet:
-            print("enable LoRA for U-Net")
+            logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
        else:
            self.unet_loras = []

@@ -966,12 +1108,12 @@ class LoRANetwork(torch.nn.Module):
        apply_unet = True

        if apply_text_encoder:
-            print("enable LoRA for text encoder")
+            logger.info("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
-            print("enable LoRA for U-Net")
+            logger.info("enable LoRA for U-Net")
        else:
            self.unet_loras = []

@@ -982,84 +1124,120 @@ class LoRANetwork(torch.nn.Module):
                sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

-        print(f"weights are merged")
+        logger.info(f"weights are merged")

-    # define per-block multipliers for layer-wise learning rates; the argument order is odd but leave it for now
-    def set_block_lr_weight(
-        self,
-        up_lr_weight: List[float] = None,
-        mid_lr_weight: float = None,
-        down_lr_weight: List[float] = None,
-    ):
+    def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]):
        self.block_lr = True
-        self.down_lr_weight = down_lr_weight
-        self.mid_lr_weight = mid_lr_weight
-        self.up_lr_weight = up_lr_weight
+        self.block_lr_weight = block_lr_weight

-    def get_lr_weight(self, lora: LoRAModule) -> float:
-        lr_weight = 1.0
-        block_idx = get_block_index(lora.lora_name)
-        if block_idx < 0:
-            return lr_weight
+    def get_lr_weight(self, block_idx: int) -> float:
+        if not self.block_lr or self.block_lr_weight is None:
+            return 1.0
+        return self.block_lr_weight[block_idx]

-        if block_idx < LoRANetwork.NUM_OF_BLOCKS:
-            if self.down_lr_weight != None:
-                lr_weight = self.down_lr_weight[block_idx]
-        elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
-            if self.mid_lr_weight != None:
-                lr_weight = self.mid_lr_weight
-        elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
-            if self.up_lr_weight != None:
-                lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

-        return lr_weight
+        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
+        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")

    # it might be nice to allow separate learning rates for the two Text Encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
-        self.requires_grad_(True)
-        all_params = []
+        # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
+        # if (
+        #     self.loraplus_lr_ratio is not None
+        #     or self.loraplus_text_encoder_lr_ratio is not None
+        #     or self.loraplus_unet_lr_ratio is not None
+        # ):
+        #     assert (
+        #         optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
+        #     ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"

-        def enumerate_params(loras):
-            params = []
+        self.requires_grad_(True)
+
+        all_params = []
+        lr_descriptions = []
+
+        def assemble_params(loras, lr, ratio):
+            param_groups = {"lora": {}, "plus": {}}
            for lora in loras:
-                params.extend(lora.parameters())
-            return params
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            descriptions = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    logger.info("NO LR skipping!")
+                    continue
+
+                params.append(param_data)
+                descriptions.append("plus" if key == "plus" else "")
+
+            return params, descriptions

        if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params, descriptions = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
+            )
+            all_params.extend(params)
+            lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])

        if self.unet_loras:
            if self.block_lr:
+                is_sdxl = False
+                for lora in self.unet_loras:
+                    if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
+                        is_sdxl = True
+                        break
+
                # group the LoRAs per block so that learning-rate graphs can be drawn per block
                block_idx_to_lora = {}
                for lora in self.unet_loras:
-                    idx = get_block_index(lora.lora_name)
+                    idx = get_block_index(lora.lora_name, is_sdxl)
                    if idx not in block_idx_to_lora:
                        block_idx_to_lora[idx] = []
                    block_idx_to_lora[idx].append(lora)

                # set parameters per block
                for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
-                    if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
-                    elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-                    all_params.append(param_data)
+                    params, descriptions = assemble_params(
+                        block_loras,
+                        (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
+                        self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
+                    )
+                    all_params.extend(params)
+                    lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])

            else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params, descriptions = assemble_params(
+                    self.unet_loras,
+                    unet_lr if unet_lr is not None else default_lr,
+                    self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
+                )
+                all_params.extend(params)
+                lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

-        return all_params
+        return all_params, lr_descriptions
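Wiring the result into an optimizer (editor's sketch; `network` is a hypothetical LoRANetwork instance): each returned group is a standard param-group dict, the "plus" groups carry the LoRA+ boosted learning rate for lora_up weights, and lr_descriptions only labels the groups (e.g. "unet_block7 plus") for logging.

params, lr_descriptions = network.prepare_optimizer_params(text_encoder_lr=1e-5, unet_lr=1e-4, default_lr=1e-4)
optimizer = torch.optim.AdamW(params)
for desc, group in zip(lr_descriptions, params):
    print(desc, group.get("lr"))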

    def enable_gradient_checkpointing(self):
        # not supported

@@ -1113,7 +1291,7 @@ class LoRANetwork(torch.nn.Module):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.set_network(self)

-    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
+    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
        self.batch_size = batch_size
        self.num_sub_prompts = num_sub_prompts
        self.current_size = (height, width)

@@ -1128,7 +1306,7 @@ class LoRANetwork(torch.nn.Module):
        device = ref_weight.device

        def resize_add(mh, mw):
-            # print(mh, mw, mh * mw)
+            # logger.info(mh, mw, mh * mw)
            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")  # doesn't work in bf16
            m = m.to(device, dtype=dtype)
            mask_dic[mh * mw] = m

@@ -1139,6 +1317,13 @@ class LoRANetwork(torch.nn.Module):
            resize_add(h, w)
            if h % 2 == 1 or w % 2 == 1:  # add extra shape if h/w is not divisible by 2
                resize_add(h + h % 2, w + w % 2)

+            # deep shrink
+            if ds_ratio is not None:
+                hd = int(h * ds_ratio)
+                wd = int(w * ds_ratio)
+                resize_add(hd, wd)
+
            h = (h + 1) // 2
            w = (w + 1) // 2

@@ -9,8 +9,15 @@ from diffusers import UNet2DConditionModel
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTextModel
|
||||
import torch
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
init_ipex()
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def make_unet_conversion_map() -> Dict[str, str]:
|
||||
unet_conversion_map_layer = []
|
||||
@@ -248,7 +255,7 @@ def create_network_from_weights(
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# print(lora_name, value.size(), dim)
|
||||
# logger.info(f"{lora_name} {value.size()} {dim}")
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
@@ -271,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
@@ -291,12 +298,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
print(f"create LoRA network from weights")
|
||||
logger.info("create LoRA network from weights")
|
||||
|
||||
# convert SDXL Stability AI's U-Net modules to Diffusers
|
||||
converted = self.convert_unet_modules(modules_dim, modules_alpha)
|
||||
if converted:
|
||||
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
||||
logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -331,7 +338,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if lora_name not in modules_dim:
|
||||
# print(f"skipped {lora_name} (not found in modules_dim)")
|
||||
# logger.info(f"skipped {lora_name} (not found in modules_dim)")
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
@@ -362,18 +369,18 @@ class LoRANetwork(torch.nn.Module):
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
if len(skipped_te) > 0:
|
||||
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||
logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||
|
||||
# extend U-Net target modules to include Conv2d 3x3
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras: List[LoRAModule]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
if len(skipped_un) > 0:
|
||||
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||
logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
@@ -420,11 +427,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
logger.info("enable LoRA for text encoder")
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.apply_to(multiplier)
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
logger.info("enable LoRA for U-Net")
|
||||
for lora in self.unet_loras:
|
||||
lora.apply_to(multiplier)
|
||||
|
||||
@@ -433,16 +440,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora.unapply_to()
|
||||
|
||||
def merge_to(self, multiplier=1.0):
|
||||
print("merge LoRA weights to original weights")
|
||||
logger.info("merge LoRA weights to original weights")
|
||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||
lora.merge_to(multiplier)
|
||||
print(f"weights are merged")
|
||||
logger.info(f"weights are merged")
|
||||
|
||||
def restore_from(self, multiplier=1.0):
|
||||
print("restore LoRA weights from original weights")
|
||||
logger.info("restore LoRA weights from original weights")
|
||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||
lora.restore_from(multiplier)
|
||||
print(f"weights are restored")
|
||||
logger.info(f"weights are restored")
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
||||
@@ -463,7 +470,7 @@ class LoRANetwork(torch.nn.Module):
         my_state_dict = self.state_dict()
         for key in state_dict.keys():
             if state_dict[key].size() != my_state_dict[key].size():
-                # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
+                # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
                 state_dict[key] = state_dict[key].view(my_state_dict[key].size())

         return super().load_state_dict(state_dict, strict)

@@ -476,7 +483,7 @@ if __name__ == "__main__":
     from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
     import torch

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = get_preferred_device()

     parser = argparse.ArgumentParser()
     parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")

@@ -490,7 +497,7 @@ if __name__ == "__main__":
     image_prefix = args.model_id.replace("/", "_") + "_"

     # load Diffusers model
-    print(f"load model from {args.model_id}")
+    logger.info(f"load model from {args.model_id}")
     pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
     if args.sdxl:
         # use_safetensors=True does not work with 0.18.2

@@ -503,7 +510,7 @@ if __name__ == "__main__":
     text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]

     # load LoRA weights
-    print(f"load LoRA weights from {args.lora_weights}")
+    logger.info(f"load LoRA weights from {args.lora_weights}")
     if os.path.splitext(args.lora_weights)[1] == ".safetensors":
         from safetensors.torch import load_file

@@ -512,10 +519,10 @@ if __name__ == "__main__":
         lora_sd = torch.load(args.lora_weights)

     # create LoRA network from weights and load them
-    print(f"create LoRA network")
+    logger.info(f"create LoRA network")
     lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)

-    print(f"load LoRA network weights")
+    logger.info(f"load LoRA network weights")
     lora_network.load_state_dict(lora_sd)

     lora_network.to(device, dtype=pipe.unet.dtype)  # required to apply_to. merge_to works without this
@@ -544,34 +551,34 @@ if __name__ == "__main__":
         random.seed(seed)

     # create image with original weights
-    print(f"create image with original weights")
+    logger.info(f"create image with original weights")
     seed_everything(args.seed)
     image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
     image.save(image_prefix + "original.png")

     # apply LoRA network to the model: slower than merge_to, but can be reverted easily
-    print(f"apply LoRA network to the model")
+    logger.info(f"apply LoRA network to the model")
     lora_network.apply_to(multiplier=1.0)

-    print(f"create image with applied LoRA")
+    logger.info(f"create image with applied LoRA")
     seed_everything(args.seed)
     image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
     image.save(image_prefix + "applied_lora.png")

     # unapply LoRA network to the model
-    print(f"unapply LoRA network to the model")
+    logger.info(f"unapply LoRA network to the model")
     lora_network.unapply_to()

-    print(f"create image with unapplied LoRA")
+    logger.info(f"create image with unapplied LoRA")
     seed_everything(args.seed)
     image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
     image.save(image_prefix + "unapplied_lora.png")

     # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
-    print(f"merge LoRA network to the model")
+    logger.info(f"merge LoRA network to the model")
     lora_network.merge_to(multiplier=1.0)

-    print(f"create image with LoRA")
+    logger.info(f"create image with LoRA")
     seed_everything(args.seed)
     image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
     image.save(image_prefix + "merged_lora.png")

@@ -579,31 +586,31 @@ if __name__ == "__main__":
     # restore (unmerge) LoRA weights: numerically unstable
     # Restore the merged weights. Because of rounding error they may not exactly match the original weights.
     # Restoring the original weights from a saved state_dict is the reliable way.
-    print(f"restore (unmerge) LoRA weights")
+    logger.info(f"restore (unmerge) LoRA weights")
     lora_network.restore_from(multiplier=1.0)

-    print(f"create image without LoRA")
+    logger.info(f"create image without LoRA")
     seed_everything(args.seed)
     image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
     image.save(image_prefix + "unmerged_lora.png")

     # restore original weights
-    print(f"restore original weights")
+    logger.info(f"restore original weights")
     pipe.unet.load_state_dict(org_unet_sd)
     pipe.text_encoder.load_state_dict(org_text_encoder_sd)
     if args.sdxl:
         pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)

-    print(f"create image with restored original weights")
+    logger.info(f"create image with restored original weights")
     seed_everything(args.seed)
     image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
     image.save(image_prefix + "restore_original.png")

     # use convenience function to merge LoRA weights
-    print(f"merge LoRA weights with convenience function")
+    logger.info(f"merge LoRA weights with convenience function")
     merge_lora_weights(pipe, lora_sd, multiplier=1.0)

-    print(f"create image with merged LoRA weights")
+    logger.info(f"create image with merged LoRA weights")
     seed_everything(args.seed)
     image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
     image.save(image_prefix + "convenience_merged_lora.png")
@@ -14,7 +14,10 @@ from transformers import CLIPTextModel
 import numpy as np
 import torch
 import re
+
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")

@@ -49,7 +52,7 @@ class LoRAModule(torch.nn.Module):
         # if limit_rank:
         #   self.lora_dim = min(lora_dim, in_dim, out_dim)
         #   if self.lora_dim != lora_dim:
-        #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+        #     logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
         # else:
         self.lora_dim = lora_dim

@@ -197,7 +200,7 @@ class LoRAInfModule(LoRAModule):
             else:
                 # conv2d 3x3
                 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
-                # print(conved.size(), weight.size(), module.stride, module.padding)
+                # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                 weight = weight + self.multiplier * conved * self.scale

         # set weight to org_module

@@ -236,7 +239,7 @@ class LoRAInfModule(LoRAModule):
         self.region_mask = None

     def default_forward(self, x):
-        # print("default_forward", self.lora_name, x.size())
+        # logger.info("default_forward", self.lora_name, x.size())
         return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

     def forward(self, x):

@@ -278,7 +281,7 @@ class LoRAInfModule(LoRAModule):
         # apply mask for LoRA result
         lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
         mask = self.get_mask_for_x(lx)
-        # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
+        # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
         lx = lx * mask

         x = self.org_forward(x)

@@ -307,7 +310,7 @@ class LoRAInfModule(LoRAModule):
         if has_real_uncond:
             query[-self.network.batch_size :] = x[-self.network.batch_size :]

-        # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
+        # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
         return query

     def sub_prompt_forward(self, x):

@@ -322,7 +325,7 @@ class LoRAInfModule(LoRAModule):
         lx = x[emb_idx :: self.network.num_sub_prompts]
         lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale

-        # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
+        # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)

         x = self.org_forward(x)
         x[emb_idx :: self.network.num_sub_prompts] += lx

@@ -330,7 +333,7 @@ class LoRAInfModule(LoRAModule):
         return x

     def to_out_forward(self, x):
-        # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
+        # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)

         if self.network.is_last_network:
             masks = [None] * self.network.num_sub_prompts

@@ -348,7 +351,7 @@ class LoRAInfModule(LoRAModule):
             )
             self.network.shared[self.lora_name] = (lx, masks)

-        # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
+        # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
         lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
         masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)

@@ -367,7 +370,7 @@ class LoRAInfModule(LoRAModule):
         if has_real_uncond:
             out[-self.network.batch_size :] = x[-self.network.batch_size :]  # real_uncond

-        # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
+        # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
         # for i in range(len(masks)):
         #     if masks[i] is None:
         #         masks[i] = torch.zeros_like(masks[-1])

@@ -389,7 +392,7 @@ class LoRAInfModule(LoRAModule):
             x1 = x1 + lx1
             out[self.network.batch_size + i] = x1

-        # print("to_out_forward", x.size(), out.size(), has_real_uncond)
+        # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
         return out
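The default_forward above is the whole LoRA inference path in one line: the original module's output plus the scaled low-rank correction. A self-contained sketch, assuming Linear layers and scale = alpha / rank (TinyLoRA is an illustrative name, not part of this repository):

import torch

class TinyLoRA(torch.nn.Module):
    def __init__(self, org: torch.nn.Linear, rank: int = 4, alpha: float = 1.0, multiplier: float = 1.0):
        super().__init__()
        self.org = org
        self.lora_down = torch.nn.Linear(org.in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, org.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_up.weight)  # zero init: starts as a no-op
        self.scale = alpha / rank
        self.multiplier = multiplier

    def forward(self, x):
        # same computation as LoRAInfModule.default_forward above
        return self.org(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale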
@@ -526,7 +529,7 @@ def get_block_dims_and_alphas(
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
         block_dims = [network_dim] * num_total_blocks

     if block_alphas is not None:

@@ -535,7 +538,7 @@ def get_block_dims_and_alphas(
             len(block_alphas) == num_total_blocks
         ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
     else:
-        print(
+        logger.warning(
             f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
         )
         block_alphas = [network_alpha] * num_total_blocks

@@ -555,13 +558,13 @@ def get_block_dims_and_alphas(
         else:
             if conv_alpha is None:
                 conv_alpha = 1.0
-            print(
+            logger.warning(
                 f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
             )
             conv_block_alphas = [conv_alpha] * num_total_blocks
     else:
         if conv_dim is not None:
-            print(
+            logger.warning(
                 f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
             )
             conv_block_dims = [conv_dim] * num_total_blocks

@@ -601,7 +604,7 @@ def get_block_lr_weight(
         elif name == "zeros":
             return [0.0 + base_lr] * max_len
         else:
-            print(
+            logger.error(
                 "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
                 % (name)
             )

@@ -613,14 +616,14 @@ def get_block_lr_weight(
     up_lr_weight = get_list(up_lr_weight)

     if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
-        print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
-        print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
+        logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
+        logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
         up_lr_weight = up_lr_weight[:max_len]
         down_lr_weight = down_lr_weight[:max_len]

     if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
-        print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
-        print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
+        logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
+        logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)

     if down_lr_weight != None and len(down_lr_weight) < max_len:
         down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))

@@ -628,24 +631,24 @@ def get_block_lr_weight(
         up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))

     if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
-        print("apply block learning rate / 階層別学習率を適用します。")
+        logger.info("apply block learning rate / 階層別学習率を適用します。")
         if down_lr_weight != None:
             down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
-            print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
+            logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
         else:
-            print("down_lr_weight: all 1.0, すべて1.0")
+            logger.info("down_lr_weight: all 1.0, すべて1.0")

         if mid_lr_weight != None:
             mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
-            print("mid_lr_weight:", mid_lr_weight)
+            logger.info(f"mid_lr_weight: {mid_lr_weight}")
         else:
-            print("mid_lr_weight: 1.0")
+            logger.info("mid_lr_weight: 1.0")

         if up_lr_weight != None:
             up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
-            print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
+            logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
         else:
-            print("up_lr_weight: all 1.0, すべて1.0")
+            logger.info("up_lr_weight: all 1.0, すべて1.0")

     return down_lr_weight, mid_lr_weight, up_lr_weight
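A small worked example of the length handling above, assuming max_len = 12 and an illustrative zero_threshold:

down_lr_weight = [0.0, 0.5, 1.0]  # too short: padded with 1.0 up to max_len
down_lr_weight = down_lr_weight + [1.0] * (12 - len(down_lr_weight))
zero_threshold = 0.1  # illustrative value
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
# -> [0, 0.5, 1.0, 1.0, ...]; blocks whose weight ends up 0 are skipped entirely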
@@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
         elif "lora_down" in key:
             dim = value.size()[0]
             modules_dim[lora_name] = dim
-            # print(lora_name, value.size(), dim)
+            # logger.info(lora_name, value.size(), dim)

     # support old LoRA without alpha
     for key in modules_dim.keys():

@@ -752,7 +755,7 @@ class LoRANetwork(torch.nn.Module):

     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"

@@ -801,20 +804,20 @@ class LoRANetwork(torch.nn.Module):
         self.module_dropout = module_dropout

         if modules_dim is not None:
-            print(f"create LoRA network from weights")
+            logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
-            print(f"create LoRA network from block_dims")
-            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
-            print(f"block_dims: {block_dims}")
-            print(f"block_alphas: {block_alphas}")
+            logger.info(f"create LoRA network from block_dims")
+            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(f"block_dims: {block_dims}")
+            logger.info(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
-                print(f"conv_block_dims: {conv_block_dims}")
-                print(f"conv_block_alphas: {conv_block_alphas}")
+                logger.info(f"conv_block_dims: {conv_block_dims}")
+                logger.info(f"conv_block_alphas: {conv_block_alphas}")
         else:
-            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             if self.conv_lora_dim is not None:
-                print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

         # create module instances
         def create_modules(

@@ -899,15 +902,15 @@ class LoRANetwork(torch.nn.Module):
         for i, text_encoder in enumerate(text_encoders):
             if len(text_encoders) > 1:
                 index = i + 1
-                print(f"create LoRA for Text Encoder {index}:")
+                logger.info(f"create LoRA for Text Encoder {index}:")
             else:
                 index = None
-                print(f"create LoRA for Text Encoder:")
+                logger.info(f"create LoRA for Text Encoder:")

             text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
             self.text_encoder_loras.extend(text_encoder_loras)
             skipped_te += skipped
-        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+        logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

         # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
         target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE

@@ -915,15 +918,15 @@ class LoRANetwork(torch.nn.Module):
             target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

         self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
-        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+        logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

         skipped = skipped_te + skipped_un
         if varbose and len(skipped) > 0:
-            print(
+            logger.warning(
                 f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
             )
             for name in skipped:
-                print(f"\t{name}")
+                logger.info(f"\t{name}")

         self.up_lr_weight: List[float] = None
         self.down_lr_weight: List[float] = None

@@ -954,12 +957,12 @@ class LoRANetwork(torch.nn.Module):

     def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
         if apply_text_encoder:
-            print("enable LoRA for text encoder")
+            logger.info("enable LoRA for text encoder")
         else:
             self.text_encoder_loras = []

         if apply_unet:
-            print("enable LoRA for U-Net")
+            logger.info("enable LoRA for U-Net")
         else:
             self.unet_loras = []

@@ -981,12 +984,12 @@ class LoRANetwork(torch.nn.Module):
             apply_unet = True

         if apply_text_encoder:
-            print("enable LoRA for text encoder")
+            logger.info("enable LoRA for text encoder")
         else:
             self.text_encoder_loras = []

         if apply_unet:
-            print("enable LoRA for U-Net")
+            logger.info("enable LoRA for U-Net")
         else:
             self.unet_loras = []

@@ -997,7 +1000,7 @@ class LoRANetwork(torch.nn.Module):
                     sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
             lora.merge_to(sd_for_lora, dtype, device)

-        print(f"weights are merged")
+        logger.info(f"weights are merged")

     # Define per-layer multipliers for block-wise learning rates (the argument order is reversed, but leave it for now)
     def set_block_lr_weight(

@@ -1144,7 +1147,7 @@ class LoRANetwork(torch.nn.Module):
         device = ref_weight.device

         def resize_add(mh, mw):
-            # print(mh, mw, mh * mw)
+            # logger.info(mh, mw, mh * mw)
             m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")  # doesn't work in bf16
             m = m.to(device, dtype=dtype)
             mask_dic[mh * mw] = m
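resize_add above caches one resized mask per resolution. Keying the cache by mh * mw is presumably because the LoRA modules later see flattened (batch, h*w, channels) activations, where only the token count is known; a sketch of that lookup under this assumption (values illustrative):

import torch

mask = torch.rand(1, 1, 64, 64)
mask_dic = {}
for mh, mw in [(64, 64), (32, 32), (16, 16)]:
    m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")
    mask_dic[mh * mw] = m
m = mask_dic[32 * 32]  # looked up from a flattened 32x32 feature map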
networks/lora_flux.py (new file, 1344 lines)
File diff suppressed because it is too large
@@ -5,27 +5,34 @@ from library import model_util
 import library.train_util as train_util
 import argparse
 from transformers import CLIPTokenizer

 import torch
+from library.device_utils import init_ipex, get_preferred_device
+init_ipex()
+
 import library.model_util as model_util
 import lora
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from here

-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+DEVICE = get_preferred_device()


 def interrogate(args):
     weights_dtype = torch.float16

     # prepare everything
-    print(f"loading SD model: {args.sd_model}")
+    logger.info(f"loading SD model: {args.sd_model}")
     args.pretrained_model_name_or_path = args.sd_model
     args.vae = None
     text_encoder, vae, unet, _ = train_util._load_target_model(args, weights_dtype, DEVICE)

-    print(f"loading LoRA: {args.model}")
+    logger.info(f"loading LoRA: {args.model}")
     network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)

     # check whether the LoRA has weights for the text encoder (this should really be done on the lora side)

@@ -35,11 +42,11 @@ def interrogate(args):
             has_te_weight = True
             break
     if not has_te_weight:
-        print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
+        logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
         return
     del vae

-    print("loading tokenizer")
+    logger.info("loading tokenizer")
     if args.v2:
         tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
     else:

@@ -53,7 +60,7 @@ def interrogate(args):
     # check the tokens one by one
     token_id_start = 0
     token_id_end = max(tokenizer.all_special_ids)
-    print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
+    logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")

     def get_all_embeddings(text_encoder):
         embs = []

@@ -79,24 +86,24 @@ def interrogate(args):
             embs.extend(encoder_hidden_states)
         return torch.stack(embs)

-    print("get original text encoder embeddings.")
+    logger.info("get original text encoder embeddings.")
     orig_embs = get_all_embeddings(text_encoder)

     network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
     info = network.load_state_dict(weights_sd, strict=False)
-    print(f"Loading LoRA weights: {info}")
+    logger.info(f"Loading LoRA weights: {info}")

     network.to(DEVICE, dtype=weights_dtype)
     network.eval()

     del unet

-    print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
-    print("get text encoder embeddings with lora.")
+    logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
+    logger.info("get text encoder embeddings with lora.")
     lora_embs = get_all_embeddings(text_encoder)

     # compare: for now just use the mean absolute difference
-    print("comparing...")
+    logger.info("comparing...")
     diffs = {}
     for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
         diff = torch.mean(torch.abs(orig_emb - lora_emb))
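Once the loop above has filled diffs, ranking tokens by mean absolute difference is straightforward; a sketch, assuming diffs maps token id to the scalar difference and tokenizer is the CLIPTokenizer loaded earlier:

top = sorted(diffs.items(), key=lambda kv: kv[1], reverse=True)[:20]
for token_id, diff in top:
    logger.info(f"{diff:.5f} {token_id} {tokenizer.decode([token_id])}")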
networks/lora_sd3.py (new file, 839 lines)
@@ -0,0 +1,839 @@
# temporary minimum implementation of LoRA
# SD3 doesn't have Conv2d, so we ignore it
# TODO commonize with the original/SD3/FLUX implementation

# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py

import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from transformers import CLIPTextModelWithProjection, T5EncoderModel
import numpy as np
import torch
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from networks.lora_flux import LoRAModule, LoRAInfModule
from library import sd3_models


def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: sd3_models.SDVAE,
    text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
    mmdit,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
        network_alpha = 1.0

    # extract dim/alpha for conv2d, and block dim
    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv
    context_attn_dim = kwargs.get("context_attn_dim", None)
    context_mlp_dim = kwargs.get("context_mlp_dim", None)
    context_mod_dim = kwargs.get("context_mod_dim", None)
    x_attn_dim = kwargs.get("x_attn_dim", None)
    x_mlp_dim = kwargs.get("x_mlp_dim", None)
    x_mod_dim = kwargs.get("x_mod_dim", None)
    if context_attn_dim is not None:
        context_attn_dim = int(context_attn_dim)
    if context_mlp_dim is not None:
        context_mlp_dim = int(context_mlp_dim)
    if context_mod_dim is not None:
        context_mod_dim = int(context_mod_dim)
    if x_attn_dim is not None:
        x_attn_dim = int(x_attn_dim)
    if x_mlp_dim is not None:
        x_mlp_dim = int(x_mlp_dim)
    if x_mod_dim is not None:
        x_mod_dim = int(x_mod_dim)
    type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
    if all([d is None for d in type_dims]):
        type_dims = None

    # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
    emb_dims = kwargs.get("emb_dims", None)
    if emb_dims is not None:
        emb_dims = emb_dims.strip()
        if emb_dims.startswith("[") and emb_dims.endswith("]"):
            emb_dims = emb_dims[1:-1]
        emb_dims = [int(d) for d in emb_dims.split(",")]  # is it better to use ast.literal_eval?
        assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)"

    # double/single train blocks
    def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
        """
        Parse a block selection string and return a list of booleans.

        Args:
            selection (str): A string specifying which blocks to select.
            total_blocks (int): The total number of blocks available.

        Returns:
            List[bool]: A list of booleans indicating which blocks are selected.
        """
        if selection == "all":
            return [True] * total_blocks
        if selection == "none" or selection == "":
            return [False] * total_blocks

        selected = [False] * total_blocks
        ranges = selection.split(",")

        for r in ranges:
            if "-" in r:
                start, end = map(str.strip, r.split("-"))
                start = int(start)
                end = int(end)
                assert 0 <= start < total_blocks, f"invalid start index: {start}"
                assert 0 <= end < total_blocks, f"invalid end index: {end}"
                assert start <= end, f"invalid range: {start}-{end}"
                for i in range(start, end + 1):
                    selected[i] = True
            else:
                index = int(r)
                assert 0 <= index < total_blocks, f"invalid index: {index}"
                selected[index] = True

        return selected
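For reference, the selection syntax parse_block_selection accepts, per the code above:

parse_block_selection("all", 6)    # -> [True] * 6
parse_block_selection("none", 6)   # -> [False] * 6
parse_block_selection("0,2-3", 6)  # -> [True, False, True, True, False, False]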
    train_block_indices = kwargs.get("train_block_indices", None)
    if train_block_indices is not None:
        train_block_indices = parse_block_selection(train_block_indices, 999)  # 999 is a dummy number

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    # split qkv
    split_qkv = kwargs.get("split_qkv", False)
    if split_qkv is not None:
        split_qkv = True if split_qkv == "True" else False

    # train T5XXL
    train_t5xxl = kwargs.get("train_t5xxl", False)
    if train_t5xxl is not None:
        train_t5xxl = True if train_t5xxl == "True" else False

    # verbose
    verbose = kwargs.get("verbose", False)
    if verbose is not None:
        verbose = True if verbose == "True" else False

    # quite a lot of arguments here ( ^ω^)・・・
    network = LoRANetwork(
        text_encoders,
        mmdit,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        conv_lora_dim=conv_dim,
        conv_alpha=conv_alpha,
        split_qkv=split_qkv,
        train_t5xxl=train_t5xxl,
        type_dims=type_dims,
        emb_dims=emb_dims,
        train_block_indices=train_block_indices,
        verbose=verbose,
    )

    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network
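A note on the casts above: these kwargs typically arrive as strings parsed from the command line (e.g. --network_args "conv_dim=4" "split_qkv=True"), which is why the function converts them with int()/float() and compares against the literal string "True". A tiny illustration:

kwargs = {"conv_dim": "4", "split_qkv": "True"}  # as parsed from --network_args
conv_dim = int(kwargs.get("conv_dim"))           # 4
split_qkv = kwargs.get("split_qkv") == "True"    # True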
# Create network from weights for inference, weights are not loaded here (because they can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs):
    # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # get dim/alpha mapping, and train t5xxl
    modules_dim = {}
    modules_alpha = {}
    train_t5xxl = None
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim
            # logger.info(lora_name, value.size(), dim)

        if train_t5xxl is None or train_t5xxl is False:
            train_t5xxl = "lora_te3" in lora_name

    if train_t5xxl is None:
        train_t5xxl = False

    split_qkv = False  # no need to care about split_qkv here, because the state_dict has qkv combined

    module_class = LoRAInfModule if for_inference else LoRAModule

    network = LoRANetwork(
        text_encoders,
        mmdit,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        split_qkv=split_qkv,
        train_t5xxl=train_t5xxl,
    )
    return network, weights_sd
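A worked example of the key parsing above, with a hypothetical state-dict key:

key = "lora_unet_joint_blocks_0_x_block_attn_qkv.lora_down.weight"
lora_name = key.split(".")[0]  # "lora_unet_joint_blocks_0_x_block_attn_qkv"
# a lora_down weight of shape (4, 1536) gives dim (rank) 4 via value.size()[0];
# "lora_te3" appearing in lora_name flags that T5XXL modules are present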
class LoRANetwork(torch.nn.Module):
    SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
    LORA_PREFIX_SD3 = "lora_unet"  # make ComfyUI compatible
    LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2"
    LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3"  # make ComfyUI compatible

    def __init__(
        self,
        text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
        unet: sd3_models.MMDiT,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        split_qkv: bool = False,
        train_t5xxl: bool = False,
        type_dims: Optional[List[int]] = None,
        emb_dims: Optional[List[int]] = None,
        train_block_indices: Optional[List[bool]] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.split_qkv = split_qkv
        self.train_t5xxl = train_t5xxl

        self.type_dims = type_dims
        self.emb_dims = emb_dims
        self.train_block_indices = train_block_indices

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if modules_dim is not None:
            logger.info(f"create LoRA network from weights")
            self.emb_dims = [0] * 6  # create emb_dims
            # verbose = True
        else:
            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            logger.info(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
            )
            # if self.conv_lora_dim is not None:
            #     logger.info(
            #         f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
            #     )

        qkv_dim = 0
        if self.split_qkv:
            logger.info(f"split qkv for LoRA")
            qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0)
        if train_t5xxl:
            logger.info(f"train T5XXL as well")

        # create module instances
        def create_modules(
            is_mmdit: bool,
            text_encoder_idx: Optional[int],
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
            filter: Optional[str] = None,
            default_dim: Optional[int] = None,
            include_conv2d_if_filter: bool = False,
        ) -> List[LoRAModule]:
            prefix = (
                self.LORA_PREFIX_SD3
                if is_mmdit
                else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][
                    text_encoder_idx
                ]
            )

            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
                    if target_replace_modules is None:  # dirty hack for all modules
                        module = root_module  # search all modules

                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + (name + "." if name else "") + child_name
                            lora_name = lora_name.replace(".", "_")

                            force_incl_conv2d = False
                            if filter is not None:
                                if not filter in lora_name:
                                    continue
                                force_incl_conv2d = include_conv2d_if_filter

                            dim = None
                            alpha = None

                            if modules_dim is not None:
                                # modules are specified explicitly
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha = modules_alpha[lora_name]
                            else:
                                # usually, target all modules
                                if is_linear or is_conv2d_1x1:
                                    dim = default_dim if default_dim is not None else self.lora_dim
                                    alpha = self.alpha

                                    if is_mmdit and type_dims is not None:
                                        # type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
                                        identifier = [
                                            ("context_block", "attn"),
                                            ("context_block", "mlp"),
                                            ("context_block", "adaLN_modulation"),
                                            ("x_block", "attn"),
                                            ("x_block", "mlp"),
                                            ("x_block", "adaLN_modulation"),
                                        ]
                                        for i, d in enumerate(type_dims):
                                            if d is not None and all([id in lora_name for id in identifier[i]]):
                                                dim = d  # may be 0 for skip
                                                break

                                    if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name:
                                        # "lora_unet_joint_blocks_0_x_block_attn_proj..."
                                        block_index = int(lora_name.split("_")[4])  # bit dirty
                                        if self.train_block_indices is not None and not self.train_block_indices[block_index]:
                                            dim = 0

                                elif self.conv_lora_dim is not None:
                                    dim = self.conv_lora_dim
                                    alpha = self.conv_alpha
                                elif force_incl_conv2d:
                                    # x_embedder
                                    dim = default_dim if default_dim is not None else self.lora_dim
                                    alpha = self.alpha

                            if dim is None or dim == 0:
                                # record the skipped modules
                                if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
                                    skipped.append(lora_name)
                                continue

                            # qkv split
                            split_dims = None
                            if is_mmdit and split_qkv:
                                if "joint_blocks" in lora_name and "qkv" in lora_name:
                                    split_dims = [qkv_dim // 3] * 3

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                                dropout=dropout,
                                rank_dropout=rank_dropout,
                                module_dropout=module_dropout,
                                split_dims=split_dims,
                            )
                            loras.append(lora)

                if target_replace_modules is None:
                    break  # all modules are searched
            return loras, skipped

        # create LoRA for text encoder
        # creating all the modules every time is wasteful; worth revisiting
        self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
        skipped_te = []
        for i, text_encoder in enumerate(text_encoders):
            index = i
            if not train_t5xxl and index >= 2:  # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False
                break

            logger.info(f"create LoRA for Text Encoder {index+1}:")

            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
            self.text_encoder_loras.extend(text_encoder_loras)
            skipped_te += skipped

        # create LoRA for U-Net
        self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
        self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE)

        # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
        if self.emb_dims:
            for filter, in_dim in zip(
                [
                    "context_embedder",
                    "_t_embedder",  # don't use "t_embedder" because it's used in "context_embedder"
                    "x_embedder",
                    "y_embedder",
                    "final_layer_adaLN_modulation",
                    "final_layer_linear",
                ],
                self.emb_dims,
            ):
                # x_embedder is conv2d, so we need to include it
                loras, _ = create_modules(
                    True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
                )
                # if len(loras) > 0:
                #     logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
                self.unet_loras.extend(loras)

        logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
        if verbose:
            for lora in self.unet_loras:
                logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")

        skipped = skipped_te + skipped_un
        if verbose and len(skipped) > 0:
            logger.warning(
                f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
            )
            for name in skipped:
                logger.info(f"\t{name}")

        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)
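A worked example of the naming scheme used above, from module path to lora_name, including the "bit dirty" recovery of the joint block index (names are illustrative):

prefix = "lora_unet"
name, child_name = "joint_blocks.0.x_block.attn", "qkv"
lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
# -> "lora_unet_joint_blocks_0_x_block_attn_qkv"
block_index = int(lora_name.split("_")[4])  # -> 0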
    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def set_enabled(self, is_enabled):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enabled = is_enabled

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info
    def load_state_dict(self, state_dict, strict=True):
        # override to convert original weight to split qkv
        if not self.split_qkv:
            return super().load_state_dict(state_dict, strict)

        # split qkv
        for key in list(state_dict.keys()):
            if not ("joint_blocks" in key and "qkv" in key):
                continue

            weight = state_dict[key]
            lora_name = key.split(".")[0]
            if "lora_down" in key and "weight" in key:
                # dense weight (rank*3, in_dim)
                split_weight = torch.chunk(weight, 3, dim=0)
                for i, split_w in enumerate(split_weight):
                    state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w

                del state_dict[key]
                # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}")
            elif "lora_up" in key and "weight" in key:
                # sparse weight (out_dim=sum(split_dims), rank*3)
                rank = weight.size(1) // 3
                i = 0
                split_dim = weight.shape[0] // 3
                for j in range(3):
                    state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank]
                    i += split_dim
                del state_dict[key]

        # alpha is unchanged

        return super().load_state_dict(state_dict, strict)
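A small shape check of the qkv splitting above (illustrative dimensions):

import torch

rank, in_dim, qkv_dim = 4, 1536, 4608
down = torch.randn(rank * 3, in_dim)  # combined lora_down weight
chunks = torch.chunk(down, 3, dim=0)  # three (rank, in_dim) pieces
assert all(c.shape == (rank, in_dim) for c in chunks)
up = torch.randn(qkv_dim, rank * 3)  # combined lora_up weight
split_dim = qkv_dim // 3
q_up = up[0:split_dim, 0:rank]  # the block assigned to lora_up.0
assert q_up.shape == (split_dim, rank)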
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if not self.split_qkv:
            return super().state_dict(destination, prefix, keep_vars)

        # merge qkv
        state_dict = super().state_dict(destination, prefix, keep_vars)
        new_state_dict = {}
        for key in list(state_dict.keys()):
            if not ("joint_blocks" in key and "qkv" in key):
                new_state_dict[key] = state_dict[key]
                continue

            if key not in state_dict:
                continue  # already merged

            lora_name = key.split(".")[0]

            # (rank, in_dim) * 3
            down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)]
            # (split dim, rank) * 3
            up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)]

            alpha = state_dict.pop(f"{lora_name}.alpha")

            # merge down weight
            down_weight = torch.cat(down_weights, dim=0)  # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)

            # merge up weight (sum of split_dim, rank*3)
            split_dim, rank = up_weights[0].size()
            qkv_dim = split_dim * 3
            up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
            i = 0
            for j in range(3):
                up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j]
                i += split_dim

            new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
            new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
            new_state_dict[f"{lora_name}.alpha"] = alpha

            # print(
            #     f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
            # )
            print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")

        return new_state_dict
    def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True):
        if apply_text_encoder:
            logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    # return whether the weights can be merged
    def is_mergeable(self):
        return True

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None):
        apply_text_encoder = apply_unet = False
        for key in weights_sd.keys():
            if (
                key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
                or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
                or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5)
            ):
                apply_text_encoder = True
            elif key.startswith(LoRANetwork.LORA_PREFIX_SD3):
                apply_unet = True

        if apply_text_encoder:
            logger.info("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        logger.info(f"weights are merged")
    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")

    def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
        # make sure text_encoder_lr is a list of three elements
        # if float, use the same value for all three
        if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
            text_encoder_lr = [default_lr, default_lr, default_lr]
        elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
            text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)]
        elif len(text_encoder_lr) == 1:
            text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]]
        elif len(text_encoder_lr) == 2:
            text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]]

        self.requires_grad_(True)

        all_params = []
        lr_descriptions = []

        def assemble_params(loras, lr, loraplus_ratio):
            param_groups = {"lora": {}, "plus": {}}
            for lora in loras:
                for name, param in lora.named_parameters():
                    if loraplus_ratio is not None and "lora_up" in name:
                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                    else:
                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param

            params = []
            descriptions = []
            for key in param_groups.keys():
                param_data = {"params": param_groups[key].values()}

                if len(param_data["params"]) == 0:
                    continue

                if lr is not None:
                    if key == "plus":
                        param_data["lr"] = lr * loraplus_ratio
                    else:
                        param_data["lr"] = lr

                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                    logger.info("NO LR skipping!")
                    continue

                params.append(param_data)
                descriptions.append("plus" if key == "plus" else "")

            return params, descriptions

        if self.text_encoder_loras:
            loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio

            # split text encoder loras for te1, te2 and te3
            te1_loras = [
                lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
            ]
            te2_loras = [
                lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
            ]
            te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
            if len(te1_loras) > 0:
                logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
                params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
                all_params.extend(params)
                lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
            if len(te2_loras) > 0:
                logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}")
                params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio)
                all_params.extend(params)
                lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions])
            if len(te3_loras) > 0:
                logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}")
                params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio)
                all_params.extend(params)
                lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions])

        if self.unet_loras:
            params, descriptions = assemble_params(
                self.unet_loras,
                unet_lr if unet_lr is not None else default_lr,
                self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)
            lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

        return all_params, lr_descriptions
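A minimal sketch of the LoRA+ grouping above: lora_up parameters get lr * ratio, everything else the base lr (optimizer and values are illustrative):

import torch

lora_down = torch.nn.Linear(16, 4, bias=False)
lora_up = torch.nn.Linear(4, 16, bias=False)
base_lr, loraplus_ratio = 1e-4, 16  # illustrative values
optimizer = torch.optim.AdamW([
    {"params": lora_down.parameters(), "lr": base_lr},
    {"params": lora_up.parameters(), "lr": base_lr * loraplus_ratio},
])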
    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        # back up the original weights
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                org_module._lora_restored = True

    def restore_weights(self):
        # restore the original weights
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        # pre-calculate the merged weights
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            sd = org_module.state_dict()

            org_weight = sd["weight"]
            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            sd["weight"] = org_weight + lora_weight
            assert sd["weight"].shape == org_weight.shape
            org_module.load_state_dict(sd)

            org_module._lora_restored = False
            lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
alphakeys = []
|
||||
norms = []
|
||||
keys_scaled = 0
|
||||
|
||||
state_dict = self.state_dict()
|
||||
for key in state_dict.keys():
|
||||
if "lora_down" in key and "weight" in key:
|
||||
downkeys.append(key)
|
||||
upkeys.append(key.replace("lora_down", "lora_up"))
|
||||
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||
|
||||
for i in range(len(downkeys)):
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
up = state_dict[upkeys[i]].to(device)
|
||||
alpha = state_dict[alphakeys[i]].to(device)
|
||||
dim = down.shape[0]
|
||||
scale = alpha / dim
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
updown *= scale
|
||||
|
||||
norm = updown.norm().clamp(min=max_norm_value / 2)
|
||||
desired = torch.clamp(norm, max=max_norm_value)
|
||||
ratio = desired.cpu() / norm.cpu()
|
||||
sqrt_ratio = ratio**0.5
|
||||
if ratio != 1:
|
||||
keys_scaled += 1
|
||||
state_dict[upkeys[i]] *= sqrt_ratio
|
||||
state_dict[downkeys[i]] *= sqrt_ratio
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
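The max-norm pass above caps the norm of each merged delta scale * (up @ down) at max_norm_value; because the product scales multiplicatively, each factor is multiplied by sqrt(ratio). A minimal self-contained sketch of the same idea (shapes and tensor names are illustrative, not from this file):

import torch

max_norm_value = 1.0
up = torch.randn(320, 4)    # lora_up.weight (out_features x rank)
down = torch.randn(4, 64)   # lora_down.weight (rank x in_features)
alpha = 4.0
scale = alpha / down.shape[0]

updown = (up @ down) * scale
ratio = torch.clamp(updown.norm(), max=max_norm_value) / updown.norm()
if ratio < 1:               # only rescale pairs that exceed the cap
    up *= ratio**0.5        # sqrt on each factor so the product scales by ratio
    down *= ratio**0.5
assert ((up @ down) * scale).norm() <= max_norm_value * 1.001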
@@ -7,7 +7,10 @@ from safetensors.torch import load_file, save_file
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora

+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == ".safetensors":
@@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
                name_to_module[lora_name] = child_module

    for model, ratio in zip(models, ratios):
-        print(f"loading: {model}")
+        logger.info(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)

-        print(f"merging...")
+        logger.info(f"merging...")
        for key in lora_sd.keys():
            if "lora_down" in key:
                up_key = key.replace("lora_down", "lora_up")
@@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
                # find original module for this lora
                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
                if module_name not in name_to_module:
-                    print(f"no module found for LoRA weight: {key}")
+                    logger.info(f"no module found for LoRA weight: {key}")
                    continue
                module = name_to_module[module_name]
-                # print(f"apply {key} to {module}")
+                # logger.info(f"apply {key} to {module}")

                down_weight = lora_sd[key]
                up_weight = lora_sd[up_key]
@@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
                else:
                    # conv2d 3x3
                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
-                    # print(conved.size(), weight.size(), module.stride, module.padding)
+                    # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                    weight = weight + ratio * conved * scale

                module.weight = torch.nn.Parameter(weight)
@@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
    v2 = None
    base_model = None
    for model, ratio in zip(models, ratios):
-        print(f"loading: {model}")
+        logger.info(f"loading: {model}")
        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

        if lora_metadata is not None:
@@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
            if lora_module_name not in base_alphas:
                base_alphas[lora_module_name] = alpha

-        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
+        logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        # merge
-        print(f"merging...")
+        logger.info(f"merging...")
        for key in lora_sd.keys():
            if "alpha" in key:
                continue
@@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
            merged_sd[key_down] = merged_sd[key_down][perm]
            merged_sd[key_up] = merged_sd[key_up][:,perm]

-    print("merged model")
-    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
+    logger.info("merged model")
+    logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    # check all dims are same
    dims_list = list(set(base_dims.values()))
@@ -239,7 +242,7 @@ def merge(args):
        save_dtype = merge_dtype

    if args.sd_model is not None:
-        print(f"loading SD model: {args.sd_model}")
+        logger.info(f"loading SD model: {args.sd_model}")

        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)

@@ -264,18 +267,18 @@ def merge(args):
            )
            if args.v2:
                # TODO read sai modelspec
-                print(
+                logger.warning(
                    "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
                )

-        print(f"saving SD model to: {args.save_to}")
+        logger.info(f"saving SD model to: {args.save_to}")
        model_util.save_stable_diffusion_checkpoint(
            args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
        )
    else:
        state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

-        print(f"calculating hashes and creating metadata...")
+        logger.info(f"calculating hashes and creating metadata...")

        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
@@ -289,12 +292,12 @@ def merge(args):
            )
            if v2:
                # TODO read sai modelspec
-                print(
+                logger.warning(
                    "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
                )
            metadata.update(sai_metadata)

-        print(f"saving model to: {args.save_to}")
+        logger.info(f"saving model to: {args.save_to}")
        save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
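For reference, the Linear-module branch of merge_to_sd_model above amounts to W' = W + ratio * (lora_up @ lora_down) * (alpha / rank). A short sketch with made-up shapes (names are illustrative, not from the script):

import torch

W = torch.randn(768, 768)   # original module weight
up = torch.randn(768, 8)    # lora_up.weight
down = torch.randn(8, 768)  # lora_down.weight, rank 8
alpha, ratio = 8.0, 1.0
scale = alpha / down.shape[0]

merged = W + ratio * (up @ down) * scale  # same update the script applies in place
assert merged.shape == W.shape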
@@ -6,7 +6,10 @@ import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora

+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == '.safetensors':
@@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
                name_to_module[lora_name] = child_module

    for model, ratio in zip(models, ratios):
-        print(f"loading: {model}")
+        logger.info(f"loading: {model}")
        lora_sd = load_state_dict(model, merge_dtype)

-        print(f"merging...")
+        logger.info(f"merging...")
        for key in lora_sd.keys():
            if "lora_down" in key:
                up_key = key.replace("lora_down", "lora_up")
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
                # find original module for this lora
                module_name = '.'.join(key.split('.')[:-2])  # remove trailing ".lora_down.weight"
                if module_name not in name_to_module:
-                    print(f"no module found for LoRA weight: {key}")
+                    logger.info(f"no module found for LoRA weight: {key}")
                    continue
                module = name_to_module[module_name]
-                # print(f"apply {key} to {module}")
+                # logger.info(f"apply {key} to {module}")

                down_weight = lora_sd[key]
                up_weight = lora_sd[up_key]
@@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype):
    alpha = None
    dim = None
    for model, ratio in zip(models, ratios):
-        print(f"loading: {model}")
+        logger.info(f"loading: {model}")
        lora_sd = load_state_dict(model, merge_dtype)

-        print(f"merging...")
+        logger.info(f"merging...")
        for key in lora_sd.keys():
            if 'alpha' in key:
                if key in merged_sd:
@@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype):
                dim = lora_sd[key].size()[0]
            merged_sd[key] = lora_sd[key] * ratio

-    print(f"dim (rank): {dim}, alpha: {alpha}")
+    logger.info(f"dim (rank): {dim}, alpha: {alpha}")
    if alpha is None:
        alpha = dim

@@ -142,19 +145,21 @@ def merge(args):
        save_dtype = merge_dtype

    if args.sd_model is not None:
-        print(f"loading SD model: {args.sd_model}")
+        logger.info(f"loading SD model: {args.sd_model}")

        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)

        merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)

-        print(f"\nsaving SD model to: {args.save_to}")
+        logger.info("")
+        logger.info(f"saving SD model to: {args.save_to}")
        model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
                                                    args.sd_model, 0, 0, save_dtype, vae)
    else:
        state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)

-        print(f"\nsaving model to: {args.save_to}")
+        logger.info(f"")
+        logger.info(f"saving model to: {args.save_to}")
        save_to_file(args.save_to, state_dict, state_dict, save_dtype)
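A hypothetical invocation of the merge script whose diff appears above (the script path and file names are placeholders; the flags mirror args.sd_model, args.save_to, args.models, and args.ratios referenced in merge()):

python networks/merge_lora.py \
    --sd_model base_model.ckpt \
    --save_to merged_model.safetensors \
    --models lora_a.safetensors lora_b.safetensors \
    --ratios 1.0 0.5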
111 networks/oft.py
@@ -4,11 +4,18 @@ import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
+import einops
from transformers import CLIPTextModel
import numpy as np
import torch
+import torch.nn.functional as F
import re
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")

@@ -42,11 +49,16 @@ class OFTModule(torch.nn.Module):

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
-        self.constraint = alpha * out_dim
+
+        # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
+        # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha
+        self.constraint = alpha * out_dim
+
        self.register_buffer("alpha", torch.tensor(alpha))

        self.block_size = out_dim // self.num_blocks
        self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
+        self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1)  # cpu

        self.out_dim = out_dim
        self.shape = org_module.weight.shape
@@ -66,27 +78,36 @@ class OFTModule(torch.nn.Module):
            norm_Q = torch.norm(block_Q.flatten())
            new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
-        I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
-        block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
-
-        block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
-        R = torch.block_diag(*block_R_weighted)
-
-        return R
+        if self.I.device != block_Q.device:
+            self.I = self.I.to(block_Q.device)
+        I = self.I
+        block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
+        block_R_weighted = self.multiplier * (block_R - I) + I
+        return block_R_weighted

    def forward(self, x, scale=None):
-        x = self.org_forward(x)
        if self.multiplier == 0.0:
-            return x
+            return self.org_forward(x)
+        org_module = self.org_module[0]
+        org_dtype = x.dtype

-        R = self.get_weight().to(x.device, dtype=x.dtype)
-        if x.dim() == 4:
-            x = x.permute(0, 2, 3, 1)
-            x = torch.matmul(x, R)
-            x = x.permute(0, 3, 1, 2)
-        else:
-            x = torch.matmul(x, R)
-        return x
+        R = self.get_weight().to(torch.float32)
+        W = org_module.weight.to(torch.float32)
+
+        if len(W.shape) == 4:  # Conv2d
+            W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size)
+            RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped)
+            RW = einops.rearrange(RW, "k m ... -> (k m) ...")
+            result = F.conv2d(
+                x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups
+            )
+        else:  # Linear
+            W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size)
+            RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped)
+            RW = einops.rearrange(RW, "k m p -> (k m) p")
+            result = F.linear(x, RW.to(org_dtype), org_module.bias)
+        return result


class OFTInfModule(OFTModule):
@@ -112,18 +133,19 @@ class OFTInfModule(OFTModule):
            return self.org_forward(x)
        return super().forward(x, scale)

-    def merge_to(self, multiplier=None, sign=1):
-        R = self.get_weight(multiplier) * sign
-
+    def merge_to(self, multiplier=None):
        # get org weight
        org_sd = self.org_module[0].state_dict()
-        org_weight = org_sd["weight"]
-        R = R.to(org_weight.device, dtype=org_weight.dtype)
+        org_weight = org_sd["weight"].to(torch.float32)

-        if org_weight.dim() == 4:
-            weight = torch.einsum("oihw, op -> pihw", org_weight, R)
-        else:
-            weight = torch.einsum("oi, op -> pi", org_weight, R)
+        R = self.get_weight(multiplier).to(torch.float32)
+
+        weight = org_weight.reshape(self.num_blocks, self.block_size, -1)
+        weight = torch.einsum("k n m, k n ... -> k m ...", R, weight)
+        weight = weight.reshape(org_weight.shape)
+
+        # convert back to original dtype
+        weight = weight.to(org_sd["weight"].dtype)

        # set weight to org_module
        org_sd["weight"] = weight
@@ -142,8 +164,16 @@ def create_network(
):
    if network_dim is None:
        network_dim = 4  # default
-    if network_alpha is None:
-        network_alpha = 1.0
+    if network_alpha is None:  # should be set
+        logger.info(
+            "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
+        )
+        network_alpha = 1e-3
+    elif network_alpha >= 1:
+        logger.warning(
+            "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
+            " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
+        )

    enable_all_linear = kwargs.get("enable_all_linear", None)
    enable_conv = kwargs.get("enable_conv", None)
@@ -187,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
        else:
            if dim is None:
                dim = param.size()[0]
-            if has_conv2d is None and param.dim() == 4:
+            if has_conv2d is None and "in_layers_2" in name:
                has_conv2d = True
-            if all_linear is None:
-                if param.dim() == 3 and "attn" not in name:
-                    all_linear = True
-        if dim is not None and alpha is not None and has_conv2d is not None:
+            if all_linear is None and "_ff_" in name:
+                all_linear = True
+        if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None:
            break
    if has_conv2d is None:
        has_conv2d = False
@@ -237,8 +266,8 @@ class OFTNetwork(torch.nn.Module):
        self.dim = dim
        self.alpha = alpha

-        print(
-            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
+        logger.info(
+            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}"
        )

        # create module instances
@@ -258,7 +287,7 @@ class OFTNetwork(torch.nn.Module):
                    if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
                        oft_name = prefix + "." + name + "." + child_name
                        oft_name = oft_name.replace(".", "_")
-                        # print(oft_name)
+                        # logger.info(oft_name)

                        oft = module_class(
                            oft_name,
@@ -279,7 +308,7 @@ class OFTNetwork(torch.nn.Module):
            target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
-        print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
+        logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")

        # assertion
        names = set()
@@ -316,7 +345,7 @@ class OFTNetwork(torch.nn.Module):

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
-        print("enable OFT for U-Net")
+        logger.info("enable OFT for U-Net")

        for oft in self.unet_ofts:
            sd_for_lora = {}
@@ -326,7 +355,7 @@ class OFTNetwork(torch.nn.Module):
            oft.load_state_dict(sd_for_lora, False)
            oft.merge_to()

-        print(f"weights are merged")
+        logger.info(f"weights are merged")

    # it may be good to allow separate learning rates for the two Text Encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
@@ -338,11 +367,11 @@ class OFTNetwork(torch.nn.Module):
            for oft in ofts:
                params.extend(oft.parameters())

-            # print num of params
+            # logger.info num of params
            num_params = 0
            for p in params:
                num_params += p.numel()
-            print(f"OFT params: {num_params}")
+            logger.info(f"OFT params: {num_params}")
            return params

        param_data = {"params": enumerate_params(self.unet_ofts)}
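The rewritten get_weight above relies on the Cayley transform: for a skew-symmetric Q, R = (I + Q)(I - Q)^-1 is orthogonal, so applying R to the weight preserves norms. A standalone numeric check of that property (independent of the classes above):

import torch

block_size = 8
q = torch.randn(block_size, block_size) * 0.1
Q = q - q.T                                   # skew-symmetric, like block_Q above
I = torch.eye(block_size)
R = (I + Q) @ torch.linalg.inv(I - Q)         # Cayley transform
assert torch.allclose(R @ R.T, I, atol=1e-4)  # R is orthogonal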
482 networks/oft_flux.py (new file)
@@ -0,0 +1,482 @@
# OFT network module

import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
import einops
from transformers import CLIPTextModel
import numpy as np
import torch
import torch.nn.functional as F
import re
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class OFTModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
        split_dims: Optional[List[int]] = None,
    ):
        """
        dim -> num blocks
        alpha -> constraint

        split_dims is used to mimic the split qkv of FLUX as same as Diffusers
        """
        super().__init__()
        self.oft_name = oft_name
        self.num_blocks = dim

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
        self.register_buffer("alpha", torch.tensor(alpha))

        # No conv2d in FLUX
        # if "Linear" in org_module.__class__.__name__:
        self.out_dim = org_module.out_features
        # elif "Conv" in org_module.__class__.__name__:
        #     out_dim = org_module.out_channels

        if split_dims is None:
            split_dims = [self.out_dim]
        else:
            assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim"
        self.split_dims = split_dims

        # assert all dim is divisible by num_blocks
        for split_dim in self.split_dims:
            assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks"

        self.constraint = [alpha * split_dim for split_dim in self.split_dims]
        self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims]
        self.oft_blocks = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size]
        )
        self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size]

        self.shape = org_module.weight.shape
        self.multiplier = multiplier
        self.org_module = [org_module]  # keep in a list so it is not registered as a submodule

    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        if self.I[0].device != self.oft_blocks[0].device:
            self.I = [I.to(self.oft_blocks[0].device) for I in self.I]

        block_R_weighted_list = []
        for i in range(len(self.oft_blocks)):
            block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2)
            norm_Q = torch.norm(block_Q.flatten())
            new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i])
            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))

            I = self.I[i]
            block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
            block_R_weighted = self.multiplier * (block_R - I) + I

            block_R_weighted_list.append(block_R_weighted)

        return block_R_weighted_list

    def forward(self, x, scale=None):
        if self.multiplier == 0.0:
            return self.org_forward(x)

        org_module = self.org_module[0]
        org_dtype = x.dtype

        R = self.get_weight()
        W = org_module.weight.to(torch.float32)
        B = org_module.bias.to(torch.float32)

        # split W to match R
        results = []
        d2 = 0
        for i in range(len(R)):
            d1 = d2
            d2 += self.split_dims[i]

            W1 = W[d1:d2]
            W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i])
            RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped)
            RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p")

            B1 = B[d1:d2]
            result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype))
            results.append(result)

        result = torch.cat(results, dim=-1)
        return result


class OFTInfModule(OFTModule):
    def __init__(
        self,
        oft_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        dim=4,
        alpha=1,
        split_dims: Optional[List[int]] = None,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims)
        self.enabled = True
        self.network: OFTNetwork = None

    def set_network(self, network):
        self.network = network

    def forward(self, x, scale=None):
        if not self.enabled:
            return self.org_forward(x)
        return super().forward(x, scale)

    def merge_to(self, multiplier=None):
        # get org weight
        org_sd = self.org_module[0].state_dict()
        W = org_sd["weight"].to(torch.float32)
        R = self.get_weight(multiplier).to(torch.float32)

        d2 = 0
        W_list = []
        for i in range(len(self.oft_blocks)):
            d1 = d2
            d2 += self.split_dims[i]

            W1 = W[d1:d2]
            W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i])
            W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped)
            W1 = einops.rearrange(W1, "k m p -> (k m) p")

            W_list.append(W1)

        W = torch.cat(W_list, dim=-1)

        # convert back to original dtype
        W = W.to(org_sd["weight"].dtype)

        # set weight to org_module
        org_sd["weight"] = W
        self.org_module[0].load_state_dict(org_sd)


def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: AutoencoderKL,
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:  # should be set
        logger.info(
            "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
        )
        network_alpha = 1e-3
    elif network_alpha >= 1:
        logger.warning(
            "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
            " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
        )

    # attn only or all linear (FFN) layers
    enable_all_linear = kwargs.get("enable_all_linear", None)
    # enable_conv = kwargs.get("enable_conv", None)
    if enable_all_linear is not None:
        enable_all_linear = bool(enable_all_linear)
    # if enable_conv is not None:
    #     enable_conv = bool(enable_conv)

    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=network_dim,
        alpha=network_alpha,
        enable_all_linear=enable_all_linear,
        varbose=True,
    )
    return network


# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # check dim, alpha and if weights have for conv2d
    dim = None
    alpha = None
    all_linear = None
    for name, param in weights_sd.items():
        if name.endswith(".alpha"):
            if alpha is None:
                alpha = param.item()
        elif "qkv" in name:
            continue  # ignore qkv
        else:
            if dim is None:
                dim = param.size()[0]
            if all_linear is None and "_mlp" in name:
                all_linear = True
        if dim is not None and alpha is not None and all_linear is not None:
            break
    if all_linear is None:
        all_linear = False

    module_class = OFTInfModule if for_inference else OFTModule
    network = OFTNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        dim=dim,
        alpha=alpha,
        enable_all_linear=all_linear,
        module_class=module_class,
    )
    return network, weights_sd


class OFTNetwork(torch.nn.Module):
    FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"]
    FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"]
    OFT_PREFIX_UNET = "oft_unet"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet,
        multiplier: float = 1.0,
        dim: int = 4,
        alpha: float = 1,
        enable_all_linear: Optional[bool] = False,
        module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule,
        varbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.train_t5xxl = False  # make compatible with LoRA
        self.multiplier = multiplier

        self.dim = dim
        self.alpha = alpha

        logger.info(
            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}"
        )

        # create module instances
        def create_modules(
            root_module: torch.nn.Module,
            target_replace_modules: List[torch.nn.Module],
        ) -> List[OFTModule]:
            prefix = self.OFT_PREFIX_UNET
            ofts = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = "Linear" in child_module.__class__.__name__

                        if is_linear:
                            oft_name = prefix + "." + name + "." + child_name
                            oft_name = oft_name.replace(".", "_")
                            # logger.info(oft_name)

                            if "double" in oft_name and "qkv" in oft_name:
                                split_dims = [3072] * 3
                            elif "single" in oft_name and "linear1" in oft_name:
                                split_dims = [3072] * 3 + [12288]
                            else:
                                split_dims = None

                            oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims)
                            ofts.append(oft)
            return ofts

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        if enable_all_linear:
            target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR
        else:
            target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY

        self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
        logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.")

        # assertion
        names = set()
        for oft in self.unet_ofts:
            assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
            names.add(oft.oft_name)

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for oft in self.unet_ofts:
            oft.multiplier = self.multiplier

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        assert apply_unet, "apply_unet must be True"

        for oft in self.unet_ofts:
            oft.apply_to()
            self.add_module(oft.oft_name, oft)

    # return whether the weights can be merged
    def is_mergeable(self):
        return True

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        logger.info("enable OFT for U-Net")

        for oft in self.unet_ofts:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(oft.oft_name):
                    sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
            oft.load_state_dict(sd_for_lora, False)
            oft.merge_to()

        logger.info(f"weights are merged")

    # it may be good to allow separate learning rates for the two Text Encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(ofts):
            params = []
            for oft in ofts:
                params.extend(oft.parameters())

            # logger.info num of params
            num_params = 0
            for p in params:
                num_params += p.numel()
            logger.info(f"OFT params: {num_params}")
            return params

        param_data = {"params": enumerate_params(self.unet_ofts)}
        if unet_lr is not None:
            param_data["lr"] = unet_lr
        all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        # back up the original weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                org_module._lora_restored = True

    def restore_weights(self):
        # restore the original weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        # pre-calculate the merged weights
        ofts: List[OFTInfModule] = self.unet_ofts
        for oft in ofts:
            org_module = oft.org_module[0]
            oft.merge_to()
            # sd = org_module.state_dict()
            # org_weight = sd["weight"]
            # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            # sd["weight"] = org_weight + lora_weight
            # assert sd["weight"].shape == org_weight.shape
            # org_module.load_state_dict(sd)

            org_module._lora_restored = False
            oft.enabled = False
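The einsum pattern used in forward() and merge_to() above rotates each contiguous block of output rows by its own small matrix, which is equivalent to multiplying by the transpose of a block-diagonal rotation without ever materializing it. A small equivalence check (dimensions are illustrative):

import torch
import einops

num_blocks, block_size, in_dim = 4, 3, 5
R = torch.randn(num_blocks, block_size, block_size)
W = torch.randn(num_blocks * block_size, in_dim)

W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=num_blocks, n=block_size)
RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped)
RW = einops.rearrange(RW, "k m p -> (k m) p")

expected = torch.block_diag(*R).T @ W          # same result, materialized
assert torch.allclose(RW, expected, atol=1e-5)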
@@ -2,80 +2,86 @@
|
||||
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
import numpy as np
|
||||
|
||||
from library import train_util
|
||||
from library import model_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
# Model save and load functions
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
metadata = None
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = None
|
||||
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
|
||||
return sd, metadata
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
def save_to_file(file_name, state_dict, metadata):
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(state_dict, file_name, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file_name)
|
||||
|
||||
|
||||
# Indexing functions
|
||||
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
s_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
S_squared = S.pow(2)
|
||||
S_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv/target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S)-1))
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
return index
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
@@ -92,17 +98,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
||||
del U, S, Vh, weight
|
||||
@@ -113,7 +119,7 @@ def merge_conv(lora_down, lora_up, device):
|
||||
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
||||
out_size, out_rank, _, _ = lora_up.shape
|
||||
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
||||
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
@@ -127,236 +133,280 @@ def merge_linear(lora_down, lora_up, device):
|
||||
in_rank, in_size = lora_down.shape
|
||||
out_size, out_rank = lora_up.shape
|
||||
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
||||
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
|
||||
weight = lora_up @ lora_down
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
|
||||
# Calculate new rank
|
||||
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
if dynamic_method=="sv_ratio":
|
||||
if dynamic_method == "sv_ratio":
|
||||
# Calculate new dim and alpha based off ratio
|
||||
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
elif dynamic_method=="sv_cumulative":
|
||||
elif dynamic_method == "sv_cumulative":
|
||||
# Calculate new dim and alpha based off cumulative sum
|
||||
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
elif dynamic_method=="sv_fro":
|
||||
elif dynamic_method == "sv_fro":
|
||||
# Calculate new dim and alpha based off sqrt sum of squares
|
||||
new_rank = index_sv_fro(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
new_alpha = float(scale * new_rank)
|
||||
else:
|
||||
new_rank = rank
|
||||
new_alpha = float(scale*new_rank)
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
|
||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||
new_rank = 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
elif new_rank > rank: # cap max rank at rank
|
||||
new_alpha = float(scale * new_rank)
|
||||
elif new_rank > rank: # cap max rank at rank
|
||||
new_rank = rank
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
# Calculate resize info
|
||||
s_sum = torch.sum(torch.abs(S))
|
||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||
|
||||
|
||||
S_squared = S.pow(2)
|
||||
s_fro = torch.sqrt(torch.sum(S_squared))
|
||||
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
||||
fro_percent = float(s_red_fro/s_fro)
|
||||
fro_percent = float(s_red_fro / s_fro)
|
||||
|
||||
param_dict["new_rank"] = new_rank
|
||||
param_dict["new_alpha"] = new_alpha
|
||||
param_dict["sum_retained"] = (s_rank)/s_sum
|
||||
param_dict["sum_retained"] = (s_rank) / s_sum
|
||||
param_dict["fro_retained"] = fro_percent
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
|
||||
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
|
||||
|
||||
return param_dict
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
verbose_str = "\n"
|
||||
fro_list = []
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
verbose_str = "\n"
|
||||
fro_list = []
|
||||
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
if network_alpha is None and 'alpha' in key:
|
||||
network_alpha = value
|
||||
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
||||
network_dim = value.size()[0]
|
||||
if network_alpha is not None and network_dim is not None:
|
||||
break
|
||||
if network_alpha is None:
|
||||
network_alpha = network_dim
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
if network_alpha is None and "alpha" in key:
|
||||
network_alpha = value
|
||||
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
|
||||
network_dim = value.size()[0]
|
||||
if network_alpha is not None and network_dim is not None:
|
||||
break
|
||||
if network_alpha is None:
|
||||
network_alpha = network_dim
|
||||
|
||||
scale = network_alpha/network_dim
|
||||
scale = network_alpha / network_dim
|
||||
|
||||
if dynamic_method:
|
||||
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
||||
if dynamic_method:
|
||||
logger.info(
|
||||
f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
|
||||
)
|
||||
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
|
||||
o_lora_sd = lora_sd.copy()
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
o_lora_sd = lora_sd.copy()
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if 'lora_down' in key:
|
||||
block_down_name = key.rsplit('.lora_down', 1)[0]
|
||||
weight_name = key.rsplit(".", 1)[-1]
|
||||
lora_down_weight = value
|
||||
else:
|
||||
continue
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if "lora_down" in key:
|
||||
block_down_name = key.rsplit(".lora_down", 1)[0]
|
||||
weight_name = key.rsplit(".", 1)[-1]
|
||||
lora_down_weight = value
|
||||
else:
|
||||
continue
|
||||
|
||||
# find corresponding lora_up and alpha
|
||||
block_up_name = block_down_name
|
||||
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
|
||||
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
|
||||
# find corresponding lora_up and alpha
|
||||
block_up_name = block_down_name
|
||||
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
|
||||
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
|
||||
|
||||
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
||||
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
|
||||
|
||||
if weights_loaded:
|
||||
if weights_loaded:
|
||||
|
||||
conv2d = (len(lora_down_weight.size()) == 4)
|
||||
if lora_alpha is None:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = lora_alpha/lora_down_weight.size()[0]
|
||||
conv2d = len(lora_down_weight.size()) == 4
|
||||
if lora_alpha is None:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = lora_alpha / lora_down_weight.size()[0]
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
|
||||
if verbose:
|
||||
max_ratio = param_dict['max_ratio']
|
||||
sum_retained = param_dict['sum_retained']
|
||||
fro_retained = param_dict['fro_retained']
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
if verbose:
|
||||
max_ratio = param_dict["max_ratio"]
|
||||
sum_retained = param_dict["sum_retained"]
|
||||
fro_retained = param_dict["fro_retained"]
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
|
||||
verbose_str+=f"{block_down_name:75} | "
|
||||
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
verbose_str += f"{block_down_name:75} | "
|
||||
verbose_str += (
|
||||
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
)
|
||||
|
||||
if verbose and dynamic_method:
|
||||
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str+=f"\n"
|
||||
if verbose and dynamic_method:
|
||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str += "\n"
|
||||
|
||||
new_alpha = param_dict['new_alpha']
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
|
||||
new_alpha = param_dict["new_alpha"]
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
|
||||
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
weights_loaded = False
|
||||
del param_dict
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
weights_loaded = False
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
print("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
logger.info("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
|
||||
|
||||
def resize(args):
|
||||
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
|
||||
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||
if args.save_to is None or not (
|
||||
args.save_to.endswith(".ckpt")
|
||||
or args.save_to.endswith(".pt")
|
||||
or args.save_to.endswith(".pth")
|
||||
or args.save_to.endswith(".safetensors")
|
||||
):
|
||||
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
||||
|
||||
if args.dynamic_method and not args.dynamic_param:
|
||||
raise Exception("If using dynamic_method, then dynamic_param is required")
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
if args.dynamic_method and not args.dynamic_param:
|
||||
raise Exception("If using dynamic_method, then dynamic_param is required")
|
||||
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
print("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
||||
logger.info("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
logger.info("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
|
||||
)
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
if not args.dynamic_method:
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
else:
|
||||
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
||||
metadata["ss_network_dim"] = 'Dynamic'
|
||||
metadata["ss_network_alpha"] = 'Dynamic'
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
if not args.dynamic_method:
|
||||
conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})"
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
else:
|
||||
metadata["ss_training_comment"] = (
|
||||
f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
||||
)
|
||||
metadata["ss_network_dim"] = "Dynamic"
|
||||
metadata["ss_network_alpha"] = "Dynamic"
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
# cast to save_dtype before calculating hashes
|
||||
for key in list(state_dict.keys()):
|
||||
value = state_dict[key]
|
||||
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
|
||||
state_dict[key] = value.to(save_dtype)
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
|
||||
parser.add_argument("--new_rank", type=int, default=4,
|
||||
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--model", type=str, default=None,
|
||||
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
||||
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
|
||||
parser.add_argument("--dynamic_param", type=float, default=None,
|
||||
help="Specify target for dynamic reduction")
|
||||
|
||||
return parser
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat",
|
||||
)
|
||||
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument(
|
||||
"--new_conv_rank",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic_method",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
|
||||
)
|
||||
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
resize(args)
|
||||
args = parser.parse_args()
|
||||
resize(args)
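
The three --dynamic_method choices select a per-module rank from the singular values of each LoRA delta-weight. Below is a minimal sketch of how such criteria can be evaluated, assuming descending singular values; dynamic_rank is an illustrative helper, not a function from resize_lora.py, and the script's exact thresholds may differ slightly.

import torch

def dynamic_rank(S: torch.Tensor, method: str, param: float, max_rank: int) -> int:
    # S: singular values in descending order, as returned by torch.linalg.svdvals
    if method == "sv_ratio":
        rank = int((S > S[0] / param).sum())  # keep values within a factor `param` of the largest
    elif method == "sv_cumulative":
        cum = torch.cumsum(S, dim=0) / S.sum()  # smallest rank reaching `param` of the total sum
        rank = int(torch.searchsorted(cum, torch.tensor(param))) + 1
    elif method == "sv_fro":
        cum = torch.cumsum(S**2, dim=0) / (S**2).sum()  # same idea on the Frobenius norm
        rank = int(torch.searchsorted(cum, torch.tensor(param))) + 1
    else:
        raise ValueError(method)
    return max(1, min(rank, max_rank))  # --new_rank acts as the hard ceiling

S = torch.linalg.svdvals(torch.randn(64, 64))
print(dynamic_rank(S, "sv_fro", 0.9, 32))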
@@ -1,13 +1,23 @@
import itertools
import math
import argparse
import os
import time
import concurrent.futures
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
import oft
from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def load_state_dict(file_name, dtype):
@@ -25,36 +35,58 @@ def load_state_dict(file_name, dtype):
    return sd, metadata


def save_to_file(file_name, model, metadata):
    if os.path.splitext(file_name)[1] == ".safetensors":
        save_file(model, file_name, metadata=metadata)
    else:
        torch.save(model, file_name)


def detect_method_from_training_model(models, dtype):
    for model in models:
        # TODO It is better to use key names to detect the method
        lora_sd, _ = load_state_dict(model, dtype)
        for key in tqdm(lora_sd.keys()):
            if "lora_up" in key or "lora_down" in key:
                return "LoRA"
            elif "oft_blocks" in key:
                return "OFT"

def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype):
    text_encoder1.to(merge_dtype)
    text_encoder2.to(merge_dtype)
    unet.to(merge_dtype)

    # detect the method: OFT or LoRA_module
    method = detect_method_from_training_model(models, merge_dtype)
    logger.info(f"method:{method}")

    if lbws:
        lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
    else:
        LBW_TARGET_IDX = []

    # create module map
    name_to_module = {}
    for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
        if method == "LoRA":
            if i <= 1:
                if i == 0:
                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
                else:
                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
                target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
            else:
                prefix = lora.LoRANetwork.LORA_PREFIX_UNET
                target_replace_modules = (
                    lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
                )
        elif method == "OFT":
            prefix = oft.OFTNetwork.OFT_PREFIX_UNET
            # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
            target_replace_modules = (
                oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
            )

        for name, module in root_module.named_modules():
@@ -65,65 +97,172 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
                    lora_name = lora_name.replace(".", "_")
                    name_to_module[lora_name] = child_module

    for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
        logger.info(f"loading: {model}")
        lora_sd, _ = load_state_dict(model, merge_dtype)

        logger.info(f"merging...")

        if lbw:
            lbw_weights = [1] * 26
            for index, value in zip(LBW_TARGET_IDX, lbw):
                lbw_weights[index] = value
            logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")

        if method == "LoRA":
            for key in tqdm(lora_sd.keys()):
                if "lora_down" in key:
                    up_key = key.replace("lora_down", "lora_up")
                    alpha_key = key[: key.index("lora_down")] + "alpha"

                    # find original module for this lora
                    module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
                    if module_name not in name_to_module:
                        logger.info(f"no module found for LoRA weight: {key}")
                        continue
                    module = name_to_module[module_name]
                    # logger.info(f"apply {key} to {module}")

                    down_weight = lora_sd[key]
                    up_weight = lora_sd[up_key]

                    dim = down_weight.size()[0]
                    alpha = lora_sd.get(alpha_key, dim)
                    scale = alpha / dim

                    if lbw:
                        index = get_lbw_block_index(key, True)
                        is_lbw_target = index in LBW_TARGET_IDX
                        if is_lbw_target:
                            scale *= lbw_weights[index]  # if the key is an lbw target, apply the lbw weight

                    # W <- W + U * D
                    weight = module.weight
                    # logger.info(module_name, down_weight.size(), up_weight.size())
                    if len(weight.size()) == 2:
                        # linear
                        weight = weight + ratio * (up_weight @ down_weight) * scale
                    elif down_weight.size()[2:4] == (1, 1):
                        # conv2d 1x1
                        weight = (
                            weight
                            + ratio
                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                            * scale
                        )
                    else:
                        # conv2d 3x3
                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                        weight = weight + ratio * conved * scale

                    module.weight = torch.nn.Parameter(weight)

        elif method == "OFT":

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            for key in tqdm(lora_sd.keys()):
                if "oft_blocks" in key:
                    oft_blocks = lora_sd[key]
                    dim = oft_blocks.shape[0]
                    break
            for key in tqdm(lora_sd.keys()):
                if "alpha" in key:
                    oft_blocks = lora_sd[key]
                    alpha = oft_blocks.item()
                    break

            def merge_to(key):
                if "alpha" in key:
                    return

                # find original module for this OFT
                module_name = ".".join(key.split(".")[:-1])
                if module_name not in name_to_module:
                    logger.info(f"no module found for OFT weight: {key}")
                    return
                module = name_to_module[module_name]
                # logger.info(f"apply {key} to {module}")

                oft_blocks = lora_sd[key]

                if isinstance(module, torch.nn.Linear):
                    out_dim = module.out_features
                elif isinstance(module, torch.nn.Conv2d):
                    out_dim = module.out_channels

                num_blocks = dim
                block_size = out_dim // dim
                constraint = (0 if alpha is None else alpha) * out_dim

                multiplier = 1
                if lbw:
                    index = get_lbw_block_index(key, False)
                    is_lbw_target = index in LBW_TARGET_IDX
                    if is_lbw_target:
                        multiplier *= lbw_weights[index]

                block_Q = oft_blocks - oft_blocks.transpose(1, 2)
                norm_Q = torch.norm(block_Q.flatten())
                new_norm_Q = torch.clamp(norm_Q, max=constraint)
                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
                I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
                block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
                block_R_weighted = multiplier * block_R + (1 - multiplier) * I
                R = torch.block_diag(*block_R_weighted)

                # get org weight
                org_sd = module.state_dict()
                org_weight = org_sd["weight"].to(device)

                R = R.to(org_weight.device, dtype=org_weight.dtype)

                if org_weight.dim() == 4:
                    weight = torch.einsum("oihw, op -> pihw", org_weight, R)
                else:
                    weight = torch.einsum("oi, op -> pi", org_weight, R)

                weight = weight.contiguous()  # Make Tensor contiguous; required due to ThreadPoolExecutor

                module.weight = torch.nn.Parameter(weight)

            # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
            max_workers = 1 if device.type != "cpu" else None  # avoid OOM on GPU
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
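
The OFT branch above relies on the Cayley transform: for any skew-symmetric Q, R = (I + Q)(I - Q)^-1 is orthogonal, so merging rotates the original weight rows without changing their norms. A small self-contained check of that property (illustrative only):

import torch

blocks, size = 4, 8
raw = torch.randn(blocks, size, size)
Q = raw - raw.transpose(1, 2)                      # skew-symmetric per block
I = torch.eye(size).unsqueeze(0).repeat(blocks, 1, 1)
R = torch.matmul(I + Q, torch.linalg.inv(I - Q))   # Cayley transform
err = (torch.matmul(R, R.transpose(1, 2)) - I).abs().max()
print(f"max |R R^T - I| = {err:.2e}")              # ~0: each R is orthogonal
R_full = torch.block_diag(*R)                      # block-diagonal rotation, as in merge_to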

def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False):
    base_alphas = {}  # alpha for merged model
    base_dims = {}

    # detect the method: OFT or LoRA_module
    method = detect_method_from_training_model(models, merge_dtype)
    if method == "OFT":
        raise ValueError(
            "OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません"
        )

    if lbws:
        lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
    else:
        LBW_TARGET_IDX = []

    merged_sd = {}
    v2 = None
    base_model = None
    for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
        logger.info(f"loading: {model}")
        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

        if lbw:
            lbw_weights = [1] * 26
            for index, value in zip(LBW_TARGET_IDX, lbw):
                lbw_weights[index] = value
            logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")

        if lora_metadata is not None:
            if v2 is None:
                v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # returns a string; SDXL has no v2 flag, so this should be False
@@ -154,14 +293,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
            if lora_module_name not in base_alphas:
                base_alphas[lora_module_name] = alpha

        logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")

        # merge
        logger.info(f"merging...")
        for key in tqdm(lora_sd.keys()):
            if "alpha" in key:
                continue

            if "lora_up" in key and concat:
                concat_dim = 1
            elif "lora_down" in key and concat:
@@ -175,8 +314,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
            alpha = alphas[lora_module_name]

            scale = math.sqrt(alpha / base_alpha) * ratio
            scale = abs(scale) if "lora_up" in key else scale  # handle negative weights

            if lbw:
                index = get_lbw_block_index(key, True)
                is_lbw_target = index in LBW_TARGET_IDX
                if is_lbw_target:
                    scale *= lbw_weights[index]  # if the key is an lbw target, apply the lbw weight

            if key in merged_sd:
                assert (
                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
@@ -198,10 +343,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
            dim = merged_sd[key_down].shape[0]
            perm = torch.randperm(dim)
            merged_sd[key_down] = merged_sd[key_down][perm]
            merged_sd[key_up] = merged_sd[key_up][:, perm]

    logger.info("merged model")
    logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")

    # check all dims are same
    dims_list = list(set(base_dims.values()))
@@ -226,7 +371,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):


def merge(args):
    assert len(args.models) == len(
        args.ratios
    ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
    if args.lbws:
        assert len(args.models) == len(
            args.lbws
        ), f"number of models must be equal to number of lbws / モデルの数と層別適用率の数は合わせてください"
    else:
        args.lbws = []  # use an empty list when lbws is not given so that zip_longest works

    def str_to_dtype(p):
        if p == "float":
@@ -243,7 +396,7 @@ def merge(args):
        save_dtype = merge_dtype

    if args.sd_model is not None:
        logger.info(f"loading SD model: {args.sd_model}")

        (
            text_model1,
@@ -254,7 +407,7 @@ def merge(args):
            ckpt_info,
        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")

        merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype)

        if args.no_metadata:
            sai_metadata = None
@@ -265,14 +418,20 @@ def merge(args):
                None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
            )

        logger.info(f"saving SD model to: {args.save_to}")
        sdxl_model_util.save_stable_diffusion_checkpoint(
            args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
        )
    else:
        state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)

        # cast to save_dtype before calculating hashes
        for key in list(state_dict.keys()):
            value = state_dict[key]
            if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
                state_dict[key] = value.to(save_dtype)

        logger.info(f"calculating hashes and creating metadata...")

        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
@@ -286,8 +445,8 @@ def merge(args):
            )
            metadata.update(sai_metadata)

        logger.info(f"saving model to: {args.save_to}")
        save_to_file(args.save_to, state_dict, metadata)


def setup_parser() -> argparse.ArgumentParser:
@@ -313,12 +472,19 @@ def setup_parser() -> argparse.ArgumentParser:
        help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
    )
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="*",
        help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
    )
    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
    parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
    parser.add_argument(
        "--no_metadata",
        action="store_true",
@@ -334,8 +500,7 @@ def setup_parser() -> argparse.ArgumentParser:
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
    )

    return parser
@@ -1,6 +1,8 @@
import math
import argparse
import itertools
import json
import os
import re
import time
import torch
from safetensors.torch import load_file, save_file
@@ -8,10 +10,196 @@ from tqdm import tqdm
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

CLAMP_QUANTILE = 0.99

ACCEPTABLE = [12, 17, 20, 26]
SDXL_LAYER_NUM = [12, 20]

LAYER12 = {
    "BASE": True,
    "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True,
    "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
    "MID": True,
    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
    "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False,
}

LAYER17 = {
    "BASE": True,
    "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True,
    "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
    "MID": True,
    "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True,
    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True,
}

LAYER20 = {
    "BASE": True,
    "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True,
    "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False,
    "MID": True,
    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False,
}

LAYER26 = {
    "BASE": True,
    "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True,
    "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True,
    "MID": True,
    "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True,
    "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True,
}

assert len([v for v in LAYER12.values() if v]) == 12
assert len([v for v in LAYER17.values() if v]) == 17
assert len([v for v in LAYER20.values() if v]) == 20
assert len([v for v in LAYER26.values() if v]) == 26

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")


def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
    # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder
    if "text_model_encoder_" in lora_name:  # LoRA for text encoder
        return 0

    # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" has index 2
    block_idx = -1  # invalid lora name
    if not is_sdxl:
        NUM_OF_BLOCKS = 12  # up/down blocks
        m = RE_UPDOWN.search(lora_name)
        if m:
            g = m.groups()
            up_down = g[0]
            i = int(g[1])
            j = int(g[3])
            if up_down == "down":
                if g[2] == "resnets" or g[2] == "attentions":
                    idx = 3 * i + j + 1
                elif g[2] == "downsamplers":
                    idx = 3 * (i + 1)
                else:
                    return block_idx  # invalid lora name
            elif up_down == "up":
                if g[2] == "resnets" or g[2] == "attentions":
                    idx = 3 * i + j
                elif g[2] == "upsamplers":
                    idx = 3 * i + 2
                else:
                    return block_idx  # invalid lora name

            if g[0] == "down":
                block_idx = 1 + idx  # 1-based index, down block index
            elif g[0] == "up":
                block_idx = 1 + NUM_OF_BLOCKS + 1 + idx  # 1-based index, num blocks, mid block, up block index

        elif "mid_block_" in lora_name:
            block_idx = 1 + NUM_OF_BLOCKS  # 1-based index, num blocks, mid block
    else:
        # SDXL: some numbers are skipped
        if lora_name.startswith("lora_unet_"):
            name = lora_name[len("lora_unet_") :]
            if name.startswith("time_embed_") or name.startswith("label_emb_"):  # 1, No LoRA in sd-scripts
                block_idx = 1
            elif name.startswith("input_blocks_"):  # 1-8 to 2-9
                block_idx = 1 + int(name.split("_")[2])
            elif name.startswith("middle_block_"):  # 13
                block_idx = 13
            elif name.startswith("output_blocks_"):  # 0-8 to 14-22
                block_idx = 14 + int(name.split("_")[2])
            elif name.startswith("out_"):  # 23, No LoRA in sd-scripts
                block_idx = 23

    return block_idx
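
A few illustrative inputs and the indices the mapping above produces (key names follow the sd-scripts LoRA naming convention):

print(get_lbw_block_index("lora_te1_text_model_encoder_layers_0_mlp_fc1"))      # 0 (text encoder)
print(get_lbw_block_index("lora_unet_down_blocks_0_attentions_0_proj_in"))      # 2 (SD1.x U-Net)
print(get_lbw_block_index("lora_unet_input_blocks_1_1_proj_in", is_sdxl=True))  # 2 (SDXL)
print(get_lbw_block_index("lora_unet_middle_block_1_proj_in", is_sdxl=True))    # 13 (SDXL mid block)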

def load_state_dict(file_name, dtype):
    if os.path.splitext(file_name)[1] == ".safetensors":
@@ -28,25 +216,54 @@ def load_state_dict(file_name, dtype):
    return sd, metadata


def save_to_file(file_name, state_dict, metadata):
    if os.path.splitext(file_name)[1] == ".safetensors":
        save_file(state_dict, file_name, metadata=metadata)
    else:
        torch.save(state_dict, file_name)


def format_lbws(lbws):
    try:
        # each lbw is expected to arrive as a string like "[1,1,1,1,1,1,1,1,1,1,1,1]"
        lbws = [json.loads(lbw) for lbw in lbws]
    except Exception:
        raise ValueError(f"lbws must be JSON / 層別適用率はJSON形式で書いてください")
    assert all(isinstance(lbw, list) for lbw in lbws), f"each lbw must be a list / 層別適用率はリストにしてください"
    assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws must have the same length / 層別適用率は同じ長さにしてください"
    assert all(
        len(lbw) in ACCEPTABLE for lbw in lbws
    ), f"length of each lbw must be one of {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
    assert all(
        all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws
    ), f"values of lbws must be numbers / 層別適用率の値はすべて数値にしてください"

    layer_num = len(lbws[0])
    is_sdxl = layer_num in SDXL_LAYER_NUM
    FLAGS = {
        "12": LAYER12.values(),
        "17": LAYER17.values(),
        "20": LAYER20.values(),
        "26": LAYER26.values(),
    }[str(layer_num)]
    LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
    return lbws, is_sdxl, LBW_TARGET_IDX
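
For example, two 12-entry lbw strings as they would arrive from a --lbws argument (values here are illustrative):

lbws, is_sdxl, target_idx = format_lbws(['[1,1,1,1,1,1,1,1,1,1,1,1]',
                                         '[0,1,0,1,0,1,0,1,0,1,0,1]'])
print(is_sdxl)      # True: 12 is in SDXL_LAYER_NUM
print(target_idx)   # the 12 positions of LAYER12 whose flag is True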

def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
    logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
    merged_sd = {}
    v2 = None  # this refers to LoRA metadata v2, not SD2
    base_model = None

    if lbws:
        lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws)
    else:
        is_sdxl = False
        LBW_TARGET_IDX = []

    for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
        logger.info(f"loading: {model}")
        lora_sd, lora_metadata = load_state_dict(model, merge_dtype)

        if lora_metadata is not None:
@@ -55,8 +272,14 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
            if base_model is None:
                base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)

        if lbw:
            lbw_weights = [1] * 26
            for index, value in zip(LBW_TARGET_IDX, lbw):
                lbw_weights[index] = value
            logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")

        # merge
        logger.info(f"merging...")
        for key in tqdm(list(lora_sd.keys())):
            if "lora_down" not in key:
                continue
@@ -73,15 +296,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
            out_dim = up_weight.size()[0]
            conv2d = len(down_weight.size()) == 4
            kernel_size = None if not conv2d else down_weight.size()[2:4]
            # logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)

            # make original weight if not exist
            if lora_module_name not in merged_sd:
                weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
                if device:
                    weight = weight.to(device)
            else:
                weight = merged_sd[lora_module_name]
                if device:
                    weight = weight.to(device)

            # merge to weight
            if device:
@@ -91,6 +314,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
            # W <- W + U * D
            scale = alpha / network_dim

            if lbw:
                index = get_lbw_block_index(key, is_sdxl)
                is_lbw_target = index in LBW_TARGET_IDX
                if is_lbw_target:
                    scale *= lbw_weights[index]  # if the key is an lbw target, apply the lbw weight

            if device:  # and isinstance(scale, torch.Tensor):
                scale = scale.to(device)

@@ -107,13 +336,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            weight = weight + ratio * conved * scale

            merged_sd[lora_module_name] = weight.to("cpu")

    # extract from merged weights
    logger.info("extract new lora...")
    merged_lora_sd = {}
    with torch.no_grad():
        for lora_module_name, mat in tqdm(list(merged_sd.items())):
            if device:
                mat = mat.to(device)

            conv2d = len(mat.size()) == 4
            kernel_size = None if not conv2d else mat.size()[2:4]
            conv2d_3x3 = conv2d and kernel_size != (1, 1)
@@ -152,7 +384,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty

            merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
            merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
            merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
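
The extraction loop's core (elided as context by the hunk above) reduces each merged delta-weight back to a low-rank pair via SVD, clamping outliers before saving. A minimal sketch of that technique under the file's CLAMP_QUANTILE convention; the local names are illustrative, not the script's exact code:

import torch

def extract_low_rank(mat: torch.Tensor, rank: int, clamp_quantile: float = 0.99):
    U, S, Vh = torch.linalg.svd(mat)                 # full SVD of the merged delta-weight
    U = U[:, :rank] * S[:rank]                       # fold singular values into U
    Vh = Vh[:rank, :]
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist.abs(), clamp_quantile)
    U = U.clamp(-hi_val, hi_val)                     # clamp outliers so low-precision saving is safe
    Vh = Vh.clamp(-hi_val, hi_val)
    return U, Vh                                     # lora_up, lora_down

up, down = extract_low_rank(torch.randn(320, 320), rank=8)
print(up.shape, down.shape)  # torch.Size([320, 8]) torch.Size([8, 320])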

    # build minimum metadata
    dims = f"{new_rank}"
@@ -167,7 +399,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty


def merge(args):
    assert len(args.models) == len(
        args.ratios
    ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
    if args.lbws:
        assert len(args.models) == len(
            args.lbws
        ), f"number of models must be equal to number of lbws / モデルの数と層別適用率の数は合わせてください"
    else:
        args.lbws = []  # use an empty list when lbws is not given so that zip_longest works

    def str_to_dtype(p):
        if p == "float":
@@ -185,10 +425,16 @@ def merge(args):

    new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
    state_dict, metadata, v2, base_model = merge_lora_models(
        args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
    )

    # cast to save_dtype before calculating hashes
    for key in list(state_dict.keys()):
        value = state_dict[key]
        if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
            state_dict[key] = value.to(save_dtype)

    logger.info(f"calculating hashes and creating metadata...")

    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
    metadata["sshs_model_hash"] = model_hash
@@ -203,13 +449,13 @@ def merge(args):
        )
        if v2:
            # TODO read sai modelspec
            logger.warning(
                "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
            )
        metadata.update(sai_metadata)

    logger.info(f"saving model to: {args.save_to}")
    save_to_file(args.save_to, state_dict, metadata)


def setup_parser() -> argparse.ArgumentParser:
@@ -229,12 +475,19 @@ def setup_parser() -> argparse.ArgumentParser:
        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
    )
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="*",
        help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
    )
    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
    parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
    parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
    parser.add_argument(
        "--new_conv_rank",
@@ -242,7 +495,9 @@ def setup_parser() -> argparse.ArgumentParser:
        default=None,
        help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
    )
    parser.add_argument(
        "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
    )
    parser.add_argument(
        "--no_metadata",
        action="store_true",
8	pytest.ini	Normal file
@@ -0,0 +1,8 @@
[pytest]
minversion = 6.0
testpaths =
    tests
filterwarnings =
    ignore::DeprecationWarning
    ignore::UserWarning
    ignore::FutureWarning
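
The config collects tests from a tests/ directory. Below is a minimal, hypothetical example of a test file pytest would pick up under this configuration; the file name and assertions are illustrative, not taken from the repository.

# tests/test_format_lbws.py
from svd_merge_lora import format_lbws

def test_format_lbws_parses_json_lists():
    lbws, is_sdxl, target_idx = format_lbws(["[1,1,1,1,1,1,1,1,1,1,1,1]"])
    assert lbws == [[1] * 12]
    assert is_sdxl            # 12-entry lbws follow the SDXL layer layout
    assert len(target_idx) == 12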
@@ -1,20 +1,28 @@
accelerate==0.33.0
transformers==4.44.0
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
lion-pytorch==0.0.6
schedulefree==1.4
pytorch-optimizer==3.5.0
prodigy-plus-schedule-free==1.9.0
prodigyopt==1.1.2
tensorboard
safetensors==0.4.4
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.24.5
# for Image utils
imagesize==1.4.1
numpy<=2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -22,12 +30,19 @@ huggingface-hub==0.20.1
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.15.0
# onnxruntime-gpu==1.17.1
# onnxruntime==1.17.1
# for cuda 12.1(default 11.8)
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/

# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
# for kohya_ss library
-e .
407	sd3_minimal_inference.py	Normal file
@@ -0,0 +1,407 @@
# Minimum Inference Code for SD3

import argparse
import datetime
import math
import os
import random
from typing import Optional, Tuple
import numpy as np

import torch
from safetensors.torch import safe_open, load_file
import torch.amp
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel

from library.device_utils import init_ipex, get_preferred_device
from networks import lora_sd3

init_ipex()

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from library import sd3_models, sd3_utils, strategy_sd3
from library.utils import load_safetensors


def get_noise(seed, latent, device="cpu"):
    # generator = torch.manual_seed(seed)
    generator = torch.Generator(device)
    generator.manual_seed(seed)
    return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device)


def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
    start = sampling.timestep(sampling.sigma_max)
    end = sampling.timestep(sampling.sigma_min)
    timesteps = torch.linspace(start, end, steps)
    sigs = []
    for x in range(len(timesteps)):
        ts = timesteps[x]
        sigs.append(sampling.sigma(ts))
    sigs += [0.0]
    return torch.FloatTensor(sigs)


def max_denoise(model_sampling, sigmas):
    max_sigma = float(model_sampling.sigma_max)
    sigma = float(sigmas[0])
    return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma


def do_sample(
    height: int,
    width: int,
    initial_latent: Optional[torch.Tensor],
    seed: int,
    cond: Tuple[torch.Tensor, torch.Tensor],
    neg_cond: Tuple[torch.Tensor, torch.Tensor],
    mmdit: sd3_models.MMDiT,
    steps: int,
    cfg_scale: float,
    dtype: torch.dtype,
    device: str,
):
    if initial_latent is None:
        # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609  # this seems to be a bug in the original code. thanks to furusu for pointing it out
        latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
    else:
        latent = initial_latent

    latent = latent.to(dtype).to(device)

    noise = get_noise(seed, latent, device)

    model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0)  # 3.0 is for SD3

    sigmas = get_sigmas(model_sampling, steps).to(device)
    # sigmas = sigmas[int(steps * (1 - denoise)) :]  # do not support i2i

    # conditioning = fix_cond(conditioning)
    # neg_cond = fix_cond(neg_cond)
    # extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale}

    noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))

    c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
    y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)

    x = noise_scaled.to(device).to(dtype)
    # print(x.shape)

    with torch.no_grad():
        for i in tqdm(range(len(sigmas) - 1)):
            sigma_hat = sigmas[i]

            timestep = model_sampling.timestep(sigma_hat).float()
            timestep = torch.FloatTensor([timestep, timestep]).to(device)

            x_c_nc = torch.cat([x, x], dim=0)
            # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)

            with torch.autocast(device_type=device.type, dtype=dtype):
                model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
            model_output = model_output.float()
            batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)

            pos_out, neg_out = batched.chunk(2)
            denoised = neg_out + (pos_out - neg_out) * cfg_scale
            # print(denoised.shape)

            # d = to_d(x, sigma_hat, denoised)
            dims_to_append = x.ndim - sigma_hat.ndim
            sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
            # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
            """Converts a denoiser output to a Karras ODE derivative."""
            d = (x - denoised) / sigma_hat_dims

            dt = sigmas[i + 1] - sigma_hat

            # Euler method
            x = x + d * dt
            x = x.to(dtype)

    latent = x
    latent = vae.process_out(latent)
    return latent
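
The loop above is a plain Euler integrator of the flow ODE: d = (x - denoised) / sigma is the Karras derivative, and each step advances x by d * (sigma_next - sigma). A standalone sketch with a toy denoiser standing in for the MMDiT + CFG combination (illustrative only):

import torch

def euler_sample(x, sigmas, denoise_fn):
    for i in range(len(sigmas) - 1):
        sigma = sigmas[i]
        denoised = denoise_fn(x, sigma)        # model's estimate of the clean sample
        d = (x - denoised) / sigma             # Karras ODE derivative
        x = x + d * (sigmas[i + 1] - sigma)    # Euler step toward the next sigma
    return x

x = torch.randn(1, 16, 8, 8)
sigmas = torch.linspace(1.0, 0.0, 11)          # 10 steps, ending at sigma = 0
out = euler_sample(x, sigmas, lambda x, s: x * (1 - float(s)))  # toy denoiser
print(out.shape)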


def generate_image(
    mmdit: sd3_models.MMDiT,
    vae: sd3_models.SDVAE,
    clip_l: CLIPTextModelWithProjection,
    clip_g: CLIPTextModelWithProjection,
    t5xxl: T5EncoderModel,
    steps: int,
    prompt: str,
    seed: int,
    target_width: int,
    target_height: int,
    device: str,
    negative_prompt: str,
    cfg_scale: float,
):
    # prepare embeddings
    logger.info("Encoding prompts...")

    # TODO support one-by-one offloading
    clip_l.to(device)
    clip_g.to(device)
    t5xxl.to(device)

    with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
        tokens_and_masks = tokenize_strategy.tokenize(prompt)
        lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
        )
        cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
        lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
        )
        neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

    # attn masks are not used currently

    if args.offload:
        clip_l.to("cpu")
        clip_g.to("cpu")
        t5xxl.to("cpu")

    # generate image
    logger.info("Generating image...")
    mmdit.to(device)
    latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device)
    if args.offload:
        mmdit.to("cpu")

    # latent to image
    vae.to(device)
    with torch.no_grad():
        image = vae.decode(latent_sampled)

    if args.offload:
        vae.to("cpu")

    image = image.float()
    image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
    decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
    decoded_np = decoded_np.astype(np.uint8)
    out_image = Image.fromarray(decoded_np)

    # save image
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
    out_image.save(output_path)

    logger.info(f"Saved image to {output_path}")


if __name__ == "__main__":
    target_height = 1024
    target_width = 1024

    # steps = 50  # 28 # 50
    # cfg_scale = 5
    # seed = 1  # None # 1

    device = get_preferred_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--clip_g", type=str, required=False)
    parser.add_argument("--clip_l", type=str, required=False)
    parser.add_argument("--t5xxl", type=str, required=False)
    parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256")
    parser.add_argument("--apply_lg_attn_mask", action="store_true")
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    # parser.add_argument("--prompt2", type=str, default=None)  # do not support different prompts for text encoders
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--cfg_scale", type=float, default=5.0)
    parser.add_argument("--offload", action="store_true", help="Offload to CPU")
    parser.add_argument("--output_dir", type=str, default=".")
    # parser.add_argument("--do_not_use_t5xxl", action="store_true")
    # parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument(
        "--lora_weights",
        type=str,
        nargs="*",
        default=[],
        help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
    )
    parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
    parser.add_argument("--width", type=int, default=target_width)
    parser.add_argument("--height", type=int, default=target_height)
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    seed = args.seed
    steps = args.steps

    sd3_dtype = torch.float32
    if args.fp16:
        sd3_dtype = torch.float16
    elif args.bf16:
        sd3_dtype = torch.bfloat16

    loading_device = "cpu" if args.offload else device

    # load state dict
    logger.info(f"Loading SD3 models from {args.ckpt_path}...")
    # state_dict = load_file(args.ckpt_path)
    state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype)

    # load text encoders
    clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict)
    clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict)
    t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict)

    # MMDiT and VAE
    vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict)
    mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device)

    clip_l.to(sd3_dtype)
    clip_g.to(sd3_dtype)
    t5xxl.to(sd3_dtype)
    vae.to(sd3_dtype)
    mmdit.to(sd3_dtype)
    if not args.offload:
        # make sure to move to the device: some tensors are created in the constructor on the CPU
        clip_l.to(device)
        clip_g.to(device)
        t5xxl.to(device)
        vae.to(device)
        mmdit.to(device)

    clip_l.eval()
    clip_g.eval()
    t5xxl.eval()
    mmdit.eval()
    vae.eval()

    # load tokenizers
    logger.info("Loading tokenizers...")
    tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
    encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()

    # LoRA
    lora_models: list[lora_sd3.LoRANetwork] = []
    for weights_file in args.lora_weights:
        if ";" in weights_file:
            weights_file, multiplier = weights_file.split(";")
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        weights_sd = load_file(weights_file)
        module = lora_sd3
        lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)

        if args.merge_lora_weights:
            lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
        else:
            lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
            info = lora_model.load_state_dict(weights_sd, strict=True)
            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
            lora_model.eval()
            lora_model.to(device)

        lora_models.append(lora_model)

    if not args.interactive:
        generate_image(
            mmdit,
            vae,
            clip_l,
            clip_g,
            t5xxl,
            args.steps,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            device,
            args.negative_prompt,
            args.cfg_scale,
        )
    else:
        # loop for interactive
        width = args.width
        height = args.height
        steps = None
        cfg_scale = args.cfg_scale

        while True:
            print(
                "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed>"
                " --n <negative prompt>, `--n -` for empty negative prompt."
                " Options are kept for the next prompt. Current options:"
                f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}"
            )
            prompt = input()
            if prompt == "":
                break

            # parse options
            options = prompt.split("--")
            prompt = options[0].strip()
            seed = None
            negative_prompt = None
            for opt in options[1:]:
                try:
                    opt = opt.strip()
                    if opt.startswith("w"):
                        width = int(opt[1:].strip())
                    elif opt.startswith("h"):
                        height = int(opt[1:].strip())
                    elif opt.startswith("s"):
                        steps = int(opt[1:].strip())
                    elif opt.startswith("d"):
                        seed = int(opt[1:].strip())
                    elif opt.startswith("m"):
                        multipliers = opt[1:].strip().split(",")
                        if len(multipliers) != len(lora_models):
                            logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
                            continue
                        for i, lora_model in enumerate(lora_models):
                            lora_model.set_multiplier(float(multipliers[i]))
                    elif opt.startswith("n"):
                        negative_prompt = opt[1:].strip()
                        if negative_prompt == "-":
                            negative_prompt = ""
                    elif opt.startswith("c"):
                        cfg_scale = float(opt[1:].strip())
                except ValueError as e:
                    logger.error(f"Invalid option: {opt}, {e}")

            generate_image(
                mmdit,
                vae,
                clip_l,
                clip_g,
                t5xxl,
                steps if steps is not None else args.steps,
                prompt,
                seed if seed is not None else args.seed,
                width,
                height,
                device,
                negative_prompt if negative_prompt is not None else args.negative_prompt,
                cfg_scale,
            )

    logger.info("Done!")
1074	sd3_train.py	Normal file
File diff suppressed because it is too large. Load Diff
496	sd3_train_network.py	Normal file
@@ -0,0 +1,496 @@
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from library import sd3_models, strategy_sd3, utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util
|
||||
import train_network
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
# super().assert_extra_args(args, train_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.fp8_base_unet:
|
||||
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning(
|
||||
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
||||
)
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# prepare CLIP-L/CLIP-G/T5XXL training flags
|
||||
self.train_clip = not args.network_train_unet_only
|
||||
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
||||
|
||||
if args.max_token_length is not None:
|
||||
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
# enumerate resolutions from dataset for positional embeddings
|
||||
resolutions = train_dataset_group.get_resolutions()
|
||||
if val_dataset_group is not None:
|
||||
resolutions = resolutions + val_dataset_group.get_resolutions()
|
||||
self.resolutions = resolutions
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
|
||||
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
||||
loading_dtype = None if args.fp8_base else weight_dtype
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
||||
state_dict = utils.load_safetensors(
|
||||
args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype
|
||||
)
|
||||
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
|
||||
self.model_type = mmdit.model_type
|
||||
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
|
||||
|
||||
# set resolutions for positional embeddings
|
||||
if args.enable_scaled_pos_embed:
|
||||
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
|
||||
latent_sizes = list(set(latent_sizes)) # remove duplicates
|
||||
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
|
||||
mmdit.enable_scaled_pos_embed(True, latent_sizes)

        if args.fp8_base:
            # check the dtype of the model
            if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}")
            elif mmdit.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 SD3 model")
            else:
                logger.info(
                    "Cast SD3 model to fp8. This may take a while. You can reduce the time by using an fp8 checkpoint."
                    " / SD3モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
                )
                mmdit.to(torch.float8_e4m3fn)
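                # Note: float8_e4m3fn (4 exponent bits, 3 mantissa bits) is the one fp8 format
                # accepted here; the e5m2 and *fnuz variants trade precision and range
                # differently and are rejected above as unsupported for these weights.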

        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
        if self.is_swapping_blocks:
            # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
            logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
            mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device)
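            # Sketch of the assumed mechanism: with blocks_to_swap=N, N transformer blocks are
            # kept in CPU memory and streamed to the GPU just before they are needed, then
            # evicted again, trading transfer time for peak VRAM. The scheduling itself lives
            # in enable_block_swap / prepare_block_swap_before_forward.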

        clip_l = sd3_utils.load_clip_l(
            args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
        )
        clip_l.eval()
        clip_g = sd3_utils.load_clip_g(
            args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
        )
        clip_g.eval()

        # if the file is fp8 and we are using fp8_base (not fp8_base_unet only), we can load it as is (fp8)
        if args.fp8_base and not args.fp8_base_unet:
            loading_dtype = None  # as is
        else:
            loading_dtype = weight_dtype

        # loading t5xxl to cpu takes a long time, so we should load it to gpu in the future
        t5xxl = sd3_utils.load_t5xxl(
            args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
        )
        t5xxl.eval()
        if args.fp8_base and not args.fp8_base_unet:
            # check the dtype of the model
            if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
            elif t5xxl.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 T5XXL model")

        vae = sd3_utils.load_vae(
            args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
        )

        return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit

    def get_tokenize_strategy(self, args):
        logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}")
        return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy):
        return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl]

    def get_latents_caching_strategy(self, args):
        latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        return latents_caching_strategy

    def get_text_encoding_strategy(self, args):
        return strategy_sd3.Sd3TextEncodingStrategy(
            args.apply_lg_attn_mask,
            args.apply_t5_attn_mask,
            args.clip_l_dropout_rate,
            args.clip_g_dropout_rate,
            args.t5_dropout_rate,
        )

    def post_process_network(self, args, accelerator, network, text_encoders, unet):
        # check whether T5XXL is trained or not
        self.train_t5xxl = network.train_t5xxl

        if self.train_t5xxl and args.cache_text_encoder_outputs:
            raise ValueError(
                "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
            )

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        if args.cache_text_encoder_outputs:
            if self.train_clip and not self.train_t5xxl:
                return text_encoders[0:2] + [None]  # only CLIP-L/CLIP-G are needed for encoding because T5XXL outputs are cached
            else:
                return None  # no text encoders are needed for encoding because all outputs are cached
        else:
            return text_encoders  # CLIP-L, CLIP-G and T5XXL are needed for encoding

    def get_text_encoders_train_flags(self, args, text_encoders):
        return [self.train_clip, self.train_clip, self.train_t5xxl]

    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            # if any text encoder is trained, we still need tokenization at training time, so is_partial is True
            return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk,
                args.text_encoder_batch_size,
                args.skip_cache_check,
                is_partial=self.train_clip or self.train_t5xxl,
                apply_lg_attn_mask=args.apply_lg_attn_mask,
                apply_t5_attn_mask=args.apply_t5_attn_mask,
            )
        else:
            return None

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption
                logger.info("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)

            # When the text encoders are not trained, they are not prepared by accelerator, so we need explicit autocast
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
            text_encoders[1].to(accelerator.device, dtype=weight_dtype)  # always not fp8
            text_encoders[2].to(accelerator.device)  # may be fp8

            if text_encoders[2].dtype == torch.float8_e4m3fn:
                # if we load fp8 weights, the model is already fp8, so we use it as is
                self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype)
            else:
                # otherwise, we need to convert it to the target dtype
                text_encoders[2].to(weight_dtype)

            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

            # cache outputs for sample prompts
            if args.sample_prompts is not None:
                logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")

                tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
                text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

                prompts = train_util.load_prompts(args.sample_prompts)
                sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
                with accelerator.autocast(), torch.no_grad():
                    for prompt_dict in prompts:
                        for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                            if p not in sample_prompts_te_outputs:
                                logger.info(f"cache Text Encoder outputs for prompt: {p}")
                                tokens_and_masks = tokenize_strategy.tokenize(p)
                                sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                    tokenize_strategy,
                                    text_encoders,
                                    tokens_and_masks,
                                    args.apply_lg_attn_mask,
                                    args.apply_t5_attn_mask,
                                )
                self.sample_prompts_te_outputs = sample_prompts_te_outputs

            accelerator.wait_for_everyone()

            # move back to cpu
            if not self.is_train_text_encoder(args):
                logger.info("move CLIP-L back to cpu")
                text_encoders[0].to("cpu")
                logger.info("move CLIP-G back to cpu")
                text_encoders[1].to("cpu")
                logger.info("move T5XXL back to cpu")
                text_encoders[2].to("cpu")
            clean_memory_on_device(accelerator.device)

            if not args.lowram:
                logger.info("move vae and unet back to the original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            # outputs are obtained from the text encoders at every step, so keep them on the gpu
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device, dtype=weight_dtype)
            text_encoders[2].to(accelerator.device)

    # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
    #     noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

    #     # get size embeddings
    #     orig_size = batch["original_sizes_hw"]
    #     crop_size = batch["crop_top_lefts"]
    #     target_size = batch["target_sizes_hw"]
    #     embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)

    #     # concat embeddings
    #     encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
    #     vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
    #     text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)

    #     noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
    #     return noise_pred

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit):
        text_encoders = text_encoder  # for compatibility
        text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

        sd3_train_utils.sample_images(
            accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs
        )

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        # this scheduler is not used in training, but it is used to get num_train_timesteps etc.
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
        return noise_scheduler
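        # The shift is assumed to follow SD3-style timestep shifting, rescaling the
        # flow-matching sigmas roughly as sigma' = shift * sigma / (1 + (shift - 1) * sigma),
        # so shift > 1 biases training/sampling toward the higher-noise end.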

    def encode_images_to_latents(self, args, vae, images):
        return vae.encode(images)

    def shift_scale_latents(self, args, latents):
        return sd3_models.SDVAE.process_in(latents)
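        # process_in is assumed to apply the SD3 VAE normalization, roughly
        # (latents - shift_factor) * scaling_factor with shift_factor ~= 0.0609 and
        # scaling_factor ~= 1.5305, so the latents are standardized before noise is added.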

    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet: sd3_models.MMDiT,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)

        # get the noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
            args, latents, noise, accelerator.device, weight_dtype
        )
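        # Under the rectified-flow setup assumed here, the noisy input is the linear
        # interpolation x_t = (1 - sigma) * latents + sigma * noise, with sigma/timesteps
        # drawn according to args.weighting_scheme.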

        # ensure the hidden state will require grad
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)

        # Predict the noise residual
        lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
        text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
        context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
        if not args.apply_lg_attn_mask:
            l_attn_mask = None
            g_attn_mask = None
        if not args.apply_t5_attn_mask:
            t5_attn_mask = None

        # call the model
        with torch.set_grad_enabled(is_train), accelerator.autocast():
            # TODO support attention mask
            model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)

        # Preconditioning of the model outputs, following Section 5 of https://arxiv.org/abs/2206.00364.
        model_pred = model_pred * (-sigmas) + noisy_model_input
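        # Why this recovers a sample prediction: with x_t = (1 - sigma) * x0 + sigma * noise
        # and the network trained toward the velocity v = noise - x0, we get
        #   x_t - sigma * v = (1 - sigma) * x0 + sigma * noise - sigma * (noise - x0) = x0,
        # so model_pred * (-sigmas) + noisy_model_input is the model's estimate of the clean latents.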

        # these weighting schemes use uniform timestep sampling and instead post-weight the loss
        weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

        # flow matching loss
        target = latents

        # differential output preservation
        if "custom_attributes" in batch:
            diff_output_pr_indices = []
            for i, custom_attributes in enumerate(batch["custom_attributes"]):
                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
                    diff_output_pr_indices.append(i)

            if len(diff_output_pr_indices) > 0:
                network.set_multiplier(0.0)
                with torch.no_grad(), accelerator.autocast():
                    model_pred_prior = unet(
                        noisy_model_input[diff_output_pr_indices],
                        timesteps[diff_output_pr_indices],
                        context=context[diff_output_pr_indices],
                        y=lg_pooled[diff_output_pr_indices],
                    )
                network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step

                model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices]

                # weighting for differential output preservation is not needed because it is already applied

                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
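                # With the multiplier at 0.0 the network (LoRA) is disabled, so model_pred_prior is
                # the base model's own denoised estimate; substituting it into the target trains the
                # network to reproduce base-model outputs on these samples, preserving the prior.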

        return model_pred, target, timesteps, weighting

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss

    def get_sai_model_spec(self, args):
        return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type)

    def update_metadata(self, metadata, args):
        metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask
        metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
        metadata["ss_weighting_scheme"] = args.weighting_scheme
        metadata["ss_logit_mean"] = args.logit_mean
        metadata["ss_logit_std"] = args.logit_std
        metadata["ss_mode_scale"] = args.mode_scale

    def is_text_encoder_not_needed_for_training(self, args):
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        if index == 0 or index == 1:  # CLIP-L/CLIP-G
            return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
        else:  # T5XXL
            text_encoder.encoder.embed_tokens.requires_grad_(True)

    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
        if index == 0 or index == 1:  # CLIP-L/CLIP-G
            clip_type = "CLIP-L" if index == 0 else "CLIP-G"
            logger.info(f"prepare {clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
            text_encoder.to(te_weight_dtype)  # fp8
            text_encoder.text_model.embeddings.to(dtype=weight_dtype)
        else:  # T5XXL

            def prepare_fp8(text_encoder, target_dtype):
                def forward_hook(module):
                    def forward(hidden_states):
                        hidden_gelu = module.act(module.wi_0(hidden_states))
                        hidden_linear = module.wi_1(hidden_states)
                        hidden_states = hidden_gelu * hidden_linear
                        hidden_states = module.dropout(hidden_states)

                        hidden_states = module.wo(hidden_states)
                        return hidden_states

                    return forward

                for module in text_encoder.modules():
                    if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                        # print("set", module.__class__.__name__, "to", target_dtype)
                        module.to(target_dtype)
                    if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                        # print("set", module.__class__.__name__, "hooks")
                        module.forward = forward_hook(module)
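
                # Rationale (assumed from the stock transformers implementation): the standard
                # T5DenseGatedActDense.forward casts hidden_states down to self.wo.weight.dtype
                # before the final projection, which would be fp8 here; this replacement forward
                # skips that cast so activations stay in the autocast compute dtype while only
                # the weights remain fp8.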

            if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
                logger.info("T5XXL is already prepared for fp8")
            else:
                logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
                text_encoder.to(te_weight_dtype)  # fp8
                prepare_fp8(text_encoder, weight_dtype)

    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
        # drop cached text encoder outputs: in validation, cached outputs are dropped deterministically with a fixed seed
        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
        if text_encoder_outputs_list is not None:
            text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
            text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
            batch["text_encoder_outputs_list"] = text_encoder_outputs_list

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        if self.is_swapping_blocks:
            # prepare for the next forward: since the backward pass is not called in validation, we need to prepare it here
            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        if not self.is_swapping_blocks:
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # since blocks are swapped, accelerator should not move the whole model to the device
        mmdit: sd3_models.MMDiT = unet
        mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks])
        accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
        accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward()
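
        # device_placement=[False] above stops accelerator.prepare from moving the wrapped model
        # to the GPU wholesale; move_to_device_except_swap_blocks then places everything except
        # the swappable blocks, which are streamed in on demand during the forward pass
        # (behavior as suggested by the method names).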

        return mmdit


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    sd3_train_utils.add_sd3_training_arguments(parser)
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = Sd3NetworkTrainer()
    trainer.train(args)