mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
895 Commits
v0.5.1
...
nw_applica
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae0872ba3b | ||
|
|
7f948db158 | ||
|
|
9d7729c00d | ||
|
|
988dee02b9 | ||
|
|
d4b9568269 | ||
|
|
ccc3a481e7 | ||
|
|
cd19df49cd | ||
|
|
736365bdd5 | ||
|
|
6ceedb9448 | ||
|
|
930a3912a7 | ||
|
|
cf790d87c4 | ||
|
|
4e67fb8444 | ||
|
|
50f631c768 | ||
|
|
85bc371ebc | ||
|
|
322ee52c77 | ||
|
|
c576f80639 | ||
|
|
d5ab97b69b | ||
|
|
7cb44e4502 | ||
|
|
7a20df5ad5 | ||
|
|
bea4362e21 | ||
|
|
6805cafa9b | ||
|
|
711b40ccda | ||
|
|
696dd7f668 | ||
|
|
e0a3c69223 | ||
|
|
c59249a664 | ||
|
|
fef172966f | ||
|
|
5a1ebc4c7c | ||
|
|
2a0f45aea9 | ||
|
|
1f77bb6e73 | ||
|
|
a7ef6422b6 | ||
|
|
9cfa68c92f | ||
|
|
6f3f701d3d | ||
|
|
d2a99a19d4 | ||
|
|
0395a35543 | ||
|
|
987d4a969d | ||
|
|
976d092c68 | ||
|
|
e6b15c7e4a | ||
|
|
ef50436464 | ||
|
|
26d35794e3 | ||
|
|
dcf0eeb5b6 | ||
|
|
32b759a328 | ||
|
|
09ef3ffa8b | ||
|
|
aab265e431 | ||
|
|
716bad188b | ||
|
|
4f93bf10f0 | ||
|
|
07bf2a21ac | ||
|
|
8ac2d2a92f | ||
|
|
76aee71257 | ||
|
|
1db5d790ed | ||
|
|
663b481029 | ||
|
|
1ab6493268 | ||
|
|
ab716302e4 | ||
|
|
b9d2181192 | ||
|
|
49148eb36e | ||
|
|
479bac447e | ||
|
|
15d5e78ac2 | ||
|
|
fd7f27f044 | ||
|
|
62e7516537 | ||
|
|
20296b4f0e | ||
|
|
5cae6db804 | ||
|
|
1a36f9dc65 | ||
|
|
c2497877ca | ||
|
|
3b5c1a1d4b | ||
|
|
9a2e385f12 | ||
|
|
7080e1a11c | ||
|
|
0a52b83c6a | ||
|
|
11ed8e2a6d | ||
|
|
bb20c09a9a | ||
|
|
04ef8d395f | ||
|
|
0676f1a86f | ||
|
|
6b7823df07 | ||
|
|
2186e417ba | ||
|
|
1519e3067c | ||
|
|
35e5424255 | ||
|
|
8c7d05afd2 | ||
|
|
f8360a4831 | ||
|
|
8556b9d7f5 | ||
|
|
3efd90b2ad | ||
|
|
7adcd9cd1a | ||
|
|
aff05e043f | ||
|
|
ff2c0c192e | ||
|
|
d309a27a51 | ||
|
|
471d274803 | ||
|
|
35f4c9b5c7 | ||
|
|
034a49c69d | ||
|
|
3b6825d7e2 | ||
|
|
bb5ae389f7 | ||
|
|
4a2cef887c | ||
|
|
42750f7846 | ||
|
|
d31aa143f4 | ||
|
|
710e777a92 | ||
|
|
912dca8f65 | ||
|
|
db84530074 | ||
|
|
72bbaac96d | ||
|
|
5713d63dc5 | ||
|
|
d653e594c2 | ||
|
|
dd7bb33ab6 | ||
|
|
a9c6182b3f | ||
|
|
3d70137d31 | ||
|
|
bce9a081db | ||
|
|
46cf41cc93 | ||
|
|
81a440c8e8 | ||
|
|
f24a3b5282 | ||
|
|
383b4a2c3e | ||
|
|
df59822a27 | ||
|
|
0908c5414d | ||
|
|
ee46134fa7 | ||
|
|
39bb319d4c | ||
|
|
1bdd83a85f | ||
|
|
1624c239c2 | ||
|
|
4a913ce61e | ||
|
|
764e333fa2 | ||
|
|
c61e3bf4c9 | ||
|
|
fc8649d80f | ||
|
|
0fb9ecf1f3 | ||
|
|
97958400fb | ||
|
|
6d6d86260b | ||
|
|
c856ea4249 | ||
|
|
d0923d6710 | ||
|
|
f312522cef | ||
|
|
da5a144589 | ||
|
|
2c1e669bd8 | ||
|
|
e20e9f61ac | ||
|
|
6b3148fd3f | ||
|
|
95ae56bd22 | ||
|
|
990192d077 | ||
|
|
f3e69531c3 | ||
|
|
0cb3272bda | ||
|
|
6231aa91e2 | ||
|
|
489b728dbc | ||
|
|
583e2b2d01 | ||
|
|
5dc2a0d3fd | ||
|
|
2c731418ad | ||
|
|
5c150675bf | ||
|
|
fea810b437 | ||
|
|
96d877be90 | ||
|
|
40d917b0fe | ||
|
|
e72020ae01 | ||
|
|
01d929ee2a | ||
|
|
cf876fcdb4 | ||
|
|
291c29caaf | ||
|
|
01e00ac1b0 | ||
|
|
a9ed4ed8a8 | ||
|
|
9d6a5a0c79 | ||
|
|
fb97a7aab1 | ||
|
|
1cefb2a753 | ||
|
|
63992b81c8 | ||
|
|
d8f68674fb | ||
|
|
9d00c8eea2 | ||
|
|
0d21925bdf | ||
|
|
efef5c8ead | ||
|
|
3d2bb1a8f1 | ||
|
|
837a4dddb8 | ||
|
|
b2626bc7a9 | ||
|
|
202f2c3292 | ||
|
|
2a23713f71 | ||
|
|
681034d001 | ||
|
|
17813ff5b4 | ||
|
|
3e81bd6b67 | ||
|
|
23ae358e0f | ||
|
|
f611726364 | ||
|
|
33ee0acd35 | ||
|
|
8b79e3b06c | ||
|
|
cf49e912fc | ||
|
|
66741c035c | ||
|
|
406511c333 | ||
|
|
8a2d68d63e | ||
|
|
07d297fdbe | ||
|
|
0d4e8b50d0 | ||
|
|
1d7c5c2a98 | ||
|
|
0faa350175 | ||
|
|
8a7509db75 | ||
|
|
025368f51c | ||
|
|
5fe52ed322 | ||
|
|
8b247a330b | ||
|
|
d6f458fcb3 | ||
|
|
b8b84021e5 | ||
|
|
70fe7e18be | ||
|
|
9378da3c82 | ||
|
|
a4857fa764 | ||
|
|
592014923f | ||
|
|
6d06b215bf | ||
|
|
2d87bb648f | ||
|
|
56ebef35b0 | ||
|
|
13d8b22d25 | ||
|
|
27f9b6ffeb | ||
|
|
c8fcfd4581 | ||
|
|
49c24285c7 | ||
|
|
c918489259 | ||
|
|
93155242fa | ||
|
|
4cc919607a | ||
|
|
81419f7f32 | ||
|
|
6bd6cd9c51 | ||
|
|
35a1d68eb6 | ||
|
|
365a06bdb6 | ||
|
|
8e117f9f92 | ||
|
|
209eafb631 | ||
|
|
14aa2923cf | ||
|
|
1e395ed285 | ||
|
|
98615166b0 | ||
|
|
28272de97a | ||
|
|
7e736da30c | ||
|
|
20e929e27e | ||
|
|
477b5260aa | ||
|
|
d39f1a3427 | ||
|
|
3757855231 | ||
|
|
d846431015 | ||
|
|
624edf428f | ||
|
|
54500b861d | ||
|
|
f2491ee0ac | ||
|
|
1f169ee7fb | ||
|
|
66817992c1 | ||
|
|
8052bcd5cd | ||
|
|
55886a0116 | ||
|
|
33e90cc6a0 | ||
|
|
d5be8125b0 | ||
|
|
b99cd2a920 | ||
|
|
b64389c8a9 | ||
|
|
db7a28ac25 | ||
|
|
d337bbf8a0 | ||
|
|
90c47140b8 | ||
|
|
0ecfd91a20 | ||
|
|
a0e05fa291 | ||
|
|
e33c007cd0 | ||
|
|
80aca1ccc7 | ||
|
|
6b3a580ee5 | ||
|
|
207fc8b256 | ||
|
|
74561dbdac | ||
|
|
867e7d3238 | ||
|
|
5f08a21d12 | ||
|
|
95bc6e8749 | ||
|
|
4530b96c67 | ||
|
|
360af27749 | ||
|
|
0ee75fd75d | ||
|
|
2eae9b66d0 | ||
|
|
f6d417e26d | ||
|
|
903825af6f | ||
|
|
948cf17499 | ||
|
|
cd59003003 | ||
|
|
f19a48a28c | ||
|
|
4c6f3125fc | ||
|
|
497051c14b | ||
|
|
6400116715 | ||
|
|
f77bdf96d8 | ||
|
|
c06a86706a | ||
|
|
e0beb6a999 | ||
|
|
633bb8d339 | ||
|
|
7e850f3b7e | ||
|
|
59c9a8e7ae | ||
|
|
c2419ddabf | ||
|
|
2e0942d5c8 | ||
|
|
6155f9c171 | ||
|
|
f64c78b777 | ||
|
|
3d12cdc643 | ||
|
|
526488feaa | ||
|
|
5d88351bb5 | ||
|
|
a46a4781e8 | ||
|
|
b44644bcec | ||
|
|
1f4a495e16 | ||
|
|
d97a1638d3 | ||
|
|
ef28a919d2 | ||
|
|
71369ac98b | ||
|
|
85f1114c4a | ||
|
|
927c687628 | ||
|
|
6d5cffaee9 | ||
|
|
fbc550d02e | ||
|
|
014c4b47c9 | ||
|
|
9be19ad777 | ||
|
|
1161a5c6da | ||
|
|
9947197a84 | ||
|
|
50c6aaae62 | ||
|
|
edd314cc8a | ||
|
|
8b2a11fd5e | ||
|
|
15b463d18d | ||
|
|
0c1975501c | ||
|
|
98f8785a4f | ||
|
|
b74dfba215 | ||
|
|
bee5c3f1b8 | ||
|
|
e191892824 | ||
|
|
2841927dba | ||
|
|
0646112010 | ||
|
|
782b11b844 | ||
|
|
5a86bbc0a0 | ||
|
|
fef7eb73ad | ||
|
|
62fa4734fe | ||
|
|
b5db90c8a8 | ||
|
|
3e1591661e | ||
|
|
1e52fe6e09 | ||
|
|
809fca0be9 | ||
|
|
5fa473d5f3 | ||
|
|
784a90c3a6 | ||
|
|
6111151f50 | ||
|
|
afc03af3ca | ||
|
|
306ee24c90 | ||
|
|
3f7235c36f | ||
|
|
9d678a6f41 | ||
|
|
983698dd1b | ||
|
|
9a60b8a0ba | ||
|
|
adf99a332e | ||
|
|
d713e4c757 | ||
|
|
a90c9c2776 | ||
|
|
d43fcd638e | ||
|
|
e32e24adf5 | ||
|
|
e2c2689f5c | ||
|
|
8415014de6 | ||
|
|
3307ccb2dc | ||
|
|
6889ee2b85 | ||
|
|
bf31f18c46 | ||
|
|
e73d103eca | ||
|
|
12e58ab37f | ||
|
|
daad50e384 | ||
|
|
4e339bb101 | ||
|
|
b83ce0c352 | ||
|
|
6f80fe17fc | ||
|
|
7ea38f90d7 | ||
|
|
f4a2bc6cf8 | ||
|
|
78226f8574 | ||
|
|
04b1defaf9 | ||
|
|
3cdbbb43be | ||
|
|
92f41f1051 | ||
|
|
c142dadb46 | ||
|
|
cd54af019a | ||
|
|
e5f9772a35 | ||
|
|
a02056c566 | ||
|
|
2dfa26cca0 | ||
|
|
25d8cd473e | ||
|
|
f4935dd6be | ||
|
|
9d855091bf | ||
|
|
f3be995c28 | ||
|
|
9d7619d1eb | ||
|
|
c6d52fdea4 | ||
|
|
cf6832896f | ||
|
|
6b1cf6c4fd | ||
|
|
db80c5a2e7 | ||
|
|
89aae3e04f | ||
|
|
0636399c8c | ||
|
|
7e474d21ca | ||
|
|
f61996b425 | ||
|
|
496c3f2732 | ||
|
|
8856c19c76 | ||
|
|
0eacadfa99 | ||
|
|
2a4ae88f18 | ||
|
|
a296654c1b | ||
|
|
b62185b821 | ||
|
|
e6034b7eb6 | ||
|
|
54a4aa22ed | ||
|
|
9ec70252d0 | ||
|
|
e20b6acfe9 | ||
|
|
d9180c03f6 | ||
|
|
4072f723c1 | ||
|
|
cf8021020f | ||
|
|
fb1054b5e3 | ||
|
|
1e4512b2c8 | ||
|
|
3a7326ae46 | ||
|
|
38b59a93de | ||
|
|
1199eacb72 | ||
|
|
fdb58b0b62 | ||
|
|
315fbc11e5 | ||
|
|
4a1b92d309 | ||
|
|
272dd993e6 | ||
|
|
96a52d9810 | ||
|
|
50544b7805 | ||
|
|
b78c0e2a69 | ||
|
|
2b969e9c42 | ||
|
|
e83ee217d3 | ||
|
|
b1e44e96bc | ||
|
|
7ae0cde754 | ||
|
|
c1d5c24bc7 | ||
|
|
eec6aaddda | ||
|
|
bb167f94ca | ||
|
|
2e4783bcdf | ||
|
|
7b31c0830f | ||
|
|
8f645d354e | ||
|
|
7ec9a7af79 | ||
|
|
50b53e183e | ||
|
|
d131bde183 | ||
|
|
d1864e2430 | ||
|
|
8ba02ac829 | ||
|
|
73a08c0be0 | ||
|
|
c45d2f214b | ||
|
|
9a67e0df39 | ||
|
|
acf16c063a | ||
|
|
86a8cbd002 | ||
|
|
fc276a51fb | ||
|
|
771f33d17d | ||
|
|
e6d1f509a0 | ||
|
|
225e871819 | ||
|
|
7875ca8fb5 | ||
|
|
6d2d8dfd2f | ||
|
|
0ec7166098 | ||
|
|
3d66a234b0 | ||
|
|
8a073ee49f | ||
|
|
7e20c6d1a1 | ||
|
|
1d4672d747 | ||
|
|
39e62b948e | ||
|
|
41d195715d | ||
|
|
3db97f8897 | ||
|
|
516f64f4d9 | ||
|
|
62dd99bee5 | ||
|
|
94c151aea3 | ||
|
|
81fa54837f | ||
|
|
9de357e373 | ||
|
|
b4a3824ce4 | ||
|
|
3bb80ebf20 | ||
|
|
cdffd19f61 | ||
|
|
a7ce2633f3 | ||
|
|
8fa5fb2816 | ||
|
|
8df948565a | ||
|
|
3c67e595b8 | ||
|
|
814996b14f | ||
|
|
2e67d74df4 | ||
|
|
b841dd78fe | ||
|
|
68ca0ea995 | ||
|
|
f54b784d88 | ||
|
|
b6e328ea8f | ||
|
|
5c80117fbd | ||
|
|
c2ceb6de5f | ||
|
|
77ec70d145 | ||
|
|
a380502c01 | ||
|
|
0416f26a76 | ||
|
|
3579b4570f | ||
|
|
256ff5b56c | ||
|
|
7502f662ab | ||
|
|
d974959738 | ||
|
|
5f348579d1 | ||
|
|
8371a7a3aa | ||
|
|
1d25703ac3 | ||
|
|
fe7ede5af3 | ||
|
|
d599394f60 | ||
|
|
66c03be45f | ||
|
|
c1d62383c6 | ||
|
|
73ab110260 | ||
|
|
cc3d40ca44 | ||
|
|
288efddf2f | ||
|
|
4a34e5804e | ||
|
|
3d0375daa6 | ||
|
|
3060eb5baf | ||
|
|
ce46aa0c3b | ||
|
|
3b35547da0 | ||
|
|
6aa62b9b66 | ||
|
|
2febbfe4b0 | ||
|
|
ea182461d3 | ||
|
|
5863676ccb | ||
|
|
97611e89ca | ||
|
|
64cf922841 | ||
|
|
227a62e4c4 | ||
|
|
38e21f5c1a | ||
|
|
d395bc0647 | ||
|
|
afce13d101 | ||
|
|
8521ab7990 | ||
|
|
71a6d49d06 | ||
|
|
07d5c71090 | ||
|
|
a751dc25d6 | ||
|
|
753c63e11b | ||
|
|
b0dfbe7086 | ||
|
|
31018d57b6 | ||
|
|
9ebebb22db | ||
|
|
2c461e4ad3 | ||
|
|
56ca5dfa15 | ||
|
|
747af145ed | ||
|
|
7981ee186f | ||
|
|
9e9df2b501 | ||
|
|
f7f762c676 | ||
|
|
0b730d904f | ||
|
|
11e8c7d8ff | ||
|
|
663f953a78 | ||
|
|
bfd909ab79 | ||
|
|
0cfcb5a49c | ||
|
|
6a86de1927 | ||
|
|
5114e8daf1 | ||
|
|
1c09867b3e | ||
|
|
2b4229fa51 | ||
|
|
92e50133f8 | ||
|
|
c4269b5efa | ||
|
|
19dfa24abb | ||
|
|
c7fd336c5d | ||
|
|
ed30af8343 | ||
|
|
1e0b059982 | ||
|
|
038c09f552 | ||
|
|
5d1b54de45 | ||
|
|
18156bf2a1 | ||
|
|
5845de7d7c | ||
|
|
e97d67a681 | ||
|
|
f0bb3ae825 | ||
|
|
9806b00f74 | ||
|
|
f2989b36c2 | ||
|
|
624fbadea2 | ||
|
|
d4ba37f543 | ||
|
|
449ad7502c | ||
|
|
44404fcd6d | ||
|
|
1da6d43109 | ||
|
|
9aee793078 | ||
|
|
89c3033401 | ||
|
|
67f09b7d7e | ||
|
|
0dfffcd88a | ||
|
|
9e1683cf2b | ||
|
|
4d0c06e397 | ||
|
|
0315611b11 | ||
|
|
33a6234b52 | ||
|
|
4b7b3bc04a | ||
|
|
035dd3a900 | ||
|
|
4e25c8f78e | ||
|
|
7f6b581ef8 | ||
|
|
cc274fb7fb | ||
|
|
334d07bf96 | ||
|
|
6417f5d7c1 | ||
|
|
8088c04a71 | ||
|
|
f7b1911f1b | ||
|
|
045cd38b6e | ||
|
|
dccdb8771c | ||
|
|
d4b5cab7f7 | ||
|
|
363f1dfab9 | ||
|
|
4e24733f1c | ||
|
|
bb91a10b5f | ||
|
|
98635ebde2 | ||
|
|
24823b061d | ||
|
|
0fe1afd4ef | ||
|
|
c0a7df9ee1 | ||
|
|
5907bbd9de | ||
|
|
5db792b10b | ||
|
|
7c38c33ed6 | ||
|
|
5bec05e045 | ||
|
|
6084611508 | ||
|
|
71a7a27319 | ||
|
|
ec2efe52e4 | ||
|
|
0f0158ddaa | ||
|
|
dde7807b00 | ||
|
|
1e3daa247b | ||
|
|
3bd00b88c2 | ||
|
|
62d00b4520 | ||
|
|
4f8ce00477 | ||
|
|
1214f35985 | ||
|
|
e743ee5d5c | ||
|
|
23c4e5cb01 | ||
|
|
1f1cae6c5a | ||
|
|
c8d209d36c | ||
|
|
f8e8df5a04 | ||
|
|
f4c9276336 | ||
|
|
a5c38e5d5b | ||
|
|
9c7237157d | ||
|
|
5931948adb | ||
|
|
8a5e3904a0 | ||
|
|
d679dc4de1 | ||
|
|
a002d10a4d | ||
|
|
3a06968332 | ||
|
|
6fbd526931 | ||
|
|
c437dce056 | ||
|
|
fc00691898 | ||
|
|
990ceddd14 | ||
|
|
226db64736 | ||
|
|
2429ac73b2 | ||
|
|
dd8e17cb37 | ||
|
|
db756e9a34 | ||
|
|
16e5981d31 | ||
|
|
575c51fd3b | ||
|
|
5b2447f71d | ||
|
|
0ccb4d4a3a | ||
|
|
b5bb8bec67 | ||
|
|
5cdf4e34a1 | ||
|
|
061e157191 | ||
|
|
d859a3a925 | ||
|
|
5a1a14f9fc | ||
|
|
b6ba4cac83 | ||
|
|
99b607c60c | ||
|
|
289298b17d | ||
|
|
f7a1868fc2 | ||
|
|
02bb8e0ac3 | ||
|
|
bc909e8359 | ||
|
|
c971d9319c | ||
|
|
0c942106bf | ||
|
|
c0c4d4ddc6 | ||
|
|
c924c47f37 | ||
|
|
5b54086663 | ||
|
|
9e797cc151 | ||
|
|
cc10a62e16 | ||
|
|
7e5b6154d0 | ||
|
|
6d6df18387 | ||
|
|
ca36f47dfc | ||
|
|
45f9cc9e0e | ||
|
|
3699a90645 | ||
|
|
714846e1e1 | ||
|
|
08d85d4013 | ||
|
|
0ec7743436 | ||
|
|
a72d80aa85 | ||
|
|
b556fc43bc | ||
|
|
dbb9c19669 | ||
|
|
bca6a44974 | ||
|
|
8ab5c8cb28 | ||
|
|
774c4059fb | ||
|
|
5f1d07d62f | ||
|
|
cd984992cf | ||
|
|
99f4940eb7 | ||
|
|
41dd835a89 | ||
|
|
ee42c5cd42 | ||
|
|
47b6101465 | ||
|
|
7889a52f95 | ||
|
|
8d562ecf48 | ||
|
|
2767a0f9f2 | ||
|
|
af08c56ce0 | ||
|
|
dfc56e9227 | ||
|
|
84d157995e | ||
|
|
ed5bfda372 | ||
|
|
a59822540f | ||
|
|
968bbd2f47 | ||
|
|
1b4bdff331 | ||
|
|
678fe003e3 | ||
|
|
3b1af3f1a6 | ||
|
|
437501cde3 | ||
|
|
8bd2072e19 | ||
|
|
85df289190 | ||
|
|
8856496aac | ||
|
|
a7df7db464 | ||
|
|
59507c7c02 | ||
|
|
09c719c926 | ||
|
|
e54b6311ef | ||
|
|
fdbdb4748a | ||
|
|
76a2b14cdb | ||
|
|
b08154dc36 | ||
|
|
165fc43655 | ||
|
|
42cbf75cfa | ||
|
|
e6ad3cbc66 | ||
|
|
2127907dd3 | ||
|
|
164a1978de | ||
|
|
cb1076ed23 | ||
|
|
ad5f318d06 | ||
|
|
60bbe64489 | ||
|
|
b9085fc80a | ||
|
|
2fad5b88bc | ||
|
|
b271a6bd89 | ||
|
|
758a1e7f66 | ||
|
|
1cba447102 | ||
|
|
e25164cfed | ||
|
|
f6556f7972 | ||
|
|
69579668bb | ||
|
|
2e688b7cd3 | ||
|
|
2fcbfec178 | ||
|
|
e1143caf38 | ||
|
|
a7485e4d9e | ||
|
|
335b2f960e | ||
|
|
b18d099291 | ||
|
|
bc803e01c7 | ||
|
|
eaa2460701 | ||
|
|
c7dbcc6483 | ||
|
|
ad8a5934e1 | ||
|
|
7078e6477e | ||
|
|
69475f5bf1 | ||
|
|
ddeeb9428c | ||
|
|
780c60630c | ||
|
|
40c37b1219 | ||
|
|
c14b09376a | ||
|
|
fbcf56b2ba | ||
|
|
2d369b32f9 | ||
|
|
d52c524fc2 | ||
|
|
c2b51fbe98 | ||
|
|
7f2ac589f9 | ||
|
|
dff3872897 | ||
|
|
4f4b92da7d | ||
|
|
18f171d885 | ||
|
|
c72f8acea1 | ||
|
|
abedbc726f | ||
|
|
3e8d389e3e | ||
|
|
8810f8a728 | ||
|
|
5de91b9d81 | ||
|
|
57bc2abf41 | ||
|
|
dd50514d17 | ||
|
|
ac4935bf79 | ||
|
|
c817862cf7 | ||
|
|
c3768aaa46 | ||
|
|
a85fcfe05f | ||
|
|
1890535d1b | ||
|
|
9bb52acc14 | ||
|
|
551fdf32c3 | ||
|
|
74008ce487 | ||
|
|
852481e14d | ||
|
|
25c8279f26 | ||
|
|
05c57b9c7b | ||
|
|
46cbae088e | ||
|
|
b824bbfce6 | ||
|
|
9ba4c3edca | ||
|
|
ed2eef1625 | ||
|
|
e9a641bde7 | ||
|
|
ae3965a2a7 | ||
|
|
700af1c96d | ||
|
|
66edc5af7b | ||
|
|
ed15f6808b | ||
|
|
dc37fd2ff6 | ||
|
|
f256660780 | ||
|
|
23b261de3f | ||
|
|
884e6bff5d | ||
|
|
220436244c | ||
|
|
c430cf481a | ||
|
|
9f8f27fbad | ||
|
|
e746829b5f | ||
|
|
a69b24a069 | ||
|
|
12567f55cd | ||
|
|
8090daca40 | ||
|
|
27ffd9fe3d | ||
|
|
ee5cec7530 | ||
|
|
589a90bfbc | ||
|
|
314a364f61 | ||
|
|
f770cd96c6 | ||
|
|
01df1c0cc4 | ||
|
|
334589af4e | ||
|
|
43ef635be3 | ||
|
|
47d61e2c02 | ||
|
|
8f6fc8daa1 | ||
|
|
01ebfc41f3 | ||
|
|
87163cff8b | ||
|
|
6d5f847edc | ||
|
|
afb8700a95 | ||
|
|
e60d18cfb3 | ||
|
|
92332eb96e | ||
|
|
d5263d442f | ||
|
|
7ad7cac0c2 | ||
|
|
06a9f51431 | ||
|
|
849bc24d20 | ||
|
|
423e6c229c | ||
|
|
9fc27403b2 | ||
|
|
2de9a51591 | ||
|
|
a8632b7329 | ||
|
|
9ff32fd4c0 | ||
|
|
a097c42579 | ||
|
|
68e0767404 | ||
|
|
e09966024c | ||
|
|
893c2fc08a | ||
|
|
2e9f7b5f91 | ||
|
|
7f8e05ccad | ||
|
|
c316c63dff | ||
|
|
683680e5c8 | ||
|
|
bf8088e225 | ||
|
|
5050971ac6 | ||
|
|
08c54dcf22 | ||
|
|
6a5f87d874 | ||
|
|
a876f2d3fb | ||
|
|
a75f5898e6 | ||
|
|
dbab72153f | ||
|
|
0d54609435 | ||
|
|
07aa000750 | ||
|
|
b5c60d7d62 | ||
|
|
defefd79c5 | ||
|
|
27834df444 | ||
|
|
5c020bed49 | ||
|
|
c775ec1255 | ||
|
|
7527436549 | ||
|
|
541539a144 | ||
|
|
74220bb52c | ||
|
|
8eb60baf3a | ||
|
|
4b47e8ecb0 | ||
|
|
76bac2c1c5 | ||
|
|
0fcdda7175 | ||
|
|
e4eb3e63e6 | ||
|
|
626d4b433a | ||
|
|
83c7e03d05 | ||
|
|
959561473c | ||
|
|
7209eb74cc | ||
|
|
53cc3583df | ||
|
|
82c2553f07 | ||
|
|
6f6f9b537f | ||
|
|
f407f5a686 | ||
|
|
6134619998 | ||
|
|
817a9268ff | ||
|
|
3beddf341e | ||
|
|
1892c82a60 | ||
|
|
3f339cda6f | ||
|
|
16ba1cec69 | ||
|
|
8bfa50e283 | ||
|
|
c4a11e5a5a | ||
|
|
3cc4939dd3 | ||
|
|
b5c7937f8d | ||
|
|
b5ff4e816f | ||
|
|
a7d302e196 | ||
|
|
45381b188c | ||
|
|
054fb3308c | ||
|
|
d42431d73a | ||
|
|
c639cb7d5d | ||
|
|
97e65bf93f | ||
|
|
36c8a4aee7 | ||
|
|
19340d82e6 | ||
|
|
058e442072 | ||
|
|
9577a9f38d | ||
|
|
786971d443 | ||
|
|
f037b09c2d | ||
|
|
18d69d8e5e | ||
|
|
770a56193e | ||
|
|
4627b389ff | ||
|
|
1cd07770a4 | ||
|
|
1e164b6ec3 | ||
|
|
41ecccb2a9 | ||
|
|
c93cbbc373 | ||
|
|
8cecc676cf | ||
|
|
94441fa746 | ||
|
|
ccb0ef518a | ||
|
|
3032a47af4 | ||
|
|
1b75dbd4f2 | ||
|
|
dade23a414 | ||
|
|
313f3e8286 | ||
|
|
4dacc52bde | ||
|
|
b1dffe8d9a | ||
|
|
ea1cf4acee | ||
|
|
cd5e3baace | ||
|
|
e76ea7cd7d | ||
|
|
d68ba2f9de | ||
|
|
5fc80b7a5b | ||
|
|
31069e1dc5 | ||
|
|
6c28dfb417 | ||
|
|
2d6faa9860 | ||
|
|
cb53a77334 | ||
|
|
4d91dc0d30 | ||
|
|
935d4774a9 | ||
|
|
24e3d4b464 | ||
|
|
b0c33a4294 | ||
|
|
bf3674c1db | ||
|
|
b996f5a6d6 | ||
|
|
472f516e7c | ||
|
|
c838efcfa8 | ||
|
|
4f70e5dca6 | ||
|
|
0138a917d8 | ||
|
|
49b29f2db2 | ||
|
|
99eaf1fd65 | ||
|
|
5fa20b5348 | ||
|
|
895b0b6ca7 | ||
|
|
238f01bc9c | ||
|
|
43a08b4061 | ||
|
|
066b1bb57e | ||
|
|
3cdae0cbd2 | ||
|
|
14891523ce | ||
|
|
559a1aeeda | ||
|
|
a18558ddfe | ||
|
|
6732df93e2 | ||
|
|
4f42f759ea | ||
|
|
c9b157b536 | ||
|
|
4c06bfad60 | ||
|
|
a35d7ef227 | ||
|
|
a4b34a9c3c | ||
|
|
5a3d564a30 | ||
|
|
4dc1124f93 | ||
|
|
9c80da6ac5 | ||
|
|
292cdb8379 | ||
|
|
5ec90990de | ||
|
|
e203270e31 | ||
|
|
b2c5b96f2a | ||
|
|
1b89b2a10e | ||
|
|
143c26e552 | ||
|
|
518a18aeff | ||
|
|
a3c7d711e4 | ||
|
|
dbadc40ec2 | ||
|
|
447c56bf50 | ||
|
|
a9b26b73e0 | ||
|
|
64c923230e | ||
|
|
795a6bd2d8 | ||
|
|
aee343a9ee | ||
|
|
2c5949c155 | ||
|
|
193674e16c | ||
|
|
4f92b6266c | ||
|
|
2d86f63e15 | ||
|
|
88751f58f6 | ||
|
|
7b324bcc3b | ||
|
|
1645698ec0 | ||
|
|
5aa5a07260 | ||
|
|
6d9f3bc0b2 | ||
|
|
1816ac3271 | ||
|
|
cca3804503 | ||
|
|
cb08fa0379 | ||
|
|
a265225972 | ||
|
|
eb66e5ebac | ||
|
|
9d4cf8b03b | ||
|
|
a167a592e2 | ||
|
|
432353185c | ||
|
|
d526f1d3d3 | ||
|
|
c219600ca0 | ||
|
|
de95431895 | ||
|
|
c86bf213d1 | ||
|
|
48c1be34f3 | ||
|
|
140b4fad43 | ||
|
|
1f7babd2c7 | ||
|
|
cfb19ad0da | ||
|
|
1214760cea | ||
|
|
64d85b2f51 | ||
|
|
8f08feb577 | ||
|
|
ec7f9bab6c | ||
|
|
83e102c691 | ||
|
|
c3f9eb10f1 | ||
|
|
563a4dc897 | ||
|
|
370ca9e8cd | ||
|
|
5dad64b684 | ||
|
|
e24a43ae0b | ||
|
|
44d4cfb453 | ||
|
|
7c1cf7f4ea | ||
|
|
0b38e663fd | ||
|
|
8b25929765 | ||
|
|
e3b2bb5b80 | ||
|
|
7544b38635 | ||
|
|
c4a596df9e | ||
|
|
ab05be11d2 | ||
|
|
eb68892ab1 |
7
.github/dependabot.yml
vendored
Normal file
7
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "monthly"
|
||||
4
.github/workflows/typos.yml
vendored
4
.github/workflows/typos.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.13.10
|
||||
uses: crate-ci/typos@v1.16.26
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,4 +4,5 @@ wd14_tagger_model
|
||||
venv
|
||||
*.egg-info
|
||||
build
|
||||
.vscode
|
||||
.vscode
|
||||
wandb
|
||||
|
||||
97
README-ja.md
97
README-ja.md
@@ -1,3 +1,7 @@
|
||||
SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。
|
||||
|
||||
SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。
|
||||
|
||||
## リポジトリについて
|
||||
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
|
||||
|
||||
@@ -9,20 +13,19 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
|
||||
|
||||
* DreamBooth、U-NetおよびText Encoderの学習をサポート
|
||||
* fine-tuning、同上
|
||||
* LoRAの学習をサポート
|
||||
* 画像生成
|
||||
* モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換)
|
||||
|
||||
## 使用法について
|
||||
|
||||
当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください(将来的にはすべてこちらへ移すかもしれません)。
|
||||
|
||||
* [学習について、共通編](./train_README-ja.md) : データ整備やオプションなど
|
||||
* [データセット設定](./config_README-ja.md)
|
||||
* [DreamBoothの学習について](./train_db_README-ja.md)
|
||||
* [fine-tuningのガイド](./fine_tune_README_ja.md):
|
||||
* [LoRAの学習について](./train_network_README-ja.md)
|
||||
* [Textual Inversionの学習について](./train_ti_README-ja.md)
|
||||
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
|
||||
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
|
||||
* [データセット設定](./docs/config_README-ja.md)
|
||||
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
|
||||
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
|
||||
* [LoRAの学習について](./docs/train_network_README-ja.md)
|
||||
* [Textual Inversionの学習について](./docs/train_ti_README-ja.md)
|
||||
* [画像生成スクリプト](./docs/gen_img_README-ja.md)
|
||||
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||
|
||||
## Windowsでの動作に必要なプログラム
|
||||
@@ -41,11 +44,13 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
|
||||
|
||||
## Windows環境でのインストール
|
||||
|
||||
以下の例ではPyTorchは1.12.1/CUDA 11.6版をインストールします。CUDA 11.3版やPyTorch 1.13を使う場合は適宜書き換えください。
|
||||
スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。
|
||||
|
||||
以下の例ではPyTorchは2.0.1/CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。
|
||||
|
||||
(なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
|
||||
|
||||
通常の(管理者ではない)PowerShellを開き以下を順に実行します。
|
||||
PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。
|
||||
|
||||
```powershell
|
||||
git clone https://github.com/kohya-ss/sd-scripts.git
|
||||
@@ -54,43 +59,14 @@ cd sd-scripts
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
pip install xformers==0.0.20
|
||||
|
||||
accelerate config
|
||||
```
|
||||
|
||||
<!--
|
||||
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install --use-pep517 --upgrade -r requirements.txt
|
||||
pip install -U -I --no-deps xformers==0.0.16
|
||||
-->
|
||||
|
||||
コマンドプロンプトでは以下になります。
|
||||
|
||||
|
||||
```bat
|
||||
git clone https://github.com/kohya-ss/sd-scripts.git
|
||||
cd sd-scripts
|
||||
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
|
||||
copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
|
||||
accelerate config
|
||||
```
|
||||
コマンドプロンプトでも同一です。
|
||||
|
||||
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。)
|
||||
|
||||
@@ -111,9 +87,40 @@ accelerate configの質問には以下のように答えてください。(bf1
|
||||
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
|
||||
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
|
||||
|
||||
### PyTorchとxformersのバージョンについて
|
||||
### オプション:`bitsandbytes`(8bit optimizer)を使う
|
||||
|
||||
他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。
|
||||
`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます(0.41.1または以降のバージョンを推奨)。
|
||||
|
||||
Windowsでは0.35.0または0.41.1を推奨します。
|
||||
|
||||
- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。
|
||||
- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。
|
||||
|
||||
注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659
|
||||
|
||||
以下の手順に従い、`bitsandbytes`をインストールしてください。
|
||||
|
||||
### 0.35.0を使う場合
|
||||
|
||||
PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。
|
||||
|
||||
```powershell
|
||||
cd sd-scripts
|
||||
.\venv\Scripts\activate
|
||||
pip install bitsandbytes==0.35.0
|
||||
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
```
|
||||
|
||||
### 0.41.1を使う場合
|
||||
|
||||
jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。
|
||||
|
||||
```powershell
|
||||
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
|
||||
```
|
||||
|
||||
## アップグレード
|
||||
|
||||
|
||||
460
README.md
460
README.md
@@ -1,9 +1,47 @@
|
||||
## LoRAの層別適用率の探索について
|
||||
|
||||
層別適用率を探索する `train_network_appl_weights.py` を追加してあります。現在は SDXL のみ対応しています。
|
||||
|
||||
LoRA 等の学習済みネットワークに対して、層別適用率を変化させながら通常の学習プロセスを実行することで、適用率を探索します。つまり、どのような層別適用率を適用すると、学習データに近い画像が生成されるかを探索することができます。
|
||||
|
||||
層別適用率の合計をペナルティとすることが可能です。つまり、画像を再現しつつ、影響の少ない層の適用率が低くなるような適用率が探索できるはずです。
|
||||
|
||||
複数のネットワークを対象に探索できます。また探索には最低 1 枚の学習データが必要になります。
|
||||
|
||||
(何枚程度から正しく動くかは確認していません。50枚程度の画像でテスト済みです。また学習データは LoRA 学習時のデータでなくてもよいはずですが、未確認です。)
|
||||
|
||||
コマンドラインオプションは `sdxl_train_network.py` とほぼ同じですが、以下のオプションが追加、拡張されています。
|
||||
|
||||
- `--application_loss_weight` : 層別適用率を loss に加える際の重みです。デフォルトは 0.0001 です。大きくすると、なるべく適用率を低くするように学習します。0 を指定するとペナルティが適用されないため、再現度が最も高くなる適用率を自由に探索します。
|
||||
- `--network_module` : 探索対象の複数のモジュールを指定することができます。たとえば `--network_module networks.lora networks.lora` のように指定します。
|
||||
- `--network_weights` : 探索対象の複数のネットワークの重みを指定することができます。たとえば `--network_weights model1.safetensors model2.safetensors` のように指定します。
|
||||
|
||||
層別適用率のパラメータ数は 20個で、`BASE, IN00-08, MID, OUT00-08` となります。`BASE` は Text Encoder に適用されます。(Text Encoder を対象とした LoRA の動作は未確認です。)
|
||||
|
||||
パラメータは一応ファイルに保存されますが、画面に表示される値をコピーして保存することをお勧めします。
|
||||
|
||||
### 備考
|
||||
|
||||
オプティマイザ AdamW、学習率 1e-1 で動作確認しています。学習率はかなり高めに設定してよいようです。この設定では LoRA 学習時の 1/20 ~ 1/10 ほどの epoch 数でそれなりの結果が得られます。
|
||||
|
||||
`application_loss_weight` を 0.0001 より大きくすると合計の適用率がかなり低くなる(=LoRA があまり適用されない)ようです。条件にもよると思いますので、適宜調整してください。
|
||||
|
||||
適用率に負の値を使うと、影響の少ない層の適用率を極端に低くして合計を小さくする、という動きをしてしまうので、負の値は10倍の重み付けをしてあります(-0.01 は 0.1 とほぼ同じペナルティ)。重み付けを変更するときはソースを修正してください。
|
||||
|
||||
「必要ない層への適用率を下げて影響範囲を小さくする」という使い方だけでなく、「あるキャラクターがあるポーズをしている画像を教師データに、キャラクターを維持しつつポーズを取るための LoRA の適用率を探索する」、「ある画風のあるキャラクターの画像を教師データに、画風 LoRA とキャラクター LoRA の適用率を探索する」などの使い方が考えられます。
|
||||
|
||||
もしかすると、「あるキャラクターの、あえて別の画風の画像を教師データに、キャラクターの属性を再現するのに必要な層を探す」、「理想とする画像を教師データに、使えそうな LoRA を多数適用し、その中から最も再現度が高い適用率を探す(ただし LoRA の数が多いほど学習が遅くなります)」といった使い方もできるかもしれません。
|
||||
|
||||
---
|
||||
|
||||
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
|
||||
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion.
|
||||
|
||||
[__Change History__](#change-history) is moved to the bottom of the page.
|
||||
[__Change History__](#change-history) is moved to the bottom of the page.
|
||||
更新履歴は[ページ末尾](#change-history)に移しました。
|
||||
|
||||
[日本語版README](./README-ja.md)
|
||||
[日本語版READMEはこちら](./README-ja.md)
|
||||
|
||||
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
|
||||
|
||||
@@ -12,29 +50,30 @@ This repository contains the scripts for:
|
||||
* DreamBooth training, including U-Net and Text Encoder
|
||||
* Fine-tuning (native training), including U-Net and Text Encoder
|
||||
* LoRA training
|
||||
* Texutl Inversion training
|
||||
* Textual Inversion training
|
||||
* Image generation
|
||||
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
|
||||
|
||||
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ (SD 1.x based only) Thank you for great work!!!
|
||||
|
||||
## About requirements.txt
|
||||
|
||||
These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.)
|
||||
|
||||
The scripts are tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2.
|
||||
The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work.
|
||||
|
||||
## Links to how-to-use documents
|
||||
## Links to usage documentation
|
||||
|
||||
All documents are in Japanese currently.
|
||||
Most of the documents are written in Japanese.
|
||||
|
||||
* [Training guide - common](./train_README-ja.md) : data preparation, options etc...
|
||||
* [Dataset config](./config_README-ja.md)
|
||||
* [DreamBooth training guide](./train_db_README-ja.md)
|
||||
* [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
|
||||
* [training LoRA](./train_network_README-ja.md)
|
||||
* [training Textual Inversion](./train_ti_README-ja.md)
|
||||
* note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
|
||||
[English translation by darkstorm2150 is here](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation). Thanks to darkstorm2150!
|
||||
|
||||
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
|
||||
* [Chinese version](./docs/train_README-zh.md)
|
||||
* [Dataset config](./docs/config_README-ja.md)
|
||||
* [DreamBooth training guide](./docs/train_db_README-ja.md)
|
||||
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
|
||||
* [training LoRA](./docs/train_network_README-ja.md)
|
||||
* [training Textual Inversion](./docs/train_ti_README-ja.md)
|
||||
* [Image generation](./docs/gen_img_README-ja.md)
|
||||
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||
|
||||
## Windows Required Dependencies
|
||||
@@ -61,19 +100,20 @@ cd sd-scripts
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
pip install xformers==0.0.20
|
||||
|
||||
accelerate config
|
||||
```
|
||||
|
||||
update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python).
|
||||
__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section.
|
||||
|
||||
<!--
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
-->
|
||||
Answers to accelerate config:
|
||||
|
||||
```txt
|
||||
@@ -91,10 +131,42 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
### about PyTorch and xformers
|
||||
### Optional: Use `bitsandbytes` (8bit optimizer)
|
||||
|
||||
Other versions of PyTorch and xformers seem to have problems with training.
|
||||
If there is no other reason, please install the specified version.
|
||||
For 8bit optimizer, you need to install `bitsandbytes`. For Linux, please install `bitsandbytes` as usual (0.41.1 or later is recommended.)
|
||||
|
||||
For Windows, there are several versions of `bitsandbytes`:
|
||||
|
||||
- `bitsandbytes` 0.35.0: Stable version. AdamW8bit is available. `full_bf16` is not available.
|
||||
- `bitsandbytes` 0.41.1: Lion8bit, PagedAdamW8bit and PagedLion8bit are available. `full_bf16` is available.
|
||||
|
||||
Note: `bitsandbytes`above 0.35.0 till 0.41.0 seems to have an issue: https://github.com/TimDettmers/bitsandbytes/issues/659
|
||||
|
||||
Follow the instructions below to install `bitsandbytes` for Windows.
|
||||
|
||||
### bitsandbytes 0.35.0 for Windows
|
||||
|
||||
Open a regular Powershell terminal and type the following inside:
|
||||
|
||||
```powershell
|
||||
cd sd-scripts
|
||||
.\venv\Scripts\activate
|
||||
pip install bitsandbytes==0.35.0
|
||||
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
```
|
||||
|
||||
This will install `bitsandbytes` 0.35.0 and copy the necessary files to the `bitsandbytes` directory.
|
||||
|
||||
### bitsandbytes 0.41.1 for Windows
|
||||
|
||||
Install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like:
|
||||
|
||||
```powershell
|
||||
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
|
||||
```
|
||||
|
||||
## Upgrade
|
||||
|
||||
@@ -125,103 +197,279 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
|
||||
|
||||
|
||||
## SDXL training
|
||||
|
||||
The documentation in this section will be moved to a separate document later.
|
||||
|
||||
### Training scripts for SDXL
|
||||
|
||||
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
|
||||
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
|
||||
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
|
||||
- The full bfloat16 training might be unstable. Please use it at your own risk.
|
||||
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
|
||||
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
|
||||
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
||||
|
||||
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
|
||||
|
||||
- Both scripts has following additional options:
|
||||
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
|
||||
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
|
||||
|
||||
- `--weighted_captions` option is not supported yet for both scripts.
|
||||
|
||||
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
|
||||
- `--cache_text_encoder_outputs` is not supported.
|
||||
- There are two options for captions:
|
||||
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
|
||||
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
|
||||
- See below for the format of the embeddings.
|
||||
|
||||
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
|
||||
|
||||
### Utility scripts for SDXL
|
||||
|
||||
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
|
||||
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
|
||||
- Please launch the script as follows:
|
||||
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
|
||||
- This script should work with multi-GPU, but it is not tested in my environment.
|
||||
|
||||
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
|
||||
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
|
||||
|
||||
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
|
||||
|
||||
### Tips for SDXL training
|
||||
|
||||
- The default resolution of SDXL is 1024x1024.
|
||||
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
|
||||
- Train U-Net only.
|
||||
- Use gradient checkpointing.
|
||||
- Use `--cache_text_encoder_outputs` option and caching latents.
|
||||
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
|
||||
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
|
||||
- Train U-Net only.
|
||||
- Use gradient checkpointing.
|
||||
- Use `--cache_text_encoder_outputs` option and caching latents.
|
||||
- Use one of 8bit optimizers or Adafactor optimizer.
|
||||
- Use lower dim (4 to 8 for 8GB GPU).
|
||||
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
|
||||
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
|
||||
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
|
||||
|
||||
Example of the optimizer settings for Adafactor with the fixed learning rate:
|
||||
```toml
|
||||
optimizer_type = "adafactor"
|
||||
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
|
||||
lr_scheduler = "constant_with_warmup"
|
||||
lr_warmup_steps = 100
|
||||
learning_rate = 4e-7 # SDXL original learning rate
|
||||
```
|
||||
|
||||
### Format of Textual Inversion embeddings for SDXL
|
||||
|
||||
```python
|
||||
from safetensors.torch import save_file
|
||||
|
||||
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
|
||||
save_file(state_dict, file)
|
||||
```
|
||||
|
||||
### ControlNet-LLLite
|
||||
|
||||
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.
|
||||
|
||||
|
||||
## Change History
|
||||
|
||||
- 10 Mar. 2023, 2023/3/10: release v0.5.1
|
||||
- Fix to LoRA modules in the model are same to the previous (before 0.5.0) if Conv2d-3x3 is disabled (no `conv_dim` arg, default).
|
||||
- Conv2D with kernel size 1x1 in ResNet modules were accidentally included in v0.5.0.
|
||||
- Trained models with v0.5.0 will work with Web UI's built-in LoRA and Additional Networks extension.
|
||||
- Fix an issue that dim (rank) of LoRA module is limited to the in/out dimensions of the target Linear/Conv2d (in case of the dim > 320).
|
||||
- `resize_lora.py` now have a feature to `dynamic resizing` which means each LoRA module can have different ranks (dims). Thanks to mgz-dev for this great work!
|
||||
- The appropriate rank is selected based on the complexity of each module with an algorithm specified in the command line arguments. For details: https://github.com/kohya-ss/sd-scripts/pull/243
|
||||
- Multiple GPUs training is finally supported in `train_network.py`. Thanks to ddPn08 to solve this long running issue!
|
||||
- Dataset with fine-tuning method (with metadata json) now works without images if `.npz` files exist. Thanks to rvhfxb!
|
||||
- `train_network.py` can work if the current directory is not the directory where the script is in. Thanks to mio2333!
|
||||
- Fix `extract_lora_from_models.py` and `svd_merge_lora.py` doesn't work with higher rank (>320).
|
||||
### Jan 27, 2024 / 2024/1/27: v0.8.3
|
||||
|
||||
- LoRAのConv2d-3x3拡張を行わない場合(`conv_dim` を指定しない場合)、以前(v0.5.0)と同じ構成になるよう修正しました。
|
||||
- ResNetのカーネルサイズ1x1のConv2dが誤って対象になっていました。
|
||||
- ただv0.5.0で学習したモデルは Additional Networks 拡張、およびWeb UIのLoRA機能で問題なく使えると思われます。
|
||||
- LoRAモジュールの dim (rank) が、対象モジュールの次元数以下に制限される不具合を修正しました(320より大きい dim を指定した場合)。
|
||||
- `resize_lora.py` に `dynamic resizing` (リサイズ後の各LoRAモジュールが異なるrank (dim) を持てる機能)を追加しました。mgz-dev 氏の貢献に感謝します。
|
||||
- 適切なランクがコマンドライン引数で指定したアルゴリズムにより自動的に選択されます。詳細はこちらをご覧ください: https://github.com/kohya-ss/sd-scripts/pull/243
|
||||
- `train_network.py` でマルチGPU学習をサポートしました。長年の懸案を解決された ddPn08 氏に感謝します。
|
||||
- fine-tuning方式のデータセット(メタデータ.jsonファイルを使うデータセット)で `.npz` が存在するときには画像がなくても動作するようになりました。rvhfxb 氏に感謝します。
|
||||
- 他のディレクトリから `train_network.py` を呼び出しても動作するよう変更しました。 mio2333 氏に感謝します。
|
||||
- `extract_lora_from_models.py` および `svd_merge_lora.py` が320より大きいrankを指定すると動かない不具合を修正しました。
|
||||
|
||||
- 9 Mar. 2023, 2023/3/9: release v0.5.0
|
||||
- There may be problems due to major changes. If you cannot revert back to the previous version when problems occur, please do not update for a while.
|
||||
- Minimum metadata (module name, dim, alpha and network_args) is recorded even with `--no_metadata`, issue https://github.com/kohya-ss/sd-scripts/issues/254
|
||||
- `train_network.py` supports LoRA for Conv2d-3x3 (extended to conv2d with a kernel size not 1x1).
|
||||
- Same as a current version of [LoCon](https://github.com/KohakuBlueleaf/LoCon). __Thank you very much KohakuBlueleaf for your help!__
|
||||
- LoCon will be enhanced in the future. Compatibility for future versions is not guaranteed.
|
||||
- Specify `--network_args` option like: `--network_args "conv_dim=4" "conv_alpha=1"`
|
||||
- [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) version 0.5.0 or later is required to use 'LoRA for Conv2d-3x3' in Stable Diffusion web UI.
|
||||
- __Stable Diffusion web UI built-in LoRA does not support 'LoRA for Conv2d-3x3' now. Consider carefully whether or not to use it.__
|
||||
- Merging/extracting scripts also support LoRA for Conv2d-3x3.
|
||||
- Free CUDA memory after sample generation to reduce VRAM usage, issue https://github.com/kohya-ss/sd-scripts/issues/260
|
||||
- Empty caption doesn't cause error now, issue https://github.com/kohya-ss/sd-scripts/issues/258
|
||||
- Fix sample generation is crashing in Textual Inversion training when using templates, or if height/width is not divisible by 8.
|
||||
- Update documents (Japanese only).
|
||||
- Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380!
|
||||
- `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library.
|
||||
- Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf!
|
||||
- Fixed a bug that the training crashes when training ControlNet-LLLite.
|
||||
|
||||
- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。
|
||||
- 最低限のメタデータ(module name, dim, alpha および network_args)が `--no_metadata` オプション指定時にも記録されます。issue https://github.com/kohya-ss/sd-scripts/issues/254
|
||||
- `train_network.py` で LoRAの Conv2d-3x3 拡張に対応しました(カーネルサイズ1x1以外のConv2dにも対象範囲を拡大します)。
|
||||
- 現在のバージョンの [LoCon](https://github.com/KohakuBlueleaf/LoCon) と同一の仕様です。__KohakuBlueleaf氏のご支援に深く感謝します。__
|
||||
- LoCon が将来的に拡張された場合、それらのバージョンでの互換性は保証できません。
|
||||
- `--network_args` オプションを `--network_args "conv_dim=4" "conv_alpha=1"` のように指定してください。
|
||||
- Stable Diffusion web UI での使用には [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) のversion 0.5.0 以降が必要です。
|
||||
- __Stable Diffusion web UI の LoRA 機能は LoRAの Conv2d-3x3 拡張に対応していないようです。使用するか否か慎重にご検討ください。__
|
||||
- マージ、抽出のスクリプトについても LoRA の Conv2d-3x3 拡張に対応しました.
|
||||
- サンプル画像生成後にCUDAメモリを解放しVRAM使用量を削減しました。 issue https://github.com/kohya-ss/sd-scripts/issues/260
|
||||
- 空のキャプションが使えるようになりました。 issue https://github.com/kohya-ss/sd-scripts/issues/258
|
||||
- Textual Inversion 学習でテンプレートを使ったとき、height/width が 8 で割り切れなかったときにサンプル画像生成がクラッシュするのを修正しました。
|
||||
- ドキュメント類を更新しました。
|
||||
- `--fp8_base` 指定時に `--save_state` での保存がエラーになる不具合が修正されました。 PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) feffy380 氏に感謝します。
|
||||
- `safetensors` がバージョンアップされていますので、[Upgrade](#upgrade) を参照し更新をお願いします。
|
||||
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
|
||||
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
|
||||
|
||||
- Sample image generation:
|
||||
A prompt file might look like this, for example
|
||||
### Jan 23, 2024 / 2024/1/23: v0.8.2
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
|
||||
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
|
||||
- PyTorch 2.1 or later is required.
|
||||
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
|
||||
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
|
||||
- This is an experimental option and may be removed or changed in the future.
|
||||
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
|
||||
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
|
||||
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
|
||||
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
|
||||
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
|
||||
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
|
||||
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
|
||||
- Fixed a bug in multi-GPU Textual Inversion training.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
- (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
|
||||
- `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。
|
||||
- PyTorch 2.1 以降が必要です。
|
||||
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
|
||||
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
|
||||
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
|
||||
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
|
||||
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
|
||||
- また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
|
||||
- `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。
|
||||
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
|
||||
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
|
||||
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
|
||||
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
|
||||
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
|
||||
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are not working.
|
||||
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
|
||||
|
||||
- サンプル画像生成:
|
||||
プロンプトファイルは例えば以下のようになります。
|
||||
```toml
|
||||
[general]
|
||||
[[datasets]]
|
||||
resolution = 512
|
||||
batch_size = 8
|
||||
network_multiplier = 1.0
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
... subset settings ...
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
[[datasets]]
|
||||
resolution = 512
|
||||
batch_size = 8
|
||||
network_multiplier = -1.0
|
||||
|
||||
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
|
||||
... subset settings ...
|
||||
```
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
`( )` や `[ ]` などの重みづけは動作しません。
|
||||
### Jan 17, 2024 / 2024/1/17: v0.8.1
|
||||
|
||||
- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`).
|
||||
- Text Encoders were not moved to CPU.
|
||||
- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053)
|
||||
|
||||
- LoRA 等の学習スクリプト(`train_network.py`、`sdxl_train_network.py`)で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。
|
||||
- Text Encoder が GPU に保持されたままになっていました。
|
||||
- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。
|
||||
|
||||
### Jan 15, 2024 / 2024/1/15: v0.8.0
|
||||
|
||||
- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade).
|
||||
- Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded.
|
||||
- `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev!
|
||||
- This feature works only on Linux or WSL.
|
||||
- Please specify `--torch_compile` option in each training script.
|
||||
- You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work.
|
||||
- Please use `--sdpa` option instead of `--xformers` option.
|
||||
- PyTorch 2.1 or later is recommended.
|
||||
- Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details.
|
||||
- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t!
|
||||
- IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0!
|
||||
- Fixed a bug that Diffusers format model cannot be saved.
|
||||
|
||||
- Diffusers、Accelerate、Transformers 等の関連ライブラリを更新しました。[Upgrade](#upgrade) を参照し更新をお願いします。
|
||||
- 最新の Transformers を前提とした一部のモデルファイル(Text Encoder が position_id を持たないもの)が読み込めるようになりました。
|
||||
- `torch.compile` がサポートされしました(実験的)。 PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) p1atdev 氏に感謝します。
|
||||
- Linux または WSL でのみ動作します。
|
||||
- 各学習スクリプトで `--torch_compile` オプションを指定してください。
|
||||
- `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです。
|
||||
- `--xformers` オプションとは互換性がありません。 代わりに `--sdpa` オプションを使用してください。
|
||||
- PyTorch 2.1以降を推奨します。
|
||||
- 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください。
|
||||
- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します。
|
||||
- IPEX ライブラリが更新されました。[PR #1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Disty0 氏に感謝します。
|
||||
- Diffusers 形式でのモデル保存ができなくなっていた不具合を修正しました。
|
||||
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
### Naming of LoRA
|
||||
|
||||
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)
|
||||
|
||||
LoRA for Linear layers and Conv2d layers with 1x1 kernel
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers)
|
||||
|
||||
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
|
||||
|
||||
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
|
||||
|
||||
To use LoRA-C3Lier with Web UI, please use our extension.
|
||||
|
||||
### LoRAの名称について
|
||||
|
||||
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
|
||||
|
||||
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
|
||||
|
||||
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
|
||||
|
||||
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
|
||||
|
||||
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
|
||||
|
||||
## Sample image generation during training
|
||||
A prompt file might look like this, for example
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are working.
|
||||
|
||||
## サンプル画像生成
|
||||
プロンプトファイルは例えば以下のようになります。
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
`( )` や `[ ]` などの重みづけも動作します。
|
||||
|
||||
|
||||
204
XTI_hijack.py
Normal file
204
XTI_hijack.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import torch
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
|
||||
from library.original_unet import SampleOutput
|
||||
|
||||
|
||||
def unet_forward_XTI(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Dict, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a dict instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
`SampleOutput` or `tuple`:
|
||||
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
||||
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
||||
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
# 64で割り切れないときはupsamplerにサイズを伝える
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
# logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
||||
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
||||
# time_projでキャストしておけばいいんじゃね?
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
down_i = 0
|
||||
for downsample_block in self.down_blocks:
|
||||
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||
# まあこちらのほうがわかりやすいかもしれない
|
||||
if downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
|
||||
)
|
||||
down_i += 2
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
|
||||
|
||||
# 5. up
|
||||
up_i = 7
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
||||
|
||||
# if we have not reached the final block and need to forward the upsample size, we do it here
|
||||
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
up_i += 3
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return SampleOutput(sample=sample)
|
||||
|
||||
|
||||
def downblock_forward_XTI(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
):
|
||||
output_states = ()
|
||||
i = 0
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
|
||||
|
||||
output_states += (hidden_states,)
|
||||
i += 1
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
def upblock_forward_XTI(
|
||||
self,
|
||||
hidden_states,
|
||||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
):
|
||||
i = 0
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
|
||||
|
||||
i += 1
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
20
_typos.toml
20
_typos.toml
@@ -9,7 +9,25 @@ parms="parms"
|
||||
nin="nin"
|
||||
extention="extention" # Intentionally left
|
||||
nd="nd"
|
||||
shs="shs"
|
||||
sts="sts"
|
||||
scs="scs"
|
||||
cpc="cpc"
|
||||
coc="coc"
|
||||
cic="cic"
|
||||
msm="msm"
|
||||
usu="usu"
|
||||
ici="ici"
|
||||
lvl="lvl"
|
||||
dii="dii"
|
||||
muk="muk"
|
||||
ori="ori"
|
||||
hru="hru"
|
||||
rik="rik"
|
||||
koo="koo"
|
||||
yos="yos"
|
||||
wn="wn"
|
||||
|
||||
|
||||
[files]
|
||||
extend-exclude = ["_typos.toml"]
|
||||
extend-exclude = ["_typos.toml", "venv"]
|
||||
|
||||
BIN
bitsandbytes_windows/libbitsandbytes_cuda118.dll
Normal file
BIN
bitsandbytes_windows/libbitsandbytes_cuda118.dll
Normal file
Binary file not shown.
@@ -1,166 +1,166 @@
|
||||
"""
|
||||
extract factors the build is dependent on:
|
||||
[X] compute capability
|
||||
[ ] TODO: Q - What if we have multiple GPUs of different makes?
|
||||
- CUDA version
|
||||
- Software:
|
||||
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
|
||||
- CuBLAS-LT: full-build 8-bit optimizer
|
||||
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
|
||||
|
||||
evaluation:
|
||||
- if paths faulty, return meaningful error
|
||||
- else:
|
||||
- determine CUDA version
|
||||
- determine capabilities
|
||||
- based on that set the default path
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from .paths import determine_cuda_runtime_lib_path
|
||||
|
||||
|
||||
def check_cuda_result(cuda, result_val):
|
||||
# 3. Check for CUDA errors
|
||||
if result_val != 0:
|
||||
error_str = ctypes.c_char_p()
|
||||
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
|
||||
print(f"CUDA exception! Error code: {error_str.value.decode()}")
|
||||
|
||||
def get_cuda_version(cuda, cudart_path):
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
|
||||
try:
|
||||
cudart = ctypes.CDLL(cudart_path)
|
||||
except OSError:
|
||||
# TODO: shouldn't we error or at least warn here?
|
||||
print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
|
||||
return None
|
||||
|
||||
version = ctypes.c_int()
|
||||
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
|
||||
version = int(version.value)
|
||||
major = version//1000
|
||||
minor = (version-(major*1000))//10
|
||||
|
||||
if major < 11:
|
||||
print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
|
||||
|
||||
return f'{major}{minor}'
|
||||
|
||||
|
||||
def get_cuda_lib_handle():
|
||||
# 1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||
try:
|
||||
cuda = ctypes.CDLL("libcuda.so")
|
||||
except OSError:
|
||||
# TODO: shouldn't we error or at least warn here?
|
||||
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
|
||||
return None
|
||||
check_cuda_result(cuda, cuda.cuInit(0))
|
||||
|
||||
return cuda
|
||||
|
||||
|
||||
def get_compute_capabilities(cuda):
|
||||
"""
|
||||
1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||
init_device -> init variables -> call function by reference
|
||||
2. call extern C function to determine CC
|
||||
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
||||
3. Check for CUDA errors
|
||||
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
|
||||
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
||||
"""
|
||||
|
||||
|
||||
nGpus = ctypes.c_int()
|
||||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
|
||||
device = ctypes.c_int()
|
||||
|
||||
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
|
||||
ccs = []
|
||||
for i in range(nGpus.value):
|
||||
check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
|
||||
ref_major = ctypes.byref(cc_major)
|
||||
ref_minor = ctypes.byref(cc_minor)
|
||||
# 2. call extern C function to determine CC
|
||||
check_cuda_result(
|
||||
cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
|
||||
)
|
||||
ccs.append(f"{cc_major.value}.{cc_minor.value}")
|
||||
|
||||
return ccs
|
||||
|
||||
|
||||
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
|
||||
def get_compute_capability(cuda):
|
||||
"""
|
||||
Extracts the highest compute capbility from all available GPUs, as compute
|
||||
capabilities are downwards compatible. If no GPUs are detected, it returns
|
||||
None.
|
||||
"""
|
||||
ccs = get_compute_capabilities(cuda)
|
||||
if ccs is not None:
|
||||
# TODO: handle different compute capabilities; for now, take the max
|
||||
return ccs[-1]
|
||||
return None
|
||||
|
||||
|
||||
def evaluate_cuda_setup():
|
||||
print('')
|
||||
print('='*35 + 'BUG REPORT' + '='*35)
|
||||
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
|
||||
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
|
||||
print('='*80)
|
||||
return "libbitsandbytes_cuda116.dll" # $$$
|
||||
|
||||
binary_name = "libbitsandbytes_cpu.so"
|
||||
#if not torch.cuda.is_available():
|
||||
#print('No GPU detected. Loading CPU library...')
|
||||
#return binary_name
|
||||
|
||||
cudart_path = determine_cuda_runtime_lib_path()
|
||||
if cudart_path is None:
|
||||
print(
|
||||
"WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
|
||||
)
|
||||
return binary_name
|
||||
|
||||
print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
|
||||
cuda = get_cuda_lib_handle()
|
||||
cc = get_compute_capability(cuda)
|
||||
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
|
||||
cuda_version_string = get_cuda_version(cuda, cudart_path)
|
||||
|
||||
|
||||
if cc == '':
|
||||
print(
|
||||
"WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
|
||||
)
|
||||
return binary_name
|
||||
|
||||
# 7.5 is the minimum CC vor cublaslt
|
||||
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
|
||||
|
||||
# TODO:
|
||||
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
|
||||
# (2) Multiple CUDA versions installed
|
||||
|
||||
# we use ls -l instead of nvcc to determine the cuda version
|
||||
# since most installations will have the libcudart.so installed, but not the compiler
|
||||
print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
|
||||
|
||||
def get_binary_name():
|
||||
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
|
||||
bin_base_name = "libbitsandbytes_cuda"
|
||||
if has_cublaslt:
|
||||
return f"{bin_base_name}{cuda_version_string}.so"
|
||||
else:
|
||||
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
|
||||
|
||||
binary_name = get_binary_name()
|
||||
|
||||
return binary_name
|
||||
"""
|
||||
extract factors the build is dependent on:
|
||||
[X] compute capability
|
||||
[ ] TODO: Q - What if we have multiple GPUs of different makes?
|
||||
- CUDA version
|
||||
- Software:
|
||||
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
|
||||
- CuBLAS-LT: full-build 8-bit optimizer
|
||||
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
|
||||
|
||||
evaluation:
|
||||
- if paths faulty, return meaningful error
|
||||
- else:
|
||||
- determine CUDA version
|
||||
- determine capabilities
|
||||
- based on that set the default path
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from .paths import determine_cuda_runtime_lib_path
|
||||
|
||||
|
||||
def check_cuda_result(cuda, result_val):
|
||||
# 3. Check for CUDA errors
|
||||
if result_val != 0:
|
||||
error_str = ctypes.c_char_p()
|
||||
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
|
||||
print(f"CUDA exception! Error code: {error_str.value.decode()}")
|
||||
|
||||
def get_cuda_version(cuda, cudart_path):
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
|
||||
try:
|
||||
cudart = ctypes.CDLL(cudart_path)
|
||||
except OSError:
|
||||
# TODO: shouldn't we error or at least warn here?
|
||||
print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
|
||||
return None
|
||||
|
||||
version = ctypes.c_int()
|
||||
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
|
||||
version = int(version.value)
|
||||
major = version//1000
|
||||
minor = (version-(major*1000))//10
|
||||
|
||||
if major < 11:
|
||||
print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
|
||||
|
||||
return f'{major}{minor}'
|
||||
|
||||
|
||||
def get_cuda_lib_handle():
|
||||
# 1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||
try:
|
||||
cuda = ctypes.CDLL("libcuda.so")
|
||||
except OSError:
|
||||
# TODO: shouldn't we error or at least warn here?
|
||||
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
|
||||
return None
|
||||
check_cuda_result(cuda, cuda.cuInit(0))
|
||||
|
||||
return cuda
|
||||
|
||||
|
||||
def get_compute_capabilities(cuda):
|
||||
"""
|
||||
1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||
init_device -> init variables -> call function by reference
|
||||
2. call extern C function to determine CC
|
||||
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
||||
3. Check for CUDA errors
|
||||
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
|
||||
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
||||
"""
|
||||
|
||||
|
||||
nGpus = ctypes.c_int()
|
||||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
|
||||
device = ctypes.c_int()
|
||||
|
||||
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
|
||||
ccs = []
|
||||
for i in range(nGpus.value):
|
||||
check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
|
||||
ref_major = ctypes.byref(cc_major)
|
||||
ref_minor = ctypes.byref(cc_minor)
|
||||
# 2. call extern C function to determine CC
|
||||
check_cuda_result(
|
||||
cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
|
||||
)
|
||||
ccs.append(f"{cc_major.value}.{cc_minor.value}")
|
||||
|
||||
return ccs
|
||||
|
||||
|
||||
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
|
||||
def get_compute_capability(cuda):
|
||||
"""
|
||||
Extracts the highest compute capbility from all available GPUs, as compute
|
||||
capabilities are downwards compatible. If no GPUs are detected, it returns
|
||||
None.
|
||||
"""
|
||||
ccs = get_compute_capabilities(cuda)
|
||||
if ccs is not None:
|
||||
# TODO: handle different compute capabilities; for now, take the max
|
||||
return ccs[-1]
|
||||
return None
|
||||
|
||||
|
||||
def evaluate_cuda_setup():
|
||||
print('')
|
||||
print('='*35 + 'BUG REPORT' + '='*35)
|
||||
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
|
||||
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
|
||||
print('='*80)
|
||||
return "libbitsandbytes_cuda116.dll" # $$$
|
||||
|
||||
binary_name = "libbitsandbytes_cpu.so"
|
||||
#if not torch.cuda.is_available():
|
||||
#print('No GPU detected. Loading CPU library...')
|
||||
#return binary_name
|
||||
|
||||
cudart_path = determine_cuda_runtime_lib_path()
|
||||
if cudart_path is None:
|
||||
print(
|
||||
"WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
|
||||
)
|
||||
return binary_name
|
||||
|
||||
print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
|
||||
cuda = get_cuda_lib_handle()
|
||||
cc = get_compute_capability(cuda)
|
||||
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
|
||||
cuda_version_string = get_cuda_version(cuda, cudart_path)
|
||||
|
||||
|
||||
if cc == '':
|
||||
print(
|
||||
"WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
|
||||
)
|
||||
return binary_name
|
||||
|
||||
# 7.5 is the minimum CC vor cublaslt
|
||||
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
|
||||
|
||||
# TODO:
|
||||
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
|
||||
# (2) Multiple CUDA versions installed
|
||||
|
||||
# we use ls -l instead of nvcc to determine the cuda version
|
||||
# since most installations will have the libcudart.so installed, but not the compiler
|
||||
print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
|
||||
|
||||
def get_binary_name():
|
||||
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
|
||||
bin_base_name = "libbitsandbytes_cuda"
|
||||
if has_cublaslt:
|
||||
return f"{bin_base_name}{cuda_version_string}.so"
|
||||
else:
|
||||
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
|
||||
|
||||
binary_name = get_binary_name()
|
||||
|
||||
return binary_name
|
||||
|
||||
@@ -138,9 +138,13 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
|
||||
| `num_repeats` | `10` | o | o | o |
|
||||
| `random_crop` | `false` | o | o | o |
|
||||
| `shuffle_caption` | `true` | o | o | o |
|
||||
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
|
||||
| `caption_suffix` | `“, from side”` | o | o | o |
|
||||
|
||||
* `num_repeats`
|
||||
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
|
||||
* `caption_prefix`, `caption_suffix`
|
||||
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
|
||||
|
||||
### DreamBooth 方式専用のオプション
|
||||
|
||||
454
docs/gen_img_README-ja.md
Normal file
454
docs/gen_img_README-ja.md
Normal file
@@ -0,0 +1,454 @@
|
||||
SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、ControlNet(v1.0のみ動作確認)などに対応した、Diffusersベースの推論(画像生成)スクリプトです。コマンドラインから用います。
|
||||
|
||||
# 概要
|
||||
|
||||
* Diffusers (v0.10.2) ベースの推論(画像生成)スクリプト。
|
||||
* SD 1.xおよび2.x (base/v-parameterization)モデルに対応。
|
||||
* txt2img、img2img、inpaintingに対応。
|
||||
* 対話モード、およびファイルからのプロンプト読み込み、連続生成に対応。
|
||||
* プロンプト1行あたりの生成枚数を指定可能。
|
||||
* 全体の繰り返し回数を指定可能。
|
||||
* `fp16`だけでなく`bf16`にも対応。
|
||||
* xformersに対応し高速生成が可能。
|
||||
* xformersにより省メモリ生成を行いますが、Automatic 1111氏のWeb UIほど最適化していないため、512*512の画像生成でおおむね6GB程度のVRAMを使用します。
|
||||
* プロンプトの225トークンへの拡張。ネガティブプロンプト、重みづけに対応。
|
||||
* Diffusersの各種samplerに対応(Web UIよりもsampler数は少ないです)。
|
||||
* Text Encoderのclip skip(最後からn番目の層の出力を用いる)に対応。
|
||||
* VAEの別途読み込み。
|
||||
* CLIP Guided Stable Diffusion、VGG16 Guided Stable Diffusion、Highres. fix、upscale対応。
|
||||
* Highres. fixはWeb UIの実装を全く確認していない独自実装のため、出力結果は異なるかもしれません。
|
||||
* LoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
|
||||
* Text EncoderとU-Netで別の適用率を指定することはできません。
|
||||
* Attention Coupleに対応。
|
||||
* ControlNet v1.0に対応。
|
||||
* 途中でモデルを切り替えることはできませんが、バッチファイルを組むことで対応できます。
|
||||
* 個人的に欲しくなった機能をいろいろ追加。
|
||||
|
||||
機能追加時にすべてのテストを行っているわけではないため、以前の機能に影響が出て一部機能が動かない可能性があります。何か問題があればお知らせください。
|
||||
|
||||
# 基本的な使い方
|
||||
|
||||
## 対話モードでの画像生成
|
||||
|
||||
以下のように入力してください。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
|
||||
```
|
||||
|
||||
`--ckpt`オプションにモデル(Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ)、`--outdir`オプションに画像の出力先フォルダを指定します。
|
||||
|
||||
`--xformers`オプションでxformersの使用を指定します(xformersを使わない場合は外してください)。`--fp16`オプションでfp16(単精度)での推論を行います。RTX 30系のGPUでは `--bf16`オプションでbf16(bfloat16)での推論を行うこともできます。
|
||||
|
||||
`--interactive`オプションで対話モードを指定しています。
|
||||
|
||||
Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う場合は`--v2`オプションを追加してください。v-parameterizationを使うモデル(`768-v-ema.ckpt`およびそこからの追加学習モデル)を使う場合はさらに`--v_parameterization`を追加してください。
|
||||
|
||||
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
|
||||
`Type prompt:`と表示されたらプロンプトを入力してください。
|
||||
|
||||

|
||||
|
||||
※画像が表示されずエラーになる場合、headless(画面表示機能なし)のOpenCVがインストールされているかもしれません。`pip install opencv-python`として通常のOpenCVを入れてください。または`--no_preview`オプションで画像表示を止めてください。
|
||||
|
||||
画像ウィンドウを選択してから何らかのキーを押すとウィンドウが閉じ、次のプロンプトが入力できます。プロンプトでCtrl+Z、エンターの順に打鍵するとスクリプトを閉じます。
|
||||
|
||||
## 単一のプロンプトで画像を一括生成
|
||||
|
||||
以下のように入力します(実際には1行で入力します)。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
--xformers --fp16 --images_per_prompt <生成枚数> --prompt "<プロンプト>"
|
||||
```
|
||||
|
||||
`--images_per_prompt`オプションで、プロンプト1件当たりの生成枚数を指定します。`--prompt`オプションでプロンプトを指定します。スペースを含む場合はダブルクォーテーションで囲んでください。
|
||||
|
||||
`--batch_size`オプションでバッチサイズを指定できます(後述)。
|
||||
|
||||
## ファイルからプロンプトを読み込み一括生成
|
||||
|
||||
以下のように入力します。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
||||
--xformers --fp16 --from_file <プロンプトファイル名>
|
||||
```
|
||||
|
||||
`--from_file`オプションで、プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。`--images_per_prompt`オプションを指定して1行あたり生成枚数を指定できます。
|
||||
|
||||
## ネガティブプロンプト、重みづけの使用
|
||||
|
||||
プロンプトオプション(プロンプト内で`--x`のように指定、後述)で`--n`を書くと、以降がネガティブプロンプトとなります。
|
||||
|
||||
またAUTOMATIC1111氏のWeb UIと同様の `()` や` []` 、`(xxx:1.3)` などによる重みづけが可能です(実装はDiffusersの[Long Prompt Weighting Stable Diffusion](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#long-prompt-weighting-stable-diffusion)からコピーしたものです)。
|
||||
|
||||
コマンドラインからのプロンプト指定、ファイルからのプロンプト読み込みでも同様に指定できます。
|
||||
|
||||

|
||||
|
||||
# 主なオプション
|
||||
|
||||
コマンドラインから指定してください。
|
||||
|
||||
## モデルの指定
|
||||
|
||||
- `--ckpt <モデル名>`:モデル名を指定します。`--ckpt`オプションは必須です。Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ、Hugging FaceのモデルIDを指定できます。
|
||||
|
||||
- `--v2`:Stable Diffusion 2.x系のモデルを使う場合に指定します。1.x系の場合には指定不要です。
|
||||
|
||||
- `--v_parameterization`:v-parameterizationを使うモデルを使う場合に指定します(`768-v-ema.ckpt`およびそこからの追加学習モデル、Waifu Diffusion v1.5など)。
|
||||
|
||||
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
||||
|
||||
- `--vae`:使用するVAEを指定します。未指定時はモデル内のVAEを使用します。
|
||||
|
||||
## 画像生成と出力
|
||||
|
||||
- `--interactive`:インタラクティブモードで動作します。プロンプトを入力すると画像が生成されます。
|
||||
|
||||
- `--prompt <プロンプト>`:プロンプトを指定します。スペースを含む場合はダブルクォーテーションで囲んでください。
|
||||
|
||||
- `--from_file <プロンプトファイル名>`:プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。なお画像サイズやguidance scaleはプロンプトオプション(後述)で指定できます。
|
||||
|
||||
- `--W <画像幅>`:画像の幅を指定します。デフォルトは`512`です。
|
||||
|
||||
- `--H <画像高さ>`:画像の高さを指定します。デフォルトは`512`です。
|
||||
|
||||
- `--steps <ステップ数>`:サンプリングステップ数を指定します。デフォルトは`50`です。
|
||||
|
||||
- `--scale <ガイダンススケール>`:unconditionalガイダンススケールを指定します。デフォルトは`7.5`です。
|
||||
|
||||
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。Diffusersで提供されているddim、pndm、dpmsolver、dpmsolver+++、lms、euler、euler_a、が指定可能です(後ろの三つはk_lms、k_euler、k_euler_aでも指定できます)。
|
||||
|
||||
- `--outdir <画像出力先フォルダ>`:画像の出力先を指定します。
|
||||
|
||||
- `--images_per_prompt <生成枚数>`:プロンプト1件当たりの生成枚数を指定します。デフォルトは`1`です。
|
||||
|
||||
- `--clip_skip <スキップ数>`:CLIPの後ろから何番目の層を使うかを指定します。省略時は最後の層を使います。
|
||||
|
||||
- `--max_embeddings_multiples <倍数>`:CLIPの入出力長をデフォルト(75)の何倍にするかを指定します。未指定時は75のままです。たとえば3を指定すると入出力長が225になります。
|
||||
|
||||
- `--negative_scale` : uncoditioningのguidance scaleを個別に指定します。[gcem156氏のこちらの記事](https://note.com/gcem156/n/ne9a53e4a6f43)を参考に実装したものです。
|
||||
|
||||
## メモリ使用量や生成速度の調整
|
||||
|
||||
- `--batch_size <バッチサイズ>`:バッチサイズを指定します。デフォルトは`1`です。バッチサイズが大きいとメモリを多く消費しますが、生成速度が速くなります。
|
||||
|
||||
- `--vae_batch_size <VAEのバッチサイズ>`:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。
|
||||
VAEのほうがメモリを多く消費するため、デノイジング後(stepが100%になった後)でメモリ不足になる場合があります。このような場合にはVAEのバッチサイズを小さくしてください。
|
||||
|
||||
- `--xformers`:xformersを使う場合に指定します。
|
||||
|
||||
- `--fp16`:fp16(単精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。
|
||||
|
||||
- `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。
|
||||
|
||||
## 追加ネットワーク(LoRA等)の使用
|
||||
|
||||
- `--network_module`:使用する追加ネットワークを指定します。LoRAの場合は`--network_module networks.lora`と指定します。複数のLoRAを使用する場合は`--network_module networks.lora networks.lora networks.lora`のように指定します。
|
||||
|
||||
- `--network_weights`:使用する追加ネットワークの重みファイルを指定します。`--network_weights model.safetensors`のように指定します。複数のLoRAを使用する場合は`--network_weights model1.safetensors model2.safetensors model3.safetensors`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
|
||||
|
||||
- `--network_mul`:使用する追加ネットワークの重みを何倍にするかを指定します。デフォルトは`1`です。`--network_mul 0.8`のように指定します。複数のLoRAを使用する場合は`--network_mul 0.4 0.5 0.7`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
|
||||
|
||||
- `--network_merge`:使用する追加ネットワークの重みを`--network_mul`に指定した重みであらかじめマージします。`--network_pre_calc` と同時に使用できません。プロンプトオプションの`--am`、およびRegional LoRAは使用できなくなりますが、LoRA未使用時と同じ程度まで生成が高速化されます。
|
||||
|
||||
- `--network_pre_calc`:使用する追加ネットワークの重みを生成ごとにあらかじめ計算します。プロンプトオプションの`--am`が使用できます。LoRA未使用時と同じ程度まで生成は高速化されますが、生成前に重みを計算する時間が必要で、またメモリ使用量も若干増加します。Regional LoRA使用時は無効になります 。
|
||||
|
||||
# 主なオプションの指定例
|
||||
|
||||
次は同一プロンプトで64枚をバッチサイズ4で一括生成する例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
||||
--steps 32 --batch_size 4 --images_per_prompt 64
|
||||
--prompt "beautiful flowers --n monochrome"
|
||||
```
|
||||
|
||||
次はファイルに書かれたプロンプトを、それぞれ10枚ずつ、バッチサイズ4で一括生成する例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
||||
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
||||
--steps 32 --batch_size 4 --images_per_prompt 10
|
||||
--from_file prompts.txt
|
||||
```
|
||||
|
||||
Textual Inversion(後述)およびLoRAの使用例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model.safetensors
|
||||
--scale 8 --steps 48 --outdir txt2img --xformers
|
||||
--W 512 --H 768 --fp16 --sampler k_euler_a
|
||||
--textual_inversion_embeddings goodembed.safetensors negprompt.pt
|
||||
--network_module networks.lora networks.lora
|
||||
--network_weights model1.safetensors model2.safetensors
|
||||
--network_mul 0.4 0.8
|
||||
--clip_skip 2 --max_embeddings_multiples 1
|
||||
--batch_size 8 --images_per_prompt 1 --interactive
|
||||
```
|
||||
|
||||
# プロンプトオプション
|
||||
|
||||
プロンプト内で、`--n`のように「ハイフンふたつ+アルファベットn文字」でプロンプトから各種オプションの指定が可能です。対話モード、コマンドライン、ファイル、いずれからプロンプトを指定する場合でも有効です。
|
||||
|
||||
プロンプトのオプション指定`--n`の前後にはスペースを入れてください。
|
||||
|
||||
- `--n`:ネガティブプロンプトを指定します。
|
||||
|
||||
- `--w`:画像幅を指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--h`:画像高さを指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--s`:ステップ数を指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--d`:この画像の乱数seedを指定します。`--images_per_prompt`を指定している場合は「--d 1,2,3,4」のようにカンマ区切りで複数指定してください。
|
||||
※様々な理由により、Web UIとは同じ乱数seedでも生成される画像が異なる場合があります。
|
||||
|
||||
- `--l`:guidance scaleを指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--t`:img2img(後述)のstrengthを指定します。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--nl`:ネガティブプロンプトのguidance scaleを指定します(後述)。コマンドラインからの指定を上書きします。
|
||||
|
||||
- `--am`:追加ネットワークの重みを指定します。コマンドラインからの指定を上書きします。複数の追加ネットワークを使用する場合は`--am 0.8,0.5,0.3`のように __カンマ区切りで__ 指定します。
|
||||
|
||||
※これらのオプションを指定すると、バッチサイズよりも小さいサイズでバッチが実行される場合があります(これらの値が異なると一括生成できないため)。(あまり気にしなくて大丈夫ですが、ファイルからプロンプトを読み込み生成する場合は、これらの値が同一のプロンプトを並べておくと効率が良くなります。)
|
||||
|
||||
例:
|
||||
```
|
||||
(masterpiece, best quality), 1girl, in shirt and plated skirt, standing at street under cherry blossoms, upper body, [from below], kind smile, looking at another, [goodembed] --n realistic, real life, (negprompt), (lowres:1.1), (worst quality:1.2), (low quality:1.1), bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, normal quality, jpeg artifacts, signature, watermark, username, blurry --w 960 --h 640 --s 28 --d 1
|
||||
```
|
||||
|
||||

|
||||
|
||||
# img2img
|
||||
|
||||
## オプション
|
||||
|
||||
- `--image_path`:img2imgに利用する画像を指定します。`--image_path template.png`のように指定します。フォルダを指定すると、そのフォルダの画像を順次利用します。
|
||||
|
||||
- `--strength`:img2imgのstrengthを指定します。`--strength 0.8`のように指定します。デフォルトは`0.8`です。
|
||||
|
||||
- `--sequential_file_name`:ファイル名を連番にするかどうかを指定します。指定すると生成されるファイル名が`im_000001.png`からの連番になります。
|
||||
|
||||
- `--use_original_file_name`:指定すると生成ファイル名がオリジナルのファイル名と同じになります。
|
||||
|
||||
## コマンドラインからの実行例
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--outdir outputs --xformers --fp16 --scale 12.5 --sampler k_euler --steps 32
|
||||
--image_path template.png --strength 0.8
|
||||
--prompt "1girl, cowboy shot, brown hair, pony tail, brown eyes,
|
||||
sailor school uniform, outdoors
|
||||
--n lowres, bad anatomy, bad hands, error, missing fingers, cropped,
|
||||
worst quality, low quality, normal quality, jpeg artifacts, (blurry),
|
||||
hair ornament, glasses"
|
||||
--batch_size 8 --images_per_prompt 32
|
||||
```
|
||||
|
||||
`--image_path`オプションにフォルダを指定すると、そのフォルダの画像を順次読み込みます。生成される枚数は画像枚数ではなく、プロンプト数になりますので、`--images_per_promptPPオプションを指定してimg2imgする画像の枚数とプロンプト数を合わせてください。
|
||||
|
||||
ファイルはファイル名でソートして読み込みます。なおソート順は文字列順となりますので(`1.jpg→2.jpg→10.jpg`ではなく`1.jpg→10.jpg→2.jpg`の順)、頭を0埋めするなどしてご対応ください(`01.jpg→02.jpg→10.jpg`)。
|
||||
|
||||
## img2imgを利用したupscale
|
||||
|
||||
img2img時にコマンドラインオプションの`--W`と`--H`で生成画像サイズを指定すると、元画像をそのサイズにリサイズしてからimg2imgを行います。
|
||||
|
||||
またimg2imgの元画像がこのスクリプトで生成した画像の場合、プロンプトを省略すると、元画像のメタデータからプロンプトを取得しそのまま用います。これによりHighres. fixの2nd stageの動作だけを行うことができます。
|
||||
|
||||
## img2img時のinpainting
|
||||
|
||||
画像およびマスク画像を指定してinpaintingできます(inpaintingモデルには対応しておらず、単にマスク領域を対象にimg2imgするだけです)。
|
||||
|
||||
オプションは以下の通りです。
|
||||
|
||||
- `--mask_image`:マスク画像を指定します。`--img_path`と同様にフォルダを指定すると、そのフォルダの画像を順次利用します。
|
||||
|
||||
マスク画像はグレースケール画像で、白の部分がinpaintingされます。境界をグラデーションしておくとなんとなく滑らかになりますのでお勧めです。
|
||||
|
||||

|
||||
|
||||
# その他の機能
|
||||
|
||||
## Textual Inversion
|
||||
|
||||
`--textual_inversion_embeddings`オプションで使用するembeddingsを指定します(複数指定可)。拡張子を除いたファイル名をプロンプト内で使用することで、そのembeddingsを利用します(Web UIと同様の使用法です)。ネガティブプロンプト内でも使用できます。
|
||||
|
||||
モデルとして、当リポジトリで学習したTextual Inversionモデル、およびWeb UIで学習したTextual Inversionモデル(画像埋め込みは非対応)を利用できます
|
||||
|
||||
## Extended Textual Inversion
|
||||
|
||||
`--textual_inversion_embeddings`の代わりに`--XTI_embeddings`オプションを指定してください。使用法は`--textual_inversion_embeddings`と同じです。
|
||||
|
||||
## Highres. fix
|
||||
|
||||
AUTOMATIC1111氏のWeb UIにある機能の類似機能です(独自実装のためもしかしたらいろいろ異なるかもしれません)。最初に小さめの画像を生成し、その画像を元にimg2imgすることで、画像全体の破綻を防ぎつつ大きな解像度の画像を生成します。
|
||||
|
||||
2nd stageのstep数は`--steps` と`--strength`オプションの値から計算されます(`steps*strength`)。
|
||||
|
||||
img2imgと併用できません。
|
||||
|
||||
以下のオプションがあります。
|
||||
|
||||
- `--highres_fix_scale`:Highres. fixを有効にして、1st stageで生成する画像のサイズを、倍率で指定します。最終出力が1024x1024で、最初に512x512の画像を生成する場合は`--highres_fix_scale 0.5`のように指定します。Web UI出の指定の逆数になっていますのでご注意ください。
|
||||
|
||||
- `--highres_fix_steps`:1st stageの画像のステップ数を指定します。デフォルトは`28`です。
|
||||
|
||||
- `--highres_fix_save_1st`:1st stageの画像を保存するかどうかを指定します。
|
||||
|
||||
- `--highres_fix_latents_upscaling`:指定すると2nd stageの画像生成時に1st stageの画像をlatentベースでupscalingします(bilinearのみ対応)。未指定時は画像をLANCZOS4でupscalingします。
|
||||
|
||||
- `--highres_fix_upscaler`:2nd stageに任意のupscalerを利用します。現在は`--highres_fix_upscaler tools.latent_upscaler` のみ対応しています。
|
||||
|
||||
- `--highres_fix_upscaler_args`:`--highres_fix_upscaler`で指定したupscalerに渡す引数を指定します。
|
||||
`tools.latent_upscaler`の場合は、`--highres_fix_upscaler_args "weights=D:\Work\SD\Models\others\etc\upscaler-v1-e100-220.safetensors"`のように重みファイルを指定します。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--n_iter 1 --scale 7.5 --W 1024 --H 1024 --batch_size 1 --outdir ../txt2img
|
||||
--steps 48 --sampler ddim --fp16
|
||||
--xformers
|
||||
--images_per_prompt 1 --interactive
|
||||
--highres_fix_scale 0.5 --highres_fix_steps 28 --strength 0.5
|
||||
```
|
||||
|
||||
## ControlNet
|
||||
|
||||
現在はControlNet 1.0のみ動作確認しています。プリプロセスはCannyのみサポートしています。
|
||||
|
||||
以下のオプションがあります。
|
||||
|
||||
- `--control_net_models`:ControlNetのモデルファイルを指定します。
|
||||
複数指定すると、それらをstepごとに切り替えて利用します(Web UIのControlNet拡張の実装と異なります)。diffと通常の両方をサポートします。
|
||||
|
||||
- `--guide_image_path`:ControlNetに使うヒント画像を指定します。`--img_path`と同様にフォルダを指定すると、そのフォルダの画像を順次利用します。Canny以外のモデルの場合には、あらかじめプリプロセスを行っておいてください。
|
||||
|
||||
- `--control_net_preps`:ControlNetのプリプロセスを指定します。`--control_net_models`と同様に複数指定可能です。現在はcannyのみ対応しています。対象モデルでプリプロセスを使用しない場合は `none` を指定します。
|
||||
cannyの場合 `--control_net_preps canny_63_191`のように、閾値1と2を'_'で区切って指定できます。
|
||||
|
||||
- `--control_net_weights`:ControlNetの適用時の重みを指定します(`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
|
||||
|
||||
- `--control_net_ratios`:ControlNetを適用するstepの範囲を指定します。`0.5`の場合は、step数の半分までControlNetを適用します。`--control_net_models`と同様に複数指定可能です。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
|
||||
--W 512 --H 768 --bf16 --sampler k_euler_a
|
||||
--control_net_models diff_control_sd15_canny.safetensors --control_net_weights 1.0
|
||||
--guide_image_path guide.png --control_net_ratios 1.0 --interactive
|
||||
```
|
||||
|
||||
## Attention Couple + Reginal LoRA
|
||||
|
||||
プロンプトをいくつかの部分に分割し、それぞれのプロンプトを画像内のどの領域に適用するかを指定できる機能です。個別のオプションはありませんが、`mask_path`とプロンプトで指定します。
|
||||
|
||||
まず、プロンプトで` AND `を利用して、複数部分を定義します。最初の3つに対して領域指定ができ、以降の部分は画像全体へ適用されます。ネガティブプロンプトは画像全体に適用されます。
|
||||
|
||||
以下ではANDで3つの部分を定義しています。
|
||||
|
||||
```
|
||||
shs 2girls, looking at viewer, smile AND bsb 2girls, looking back AND 2girls --n bad quality, worst quality
|
||||
```
|
||||
|
||||
次にマスク画像を用意します。マスク画像はカラーの画像で、RGBの各チャネルがプロンプトのANDで区切られた部分に対応します。またあるチャネルの値がすべて0の場合、画像全体に適用されます。
|
||||
|
||||
上記の例では、Rチャネルが`shs 2girls, looking at viewer, smile`、Gチャネルが`bsb 2girls, looking back`に、Bチャネルが`2girls`に対応します。次のようなマスク画像を使用すると、Bチャネルに指定がありませんので、`2girls`は画像全体に適用されます。
|
||||
|
||||

|
||||
|
||||
マスク画像は`--mask_path`で指定します。現在は1枚のみ対応しています。指定した画像サイズに自動的にリサイズされ適用されます。
|
||||
|
||||
ControlNetと組み合わせることも可能です(細かい位置指定にはControlNetとの組み合わせを推奨します)。
|
||||
|
||||
LoRAを指定すると、`--network_weights`で指定した複数のLoRAがそれぞれANDの各部分に対応します。現在の制約として、LoRAの数はANDの部分の数と同じである必要があります。
|
||||
|
||||
## CLIP Guided Stable Diffusion
|
||||
|
||||
DiffusersのCommunity Examplesの[こちらのcustom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion)からソースをコピー、変更したものです。
|
||||
|
||||
通常のプロンプトによる生成指定に加えて、追加でより大規模のCLIPでプロンプトのテキストの特徴量を取得し、生成中の画像の特徴量がそのテキストの特徴量に近づくよう、生成される画像をコントロールします(私のざっくりとした理解です)。大きめのCLIPを使いますのでVRAM使用量はかなり増加し(VRAM 8GBでは512*512でも厳しいかもしれません)、生成時間も掛かります。
|
||||
|
||||
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
||||
|
||||
`--clip_guidance_scale`オプションにどの程度、CLIPの特徴量を反映するかを数値で指定します。先のサンプルでは100になっていますので、そのあたりから始めて増減すると良いようです。
|
||||
|
||||
デフォルトではプロンプトの先頭75トークン(重みづけの特殊文字を除く)がCLIPに渡されます。プロンプトの`--c`オプションで、通常のプロンプトではなく、CLIPに渡すテキストを別に指定できます(たとえばCLIPはDreamBoothのidentifier(識別子)や「1girl」などのモデル特有の単語は認識できないと思われますので、それらを省いたテキストが良いと思われます)。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1
|
||||
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36
|
||||
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1
|
||||
--interactive --clip_guidance_scale 100
|
||||
```
|
||||
|
||||
## CLIP Image Guided Stable Diffusion
|
||||
|
||||
テキストではなくCLIPに別の画像を渡し、その特徴量に近づくよう生成をコントロールする機能です。`--clip_image_guidance_scale`オプションで適用量の数値を、`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
||||
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img
|
||||
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers
|
||||
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100
|
||||
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
|
||||
```
|
||||
|
||||
### VGG16 Guided Stable Diffusion
|
||||
|
||||
指定した画像に近づくように画像生成する機能です。通常のプロンプトによる生成指定に加えて、追加でVGG16の特徴量を取得し、生成中の画像が指定したガイド画像に近づくよう、生成される画像をコントロールします。img2imgでの使用をお勧めします(通常の生成では画像がぼやけた感じになります)。CLIP Guided Stable Diffusionの仕組みを流用した独自の機能です。またアイデアはVGGを利用したスタイル変換から拝借しています。
|
||||
|
||||
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
||||
|
||||
`--vgg16_guidance_scale`オプションにどの程度、VGG16特徴量を反映するかを数値で指定します。試した感じでは100くらいから始めて増減すると良いようです。`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
||||
|
||||
複数枚の画像を一括でimg2img変換し、元画像をガイド画像とする場合、`--guide_image_path`と`--image_path`に同じ値を指定すればOKです。
|
||||
|
||||
コマンドラインの例です。
|
||||
|
||||
```batchfile
|
||||
python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
|
||||
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img
|
||||
--xformers --sampler ddim --fp16 --W 512 --H 704
|
||||
--batch_size 1 --images_per_prompt 1
|
||||
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face
|
||||
--n lowres, bad anatomy, bad hands, error, missing fingers,
|
||||
cropped, worst quality, low quality, normal quality,
|
||||
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1"
|
||||
--strength 0.8 --image_path ..\src_image
|
||||
--vgg16_guidance_scale 100 --guide_image_path ..\src_image
|
||||
```
|
||||
|
||||
`--vgg16_guidance_layerPで特徴量取得に使用するVGG16のレイヤー番号を指定できます(デフォルトは20でconv4-2のReLUです)。上の層ほど画風を表現し、下の層ほどコンテンツを表現するといわれています。
|
||||
|
||||

|
||||
|
||||
# その他のオプション
|
||||
|
||||
- `--no_preview` : 対話モードでプレビュー画像を表示しません。OpenCVがインストールされていない場合や、出力されたファイルを直接確認する場合に指定してください。
|
||||
|
||||
- `--n_iter` : 生成を繰り返す回数を指定します。デフォルトは1です。プロンプトをファイルから読み込むとき、複数回の生成を行いたい場合に指定します。
|
||||
|
||||
- `--tokenizer_cache_dir` : トークナイザーのキャッシュディレクトリを指定します。(作業中)
|
||||
|
||||
- `--seed` : 乱数seedを指定します。1枚生成時はその画像のseed、複数枚生成時は各画像のseedを生成するための乱数のseedになります(`--from_file`で複数画像生成するとき、`--seed`オプションを指定すると複数回実行したときに各画像が同じseedになります)。
|
||||
|
||||
- `--iter_same_seed` : プロンプトに乱数seedの指定がないとき、`--n_iter`の繰り返し内ではすべて同じseedを使います。`--from_file`で指定した複数のプロンプト間でseedを統一して比較するときに使います。
|
||||
|
||||
- `--diffusers_xformers` : Diffuserのxformersを使用します。
|
||||
|
||||
- `--opt_channels_last` : 推論時にテンソルのチャンネルを最後に配置します。場合によっては高速化されることがあります。
|
||||
|
||||
- `--network_show_meta` : 追加ネットワークのメタデータを表示します。
|
||||
|
||||
@@ -2,7 +2,7 @@ __ドキュメント更新中のため記述に誤りがあるかもしれませ
|
||||
|
||||
# 学習について、共通編
|
||||
|
||||
当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
|
||||
当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversion([XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)を含む)の学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
|
||||
|
||||
# 概要
|
||||
|
||||
@@ -295,7 +295,7 @@ Stable Diffusion のv1は512\*512で学習されていますが、それに加
|
||||
|
||||
また任意の解像度で学習するため、事前に画像データの縦横比を統一しておく必要がなくなります。
|
||||
|
||||
設定で有効、向こうが切り替えられますが、ここまでの設定ファイルの記述例では有効になっています(`true` が設定されています)。
|
||||
設定で有効、無効が切り替えられますが、ここまでの設定ファイルの記述例では有効になっています(`true` が設定されています)。
|
||||
|
||||
学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位(デフォルト、変更可)で縦横に調整、作成されます。
|
||||
|
||||
@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
|
||||
|
||||
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
|
||||
|
||||
- `--sample_at_first`
|
||||
|
||||
学習開始前にサンプル出力します。学習前との比較ができます。
|
||||
|
||||
- `--sample_prompts`
|
||||
|
||||
サンプル出力用プロンプトのファイルを指定します。
|
||||
@@ -463,27 +467,6 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
xformersオプションを指定するとxformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(xformersよりも速度は遅くなります)。
|
||||
|
||||
- `--save_precision`
|
||||
|
||||
保存時のデータ精度を指定します。save_precisionオプションにfloat、fp16、bf16のいずれかを指定すると、その形式でモデルを保存します(DreamBooth、fine tuningでDiffusers形式でモデルを保存する場合は無効です)。モデルのサイズを削減したい場合などにお使いください。
|
||||
|
||||
- `--save_every_n_epochs` / `--save_state` / `--resume`
|
||||
save_every_n_epochsオプションに数値を指定すると、そのエポックごとに学習途中のモデルを保存します。
|
||||
|
||||
save_stateオプションを同時に指定すると、optimizer等の状態も含めた学習状態を合わせて保存します(保存したモデルからも学習再開できますが、それに比べると精度の向上、学習時間の短縮が期待できます)。保存先はフォルダになります。
|
||||
|
||||
学習状態は保存先フォルダに `<output_name>-??????-state`(??????はエポック数)という名前のフォルダで出力されます。長時間にわたる学習時にご利用ください。
|
||||
|
||||
保存された学習状態から学習を再開するにはresumeオプションを使います。学習状態のフォルダ(`output_dir` ではなくその中のstateのフォルダ)を指定してください。
|
||||
|
||||
なおAcceleratorの仕様により、エポック数、global stepは保存されておらず、resumeしたときにも1からになりますがご容赦ください。
|
||||
|
||||
- `--save_model_as` (DreamBooth, fine tuning のみ)
|
||||
|
||||
モデルの保存形式を`ckpt, safetensors, diffusers, diffusers_safetensors` から選べます。
|
||||
|
||||
`--save_model_as=safetensors` のように指定します。Stable Diffusion形式(ckptまたはsafetensors)を読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。
|
||||
|
||||
- `--clip_skip`
|
||||
|
||||
`2` を指定すると、Text Encoder (CLIP) の後ろから二番目の層の出力を用います。1またはオプション省略時は最後の層を用います。
|
||||
@@ -502,6 +485,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。
|
||||
|
||||
- `--weighted_captions`
|
||||
|
||||
指定するとAutomatic1111氏のWeb UIと同様の重み付きキャプションが有効になります。「Textual Inversion と XTI」以外の学習に使用できます。キャプションだけでなく DreamBooth 手法の token string でも有効です。
|
||||
|
||||
重みづけキャプションの記法はWeb UIとほぼ同じで、(abc)や[abc]、(abc:1.23)などが使用できます。入れ子も可能です。括弧内にカンマを含めるとプロンプトのshuffle/dropoutで括弧の対応付けがおかしくなるため、括弧内にはカンマを含めないでください。
|
||||
|
||||
- `--persistent_data_loader_workers`
|
||||
|
||||
Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
|
||||
@@ -527,15 +516,31 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
|
||||
|
||||
- `--log_with` / `--log_tracker_name`
|
||||
|
||||
学習ログの保存に関するオプションです。`tensorboard` だけでなく `wandb`への保存が可能です。詳細は [PR#428](https://github.com/kohya-ss/sd-scripts/pull/428)をご覧ください。
|
||||
|
||||
- `--noise_offset`
|
||||
|
||||
こちらの記事の実装になります: https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
|
||||
全体的に暗い、明るい画像の生成結果が良くなる可能性があるようです。LoRA学習でも有効なようです。`0.1` 程度の値を指定するとよいようです。
|
||||
|
||||
- `--adaptive_noise_scale` (実験的オプション)
|
||||
|
||||
Noise offsetの値を、latentsの各チャネルの平均値の絶対値に応じて自動調整するオプションです。`--noise_offset` と同時に指定することで有効になります。Noise offsetの値は `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale` で計算されます。latentは正規分布に近いためnoise_offsetの1/10~同程度の値を指定するとよいかもしれません。
|
||||
|
||||
負の値も指定でき、その場合はnoise offsetは0以上にclipされます。
|
||||
|
||||
- `--multires_noise_iterations` / `--multires_noise_discount`
|
||||
|
||||
Multi resolution noise (pyramid noise)の設定です。詳細は [PR#471](https://github.com/kohya-ss/sd-scripts/pull/471) およびこちらのページ [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2) を参照してください。
|
||||
|
||||
`--multires_noise_iterations` に数値を指定すると有効になります。6~10程度の値が良いようです。`--multires_noise_discount` に0.1~0.3 程度の値(LoRA学習等比較的データセットが小さい場合のPR作者の推奨)、ないしは0.8程度の値(元記事の推奨)を指定してください(デフォルトは 0.3)。
|
||||
|
||||
- `--debug_dataset`
|
||||
|
||||
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。
|
||||
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。`S`キーで次のステップ(バッチ)、`E`キーで次のエポックに進みます。
|
||||
|
||||
※Linux環境(Colabを含む)では画像は表示されません。
|
||||
|
||||
@@ -545,6 +550,61 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。
|
||||
|
||||
- `--cache_latents` / `--cache_latents_to_disk`
|
||||
|
||||
使用VRAMを減らすためVAEの出力をメインメモリにキャッシュします。`flip_aug` 以外のaugmentationは使えなくなります。また全体の学習速度が若干速くなります。
|
||||
|
||||
cache_latents_to_diskを指定するとキャッシュをディスクに保存します。スクリプトを終了し、再度起動した場合もキャッシュが有効になります。
|
||||
|
||||
- `--min_snr_gamma`
|
||||
|
||||
Min-SNR Weighting strategyを指定します。詳細は[こちら](https://github.com/kohya-ss/sd-scripts/pull/308)を参照してください。論文では`5`が推奨されています。
|
||||
|
||||
## モデルの保存に関する設定
|
||||
|
||||
- `--save_precision`
|
||||
|
||||
保存時のデータ精度を指定します。save_precisionオプションにfloat、fp16、bf16のいずれかを指定すると、その形式でモデルを保存します(DreamBooth、fine tuningでDiffusers形式でモデルを保存する場合は無効です)。モデルのサイズを削減したい場合などにお使いください。
|
||||
|
||||
- `--save_every_n_epochs` / `--save_state` / `--resume`
|
||||
|
||||
save_every_n_epochsオプションに数値を指定すると、そのエポックごとに学習途中のモデルを保存します。
|
||||
|
||||
save_stateオプションを同時に指定すると、optimizer等の状態も含めた学習状態を合わせて保存します(保存したモデルからも学習再開できますが、それに比べると精度の向上、学習時間の短縮が期待できます)。保存先はフォルダになります。
|
||||
|
||||
学習状態は保存先フォルダに `<output_name>-??????-state`(??????はエポック数)という名前のフォルダで出力されます。長時間にわたる学習時にご利用ください。
|
||||
|
||||
保存された学習状態から学習を再開するにはresumeオプションを使います。学習状態のフォルダ(`output_dir` ではなくその中のstateのフォルダ)を指定してください。
|
||||
|
||||
なおAcceleratorの仕様により、エポック数、global stepは保存されておらず、resumeしたときにも1からになりますがご容赦ください。
|
||||
|
||||
- `--save_every_n_steps`
|
||||
|
||||
save_every_n_stepsオプションに数値を指定すると、そのステップごとに学習途中のモデルを保存します。save_every_n_epochsと同時に指定できます。
|
||||
|
||||
- `--save_model_as` (DreamBooth, fine tuning のみ)
|
||||
|
||||
モデルの保存形式を`ckpt, safetensors, diffusers, diffusers_safetensors` から選べます。
|
||||
|
||||
`--save_model_as=safetensors` のように指定します。Stable Diffusion形式(ckptまたはsafetensors)を読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。
|
||||
|
||||
- `--huggingface_repo_id` 等
|
||||
|
||||
huggingface_repo_idが指定されているとモデル保存時に同時にHuggingFaceにアップロードします。アクセストークンの取り扱いに注意してください(HuggingFaceのドキュメントを参照してください)。
|
||||
|
||||
他の引数をたとえば以下のように指定してください。
|
||||
|
||||
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
|
||||
|
||||
huggingface_repo_visibilityに`public`を指定するとリポジトリが公開されます。省略時または`private`(などpublic以外)を指定すると非公開になります。
|
||||
|
||||
`--save_state`オプション指定時に`--save_state_to_huggingface`を指定するとstateもアップロードします。
|
||||
|
||||
`--resume`オプション指定時に`--resume_from_huggingface`を指定するとHuggingFaceからstateをダウンロードして再開します。その時の --resumeオプションは `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`になります。
|
||||
|
||||
例: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
|
||||
|
||||
`--async_upload`オプションを指定するとアップロードを非同期で行います。
|
||||
|
||||
## オプティマイザ関係
|
||||
|
||||
@@ -553,12 +613,22 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
|
||||
- 過去のバージョンのオプション未指定時と同じ
|
||||
- AdamW8bit : 引数は同上
|
||||
- PagedAdamW8bit : 引数は同上
|
||||
- 過去のバージョンの--use_8bit_adam指定時と同じ
|
||||
- Lion : https://github.com/lucidrains/lion-pytorch
|
||||
- 過去のバージョンの--use_lion_optimizer指定時と同じ
|
||||
- Lion8bit : 引数は同上
|
||||
- PagedLion8bit : 引数は同上
|
||||
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
|
||||
- SGDNesterov8bit : 引数は同上
|
||||
- DAdaptation : https://github.com/facebookresearch/dadaptation
|
||||
- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
|
||||
- DAdaptAdam : 引数は同上
|
||||
- DAdaptAdaGrad : 引数は同上
|
||||
- DAdaptAdan : 引数は同上
|
||||
- DAdaptAdanIP : 引数は同上
|
||||
- DAdaptLion : 引数は同上
|
||||
- DAdaptSGD : 引数は同上
|
||||
- Prodigy : https://github.com/konstmish/prodigy
|
||||
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
||||
- 任意のオプティマイザ
|
||||
|
||||
@@ -570,7 +640,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
学習率のスケジューラ関連の指定です。
|
||||
|
||||
lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。
|
||||
lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup, 任意のスケジューラから選べます。デフォルトはconstantです。
|
||||
|
||||
lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。
|
||||
|
||||
@@ -578,6 +648,8 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
詳細については各自お調べください。
|
||||
|
||||
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。
|
||||
|
||||
### オプティマイザの指定について
|
||||
|
||||
オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。
|
||||
@@ -801,7 +873,7 @@ model_dirオプションでモデルの保存先フォルダを指定できま
|
||||
キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。`--full_path` オプションを指定してメタデータに画像ファイルの場所をフルパスで格納します。このオプションを省略すると相対パスで記録されますが、フォルダ指定が `.toml` ファイル内で別途必要になります。
|
||||
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_apth <教師データフォルダ>
|
||||
python merge_captions_to_metadata.py --full_path <教師データフォルダ>
|
||||
--in_json <読み込むメタデータファイル名> <メタデータファイル名>
|
||||
```
|
||||
|
||||
912
docs/train_README-zh.md
Normal file
912
docs/train_README-zh.md
Normal file
@@ -0,0 +1,912 @@
|
||||
__由于文档正在更新中,描述可能有错误。__
|
||||
|
||||
# 关于训练,通用描述
|
||||
本库支持模型微调(fine tuning)、DreamBooth、训练LoRA和文本反转(Textual Inversion)(包括[XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)
|
||||
)
|
||||
本文档将说明它们通用的训练数据准备方法和选项等。
|
||||
|
||||
# 概要
|
||||
|
||||
请提前参考本仓库的README,准备好环境。
|
||||
|
||||
|
||||
以下本节说明。
|
||||
|
||||
1. 准备训练数据(使用设置文件的新格式)
|
||||
1. 训练中使用的术语的简要解释
|
||||
1. 先前的指定格式(不使用设置文件,而是从命令行指定)
|
||||
1. 生成训练过程中的示例图像
|
||||
1. 各脚本中常用的共同选项
|
||||
1. 准备 fine tuning 方法的元数据:如说明文字(打标签)等
|
||||
|
||||
|
||||
1. 如果只执行一次,训练就可以进行(相关内容,请参阅各个脚本的文档)。如果需要,以后可以随时参考。
|
||||
|
||||
|
||||
|
||||
# 关于准备训练数据
|
||||
|
||||
在任意文件夹(也可以是多个文件夹)中准备好训练数据的图像文件。支持 `.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` 格式的文件。通常不需要进行任何预处理,如调整大小等。
|
||||
|
||||
但是请勿使用极小的图像,若其尺寸比训练分辨率(稍后将提到)还小,建议事先使用超分辨率AI等进行放大。另外,请注意不要使用过大的图像(约为3000 x 3000像素以上),因为这可能会导致错误,建议事先缩小。
|
||||
|
||||
在训练时,需要整理要用于训练模型的图像数据,并将其指定给脚本。根据训练数据的数量、训练目标和说明(图像描述)是否可用等因素,可以使用几种方法指定训练数据。以下是其中的一些方法(每个名称都不是通用的,而是该存储库自定义的定义)。有关正则化图像的信息将在稍后提供。
|
||||
|
||||
1. DreamBooth、class + identifier方式(可使用正则化图像)
|
||||
|
||||
将训练目标与特定单词(identifier)相关联进行训练。无需准备说明。例如,当要学习特定角色时,由于无需准备说明,因此比较方便,但由于训练数据的所有元素都与identifier相关联,例如发型、服装、背景等,因此在生成时可能会出现无法更换服装的情况。
|
||||
|
||||
2. DreamBooth、说明方式(可使用正则化图像)
|
||||
|
||||
事先给每个图片写说明(caption),存放到文本文件中,然后进行训练。例如,通过将图像详细信息(如穿着白色衣服的角色A、穿着红色衣服的角色A等)记录在caption中,可以将角色和其他元素分离,并期望模型更准确地学习角色。
|
||||
|
||||
3. 微调方式(不可使用正则化图像)
|
||||
|
||||
先将说明收集到元数据文件中。支持分离标签和说明以及预先缓存latents等功能,以加速训练(这些将在另一篇文档中介绍)。(虽然名为fine tuning方式,但不仅限于fine tuning。)
|
||||
|
||||
训练对象和你可以使用的规范方法的组合如下。
|
||||
|
||||
| 训练对象或方法 | 脚本 | DB/class+identifier | DB/caption | fine tuning |
|
||||
|----------------| ----- | ----- | ----- | ----- |
|
||||
| fine tuning微调模型 | `fine_tune.py`| x | x | o |
|
||||
| DreamBooth训练模型 | `train_db.py`| o | o | x |
|
||||
| LoRA | `train_network.py`| o | o | o |
|
||||
| Textual Invesion | `train_textual_inversion.py`| o | o | o |
|
||||
|
||||
## 选择哪一个
|
||||
|
||||
如果您想要训练LoRA、Textual Inversion而不需要准备说明(caption)文件,则建议使用DreamBooth class+identifier。如果您能够准备caption文件,则DreamBooth Captions方法更好。如果您有大量的训练数据并且不使用正则化图像,则请考虑使用fine-tuning方法。
|
||||
|
||||
对于DreamBooth也是一样的,但不能使用fine-tuning方法。若要进行微调,只能使用fine-tuning方式。
|
||||
|
||||
# 每种方法的指定方式
|
||||
|
||||
在这里,我们只介绍每种指定方法的典型模式。有关更详细的指定方法,请参见[数据集设置](./config_README-ja.md)。
|
||||
|
||||
# DreamBooth,class+identifier方法(可使用正则化图像)
|
||||
|
||||
在该方法中,每个图像将被视为使用与 `class identifier` 相同的标题进行训练(例如 `shs dog`)。
|
||||
|
||||
这样一来,每张图片都相当于使用标题“分类标识”(例如“shs dog”)进行训练。
|
||||
|
||||
## step 1.确定identifier和class
|
||||
|
||||
要将训练的目标与identifier和属于该目标的class相关联。
|
||||
|
||||
(虽然有很多称呼,但暂时按照原始论文的说法。)
|
||||
|
||||
以下是简要说明(请查阅详细信息)。
|
||||
|
||||
class是训练目标的一般类别。例如,如果要学习特定品种的狗,则class将是“dog”。对于动漫角色,根据模型不同,可能是“boy”或“girl”,也可能是“1boy”或“1girl”。
|
||||
|
||||
identifier是用于识别训练目标并进行学习的单词。可以使用任何单词,但是根据原始论文,“Tokenizer生成的3个或更少字符的罕见单词”是最好的选择。
|
||||
|
||||
使用identifier和class,例如,“shs dog”可以将模型训练为从class中识别并学习所需的目标。
|
||||
|
||||
在图像生成时,使用“shs dog”将生成所学习狗种的图像。
|
||||
|
||||
(作为identifier,我最近使用的一些参考是“shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny”等。最好是不包含在Danbooru标签中的单词。)
|
||||
|
||||
## step 2. 决定是否使用正则化图像,并在使用时生成正则化图像
|
||||
|
||||
正则化图像是为防止前面提到的语言漂移,即整个类别被拉扯成为训练目标而生成的图像。如果不使用正则化图像,例如在 `shs 1girl` 中学习特定角色时,即使在简单的 `1girl` 提示下生成,也会越来越像该角色。这是因为 `1girl` 在训练时的标题中包含了该角色的信息。
|
||||
|
||||
通过同时学习目标图像和正则化图像,类别仍然保持不变,仅在将标识符附加到提示中时才生成目标图像。
|
||||
|
||||
如果您只想在LoRA或DreamBooth中使用特定的角色,则可以不使用正则化图像。
|
||||
|
||||
在Textual Inversion中也不需要使用(如果要学习的token string不包含在标题中,则不会学习任何内容)。
|
||||
|
||||
一般情况下,使用在训练目标模型时只使用类别名称生成的图像作为正则化图像是常见的做法(例如 `1girl`)。但是,如果生成的图像质量不佳,可以尝试修改提示或使用从网络上另外下载的图像。
|
||||
|
||||
(由于正则化图像也被训练,因此其质量会影响模型。)
|
||||
|
||||
通常,准备数百张图像是理想的(图像数量太少会导致类别图像无法被归纳,特征也不会被学习)。
|
||||
|
||||
如果要使用生成的图像,生成图像的大小通常应与训练分辨率(更准确地说,是bucket的分辨率,见下文)相匹配。
|
||||
|
||||
|
||||
|
||||
## step 2. 设置文件的描述
|
||||
|
||||
创建一个文本文件,并将其扩展名更改为`.toml`。例如,您可以按以下方式进行描述:
|
||||
|
||||
(以`#`开头的部分是注释,因此您可以直接复制粘贴,或者将其删除。)
|
||||
|
||||
```toml
|
||||
[general]
|
||||
enable_bucket = true # 是否使用Aspect Ratio Bucketing
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512 # 训练分辨率
|
||||
batch_size = 4 # 批次大小
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge' # 指定包含训练图像的文件夹
|
||||
class_tokens = 'hoge girl' # 指定标识符类
|
||||
num_repeats = 10 # 训练图像的重复次数
|
||||
|
||||
# 以下仅在使用正则化图像时进行描述。不使用则删除
|
||||
[[datasets.subsets]]
|
||||
is_reg = true
|
||||
image_dir = 'C:\reg' # 指定包含正则化图像的文件夹
|
||||
class_tokens = 'girl' # 指定class
|
||||
num_repeats = 1 # 正则化图像的重复次数,基本上1就可以了
|
||||
```
|
||||
|
||||
基本上只需更改以下几个地方即可进行训练。
|
||||
|
||||
1. 训练分辨率
|
||||
|
||||
指定一个数字表示正方形(如果是 `512`,则为 512x512),如果使用方括号和逗号分隔的两个数字,则表示横向×纵向(如果是`[512,768]`,则为 512x768)。在SD1.x系列中,原始训练分辨率为512。指定较大的分辨率,如 `[512,768]` 可能会减少纵向和横向图像生成时的错误。在SD2.x 768系列中,分辨率为 `768`。
|
||||
|
||||
1. 批次大小
|
||||
|
||||
指定同时训练多少个数据。这取决于GPU的VRAM大小和训练分辨率。详细信息将在后面说明。此外,fine tuning/DreamBooth/LoRA等也会影响批次大小,请查看各个脚本的说明。
|
||||
|
||||
1. 文件夹指定
|
||||
|
||||
指定用于学习的图像和正则化图像(仅在使用时)的文件夹。指定包含图像数据的文件夹。
|
||||
|
||||
1. identifier 和 class 的指定
|
||||
|
||||
如前所述,与示例相同。
|
||||
|
||||
1. 重复次数
|
||||
|
||||
将在后面说明。
|
||||
|
||||
### 关于重复次数
|
||||
|
||||
重复次数用于调整正则化图像和训练用图像的数量。由于正则化图像的数量多于训练用图像,因此需要重复使用训练用图像来达到一对一的比例,从而实现训练。
|
||||
|
||||
请将重复次数指定为“ __训练用图像的重复次数×训练用图像的数量≥正则化图像的重复次数×正则化图像的数量__ ”。
|
||||
|
||||
(1个epoch(指训练数据过完一遍)的数据量为“训练用图像的重复次数×训练用图像的数量”。如果正则化图像的数量多于这个值,则剩余的正则化图像将不会被使用。)
|
||||
|
||||
## 步骤 3. 训练
|
||||
|
||||
详情请参考相关文档进行训练。
|
||||
|
||||
# DreamBooth,文本说明(caption)方式(可使用正则化图像)
|
||||
|
||||
在此方式中,每个图像都将通过caption进行训练。
|
||||
|
||||
## 步骤 1. 准备文本说明文件
|
||||
|
||||
请将与图像具有相同文件名且扩展名为 `.caption`(可以在设置中更改)的文件放置在用于训练图像的文件夹中。每个文件应该只有一行。编码为 `UTF-8`。
|
||||
|
||||
## 步骤 2. 决定是否使用正则化图像,并在使用时生成正则化图像
|
||||
|
||||
与class+identifier格式相同。可以在规范化图像上附加caption,但通常不需要。
|
||||
|
||||
## 步骤 2. 编写设置文件
|
||||
|
||||
创建一个文本文件并将扩展名更改为 `.toml`。例如,您可以按以下方式进行描述:
|
||||
|
||||
```toml
|
||||
[general]
|
||||
enable_bucket = true # 是否使用Aspect Ratio Bucketing
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512 # 训练分辨率
|
||||
batch_size = 4 # 批次大小
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\hoge' # 指定包含训练图像的文件夹
|
||||
caption_extension = '.caption' # 若使用txt文件,更改此项
|
||||
num_repeats = 10 # 训练图像的重复次数
|
||||
|
||||
# 以下仅在使用正则化图像时进行描述。不使用则删除
|
||||
[[datasets.subsets]]
|
||||
is_reg = true
|
||||
image_dir = 'C:\reg' # 指定包含正则化图像的文件夹
|
||||
class_tokens = 'girl' # 指定class
|
||||
num_repeats = 1 # 正则化图像的重复次数,基本上1就可以了
|
||||
```
|
||||
|
||||
基本上只需更改以下几个地方来训练。除非另有说明,否则与class+identifier方法相同。
|
||||
|
||||
1. 训练分辨率
|
||||
2. 批次大小
|
||||
3. 文件夹指定
|
||||
4. caption文件的扩展名
|
||||
|
||||
可以指定任意的扩展名。
|
||||
5. 重复次数
|
||||
|
||||
## 步骤 3. 训练
|
||||
|
||||
详情请参考相关文档进行训练。
|
||||
|
||||
# 微调方法(fine tuning)
|
||||
|
||||
## 步骤 1. 准备元数据
|
||||
|
||||
将caption和标签整合到管理文件中称为元数据。它的扩展名为 `.json`,格式为json。由于创建方法较长,因此在本文档的末尾进行描述。
|
||||
|
||||
## 步骤 2. 编写设置文件
|
||||
|
||||
创建一个文本文件,将扩展名设置为 `.toml`。例如,可以按以下方式编写:
|
||||
```toml
|
||||
[general]
|
||||
shuffle_caption = true
|
||||
keep_tokens = 1
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512 # 图像分辨率
|
||||
batch_size = 4 # 批次大小
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = 'C:\piyo' # 指定包含训练图像的文件夹
|
||||
metadata_file = 'C:\piyo\piyo_md.json' # 元数据文件名
|
||||
```
|
||||
|
||||
基本上只需更改以下几个地方来训练。除非另有说明,否则与DreamBooth, class+identifier方法相同。
|
||||
|
||||
1. 训练分辨率
|
||||
2. 批次大小
|
||||
3. 指定文件夹
|
||||
4. 元数据文件名
|
||||
|
||||
指定使用后面所述方法创建的元数据文件。
|
||||
|
||||
|
||||
## 第三步:训练
|
||||
|
||||
详情请参考相关文档进行训练。
|
||||
|
||||
# 训练中使用的术语简单解释
|
||||
|
||||
由于省略了细节并且我自己也没有完全理解,因此请自行查阅详细信息。
|
||||
|
||||
## 微调(fine tuning)
|
||||
|
||||
指训练模型并微调其性能。具体含义因用法而异,但在 Stable Diffusion 中,狭义的微调是指使用图像和caption进行训练模型。DreamBooth 可视为狭义微调的一种特殊方法。广义的微调包括 LoRA、Textual Inversion、Hypernetworks 等,包括训练模型的所有内容。
|
||||
|
||||
## 步骤(step)
|
||||
|
||||
粗略地说,每次在训练数据上进行一次计算即为一步。具体来说,“将训练数据的caption传递给当前模型,将生成的图像与训练数据的图像进行比较,稍微更改模型,以使其更接近训练数据”即为一步。
|
||||
|
||||
## 批次大小(batch size)
|
||||
|
||||
批次大小指定每个步骤要计算多少数据。批次计算可以提高速度。一般来说,批次大小越大,精度也越高。
|
||||
|
||||
“批次大小×步数”是用于训练的数据数量。因此,建议减少步数以增加批次大小。
|
||||
|
||||
(但是,例如,“批次大小为 1,步数为 1600”和“批次大小为 4,步数为 400”将不会产生相同的结果。如果使用相同的学习速率,通常后者会导致模型欠拟合。请尝试增加学习率(例如 `2e-6`),将步数设置为 500 等。)
|
||||
|
||||
批次大小越大,GPU 内存消耗就越大。如果内存不足,将导致错误,或者在边缘时将导致训练速度降低。建议在任务管理器或 `nvidia-smi` 命令中检查使用的内存量进行调整。
|
||||
|
||||
注意,一个批次是指“一个数据单位”。
|
||||
|
||||
## 学习率
|
||||
|
||||
学习率指的是每个步骤中改变的程度。如果指定一个大的值,学习速度就会加快,但是可能会出现变化太大导致模型崩溃或无法达到最佳状态的情况。如果指定一个小的值,学习速度会变慢,同时可能无法达到最佳状态。
|
||||
|
||||
在fine tuning、DreamBooth、LoRA等过程中,学习率会有很大的差异,并且也会受到训练数据、所需训练的模型、批次大小和步骤数等因素的影响。建议从通常值开始,观察训练状态并逐渐调整。
|
||||
|
||||
默认情况下,整个训练过程中学习率是固定的。但是可以通过调度程序指定学习率如何变化,因此结果也会有所不同。
|
||||
|
||||
## Epoch
|
||||
|
||||
Epoch指的是训练数据被完整训练一遍(即数据已经迭代一轮)。如果指定了重复次数,则在重复后的数据迭代一轮后,为1个epoch。
|
||||
|
||||
1个epoch的步骤数通常为“数据量÷批次大小”,但如果使用Aspect Ratio Bucketing,则略微增加(由于不同bucket的数据不能在同一个批次中,因此步骤数会增加)。
|
||||
|
||||
## 长宽比分桶(Aspect Ratio Bucketing)
|
||||
|
||||
Stable Diffusion 的 v1 是以 512\*512 的分辨率进行训练的,但同时也可以在其他分辨率下进行训练,例如 256\*1024 和 384\*640。这样可以减少裁剪的部分,希望更准确地学习图像和标题之间的关系。
|
||||
|
||||
此外,由于可以在任意分辨率下进行训练,因此不再需要事先统一图像数据的长宽比。
|
||||
|
||||
此值可以被设定,其在此之前的配置文件示例中已被启用(设置为 `true`)。
|
||||
|
||||
只要不超过作为参数给出的分辨率区域(= 内存使用量),就可以按 64 像素的增量(默认值,可更改)在垂直和水平方向上调整和创建训练分辨率。
|
||||
|
||||
在机器学习中,通常需要将所有输入大小统一,但实际上只要在同一批次中统一即可。 NovelAI 所说的分桶(bucketing) 指的是,预先将训练数据按照长宽比分类到每个学习分辨率下,并通过使用每个 bucket 内的图像创建批次来统一批次图像大小。
|
||||
|
||||
# 以前的指定格式(不使用 .toml 文件,而是使用命令行选项指定)
|
||||
|
||||
这是一种通过命令行选项而不是指定 .toml 文件的方法。有 DreamBooth 类+标识符方法、DreamBooth caption方法、微调方法三种方式。
|
||||
|
||||
## DreamBooth、类+标识符方式
|
||||
|
||||
指定文件夹名称以指定迭代次数。还要使用 `train_data_dir` 和 `reg_data_dir` 选项。
|
||||
|
||||
### 第1步。准备用于训练的图像
|
||||
|
||||
创建一个用于存储训练图像的文件夹。__此外__,按以下名称创建目录。
|
||||
|
||||
```
|
||||
<迭代次数>_<标识符> <类别>
|
||||
```
|
||||
|
||||
不要忘记下划线``_``。
|
||||
|
||||
例如,如果在名为“sls frog”的提示下重复数据 20 次,则为“20_sls frog”。如下所示:
|
||||
|
||||

|
||||
|
||||
### 多个类别、多个标识符的训练
|
||||
|
||||
该方法很简单,在用于训练的图像文件夹中,需要准备多个文件夹,每个文件夹都是以“重复次数_<标识符> <类别>”命名的,同样,在正则化图像文件夹中,也需要准备多个文件夹,每个文件夹都是以“重复次数_<类别>”命名的。
|
||||
|
||||
例如,如果要同时训练“sls青蛙”和“cpc兔子”,则应按以下方式准备文件夹。
|
||||
|
||||

|
||||
|
||||
如果一个类别包含多个对象,可以只使用一个正则化图像文件夹。例如,如果在1girl类别中有角色A和角色B,则可以按照以下方式处理:
|
||||
|
||||
- train_girls
|
||||
- 10_sls 1girl
|
||||
- 10_cpc 1girl
|
||||
- reg_girls
|
||||
- 1_1girl
|
||||
|
||||
### step 2. 准备正规化图像
|
||||
|
||||
这是使用正则化图像时的过程。
|
||||
|
||||
创建一个文件夹来存储正则化的图像。 __此外,__ 创建一个名为``<repeat count>_<class>`` 的目录。
|
||||
|
||||
例如,使用提示“frog”并且不重复数据(仅一次):
|
||||

|
||||
|
||||
|
||||
步骤3. 执行训练
|
||||
|
||||
执行每个训练脚本。使用 `--train_data_dir` 选项指定包含训练数据文件夹的父文件夹(不是包含图像的文件夹),使用 `--reg_data_dir` 选项指定包含正则化图像的父文件夹(不是包含图像的文件夹)。
|
||||
|
||||
## DreamBooth,带文本说明(caption)的方式
|
||||
|
||||
在包含训练图像和正则化图像的文件夹中,将与图像具有相同文件名的文件.caption(可以使用选项进行更改)放置在该文件夹中,然后从该文件中加载caption所作为提示进行训练。
|
||||
|
||||
※文件夹名称(标识符类)不再用于这些图像的训练。
|
||||
|
||||
默认的caption文件扩展名为.caption。可以使用训练脚本的 `--caption_extension` 选项进行更改。 使用 `--shuffle_caption` 选项,同时对每个逗号分隔的部分进行训练时会对训练时的caption进行混洗。
|
||||
|
||||
## 微调方式
|
||||
|
||||
创建元数据的方式与使用配置文件相同。 使用 `in_json` 选项指定元数据文件。
|
||||
|
||||
# 训练过程中的样本输出
|
||||
|
||||
通过在训练中使用模型生成图像,可以检查训练进度。将以下选项指定为训练脚本。
|
||||
|
||||
- `--sample_every_n_steps` / `--sample_every_n_epochs`
|
||||
|
||||
指定要采样的步数或epoch数。为这些数字中的每一个输出样本。如果两者都指定,则 epoch 数优先。
|
||||
- `--sample_prompts`
|
||||
|
||||
指定示例输出的提示文件。
|
||||
|
||||
- `--sample_sampler`
|
||||
|
||||
指定用于采样输出的采样器。
|
||||
`'ddim', 'pndm', 'heun', 'dpmsolver', 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'`が選べます。
|
||||
|
||||
要输出样本,您需要提前准备一个包含提示的文本文件。每行输入一个提示。
|
||||
|
||||
```txt
|
||||
# prompt 1
|
||||
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
以“#”开头的行是注释。您可以使用“`--` + 小写字母”为生成的图像指定选项,例如 `--n`。您可以使用:
|
||||
|
||||
- `--n` 否定提示到下一个选项。
|
||||
- `--w` 指定生成图像的宽度。
|
||||
- `--h` 指定生成图像的高度。
|
||||
- `--d` 指定生成图像的种子。
|
||||
- `--l` 指定生成图像的 CFG 比例。
|
||||
- `--s` 指定生成过程中的步骤数。
|
||||
|
||||
|
||||
# 每个脚本通用的常用选项
|
||||
|
||||
文档更新可能跟不上脚本更新。在这种情况下,请使用 `--help` 选项检查可用选项。
|
||||
## 学习模型规范
|
||||
|
||||
- `--v2` / `--v_parameterization`
|
||||
|
||||
如果使用 Hugging Face 的 stable-diffusion-2-base 或来自它的微调模型作为学习目标模型(对于在推理时指示使用 `v2-inference.yaml` 的模型),`- 当使用-v2` 选项与 stable-diffusion-2、768-v-ema.ckpt 及其微调模型(对于在推理过程中使用 `v2-inference-v.yaml` 的模型),`- 指定两个 -v2`和 `--v_parameterization` 选项。
|
||||
|
||||
以下几点在 Stable Diffusion 2.0 中发生了显着变化。
|
||||
|
||||
1. 使用分词器
|
||||
2. 使用哪个Text Encoder,使用哪个输出层(2.0使用倒数第二层)
|
||||
3. Text Encoder的输出维度(768->1024)
|
||||
4. U-Net的结构(CrossAttention的头数等)
|
||||
5. v-parameterization(采样方式好像变了)
|
||||
|
||||
其中base使用1-4,非base使用1-5(768-v)。使用 1-4 进行 v2 选择,使用 5 进行 v_parameterization 选择。
|
||||
- `--pretrained_model_name_or_path`
|
||||
|
||||
指定要从中执行额外训练的模型。您可以指定Stable Diffusion检查点文件(.ckpt 或 .safetensors)、diffusers本地磁盘上的模型目录或diffusers模型 ID(例如“stabilityai/stable-diffusion-2”)。
|
||||
## 训练设置
|
||||
|
||||
- `--output_dir`
|
||||
|
||||
指定训练后保存模型的文件夹。
|
||||
|
||||
- `--output_name`
|
||||
|
||||
指定不带扩展名的模型文件名。
|
||||
|
||||
- `--dataset_config`
|
||||
|
||||
指定描述数据集配置的 .toml 文件。
|
||||
|
||||
- `--max_train_steps` / `--max_train_epochs`
|
||||
|
||||
指定要训练的步数或epoch数。如果两者都指定,则 epoch 数优先。
|
||||
-
|
||||
- `--mixed_precision`
|
||||
|
||||
训练混合精度以节省内存。指定像`--mixed_precision = "fp16"`。与无混合精度(默认)相比,精度可能较低,但训练所需的 GPU 内存明显较少。
|
||||
|
||||
(在RTX30系列以后也可以指定`bf16`,请配合您在搭建环境时做的加速设置)。
|
||||
- `--gradient_checkpointing`
|
||||
|
||||
通过逐步计算权重而不是在训练期间一次计算所有权重来减少训练所需的 GPU 内存量。关闭它不会影响准确性,但打开它允许更大的批次大小,所以那里有影响。
|
||||
|
||||
另外,打开它通常会减慢速度,但可以增加批次大小,因此总的训练时间实际上可能会更快。
|
||||
|
||||
- `--xformers` / `--mem_eff_attn`
|
||||
|
||||
当指定 xformers 选项时,使用 xformers 的 CrossAttention。如果未安装 xformers 或发生错误(取决于环境,例如 `mixed_precision="no"`),请指定 `mem_eff_attn` 选项而不是使用 CrossAttention 的内存节省版本(xformers 比 慢)。
|
||||
- `--save_precision`
|
||||
|
||||
指定保存时的数据精度。为 save_precision 选项指定 float、fp16 或 bf16 将以该格式保存模型(在 DreamBooth 中保存 Diffusers 格式时无效,微调)。当您想缩小模型的尺寸时请使用它。
|
||||
- `--save_every_n_epochs` / `--save_state` / `--resume`
|
||||
为 save_every_n_epochs 选项指定一个数字可以在每个时期的训练期间保存模型。
|
||||
|
||||
如果同时指定save_state选项,训练状态包括优化器的状态等都会一起保存。。保存目的地将是一个文件夹。
|
||||
|
||||
训练状态输出到目标文件夹中名为“<output_name>-??????-state”(??????是epoch数)的文件夹中。长时间训练时请使用。
|
||||
|
||||
使用 resume 选项从保存的训练状态恢复训练。指定训练状态文件夹(其中的状态文件夹,而不是 `output_dir`)。
|
||||
|
||||
请注意,由于 Accelerator 规范,epoch 数和全局步数不会保存,即使恢复时它们也从 1 开始。
|
||||
- `--save_model_as` (DreamBooth, fine tuning 仅有的)
|
||||
|
||||
您可以从 `ckpt, safetensors, diffusers, diffusers_safetensors` 中选择模型保存格式。
|
||||
|
||||
- `--save_model_as=safetensors` 指定喜欢当读取Stable Diffusion格式(ckpt 或safetensors)并以diffusers格式保存时,缺少的信息通过从 Hugging Face 中删除 v1.5 或 v2.1 信息来补充。
|
||||
|
||||
- `--clip_skip`
|
||||
|
||||
`2` 如果指定,则使用文本编码器 (CLIP) 的倒数第二层的输出。如果省略 1 或选项,则使用最后一层。
|
||||
|
||||
*SD2.0默认使用倒数第二层,训练SD2.0时请不要指定。
|
||||
|
||||
如果被训练的模型最初被训练为使用第二层,则 2 是一个很好的值。
|
||||
|
||||
如果您使用的是最后一层,那么整个模型都会根据该假设进行训练。因此,如果再次使用第二层进行训练,可能需要一定数量的teacher数据和更长时间的训练才能得到想要的训练结果。
|
||||
- `--max_token_length`
|
||||
|
||||
默认值为 75。您可以通过指定“150”或“225”来扩展令牌长度来训练。使用长字幕训练时指定。
|
||||
|
||||
但由于训练时token展开的规范与Automatic1111的web UI(除法等规范)略有不同,如非必要建议用75训练。
|
||||
|
||||
与clip_skip一样,训练与模型训练状态不同的长度可能需要一定量的teacher数据和更长的学习时间。
|
||||
|
||||
- `--persistent_data_loader_workers`
|
||||
|
||||
在 Windows 环境中指定它可以显着减少时期之间的延迟。
|
||||
|
||||
- `--max_data_loader_n_workers`
|
||||
|
||||
指定数据加载的进程数。大量的进程会更快地加载数据并更有效地使用 GPU,但会消耗更多的主内存。默认是"`8`或者`CPU并发执行线程数 - 1`,取小者",所以如果主存没有空间或者GPU使用率大概在90%以上,就看那些数字和 `2` 或将其降低到大约 `1`。
|
||||
- `--logging_dir` / `--log_prefix`
|
||||
|
||||
保存训练日志的选项。在 logging_dir 选项中指定日志保存目标文件夹。以 TensorBoard 格式保存日志。
|
||||
|
||||
例如,如果您指定 --logging_dir=logs,将在您的工作文件夹中创建一个日志文件夹,并将日志保存在日期/时间文件夹中。
|
||||
此外,如果您指定 --log_prefix 选项,则指定的字符串将添加到日期和时间之前。使用“--logging_dir=logs --log_prefix=db_style1_”进行识别。
|
||||
|
||||
要检查 TensorBoard 中的日志,请打开另一个命令提示符并在您的工作文件夹中键入:
|
||||
```
|
||||
tensorboard --logdir=logs
|
||||
```
|
||||
|
||||
我觉得tensorboard会在环境搭建的时候安装,如果没有安装,请用`pip install tensorboard`安装。)
|
||||
|
||||
然后打开浏览器到http://localhost:6006/就可以看到了。
|
||||
- `--noise_offset`
|
||||
本文的实现:https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
|
||||
看起来它可能会为整体更暗和更亮的图像产生更好的结果。它似乎对 LoRA 训练也有效。指定一个大约 0.1 的值似乎很好。
|
||||
|
||||
- `--debug_dataset`
|
||||
|
||||
通过添加此选项,您可以在训练之前检查将训练什么样的图像数据和标题。按 Esc 退出并返回命令行。按 `S` 进入下一步(批次),按 `E` 进入下一个epoch。
|
||||
|
||||
*图片在 Linux 环境(包括 Colab)下不显示。
|
||||
|
||||
- `--vae`
|
||||
|
||||
如果您在 vae 选项中指定Stable Diffusion检查点、VAE 检查点文件、扩散模型或 VAE(两者都可以指定本地或拥抱面模型 ID),则该 VAE 用于训练(缓存时的潜伏)或在训练过程中获得潜伏)。
|
||||
|
||||
对于 DreamBooth 和微调,保存的模型将包含此 VAE
|
||||
|
||||
- `--cache_latents`
|
||||
|
||||
在主内存中缓存 VAE 输出以减少 VRAM 使用。除 flip_aug 之外的任何增强都将不可用。此外,整体训练速度略快。
|
||||
- `--min_snr_gamma`
|
||||
|
||||
指定最小 SNR 加权策略。细节是[这里](https://github.com/kohya-ss/sd-scripts/pull/308)请参阅。论文中推荐`5`。
|
||||
|
||||
## 优化器相关
|
||||
|
||||
- `--optimizer_type`
|
||||
-- 指定优化器类型。您可以指定
|
||||
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
|
||||
- 与过去版本中未指定选项时相同
|
||||
- AdamW8bit : 参数同上
|
||||
- PagedAdamW8bit : 参数同上
|
||||
- 与过去版本中指定的 --use_8bit_adam 相同
|
||||
- Lion : https://github.com/lucidrains/lion-pytorch
|
||||
- Lion8bit : 参数同上
|
||||
- PagedLion8bit : 参数同上
|
||||
- 与过去版本中指定的 --use_lion_optimizer 相同
|
||||
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
|
||||
- SGDNesterov8bit : 参数同上
|
||||
- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
|
||||
- DAdaptAdam : 参数同上
|
||||
- DAdaptAdaGrad : 参数同上
|
||||
- DAdaptAdan : 参数同上
|
||||
- DAdaptAdanIP : 参数同上
|
||||
- DAdaptLion : 参数同上
|
||||
- DAdaptSGD : 参数同上
|
||||
- Prodigy : https://github.com/konstmish/prodigy
|
||||
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
||||
- 任何优化器
|
||||
|
||||
- `--learning_rate`
|
||||
|
||||
指定学习率。合适的学习率取决于训练脚本,所以请参考每个解释。
|
||||
- `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power`
|
||||
|
||||
学习率的调度程序相关规范。
|
||||
|
||||
使用 lr_scheduler 选项,您可以从线性、余弦、cosine_with_restarts、多项式、常数、constant_with_warmup 或任何调度程序中选择学习率调度程序。默认值是常量。
|
||||
|
||||
使用 lr_warmup_steps,您可以指定预热调度程序的步数(逐渐改变学习率)。
|
||||
|
||||
lr_scheduler_num_cycles 是 cosine with restarts 调度器中的重启次数,lr_scheduler_power 是多项式调度器中的多项式幂。
|
||||
|
||||
有关详细信息,请自行研究。
|
||||
|
||||
要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。
|
||||
### 关于指定优化器
|
||||
|
||||
使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外,您可以指定多个值,以逗号分隔。例如,要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。
|
||||
|
||||
指定可选参数时,请检查每个优化器的规格。
|
||||
一些优化器有一个必需的参数,如果省略它会自动添加(例如 SGDNesterov 的动量)。检查控制台输出。
|
||||
|
||||
D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是学习率本身,而是D-Adaptation决定的学习率的应用率,所以通常指定1.0。如果您希望 Text Encoder 的学习率是 U-Net 的一半,请指定 ``--text_encoder_lr=0.5 --unet_lr=1.0``。
|
||||
如果指定 relative_step=True,AdaFactor 优化器可以自动调整学习率(如果省略,将默认添加)。自动调整时,学习率调度器被迫使用 adafactor_scheduler。此外,指定 scale_parameter 和 warmup_init 似乎也不错。
|
||||
|
||||
自动调整的选项类似于``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"``。
|
||||
|
||||
如果您不想自动调整学习率,请添加可选参数 ``relative_step=False``。在那种情况下,似乎建议将 constant_with_warmup 用于学习率调度程序,而不要为梯度剪裁范数。所以参数就像``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0``。
|
||||
|
||||
### 使用任何优化器
|
||||
|
||||
使用 ``torch.optim`` 优化器时,仅指定类名(例如 ``--optimizer_type=RMSprop``),使用其他模块的优化器时,指定“模块名.类名”。(例如``--optimizer_type=bitsandbytes.optim.lamb.LAMB``)。
|
||||
|
||||
(内部仅通过 importlib 未确认操作。如果需要,请安装包。)
|
||||
<!--
|
||||
## 使用任意大小的图像进行训练 --resolution
|
||||
你可以在广场外训练。请在分辨率中指定“宽度、高度”,如“448,640”。宽度和高度必须能被 64 整除。匹配训练图像和正则化图像的大小。
|
||||
|
||||
就我个人而言,我经常生成垂直长的图像,所以我有时会用“448、640”来训练。
|
||||
|
||||
## 纵横比分桶 --enable_bucket / --min_bucket_reso / --max_bucket_reso
|
||||
它通过指定 enable_bucket 选项来启用。 Stable Diffusion 在 512x512 分辨率下训练,但也在 256x768 和 384x640 等分辨率下训练。
|
||||
|
||||
如果指定此选项,则不需要将训练图像和正则化图像统一为特定分辨率。从多种分辨率(纵横比)中进行选择,并在该分辨率下训练。
|
||||
由于分辨率为 64 像素,纵横比可能与原始图像不完全相同。
|
||||
|
||||
您可以使用 min_bucket_reso 选项指定分辨率的最小大小,使用 max_bucket_reso 指定最大大小。默认值分别为 256 和 1024。
|
||||
例如,将最小尺寸指定为 384 将不会使用 256x1024 或 320x768 等分辨率。
|
||||
如果将分辨率增加到 768x768,您可能需要将 1280 指定为最大尺寸。
|
||||
|
||||
启用 Aspect Ratio Ratio Bucketing 时,最好准备具有与训练图像相似的各种分辨率的正则化图像。
|
||||
|
||||
(因为一批中的图像不偏向于训练图像和正则化图像。
|
||||
|
||||
## 扩充 --color_aug / --flip_aug
|
||||
增强是一种通过在训练过程中动态改变数据来提高模型性能的方法。在使用 color_aug 巧妙地改变色调并使用 flip_aug 左右翻转的同时训练。
|
||||
|
||||
由于数据是动态变化的,因此不能与 cache_latents 选项一起指定。
|
||||
|
||||
## 使用 fp16 梯度训练(实验特征)--full_fp16
|
||||
如果指定 full_fp16 选项,梯度从普通 float32 变为 float16 (fp16) 并训练(它似乎是 full fp16 训练而不是混合精度)。
|
||||
结果,似乎 SD1.x 512x512 大小可以在 VRAM 使用量小于 8GB 的情况下训练,而 SD2.x 512x512 大小可以在 VRAM 使用量小于 12GB 的情况下训练。
|
||||
|
||||
预先在加速配置中指定 fp16,并可选择设置 ``mixed_precision="fp16"``(bf16 不起作用)。
|
||||
|
||||
为了最大限度地减少内存使用,请使用 xformers、use_8bit_adam、cache_latents、gradient_checkpointing 选项并将 train_batch_size 设置为 1。
|
||||
|
||||
(如果你负担得起,逐步增加 train_batch_size 应该会提高一点精度。)
|
||||
|
||||
它是通过修补 PyTorch 源代码实现的(已通过 PyTorch 1.12.1 和 1.13.0 确认)。准确率会大幅下降,途中学习失败的概率也会增加。
|
||||
学习率和步数的设置似乎很严格。请注意它们并自行承担使用它们的风险。
|
||||
-->
|
||||
|
||||
# 创建元数据文件
|
||||
|
||||
## 准备训练数据
|
||||
|
||||
如上所述准备好你要训练的图像数据,放在任意文件夹中。
|
||||
|
||||
例如,存储这样的图像:
|
||||
|
||||

|
||||
|
||||
## 自动captioning
|
||||
|
||||
如果您只想训练没有标题的标签,请跳过。
|
||||
|
||||
另外,手动准备caption时,请准备在与教师数据图像相同的目录下,文件名相同,扩展名.caption等。每个文件应该是只有一行的文本文件。
|
||||
### 使用 BLIP 添加caption
|
||||
|
||||
最新版本不再需要 BLIP 下载、权重下载和额外的虚拟环境。按原样工作。
|
||||
|
||||
运行 finetune 文件夹中的 make_captions.py。
|
||||
|
||||
```
|
||||
python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ>
|
||||
```
|
||||
|
||||
如果batch size为8,训练数据放在父文件夹train_data中,则会如下所示
|
||||
```
|
||||
python finetune\make_captions.py --batch_size 8 ..\train_data
|
||||
```
|
||||
|
||||
caption文件创建在与教师数据图像相同的目录中,具有相同的文件名和扩展名.caption。
|
||||
|
||||
根据 GPU 的 VRAM 容量增加或减少 batch_size。越大越快(我认为 12GB 的 VRAM 可以多一点)。
|
||||
您可以使用 max_length 选项指定caption的最大长度。默认值为 75。如果使用 225 的令牌长度训练模型,它可能会更长。
|
||||
您可以使用 caption_extension 选项更改caption扩展名。默认为 .caption(.txt 与稍后描述的 DeepDanbooru 冲突)。
|
||||
如果有多个教师数据文件夹,则对每个文件夹执行。
|
||||
|
||||
请注意,推理是随机的,因此每次运行时结果都会发生变化。如果要修复它,请使用 --seed 选项指定一个随机数种子,例如 `--seed 42`。
|
||||
|
||||
其他的选项,请参考help with `--help`(好像没有文档说明参数的含义,得看源码)。
|
||||
|
||||
默认情况下,会生成扩展名为 .caption 的caption文件。
|
||||
|
||||

|
||||
|
||||
例如,标题如下:
|
||||
|
||||

|
||||
|
||||
## 由 DeepDanbooru 标记
|
||||
|
||||
如果不想给danbooru标签本身打标签,请继续“标题和标签信息的预处理”。
|
||||
|
||||
标记是使用 DeepDanbooru 或 WD14Tagger 完成的。 WD14Tagger 似乎更准确。如果您想使用 WD14Tagger 进行标记,请跳至下一章。
|
||||
### 环境布置
|
||||
|
||||
将 DeepDanbooru https://github.com/KichangKim/DeepDanbooru 克隆到您的工作文件夹中,或下载并展开 zip。我解压缩了它。
|
||||
另外,从 DeepDanbooru 发布页面 https://github.com/KichangKim/DeepDanbooru/releases 上的“DeepDanbooru 预训练模型 v3-20211112-sgd-e28”的资产下载 deepdanbooru-v3-20211112-sgd-e28.zip 并解压到 DeepDanbooru 文件夹。
|
||||
|
||||
从下面下载。单击以打开资产并从那里下载。
|
||||
|
||||

|
||||
|
||||
做一个这样的目录结构
|
||||
|
||||

|
||||
为diffusers环境安装必要的库。进入 DeepDanbooru 文件夹并安装它(我认为它实际上只是添加了 tensorflow-io)。
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
接下来,安装 DeepDanbooru 本身。
|
||||
|
||||
```
|
||||
pip install .
|
||||
```
|
||||
|
||||
这样就完成了标注环境的准备工作。
|
||||
|
||||
### 实施标记
|
||||
转到 DeepDanbooru 的文件夹并运行 deepdanbooru 进行标记。
|
||||
```
|
||||
deepdanbooru evaluate <教师资料夹> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
||||
```
|
||||
|
||||
如果将训练数据放在父文件夹train_data中,则如下所示。
|
||||
```
|
||||
deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
|
||||
```
|
||||
|
||||
在与教师数据图像相同的目录中创建具有相同文件名和扩展名.txt 的标记文件。它很慢,因为它是一个接一个地处理的。
|
||||
|
||||
如果有多个教师数据文件夹,则对每个文件夹执行。
|
||||
|
||||
它生成如下。
|
||||
|
||||

|
||||
|
||||
它会被这样标记(信息量很大...)。
|
||||
|
||||

|
||||
|
||||
## WD14Tagger标记为
|
||||
|
||||
此过程使用 WD14Tagger 而不是 DeepDanbooru。
|
||||
|
||||
使用 Mr. Automatic1111 的 WebUI 中使用的标记器。我参考了这个 github 页面上的信息 (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger)。
|
||||
|
||||
初始环境维护所需的模块已经安装。权重自动从 Hugging Face 下载。
|
||||
### 实施标记
|
||||
|
||||
运行脚本以进行标记。
|
||||
```
|
||||
python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ>
|
||||
```
|
||||
|
||||
如果将训练数据放在父文件夹train_data中,则如下所示
|
||||
```
|
||||
python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
|
||||
```
|
||||
|
||||
模型文件将在首次启动时自动下载到 wd14_tagger_model 文件夹(文件夹可以在选项中更改)。它将如下所示。
|
||||

|
||||
|
||||
在与教师数据图像相同的目录中创建具有相同文件名和扩展名.txt 的标记文件。
|
||||

|
||||
|
||||

|
||||
|
||||
使用 thresh 选项,您可以指定确定的标签的置信度数以附加标签。默认值为 0.35,与 WD14Tagger 示例相同。较低的值给出更多的标签,但准确性较低。
|
||||
|
||||
根据 GPU 的 VRAM 容量增加或减少 batch_size。越大越快(我认为 12GB 的 VRAM 可以多一点)。您可以使用 caption_extension 选项更改标记文件扩展名。默认为 .txt。
|
||||
|
||||
您可以使用 model_dir 选项指定保存模型的文件夹。
|
||||
|
||||
此外,如果指定 force_download 选项,即使有保存目标文件夹,也会重新下载模型。
|
||||
|
||||
如果有多个教师数据文件夹,则对每个文件夹执行。
|
||||
|
||||
## 预处理caption和标签信息
|
||||
|
||||
将caption和标签作为元数据合并到一个文件中,以便从脚本中轻松处理。
|
||||
### caption预处理
|
||||
|
||||
要将caption放入元数据,请在您的工作文件夹中运行以下命令(如果您不使用caption进行训练,则不需要运行它)(它实际上是一行,依此类推)。指定 `--full_path` 选项以将图像文件的完整路径存储在元数据中。如果省略此选项,则会记录相对路径,但 .toml 文件中需要单独的文件夹规范。
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_path <教师资料夹>
|
||||
--in_json <要读取的元数据文件名> <元数据文件名>
|
||||
```
|
||||
|
||||
元数据文件名是任意名称。
|
||||
如果训练数据为train_data,没有读取元数据文件,元数据文件为meta_cap.json,则会如下。
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_path train_data meta_cap.json
|
||||
```
|
||||
|
||||
您可以使用 caption_extension 选项指定标题扩展。
|
||||
|
||||
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行。
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_path
|
||||
train_data1 meta_cap1.json
|
||||
python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
|
||||
train_data2 meta_cap2.json
|
||||
```
|
||||
如果省略in_json,如果有写入目标元数据文件,将从那里读取并覆盖。
|
||||
|
||||
__* 每次重写 in_json 选项和写入目标并写入单独的元数据文件是安全的。 __
|
||||
### 标签预处理
|
||||
|
||||
同样,标签也收集在元数据中(如果标签不用于训练,则无需这样做)。
|
||||
```
|
||||
python merge_dd_tags_to_metadata.py --full_path <教师资料夹>
|
||||
--in_json <要读取的元数据文件名> <要写入的元数据文件名>
|
||||
```
|
||||
|
||||
同样的目录结构,读取meta_cap.json和写入meta_cap_dd.json时,会是这样的。
|
||||
```
|
||||
python merge_dd_tags_to_metadata.py --full_path train_data --in_json meta_cap.json meta_cap_dd.json
|
||||
```
|
||||
|
||||
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行。
|
||||
|
||||
```
|
||||
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
|
||||
train_data1 meta_cap_dd1.json
|
||||
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
|
||||
train_data2 meta_cap_dd2.json
|
||||
```
|
||||
|
||||
如果省略in_json,如果有写入目标元数据文件,将从那里读取并覆盖。
|
||||
__※ 通过每次重写 in_json 选项和写入目标,写入单独的元数据文件是安全的。 __
|
||||
### 标题和标签清理
|
||||
|
||||
到目前为止,标题和DeepDanbooru标签已经被整理到元数据文件中。然而,自动标题生成的标题存在表达差异等微妙问题(※),而标签中可能包含下划线和评级(DeepDanbooru的情况下)。因此,最好使用编辑器的替换功能清理标题和标签。
|
||||
|
||||
※例如,如果要学习动漫中的女孩,标题可能会包含girl/girls/woman/women等不同的表达方式。另外,将"anime girl"简单地替换为"girl"可能更合适。
|
||||
|
||||
我们提供了用于清理的脚本,请根据情况编辑脚本并使用它。
|
||||
|
||||
(不需要指定教师数据文件夹。将清理元数据中的所有数据。)
|
||||
|
||||
```
|
||||
python clean_captions_and_tags.py <要读取的元数据文件名> <要写入的元数据文件名>
|
||||
```
|
||||
|
||||
--in_json 请注意,不包括在内。例如:
|
||||
|
||||
```
|
||||
python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
|
||||
```
|
||||
|
||||
标题和标签的预处理现已完成。
|
||||
|
||||
## 预先获取 latents
|
||||
|
||||
※ 这一步骤并非必须。即使省略此步骤,也可以在训练过程中获取 latents。但是,如果在训练时执行 `random_crop` 或 `color_aug` 等操作,则无法预先获取 latents(因为每次图像都会改变)。如果不进行预先获取,则可以使用到目前为止的元数据进行训练。
|
||||
|
||||
提前获取图像的潜在表达并保存到磁盘上。这样可以加速训练过程。同时进行 bucketing(根据宽高比对训练数据进行分类)。
|
||||
|
||||
请在工作文件夹中输入以下内容。
|
||||
|
||||
```
|
||||
python prepare_buckets_latents.py --full_path <教师资料夹>
|
||||
<要读取的元数据文件名> <要写入的元数据文件名>
|
||||
<要微调的模型名称或检查点>
|
||||
--batch_size <批次大小>
|
||||
--max_resolution <分辨率宽、高>
|
||||
--mixed_precision <准确性>
|
||||
```
|
||||
|
||||
如果要从meta_clean.json中读取元数据,并将其写入meta_lat.json,使用模型model.ckpt,批处理大小为4,训练分辨率为512*512,精度为no(float32),则应如下所示。
|
||||
```
|
||||
python prepare_buckets_latents.py --full_path
|
||||
train_data meta_clean.json meta_lat.json model.ckpt
|
||||
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
||||
```
|
||||
|
||||
教师数据文件夹中,latents以numpy的npz格式保存。
|
||||
|
||||
您可以使用--min_bucket_reso选项指定最小分辨率大小,--max_bucket_reso指定最大大小。默认值分别为256和1024。例如,如果指定最小大小为384,则将不再使用分辨率为256 * 1024或320 * 768等。如果将分辨率增加到768 * 768等较大的值,则最好将最大大小指定为1280等。
|
||||
|
||||
如果指定--flip_aug选项,则进行左右翻转的数据增强。虽然这可以使数据量伪造一倍,但如果数据不是左右对称的(例如角色外观、发型等),则可能会导致训练不成功。
|
||||
|
||||
对于翻转的图像,也会获取latents,并保存名为\ *_flip.npz的文件,这是一个简单的实现。在fline_tune.py中不需要特定的选项。如果有带有\_flip的文件,则会随机加载带有和不带有flip的文件。
|
||||
|
||||
即使VRAM为12GB,批次大小也可以稍微增加。分辨率以“宽度,高度”的形式指定,必须是64的倍数。分辨率直接影响fine tuning时的内存大小。在12GB VRAM中,512,512似乎是极限(*)。如果有16GB,则可以将其提高到512,704或512,768。即使分辨率为256,256等,VRAM 8GB也很难承受(因为参数、优化器等与分辨率无关,需要一定的内存)。
|
||||
|
||||
*有报道称,在batch size为1的训练中,使用12GB VRAM和640,640的分辨率。
|
||||
|
||||
以下是bucketing结果的显示方式。
|
||||
|
||||

|
||||
|
||||
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行
|
||||
|
||||
```
|
||||
python prepare_buckets_latents.py --full_path
|
||||
train_data1 meta_clean.json meta_lat1.json model.ckpt
|
||||
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
||||
|
||||
python prepare_buckets_latents.py --full_path
|
||||
train_data2 meta_lat1.json meta_lat2.json model.ckpt
|
||||
--batch_size 4 --max_resolution 512,512 --mixed_precision no
|
||||
|
||||
```
|
||||
可以将读取源和写入目标设为相同,但分开设定更为安全。
|
||||
|
||||
__※建议每次更改参数并将其写入另一个元数据文件,以确保安全性。__
|
||||
162
docs/train_db_README-zh.md
Normal file
162
docs/train_db_README-zh.md
Normal file
@@ -0,0 +1,162 @@
|
||||
这是DreamBooth的指南。
|
||||
|
||||
请同时查看[关于学习的通用文档](./train_README-zh.md)。
|
||||
|
||||
# 概要
|
||||
|
||||
DreamBooth是一种将特定主题添加到图像生成模型中进行学习,并使用特定识别子生成它的技术。论文链接。
|
||||
|
||||
具体来说,它可以将角色和绘画风格等添加到Stable Diffusion模型中进行学习,并使用特定的单词(例如`shs`)来调用(呈现在生成的图像中)。
|
||||
|
||||
脚本基于Diffusers的DreamBooth,但添加了以下功能(一些功能已在原始脚本中得到支持)。
|
||||
|
||||
脚本的主要功能如下:
|
||||
|
||||
- 使用8位Adam优化器和潜在变量的缓存来节省内存(与Shivam Shrirao版相似)。
|
||||
- 使用xformers来节省内存。
|
||||
- 不仅支持512x512,还支持任意尺寸的训练。
|
||||
- 通过数据增强来提高质量。
|
||||
- 支持DreamBooth和Text Encoder + U-Net的微调。
|
||||
- 支持以Stable Diffusion格式读写模型。
|
||||
- 支持Aspect Ratio Bucketing。
|
||||
- 支持Stable Diffusion v2.0。
|
||||
|
||||
# 训练步骤
|
||||
|
||||
请先参阅此存储库的README以进行环境设置。
|
||||
|
||||
## 准备数据
|
||||
|
||||
请参阅[有关准备训练数据的说明](./train_README-zh.md)。
|
||||
|
||||
## 运行训练
|
||||
|
||||
运行脚本。以下是最大程度地节省内存的命令(实际上,这将在一行中输入)。请根据需要修改每行。它似乎需要约12GB的VRAM才能运行。
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
||||
--pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录>
|
||||
--dataset_config=<数据准备时创建的.toml文件>
|
||||
--output_dir=<训练模型的输出目录>
|
||||
--output_name=<训练模型输出时的文件名>
|
||||
--save_model_as=safetensors
|
||||
--prior_loss_weight=1.0
|
||||
--max_train_steps=1600
|
||||
--learning_rate=1e-6
|
||||
--optimizer_type="AdamW8bit"
|
||||
--xformers
|
||||
--mixed_precision="fp16"
|
||||
--cache_latents
|
||||
--gradient_checkpointing
|
||||
```
|
||||
`num_cpu_threads_per_process` 通常应该设置为1。
|
||||
|
||||
`pretrained_model_name_or_path` 指定要进行追加训练的基础模型。可以指定 Stable Diffusion 的 checkpoint 文件(.ckpt 或 .safetensors)、Diffusers 的本地模型目录或模型 ID(如 "stabilityai/stable-diffusion-2")。
|
||||
|
||||
`output_dir` 指定保存训练后模型的文件夹。在 `output_name` 中指定模型文件名,不包括扩展名。使用 `save_model_as` 指定以 safetensors 格式保存。
|
||||
|
||||
在 `dataset_config` 中指定 `.toml` 文件。初始批处理大小应为 `1`,以减少内存消耗。
|
||||
|
||||
`prior_loss_weight` 是正则化图像损失的权重。通常设为1.0。
|
||||
|
||||
将要训练的步数 `max_train_steps` 设置为1600。在这里,学习率 `learning_rate` 被设置为1e-6。
|
||||
|
||||
为了节省内存,设置 `mixed_precision="fp16"`(在 RTX30 系列及更高版本中也可以设置为 `bf16`)。同时指定 `gradient_checkpointing`。
|
||||
|
||||
为了使用内存消耗较少的 8bit AdamW 优化器(将模型优化为适合于训练数据的状态),指定 `optimizer_type="AdamW8bit"`。
|
||||
|
||||
指定 `xformers` 选项,并使用 xformers 的 CrossAttention。如果未安装 xformers 或出现错误(具体情况取决于环境,例如使用 `mixed_precision="no"`),则可以指定 `mem_eff_attn` 选项以使用省内存版的 CrossAttention(速度会变慢)。
|
||||
|
||||
为了节省内存,指定 `cache_latents` 选项以缓存 VAE 的输出。
|
||||
|
||||
如果有足够的内存,请编辑 `.toml` 文件将批处理大小增加到大约 `4`(可能会提高速度和精度)。此外,取消 `cache_latents` 选项可以进行数据增强。
|
||||
|
||||
### 常用选项
|
||||
|
||||
对于以下情况,请参阅“常用选项”部分。
|
||||
|
||||
- 学习 Stable Diffusion 2.x 或其衍生模型。
|
||||
- 学习基于 clip skip 大于等于2的模型。
|
||||
- 学习超过75个令牌的标题。
|
||||
|
||||
### 关于DreamBooth中的步数
|
||||
|
||||
为了实现省内存化,该脚本中每个步骤的学习次数减半(因为学习和正则化的图像在训练时被分为不同的批次)。
|
||||
|
||||
要进行与原始Diffusers版或XavierXiao的Stable Diffusion版几乎相同的学习,请将步骤数加倍。
|
||||
|
||||
(虽然在将学习图像和正则化图像整合后再打乱顺序,但我认为对学习没有太大影响。)
|
||||
|
||||
关于DreamBooth的批量大小
|
||||
|
||||
与像LoRA这样的学习相比,为了训练整个模型,内存消耗量会更大(与微调相同)。
|
||||
|
||||
关于学习率
|
||||
|
||||
在Diffusers版中,学习率为5e-6,而在Stable Diffusion版中为1e-6,因此在上面的示例中指定了1e-6。
|
||||
|
||||
当使用旧格式的数据集指定命令行时
|
||||
|
||||
使用选项指定分辨率和批量大小。命令行示例如下。
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
||||
--pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录>
|
||||
--train_data_dir=<训练数据的目录>
|
||||
--reg_data_dir=<正则化图像的目录>
|
||||
--output_dir=<训练后模型的输出目录>
|
||||
--output_name=<训练后模型输出文件的名称>
|
||||
--prior_loss_weight=1.0
|
||||
--resolution=512
|
||||
--train_batch_size=1
|
||||
--learning_rate=1e-6
|
||||
--max_train_steps=1600
|
||||
--use_8bit_adam
|
||||
--xformers
|
||||
--mixed_precision="bf16"
|
||||
--cache_latents
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
## 使用训练好的模型生成图像
|
||||
|
||||
训练完成后,将在指定的文件夹中以指定的名称输出safetensors文件。
|
||||
|
||||
对于v1.4/1.5和其他派生模型,可以在此模型中使用Automatic1111先生的WebUI进行推断。请将其放置在models\Stable-diffusion文件夹中。
|
||||
|
||||
对于使用v2.x模型在WebUI中生成图像的情况,需要单独的.yaml文件来描述模型的规格。对于v2.x base,需要v2-inference.yaml,对于768/v,则需要v2-inference-v.yaml。请将它们放置在相同的文件夹中,并将文件扩展名之前的部分命名为与模型相同的名称。
|
||||

|
||||
|
||||
每个yaml文件都在[Stability AI的SD2.0存储库](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)……之中。
|
||||
|
||||
# DreamBooth的其他主要选项
|
||||
|
||||
有关所有选项的详细信息,请参阅另一份文档。
|
||||
|
||||
## 不在中途开始对文本编码器进行训练 --stop_text_encoder_training
|
||||
|
||||
如果在stop_text_encoder_training选项中指定一个数字,则在该步骤之后,将不再对文本编码器进行训练,只会对U-Net进行训练。在某些情况下,可能会期望提高精度。
|
||||
|
||||
(我们推测可能会有时候仅仅文本编码器会过度学习,而这样做可以避免这种情况,但详细影响尚不清楚。)
|
||||
|
||||
## 不进行分词器的填充 --no_token_padding
|
||||
|
||||
如果指定no_token_padding选项,则不会对分词器的输出进行填充(与Diffusers版本的旧DreamBooth相同)。
|
||||
|
||||
<!--
|
||||
如果使用分桶(bucketing)和数据增强(augmentation),则使用示例如下:
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 8 train_db.py
|
||||
--pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录>
|
||||
--train_data_dir=<训练数据的目录>
|
||||
--reg_data_dir=<正则化图像的目录>
|
||||
--output_dir=<训练后模型的输出目录>
|
||||
--resolution=768,512
|
||||
--train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800
|
||||
--use_8bit_adam --xformers --mixed_precision="bf16"
|
||||
--save_every_n_epochs=1 --save_state --save_precision="bf16"
|
||||
--logging_dir=logs
|
||||
--enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280
|
||||
--color_aug --flip_aug --gradient_checkpointing --seed 42
|
||||
```
|
||||
|
||||
|
||||
-->
|
||||
214
docs/train_lllite_README-ja.md
Normal file
214
docs/train_lllite_README-ja.md
Normal file
@@ -0,0 +1,214 @@
|
||||
# ControlNet-LLLite について
|
||||
|
||||
__きわめて実験的な実装のため、将来的に大きく変更される可能性があります。__
|
||||
|
||||
## 概要
|
||||
ControlNet-LLLite は、[ControlNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAからインスピレーションを得た構造を持つ、軽量なControlNetです。現在はSDXLにのみ対応しています。
|
||||
|
||||
## サンプルの重みファイルと推論
|
||||
|
||||
こちらにあります: https://huggingface.co/kohya-ss/controlnet-lllite
|
||||
|
||||
ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
|
||||
|
||||
生成サンプルはこのページの末尾にあります。
|
||||
|
||||
## モデル構造
|
||||
ひとつのLLLiteモジュールは、制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。
|
||||
|
||||
推論環境の制限で、現在はCrossAttentionのみ(attn1のq/k/v、attn2のq)に追加されます。
|
||||
|
||||
## モデルの学習
|
||||
|
||||
### データセットの準備
|
||||
通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
|
||||
|
||||
たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
|
||||
|
||||
```toml
|
||||
[[datasets.subsets]]
|
||||
image_dir = "path/to/image/dir"
|
||||
caption_extension = ".txt"
|
||||
conditioning_data_dir = "path/to/conditioning/image/dir"
|
||||
```
|
||||
|
||||
現時点の制約として、random_cropは使用できません。
|
||||
|
||||
学習データとしては、元のモデルで生成した画像を学習用画像として、そこから加工した画像をconditioning imageとした、合成によるデータセットを用いるのがもっとも簡単です(データセットの品質的には問題があるかもしれません)。具体的なデータセットの合成方法については後述します。
|
||||
|
||||
なお、元モデルと異なる画風の画像を学習用画像とすると、制御に加えて、その画風についても学ぶ必要が生じます。ControlNet-LLLiteは容量が少ないため、画風学習には不向きです。このような場合には、後述の次元数を多めにしてください。
|
||||
|
||||
### 学習
|
||||
スクリプトで生成する場合は、`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。
|
||||
|
||||
学習時にはメモリを大量に使用しますので、キャッシュやgradient checkpointingなどの省メモリ化のオプションを有効にしてください。また`--full_bf16` オプションで、BFloat16を使用するのも有効です(RTX 30シリーズ以降のGPUが必要です)。24GB VRAMで動作確認しています。
|
||||
|
||||
conditioning image embeddingの次元数は、サンプルのCannyでは32を指定しています。LoRA的モジュールのrankは同じく64です。対象とするconditioning imageの特徴に合わせて調整してください。
|
||||
|
||||
(サンプルのCannyは恐らくかなり難しいと思われます。depthなどでは半分程度にしてもいいかもしれません。)
|
||||
|
||||
以下は .toml の設定例です。
|
||||
|
||||
```toml
|
||||
pretrained_model_name_or_path = "/path/to/model_trained_on.safetensors"
|
||||
max_train_epochs = 12
|
||||
max_data_loader_n_workers = 4
|
||||
persistent_data_loader_workers = true
|
||||
seed = 42
|
||||
gradient_checkpointing = true
|
||||
mixed_precision = "bf16"
|
||||
save_precision = "bf16"
|
||||
full_bf16 = true
|
||||
optimizer_type = "adamw8bit"
|
||||
learning_rate = 2e-4
|
||||
xformers = true
|
||||
output_dir = "/path/to/output/dir"
|
||||
output_name = "output_name"
|
||||
save_every_n_epochs = 1
|
||||
save_model_as = "safetensors"
|
||||
vae_batch_size = 4
|
||||
cache_latents = true
|
||||
cache_latents_to_disk = true
|
||||
cache_text_encoder_outputs = true
|
||||
cache_text_encoder_outputs_to_disk = true
|
||||
network_dim = 64
|
||||
cond_emb_dim = 32
|
||||
dataset_config = "/path/to/dataset.toml"
|
||||
```
|
||||
|
||||
### 推論
|
||||
|
||||
スクリプトで生成する場合は、`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。
|
||||
|
||||
`--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。
|
||||
|
||||
## データセットの合成方法
|
||||
|
||||
### 学習用画像の生成
|
||||
|
||||
学習のベースとなるモデルで画像生成を行います。Web UIやComfyUIなどで生成してください。画像サイズはモデルのデフォルトサイズで良いと思われます(1024x1024など)。bucketingを用いることもできます。その場合は適宜適切な解像度で生成してください。
|
||||
|
||||
生成時のキャプション等は、ControlNet-LLLiteの利用時に生成したい画像にあわせるのが良いと思われます。
|
||||
|
||||
生成した画像を任意のディレクトリに保存してください。このディレクトリをデータセットの設定ファイルで指定します。
|
||||
|
||||
当リポジトリ内の `sdxl_gen_img.py` でも生成できます。例えば以下のように実行します。
|
||||
|
||||
```dos
|
||||
python sdxl_gen_img.py --ckpt path/to/model.safetensors --n_iter 1 --scale 10 --steps 36 --outdir path/to/output/dir --xformers --W 1024 --H 1024 --original_width 2048 --original_height 2048 --bf16 --sampler ddim --batch_size 4 --vae_batch_size 2 --images_per_prompt 512 --max_embeddings_multiples 1 --prompt "{portrait|digital art|anime screen cap|detailed illustration} of 1girl, {standing|sitting|walking|running|dancing} on {classroom|street|town|beach|indoors|outdoors}, {looking at viewer|looking away|looking at another}, {in|wearing} {shirt and skirt|school uniform|casual wear} { |, dynamic pose}, (solo), teen age, {0-1$$smile,|blush,|kind smile,|expression less,|happy,|sadness,} {0-1$$upper body,|full body,|cowboy shot,|face focus,} trending on pixiv, {0-2$$depth of fields,|8k wallpaper,|highly detailed,|pov,} {0-1$$summer, |winter, |spring, |autumn, } beautiful face { |, from below|, from above|, from side|, from behind|, from back} --n nsfw, bad face, lowres, low quality, worst quality, low effort, watermark, signature, ugly, poorly drawn"
|
||||
```
|
||||
|
||||
VRAM 24GBの設定です。VRAMサイズにより`--batch_size` `--vae_batch_size`を調整してください。
|
||||
|
||||
`--prompt`でワイルドカードを利用してランダムに生成しています。適宜調整してください。
|
||||
|
||||
### 画像の加工
|
||||
|
||||
外部のプログラムを用いて、生成した画像を加工します。加工した画像を任意のディレクトリに保存してください。これらがconditioning imageになります。
|
||||
|
||||
加工にはたとえばCannyなら以下のようなスクリプトが使えます。
|
||||
|
||||
```python
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
IMAGES_DIR = "path/to/generated/images"
|
||||
CANNY_DIR = "path/to/canny/images"
|
||||
|
||||
os.makedirs(CANNY_DIR, exist_ok=True)
|
||||
img_files = glob.glob(IMAGES_DIR + "/*.png")
|
||||
for img_file in img_files:
|
||||
can_file = CANNY_DIR + "/" + os.path.basename(img_file)
|
||||
if os.path.exists(can_file):
|
||||
print("Skip: " + img_file)
|
||||
continue
|
||||
|
||||
print(img_file)
|
||||
|
||||
img = cv2.imread(img_file)
|
||||
|
||||
# random threshold
|
||||
# while True:
|
||||
# threshold1 = random.randint(0, 127)
|
||||
# threshold2 = random.randint(128, 255)
|
||||
# if threshold2 - threshold1 > 80:
|
||||
# break
|
||||
|
||||
# fixed threshold
|
||||
threshold1 = 100
|
||||
threshold2 = 200
|
||||
|
||||
img = cv2.Canny(img, threshold1, threshold2)
|
||||
|
||||
cv2.imwrite(can_file, img)
|
||||
```
|
||||
|
||||
### キャプションファイルの作成
|
||||
|
||||
学習用画像のbasenameと同じ名前で、それぞれの画像に対応したキャプションファイルを作成してください。生成時のプロンプトをそのまま利用すれば良いと思われます。
|
||||
|
||||
`sdxl_gen_img.py` で生成した場合は、画像内のメタデータに生成時のプロンプトが記録されていますので、以下のようなスクリプトで学習用画像と同じディレクトリにキャプションファイルを作成できます(拡張子 `.txt`)。
|
||||
|
||||
```python
|
||||
import glob
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
IMAGES_DIR = "path/to/generated/images"
|
||||
|
||||
img_files = glob.glob(IMAGES_DIR + "/*.png")
|
||||
for img_file in img_files:
|
||||
cap_file = img_file.replace(".png", ".txt")
|
||||
if os.path.exists(cap_file):
|
||||
print(f"Skip: {img_file}")
|
||||
continue
|
||||
print(img_file)
|
||||
|
||||
img = Image.open(img_file)
|
||||
prompt = img.text["prompt"] if "prompt" in img.text else ""
|
||||
if prompt == "":
|
||||
print(f"Prompt not found in {img_file}")
|
||||
|
||||
with open(cap_file, "w") as f:
|
||||
f.write(prompt + "\n")
|
||||
```
|
||||
|
||||
### データセットの設定ファイルの作成
|
||||
|
||||
コマンドラインオプションからの指定も可能ですが、`.toml`ファイルを作成する場合は `conditioning_data_dir` に加工した画像を保存したディレクトリを指定します。
|
||||
|
||||
以下は設定ファイルの例です。
|
||||
|
||||
```toml
|
||||
[general]
|
||||
flip_aug = false
|
||||
color_aug = false
|
||||
resolution = [1024,1024]
|
||||
|
||||
[[datasets]]
|
||||
batch_size = 8
|
||||
enable_bucket = false
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "path/to/generated/image/dir"
|
||||
caption_extension = ".txt"
|
||||
conditioning_data_dir = "path/to/canny/image/dir"
|
||||
```
|
||||
|
||||
## 謝辞
|
||||
|
||||
ControlNetの作者である lllyasviel 氏、実装上のアドバイスとトラブル解決へのご尽力をいただいた furusu 氏、ControlNetデータセットを実装していただいた ddPn08 氏に感謝いたします。
|
||||
|
||||
## サンプル
|
||||
Canny
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
217
docs/train_lllite_README.md
Normal file
217
docs/train_lllite_README.md
Normal file
@@ -0,0 +1,217 @@
|
||||
# About ControlNet-LLLite
|
||||
|
||||
__This is an extremely experimental implementation and may change significantly in the future.__
|
||||
|
||||
日本語版は[こちら](./train_lllite_README-ja.md)
|
||||
|
||||
## Overview
|
||||
|
||||
ControlNet-LLLite is a lightweight version of [ControlNet](https://github.com/lllyasviel/ControlNet). It is a "LoRA Like Lite" that is inspired by LoRA and has a lightweight structure. Currently, only SDXL is supported.
|
||||
|
||||
## Sample weight file and inference
|
||||
|
||||
Sample weight file is available here: https://huggingface.co/kohya-ss/controlnet-lllite
|
||||
|
||||
A custom node for ComfyUI is available: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
|
||||
|
||||
Sample images are at the end of this page.
|
||||
|
||||
## Model structure
|
||||
|
||||
A single LLLite module consists of a conditioning image embedding that maps a conditioning image to a latent space and a small network with a structure similar to LoRA. The LLLite module is added to U-Net's Linear and Conv in the same way as LoRA. Please refer to the source code for details.
|
||||
|
||||
Due to the limitations of the inference environment, only CrossAttention (attn1 q/k/v, attn2 q) is currently added.
|
||||
|
||||
## Model training
|
||||
|
||||
### Preparing the dataset
|
||||
|
||||
In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
|
||||
|
||||
```toml
|
||||
[[datasets.subsets]]
|
||||
image_dir = "path/to/image/dir"
|
||||
caption_extension = ".txt"
|
||||
conditioning_data_dir = "path/to/conditioning/image/dir"
|
||||
```
|
||||
|
||||
At the moment, random_crop cannot be used.
|
||||
|
||||
For training data, it is easiest to use a synthetic dataset with the original model-generated images as training images and processed images as conditioning images (the quality of the dataset may be problematic). See below for specific methods of synthesizing datasets.
|
||||
|
||||
Note that if you use an image with a different art style than the original model as a training image, the model will have to learn not only the control but also the art style. ControlNet-LLLite has a small capacity, so it is not suitable for learning art styles. In such cases, increase the number of dimensions as described below.
|
||||
|
||||
### Training
|
||||
|
||||
Run `sdxl_train_control_net_lllite.py`. You can specify the dimension of the conditioning image embedding with `--cond_emb_dim`. You can specify the rank of the LoRA-like module with `--network_dim`. Other options are the same as `sdxl_train_network.py`, but `--network_module` is not required.
|
||||
|
||||
Since a large amount of memory is used during training, please enable memory-saving options such as cache and gradient checkpointing. It is also effective to use BFloat16 with the `--full_bf16` option (requires RTX 30 series or later GPU). It has been confirmed to work with 24GB VRAM.
|
||||
|
||||
For the sample Canny, the dimension of the conditioning image embedding is 32. The rank of the LoRA-like module is also 64. Adjust according to the features of the conditioning image you are targeting.
|
||||
|
||||
(The sample Canny is probably quite difficult. It may be better to reduce it to about half for depth, etc.)
|
||||
|
||||
The following is an example of a .toml configuration.
|
||||
|
||||
```toml
|
||||
pretrained_model_name_or_path = "/path/to/model_trained_on.safetensors"
|
||||
max_train_epochs = 12
|
||||
max_data_loader_n_workers = 4
|
||||
persistent_data_loader_workers = true
|
||||
seed = 42
|
||||
gradient_checkpointing = true
|
||||
mixed_precision = "bf16"
|
||||
save_precision = "bf16"
|
||||
full_bf16 = true
|
||||
optimizer_type = "adamw8bit"
|
||||
learning_rate = 2e-4
|
||||
xformers = true
|
||||
output_dir = "/path/to/output/dir"
|
||||
output_name = "output_name"
|
||||
save_every_n_epochs = 1
|
||||
save_model_as = "safetensors"
|
||||
vae_batch_size = 4
|
||||
cache_latents = true
|
||||
cache_latents_to_disk = true
|
||||
cache_text_encoder_outputs = true
|
||||
cache_text_encoder_outputs_to_disk = true
|
||||
network_dim = 64
|
||||
cond_emb_dim = 32
|
||||
dataset_config = "/path/to/dataset.toml"
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
If you want to generate images with a script, run `sdxl_gen_img.py`. You can specify the LLLite model file with `--control_net_lllite_models`. The dimension is automatically obtained from the model file.
|
||||
|
||||
Specify the conditioning image to be used for inference with `--guide_image_path`. Since preprocess is not performed, if it is Canny, specify an image processed with Canny (white line on black background). `--control_net_preps`, `--control_net_weights`, and `--control_net_ratios` are not supported.
|
||||
|
||||
## How to synthesize a dataset
|
||||
|
||||
### Generating training images
|
||||
|
||||
Generate images with the base model for training. Please generate them with Web UI or ComfyUI etc. The image size should be the default size of the model (1024x1024, etc.). You can also use bucketing. In that case, please generate it at an arbitrary resolution.
|
||||
|
||||
The captions and other settings when generating the images should be the same as when generating the images with the trained ControlNet-LLLite model.
|
||||
|
||||
Save the generated images in an arbitrary directory. Specify this directory in the dataset configuration file.
|
||||
|
||||
|
||||
You can also generate them with `sdxl_gen_img.py` in this repository. For example, run as follows:
|
||||
|
||||
```dos
|
||||
python sdxl_gen_img.py --ckpt path/to/model.safetensors --n_iter 1 --scale 10 --steps 36 --outdir path/to/output/dir --xformers --W 1024 --H 1024 --original_width 2048 --original_height 2048 --bf16 --sampler ddim --batch_size 4 --vae_batch_size 2 --images_per_prompt 512 --max_embeddings_multiples 1 --prompt "{portrait|digital art|anime screen cap|detailed illustration} of 1girl, {standing|sitting|walking|running|dancing} on {classroom|street|town|beach|indoors|outdoors}, {looking at viewer|looking away|looking at another}, {in|wearing} {shirt and skirt|school uniform|casual wear} { |, dynamic pose}, (solo), teen age, {0-1$$smile,|blush,|kind smile,|expression less,|happy,|sadness,} {0-1$$upper body,|full body,|cowboy shot,|face focus,} trending on pixiv, {0-2$$depth of fields,|8k wallpaper,|highly detailed,|pov,} {0-1$$summer, |winter, |spring, |autumn, } beautiful face { |, from below|, from above|, from side|, from behind|, from back} --n nsfw, bad face, lowres, low quality, worst quality, low effort, watermark, signature, ugly, poorly drawn"
|
||||
```
|
||||
|
||||
This is a setting for VRAM 24GB. Adjust `--batch_size` and `--vae_batch_size` according to the VRAM size.
|
||||
|
||||
The images are generated randomly using wildcards in `--prompt`. Adjust as necessary.
|
||||
|
||||
### Processing images
|
||||
|
||||
Use an external program to process the generated images. Save the processed images in an arbitrary directory. These will be the conditioning images.
|
||||
|
||||
For example, you can use the following script to process the images with Canny.
|
||||
|
||||
```python
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
IMAGES_DIR = "path/to/generated/images"
|
||||
CANNY_DIR = "path/to/canny/images"
|
||||
|
||||
os.makedirs(CANNY_DIR, exist_ok=True)
|
||||
img_files = glob.glob(IMAGES_DIR + "/*.png")
|
||||
for img_file in img_files:
|
||||
can_file = CANNY_DIR + "/" + os.path.basename(img_file)
|
||||
if os.path.exists(can_file):
|
||||
print("Skip: " + img_file)
|
||||
continue
|
||||
|
||||
print(img_file)
|
||||
|
||||
img = cv2.imread(img_file)
|
||||
|
||||
# random threshold
|
||||
# while True:
|
||||
# threshold1 = random.randint(0, 127)
|
||||
# threshold2 = random.randint(128, 255)
|
||||
# if threshold2 - threshold1 > 80:
|
||||
# break
|
||||
|
||||
# fixed threshold
|
||||
threshold1 = 100
|
||||
threshold2 = 200
|
||||
|
||||
img = cv2.Canny(img, threshold1, threshold2)
|
||||
|
||||
cv2.imwrite(can_file, img)
|
||||
```
|
||||
|
||||
### Creating caption files
|
||||
|
||||
Create a caption file for each image with the same basename as the training image. It is fine to use the same caption as the one used when generating the image.
|
||||
|
||||
If you generated the images with `sdxl_gen_img.py`, you can use the following script to create the caption files (`*.txt`) from the metadata in the generated images.
|
||||
|
||||
```python
|
||||
import glob
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
IMAGES_DIR = "path/to/generated/images"
|
||||
|
||||
img_files = glob.glob(IMAGES_DIR + "/*.png")
|
||||
for img_file in img_files:
|
||||
cap_file = img_file.replace(".png", ".txt")
|
||||
if os.path.exists(cap_file):
|
||||
print(f"Skip: {img_file}")
|
||||
continue
|
||||
print(img_file)
|
||||
|
||||
img = Image.open(img_file)
|
||||
prompt = img.text["prompt"] if "prompt" in img.text else ""
|
||||
if prompt == "":
|
||||
print(f"Prompt not found in {img_file}")
|
||||
|
||||
with open(cap_file, "w") as f:
|
||||
f.write(prompt + "\n")
|
||||
```
|
||||
|
||||
### Creating a dataset configuration file
|
||||
|
||||
You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
|
||||
|
||||
```toml
|
||||
[general]
|
||||
flip_aug = false
|
||||
color_aug = false
|
||||
resolution = [1024,1024]
|
||||
|
||||
[[datasets]]
|
||||
batch_size = 8
|
||||
enable_bucket = false
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "path/to/generated/image/dir"
|
||||
caption_extension = ".txt"
|
||||
conditioning_data_dir = "path/to/canny/image/dir"
|
||||
```
|
||||
|
||||
## Credit
|
||||
|
||||
I would like to thank lllyasviel, the author of ControlNet, furusu, who provided me with advice on implementation and helped me solve problems, and ddPn08, who implemented the ControlNet dataset.
|
||||
|
||||
## Sample
|
||||
|
||||
Canny
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
@@ -12,11 +12,31 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
|
||||
|
||||
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
||||
|
||||
# 学習できるLoRAの種類
|
||||
|
||||
以下の二種類をサポートします。以下は当リポジトリ内の独自の名称です。
|
||||
|
||||
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
|
||||
|
||||
Linear およびカーネルサイズ 1x1 の Conv2d に適用されるLoRA
|
||||
|
||||
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
|
||||
|
||||
1.に加え、カーネルサイズ 3x3 の Conv2d に適用されるLoRA
|
||||
|
||||
LoRA-LierLaに比べ、LoRA-C3Liarは適用される層が増える分、高い精度が期待できるかもしれません。
|
||||
|
||||
また学習時は __DyLoRA__ を使用することもできます(後述します)。
|
||||
|
||||
## 学習したモデルに関する注意
|
||||
|
||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||
LoRA-LierLa は、AUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
|
||||
|
||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||
LoRA-C3Liarを使いWeb UIで生成するには、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||
|
||||
いずれも学習したLoRAのモデルを、Stable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージすることもできます。
|
||||
|
||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||
|
||||
# 学習の手順
|
||||
|
||||
@@ -31,9 +51,9 @@ WebUI等で画像生成する場合には、学習したLoRAのモデルを学
|
||||
|
||||
`train_network.py`を用います。
|
||||
|
||||
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
|
||||
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのは`network.lora`となりますので、それを指定してください。
|
||||
|
||||
なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。
|
||||
なお学習率は通常のDreamBoothやfine tuningよりも高めの、`1e-4`~`1e-3`程度を指定するとよいようです。
|
||||
|
||||
以下はコマンドラインの例です。
|
||||
|
||||
@@ -56,6 +76,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
--network_module=networks.lora
|
||||
```
|
||||
|
||||
このコマンドラインでは LoRA-LierLa が学習されます。
|
||||
|
||||
`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。
|
||||
|
||||
その他、以下のオプションが指定できます。
|
||||
@@ -83,26 +105,151 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
|
||||
`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
||||
|
||||
## LoRA を Conv2d に拡大して適用する
|
||||
# その他の学習方法
|
||||
|
||||
通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
||||
## LoRA-C3Lier を学習する
|
||||
|
||||
`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。
|
||||
|
||||
```
|
||||
--network_args "conv_dim=1" "conv_alpha=1"
|
||||
--network_args "conv_dim=4" "conv_alpha=1"
|
||||
```
|
||||
|
||||
以下のように alpha 省略時は1になります。
|
||||
|
||||
```
|
||||
--network_args "conv_dim=1"
|
||||
--network_args "conv_dim=4"
|
||||
```
|
||||
|
||||
## DyLoRA
|
||||
|
||||
DyLoRAはこちらの論文で提案されたものです。[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558) 公式実装は[こちら](https://github.com/huawei-noah/KD-NLP/tree/main/DyLoRA)です。
|
||||
|
||||
論文によると、LoRAのrankは必ずしも高いほうが良いわけではなく、対象のモデル、データセット、タスクなどにより適切なrankを探す必要があるようです。DyLoRAを使うと、指定したdim(rank)以下のさまざまなrankで同時にLoRAを学習します。これにより最適なrankをそれぞれ学習して探す手間を省くことができます。
|
||||
|
||||
当リポジトリの実装は公式実装をベースに独自の拡張を加えています(そのため不具合などあるかもしれません)。
|
||||
|
||||
### 当リポジトリのDyLoRAの特徴
|
||||
|
||||
学習後のDyLoRAのモデルファイルはLoRAと互換性があります。また、モデルファイルから指定したdim(rank)以下の複数のdimのLoRAを抽出できます。
|
||||
|
||||
DyLoRA-LierLa、DyLoRA-C3Lierのどちらも学習できます。
|
||||
|
||||
### DyLoRAで学習する
|
||||
|
||||
`--network_module=networks.dylora` のように、DyLoRAに対応する`network.dylora`を指定してください。
|
||||
|
||||
また `--network_args` に、たとえば`--network_args "unit=4"`のように`unit`を指定します。`unit`はrankを分割する単位です。たとえば`--network_dim=16 --network_args "unit=4"` のように指定します。`unit`は`network_dim`を割り切れる値(`network_dim`は`unit`の倍数)としてください。
|
||||
|
||||
`unit`を指定しない場合は、`unit=1`として扱われます。
|
||||
|
||||
記述例は以下です。
|
||||
|
||||
```
|
||||
--network_module=networks.dylora --network_dim=16 --network_args "unit=4"
|
||||
|
||||
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "unit=4"
|
||||
```
|
||||
|
||||
DyLoRA-C3Lierの場合は、`--network_args` に`"conv_dim=4"`のように`conv_dim`を指定します。通常のLoRAと異なり、`conv_dim`は`network_dim`と同じ値である必要があります。記述例は以下です。
|
||||
|
||||
```
|
||||
--network_module=networks.dylora --network_dim=16 --network_args "conv_dim=16" "unit=4"
|
||||
|
||||
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "conv_dim=32" "conv_alpha=16" "unit=8"
|
||||
```
|
||||
|
||||
たとえばdim=16、unit=4(後述)で学習すると、4、8、12、16の4つのrankのLoRAを学習、抽出できます。抽出した各モデルで画像を生成し、比較することで、最適なrankのLoRAを選択できます。
|
||||
|
||||
その他のオプションは通常のLoRAと同じです。
|
||||
|
||||
※ `unit`は当リポジトリの独自拡張で、DyLoRAでは同dim(rank)の通常LoRAに比べると学習時間が長くなることが予想されるため、分割単位を大きくしたものです。
|
||||
|
||||
### DyLoRAのモデルからLoRAモデルを抽出する
|
||||
|
||||
`networks`フォルダ内の `extract_lora_from_dylora.py`を使用します。指定した`unit`単位で、DyLoRAのモデルからLoRAのモデルを抽出します。
|
||||
|
||||
コマンドラインはたとえば以下のようになります。
|
||||
|
||||
```powershell
|
||||
python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.safetensors" --save_to "foldername/dylora-model-split.safetensors" --unit 4
|
||||
```
|
||||
|
||||
`--model` にはDyLoRAのモデルファイルを指定します。`--save_to` には抽出したモデルを保存するファイル名を指定します(rankの数値がファイル名に付加されます)。`--unit` にはDyLoRAの学習時の`unit`を指定します。
|
||||
|
||||
## 階層別学習率
|
||||
|
||||
詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。
|
||||
|
||||
SDXLは現在サポートしていません。
|
||||
|
||||
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
|
||||
|
||||
`--network_args` で以下の引数を指定してください。
|
||||
|
||||
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
|
||||
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
|
||||
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。
|
||||
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
|
||||
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
|
||||
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
|
||||
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
|
||||
|
||||
### 階層別学習率コマンドライン指定例:
|
||||
|
||||
```powershell
|
||||
--network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"
|
||||
|
||||
--network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"
|
||||
```
|
||||
|
||||
### 階層別学習率tomlファイル指定例:
|
||||
|
||||
```toml
|
||||
network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]
|
||||
|
||||
network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]
|
||||
```
|
||||
|
||||
## 階層別dim (rank)
|
||||
|
||||
フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
|
||||
|
||||
`--network_args` で以下の引数を指定してください。
|
||||
|
||||
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
|
||||
- `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。
|
||||
- `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。
|
||||
- `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。
|
||||
|
||||
### 階層別dim (rank)コマンドライン指定例:
|
||||
|
||||
```powershell
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"
|
||||
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
||||
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
||||
```
|
||||
|
||||
### 階層別dim (rank)tomlファイル指定例:
|
||||
|
||||
```toml
|
||||
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]
|
||||
|
||||
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]
|
||||
```
|
||||
|
||||
# その他のスクリプト
|
||||
|
||||
マージ等LoRAに関連するスクリプト群です。
|
||||
|
||||
## マージスクリプトについて
|
||||
|
||||
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
|
||||
|
||||
SDXL向けにはsdxl_merge_lora.pyを用意しています。オプション等は同一ですので、以下のmerge_lora.pyを読み替えてください。
|
||||
|
||||
### Stable DiffusionのモデルにLoRAのモデルをマージする
|
||||
|
||||
マージ後のモデルは通常のStable Diffusionのckptと同様に扱えます。たとえば以下のようなコマンドラインになります。
|
||||
@@ -133,26 +280,28 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
||||
|
||||
### 複数のLoRAのモデルをマージする
|
||||
|
||||
複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
|
||||
--concatオプションを指定すると、複数のLoRAを単純に結合して新しいLoRAモデルを作成できます。ファイルサイズ(およびdim/rank)は指定したLoRAの合計サイズになります(マージ時にdim (rank)を変更する場合は `svd_merge_lora.py` を使用してください)。
|
||||
|
||||
たとえば以下のようなコマンドラインになります。
|
||||
|
||||
```
|
||||
python networks\merge_lora.py
|
||||
python networks\merge_lora.py --save_precision bf16
|
||||
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
|
||||
--ratios 1.0 -1.0 --concat --shuffle
|
||||
```
|
||||
|
||||
--sd_modelオプションは指定不要です。
|
||||
--concatオプションを指定します。
|
||||
|
||||
また--shuffleオプションを追加し、重みをシャッフルします。シャッフルしないとマージ後のLoRAから元のLoRAを取り出せるため、コピー機学習などの場合には学習元データが明らかになります。ご注意ください。
|
||||
|
||||
--save_toオプションにマージ後のLoRAモデルの保存先を指定します(.ckptまたは.safetensors、拡張子で自動判定)。
|
||||
|
||||
--modelsに学習したLoRAのモデルファイルを指定します。三つ以上も指定可能です。
|
||||
|
||||
--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
|
||||
|
||||
v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
||||
--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージする場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
|
||||
|
||||
v1で学習したLoRAとv2で学習したLoRA、rank(次元数)の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
||||
|
||||
### その他のオプション
|
||||
|
||||
@@ -161,6 +310,7 @@ v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``
|
||||
* save_precision
|
||||
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
|
||||
|
||||
他にもいくつかのオプションがありますので、--helpで確認してください。
|
||||
|
||||
## 複数のrankが異なるLoRAのモデルをマージする
|
||||
|
||||
@@ -188,6 +338,73 @@ gen_img_diffusers.pyに、--network_module、--network_weightsの各オプショ
|
||||
|
||||
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
||||
|
||||
## Diffusersのpipelineで生成する
|
||||
|
||||
以下の例を参考にしてください。必要なファイルはnetworks/lora.pyのみです。Diffusersのバージョンは0.10.2以外では動作しない可能性があります。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from networks.lora import LoRAModule, create_network_from_weights
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.
|
||||
|
||||
model_id_or_dir = r"model_id_on_hugging_face_or_dir"
|
||||
device = "cuda"
|
||||
|
||||
# create pipe
|
||||
print(f"creating pipe from {model_id_or_dir}...")
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(device)
|
||||
vae = pipe.vae
|
||||
text_encoder = pipe.text_encoder
|
||||
unet = pipe.unet
|
||||
|
||||
# load lora networks
|
||||
print(f"loading lora networks...")
|
||||
|
||||
lora_path1 = r"lora1.safetensors"
|
||||
sd = load_file(lora_path1) # If the file is .ckpt, use torch.load instead.
|
||||
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||
network1.apply_to(text_encoder, unet)
|
||||
network1.load_state_dict(sd)
|
||||
network1.to(device, dtype=torch.float16)
|
||||
|
||||
# # You can merge weights instead of apply_to+load_state_dict. network.set_multiplier does not work
|
||||
# network.merge_to(text_encoder, unet, sd)
|
||||
|
||||
lora_path2 = r"lora2.safetensors"
|
||||
sd = load_file(lora_path2)
|
||||
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder,unet, sd)
|
||||
network2.apply_to(text_encoder, unet)
|
||||
network2.load_state_dict(sd)
|
||||
network2.to(device, dtype=torch.float16)
|
||||
|
||||
lora_path3 = r"lora3.safetensors"
|
||||
sd = load_file(lora_path3)
|
||||
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||
network3.apply_to(text_encoder, unet)
|
||||
network3.load_state_dict(sd)
|
||||
network3.to(device, dtype=torch.float16)
|
||||
|
||||
# prompts
|
||||
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
|
||||
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"
|
||||
|
||||
# exec pipe
|
||||
print("generating image...")
|
||||
with torch.autocast("cuda"):
|
||||
image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]
|
||||
|
||||
# if not merged, you can use set_multiplier
|
||||
# network1.set_multiplier(0.8)
|
||||
# and generate image again...
|
||||
|
||||
# save image
|
||||
image.save(r"by_diffusers..png")
|
||||
```
|
||||
|
||||
## 二つのモデルの差分からLoRAモデルを作成する
|
||||
|
||||
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
|
||||
@@ -256,14 +473,14 @@ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256
|
||||
- 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
|
||||
|
||||
|
||||
## 追加情報
|
||||
# 追加情報
|
||||
|
||||
### cloneofsimo氏のリポジトリとの違い
|
||||
## cloneofsimo氏のリポジトリとの違い
|
||||
|
||||
2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
|
||||
|
||||
またモジュール入れ替え機構は全く異なります。
|
||||
|
||||
### 将来拡張について
|
||||
## 将来拡張について
|
||||
|
||||
LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。
|
||||
466
docs/train_network_README-zh.md
Normal file
466
docs/train_network_README-zh.md
Normal file
@@ -0,0 +1,466 @@
|
||||
# 关于LoRA的学习。
|
||||
|
||||
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)这是应用于Stable Diffusion“稳定扩散”的内容。
|
||||
|
||||
[cloneofsimo先生的代码仓库](https://github.com/cloneofsimo/lora) 我们非常感謝您提供的参考。非常感謝。
|
||||
|
||||
通常情況下,LoRA只适用于Linear和Kernel大小为1x1的Conv2d,但也可以將其擴展到Kernel大小为3x3的Conv2d。
|
||||
|
||||
Conv2d 3x3的扩展最初是由 [cloneofsimo先生的代码仓库](https://github.com/cloneofsimo/lora)
|
||||
而KohakuBlueleaf先生在[LoCon](https://github.com/KohakuBlueleaf/LoCon)中揭示了其有效性。我们深深地感谢KohakuBlueleaf先生。
|
||||
|
||||
看起来即使在8GB VRAM上也可以勉强运行。
|
||||
|
||||
请同时查看关于[学习的通用文档](./train_README-zh.md)。
|
||||
# 可学习的LoRA 类型
|
||||
|
||||
支持以下两种类型。以下是本仓库中自定义的名称。
|
||||
|
||||
1. __LoRA-LierLa__:(用于 __Li__ n __e__ a __r__ __La__ yers 的 LoRA,读作 "Liela")
|
||||
|
||||
适用于 Linear 和卷积层 Conv2d 的 1x1 Kernel 的 LoRA
|
||||
|
||||
2. __LoRA-C3Lier__:(用于具有 3x3 Kernel 的卷积层和 __Li__ n __e__ a __r__ 层的 LoRA,读作 "Seria")
|
||||
|
||||
除了第一种类型外,还适用于 3x3 Kernel 的 Conv2d 的 LoRA
|
||||
|
||||
与 LoRA-LierLa 相比,LoRA-C3Lier 可能会获得更高的准确性,因为它适用于更多的层。
|
||||
|
||||
在训练时,也可以使用 __DyLoRA__(将在后面介绍)。
|
||||
|
||||
## 请注意与所学模型相关的事项。
|
||||
|
||||
LoRA-LierLa可以用于AUTOMATIC1111先生的Web UI LoRA功能。
|
||||
|
||||
要使用LoRA-C3Liar并在Web UI中生成,请使用此处的[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)。
|
||||
|
||||
在此存储库的脚本中,您还可以预先将经过训练的LoRA模型合并到Stable Diffusion模型中。
|
||||
|
||||
请注意,与cloneofsimo先生的存储库以及d8ahazard先生的[Stable-Diffusion-WebUI的Dreambooth扩展](https://github.com/d8ahazard/sd_dreambooth_extension)不兼容,因为它们进行了一些功能扩展(如下文所述)。
|
||||
|
||||
# 学习步骤
|
||||
|
||||
请先参考此存储库的README文件并进行环境设置。
|
||||
|
||||
## 准备数据
|
||||
|
||||
请参考 [关于准备学习数据](./train_README-zh.md)。
|
||||
|
||||
## 网络训练
|
||||
|
||||
使用`train_network.py`。
|
||||
|
||||
在`train_network.py`中,使用`--network_module`选项指定要训练的模块名称。对于LoRA模块,它应该是`network.lora`,请指定它。
|
||||
|
||||
请注意,学习率应该比通常的DreamBooth或fine tuning要高,建议指定为`1e-4`至`1e-3`左右。
|
||||
|
||||
以下是命令行示例。
|
||||
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
--pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型目录>
|
||||
--dataset_config=<数据集配置的.toml文件>
|
||||
--output_dir=<训练过程中的模型输出文件夹>
|
||||
--output_name=<训练模型输出时的文件名>
|
||||
--save_model_as=safetensors
|
||||
--prior_loss_weight=1.0
|
||||
--max_train_steps=400
|
||||
--learning_rate=1e-4
|
||||
--optimizer_type="AdamW8bit"
|
||||
--xformers
|
||||
--mixed_precision="fp16"
|
||||
--cache_latents
|
||||
--gradient_checkpointing
|
||||
--save_every_n_epochs=1
|
||||
--network_module=networks.lora
|
||||
```
|
||||
|
||||
在这个命令行中,LoRA-LierLa将会被训练。
|
||||
|
||||
LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中。关于其他选项和优化器等,请参阅[学习的通用文档](./train_README-zh.md)中的“常用选项”。
|
||||
|
||||
此外,还可以指定以下选项:
|
||||
|
||||
* `--network_dim`
|
||||
* 指定LoRA的RANK(例如:`--network_dim=4`)。默认值为4。数值越大表示表现力越强,但需要更多的内存和时间来训练。而且不要盲目增加此数值。
|
||||
* `--network_alpha`
|
||||
* 指定用于防止下溢并稳定训练的alpha值。默认值为1。如果与`network_dim`指定相同的值,则将获得与以前版本相同的行为。
|
||||
* `--persistent_data_loader_workers`
|
||||
* 在Windows环境中指定可大幅缩短epoch之间的等待时间。
|
||||
* `--max_data_loader_n_workers`
|
||||
* 指定数据读取进程的数量。进程数越多,数据读取速度越快,可以更有效地利用GPU,但会占用主存。默认值为“`8`或`CPU同步执行线程数-1`的最小值”,因此如果主存不足或GPU使用率超过90%,则应将这些数字降低到约`2`或`1`。
|
||||
* `--network_weights`
|
||||
* 在训练之前读取预训练的LoRA权重,并在此基础上进行进一步的训练。
|
||||
* `--network_train_unet_only`
|
||||
* 仅启用与U-Net相关的LoRA模块。在类似fine tuning的学习中指定此选项可能会很有用。
|
||||
* `--network_train_text_encoder_only`
|
||||
* 仅启用与Text Encoder相关的LoRA模块。可能会期望Textual Inversion效果。
|
||||
* `--unet_lr`
|
||||
* 当在U-Net相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。
|
||||
* `--text_encoder_lr`
|
||||
* 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。
|
||||
* `--network_args`
|
||||
* 可以指定多个参数。将在下面详细说明。
|
||||
|
||||
当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。
|
||||
|
||||
# 其他的学习方法
|
||||
|
||||
## 学习 LoRA-C3Lier
|
||||
|
||||
请使用以下方式
|
||||
|
||||
```
|
||||
--network_args "conv_dim=4"
|
||||
```
|
||||
|
||||
DyLoRA是在这篇论文中提出的[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558),
|
||||
[其官方实现可在这里找到](https://github.com/huawei-noah/KD-NLP/tree/main/DyLoRA)。
|
||||
|
||||
根据论文,LoRA的rank并不是越高越好,而是需要根据模型、数据集、任务等因素来寻找合适的rank。使用DyLoRA,可以同时在指定的维度(rank)下学习多种rank的LoRA,从而省去了寻找最佳rank的麻烦。
|
||||
|
||||
本存储库的实现基于官方实现进行了自定义扩展(因此可能存在缺陷)。
|
||||
|
||||
### 本存储库DyLoRA的特点
|
||||
|
||||
DyLoRA训练后的模型文件与LoRA兼容。此外,可以从模型文件中提取多个低于指定维度(rank)的LoRA。
|
||||
|
||||
DyLoRA-LierLa和DyLoRA-C3Lier均可训练。
|
||||
|
||||
### 使用DyLoRA进行训练
|
||||
|
||||
请指定与DyLoRA相对应的`network.dylora`,例如 `--network_module=networks.dylora`。
|
||||
|
||||
此外,通过 `--network_args` 指定例如`--network_args "unit=4"`的参数。`unit`是划分rank的单位。例如,可以指定为`--network_dim=16 --network_args "unit=4"`。请将`unit`视为可以被`network_dim`整除的值(`network_dim`是`unit`的倍数)。
|
||||
|
||||
如果未指定`unit`,则默认为`unit=1`。
|
||||
|
||||
以下是示例说明。
|
||||
|
||||
```
|
||||
--network_module=networks.dylora --network_dim=16 --network_args "unit=4"
|
||||
|
||||
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "unit=4"
|
||||
```
|
||||
|
||||
对于DyLoRA-C3Lier,需要在 `--network_args` 中指定 `conv_dim`,例如 `conv_dim=4`。与普通的LoRA不同,`conv_dim`必须与`network_dim`具有相同的值。以下是一个示例描述:
|
||||
|
||||
```
|
||||
--network_module=networks.dylora --network_dim=16 --network_args "conv_dim=16" "unit=4"
|
||||
|
||||
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "conv_dim=32" "conv_alpha=16" "unit=8"
|
||||
```
|
||||
|
||||
例如,当使用dim=16、unit=4(如下所述)进行学习时,可以学习和提取4个rank的LoRA,即4、8、12和16。通过在每个提取的模型中生成图像并进行比较,可以选择最佳rank的LoRA。
|
||||
|
||||
其他选项与普通的LoRA相同。
|
||||
|
||||
*`unit`是本存储库的独有扩展,在DyLoRA中,由于预计相比同维度(rank)的普通LoRA,学习时间更长,因此将分割单位增加。
|
||||
|
||||
### 从DyLoRA模型中提取LoRA模型
|
||||
|
||||
请使用`networks`文件夹中的`extract_lora_from_dylora.py`。指定`unit`单位后,从DyLoRA模型中提取LoRA模型。
|
||||
|
||||
例如,命令行如下:
|
||||
|
||||
```powershell
|
||||
python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.safetensors" --save_to "foldername/dylora-model-split.safetensors" --unit 4
|
||||
```
|
||||
|
||||
`--model` 参数用于指定DyLoRA模型文件。`--save_to` 参数用于指定要保存提取的模型的文件名(rank值将附加到文件名中)。`--unit` 参数用于指定DyLoRA训练时的`unit`。
|
||||
|
||||
## 分层学习率
|
||||
|
||||
请参阅PR#355了解详细信息。
|
||||
|
||||
您可以指定完整模型的25个块的权重。虽然第一个块没有对应的LoRA,但为了与分层LoRA应用等的兼容性,将其设为25个。此外,如果不扩展到conv2d3x3,则某些块中可能不存在LoRA,但为了统一描述,请始终指定25个值。
|
||||
|
||||
请在 `--network_args` 中指定以下参数。
|
||||
|
||||
- `down_lr_weight`:指定U-Net down blocks的学习率权重。可以指定以下内容:
|
||||
- 每个块的权重:指定12个数字,例如`"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"`
|
||||
- 从预设中指定:例如`"down_lr_weight=sine"`(使用正弦曲线指定权重)。可以指定sine、cosine、linear、reverse_linear、zeros。另外,添加 `+数字` 时,可以将指定的数字加上(变为0.25〜1.25)。
|
||||
- `mid_lr_weight`:指定U-Net mid block的学习率权重。只需指定一个数字,例如 `"mid_lr_weight=0.5"`。
|
||||
- `up_lr_weight`:指定U-Net up blocks的学习率权重。与down_lr_weight相同。
|
||||
- 省略指定的部分将被视为1.0。另外,如果将权重设为0,则不会创建该块的LoRA模块。
|
||||
- `block_lr_zero_threshold`:如果权重小于此值,则不会创建LoRA模块。默认值为0。
|
||||
|
||||
### 分层学习率命令行指定示例:
|
||||
|
||||
|
||||
```powershell
|
||||
--network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"
|
||||
|
||||
--network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"
|
||||
```
|
||||
|
||||
### Hierarchical Learning Rate指定的toml文件示例:
|
||||
|
||||
```toml
|
||||
network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]
|
||||
|
||||
network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]
|
||||
```
|
||||
|
||||
## 层次结构维度(rank)
|
||||
|
||||
您可以指定完整模型的25个块的维度(rank)。与分层学习率一样,某些块可能不存在LoRA,但请始终指定25个值。
|
||||
|
||||
请在 `--network_args` 中指定以下参数:
|
||||
|
||||
- `block_dims`:指定每个块的维度(rank)。指定25个数字,例如 `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`。
|
||||
- `block_alphas`:指定每个块的alpha。与block_dims一样,指定25个数字。如果省略,将使用network_alpha的值。
|
||||
- `conv_block_dims`:将LoRA扩展到Conv2d 3x3,并指定每个块的维度(rank)。
|
||||
- `conv_block_alphas`:在将LoRA扩展到Conv2d 3x3时指定每个块的alpha。如果省略,将使用conv_alpha的值。
|
||||
|
||||
### 层次结构维度(rank)命令行指定示例:
|
||||
|
||||
|
||||
```powershell
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"
|
||||
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
||||
|
||||
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
|
||||
```
|
||||
|
||||
### 层级别dim(rank) toml文件指定示例:
|
||||
|
||||
```toml
|
||||
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]
|
||||
|
||||
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]
|
||||
```
|
||||
|
||||
# Other scripts
|
||||
这些是与LoRA相关的脚本,如合并脚本等。
|
||||
|
||||
关于合并脚本
|
||||
您可以使用merge_lora.py脚本将LoRA的训练结果合并到稳定扩散模型中,也可以将多个LoRA模型合并。
|
||||
|
||||
合并到稳定扩散模型中的LoRA模型
|
||||
合并后的模型可以像常规的稳定扩散ckpt一样使用。例如,以下是一个命令行示例:
|
||||
|
||||
```
|
||||
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
||||
--save_to ..\lora_train1\model-char1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors --ratios 0.8
|
||||
```
|
||||
|
||||
请使用 Stable Diffusion v2.x 模型进行训练并进行合并时,需要指定--v2选项。
|
||||
|
||||
使用--sd_model选项指定要合并的 Stable Diffusion 模型文件(仅支持 .ckpt 或 .safetensors 格式,目前不支持 Diffusers)。
|
||||
|
||||
使用--save_to选项指定合并后模型的保存路径(根据扩展名自动判断为 .ckpt 或 .safetensors)。
|
||||
|
||||
使用--models选项指定已训练的 LoRA 模型文件,也可以指定多个,然后按顺序进行合并。
|
||||
|
||||
使用--ratios选项以0~1.0的数字指定每个模型的应用率(将多大比例的权重反映到原始模型中)。例如,在接近过度拟合的情况下,降低应用率可能会使结果更好。请指定与模型数量相同的比率。
|
||||
|
||||
当指定多个模型时,格式如下:
|
||||
|
||||
|
||||
```
|
||||
python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
||||
--save_to ..\lora_train1\model-char1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5
|
||||
```
|
||||
|
||||
### 将多个LoRA模型合并
|
||||
|
||||
将多个LoRA模型逐个应用于SD模型与将多个LoRA模型合并后再应用于SD模型之间,由于计算顺序的不同,会得到微妙不同的结果。
|
||||
|
||||
例如,下面是一个命令行示例:
|
||||
|
||||
```
|
||||
python networks\merge_lora.py
|
||||
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
|
||||
```
|
||||
|
||||
--sd_model选项不需要指定。
|
||||
|
||||
通过--save_to选项指定合并后的LoRA模型的保存位置(.ckpt或.safetensors,根据扩展名自动识别)。
|
||||
|
||||
通过--models选项指定学习的LoRA模型文件。可以指定三个或更多。
|
||||
|
||||
通过--ratios选项以0~1.0的数字指定每个模型的比率(反映多少权重来自原始模型)。如果将两个模型一对一合并,则比率将是“0.5 0.5”。如果比率为“1.0 1.0”,则总重量将过大,可能会产生不理想的结果。
|
||||
|
||||
在v1和v2中学习的LoRA,以及rank(维数)或“alpha”不同的LoRA不能合并。仅包含U-Net的LoRA和包含U-Net+文本编码器的LoRA可以合并,但结果未知。
|
||||
|
||||
### 其他选项
|
||||
|
||||
* 精度
|
||||
* 可以从float、fp16或bf16中选择合并计算时的精度。默认为float以保证精度。如果想减少内存使用量,请指定fp16/bf16。
|
||||
* save_precision
|
||||
* 可以从float、fp16或bf16中选择在保存模型时的精度。默认与精度相同。
|
||||
|
||||
## 合并多个维度不同的LoRA模型
|
||||
|
||||
将多个LoRA近似为一个LoRA(无法完全复制)。使用'svd_merge_lora.py'。例如,以下是命令行的示例。
|
||||
```
|
||||
python networks\svd_merge_lora.py
|
||||
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
|
||||
--ratios 0.6 0.4 --new_rank 32 --device cuda
|
||||
```
|
||||
`merge_lora.py`和主要选项相同。以下选项已添加:
|
||||
|
||||
- `--new_rank`
|
||||
- 指定要创建的LoRA rank。
|
||||
- `--new_conv_rank`
|
||||
- 指定要创建的Conv2d 3x3 LoRA的rank。如果省略,则与`new_rank`相同。
|
||||
- `--device`
|
||||
- 如果指定为`--device cuda`,则在GPU上执行计算。处理速度将更快。
|
||||
|
||||
## 在此存储库中生成图像的脚本中
|
||||
|
||||
请在`gen_img_diffusers.py`中添加`--network_module`和`--network_weights`选项。其含义与训练时相同。
|
||||
|
||||
通过`--network_mul`选项,可以指定0~1.0的数字来改变LoRA的应用率。
|
||||
|
||||
## 请参考以下示例,在Diffusers的pipeline中生成。
|
||||
|
||||
所需文件仅为networks/lora.py。请注意,该示例只能在Diffusers版本0.10.2中正常运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from networks.lora import LoRAModule, create_network_from_weights
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.
|
||||
|
||||
model_id_or_dir = r"model_id_on_hugging_face_or_dir"
|
||||
device = "cuda"
|
||||
|
||||
# create pipe
|
||||
print(f"creating pipe from {model_id_or_dir}...")
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(device)
|
||||
vae = pipe.vae
|
||||
text_encoder = pipe.text_encoder
|
||||
unet = pipe.unet
|
||||
|
||||
# load lora networks
|
||||
print(f"loading lora networks...")
|
||||
|
||||
lora_path1 = r"lora1.safetensors"
|
||||
sd = load_file(lora_path1) # If the file is .ckpt, use torch.load instead.
|
||||
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||
network1.apply_to(text_encoder, unet)
|
||||
network1.load_state_dict(sd)
|
||||
network1.to(device, dtype=torch.float16)
|
||||
|
||||
# # You can merge weights instead of apply_to+load_state_dict. network.set_multiplier does not work
|
||||
# network.merge_to(text_encoder, unet, sd)
|
||||
|
||||
lora_path2 = r"lora2.safetensors"
|
||||
sd = load_file(lora_path2)
|
||||
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder,unet, sd)
|
||||
network2.apply_to(text_encoder, unet)
|
||||
network2.load_state_dict(sd)
|
||||
network2.to(device, dtype=torch.float16)
|
||||
|
||||
lora_path3 = r"lora3.safetensors"
|
||||
sd = load_file(lora_path3)
|
||||
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
|
||||
network3.apply_to(text_encoder, unet)
|
||||
network3.load_state_dict(sd)
|
||||
network3.to(device, dtype=torch.float16)
|
||||
|
||||
# prompts
|
||||
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
|
||||
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"
|
||||
|
||||
# exec pipe
|
||||
print("generating image...")
|
||||
with torch.autocast("cuda"):
|
||||
image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]
|
||||
|
||||
# if not merged, you can use set_multiplier
|
||||
# network1.set_multiplier(0.8)
|
||||
# and generate image again...
|
||||
|
||||
# save image
|
||||
image.save(r"by_diffusers..png")
|
||||
```
|
||||
|
||||
## 从两个模型的差异中创建LoRA模型。
|
||||
|
||||
[参考讨论链接](https://github.com/cloneofsimo/lora/discussions/56)這是參考實現的結果。數學公式沒有改變(我並不完全理解,但似乎使用奇異值分解進行了近似)。
|
||||
|
||||
将两个模型(例如微调原始模型和微调后的模型)的差异近似为LoRA。
|
||||
|
||||
### 脚本执行方法
|
||||
|
||||
请按以下方式指定。
|
||||
|
||||
```
|
||||
python networks\extract_lora_from_models.py --model_org base-model.ckpt
|
||||
--model_tuned fine-tuned-model.ckpt
|
||||
--save_to lora-weights.safetensors --dim 4
|
||||
```
|
||||
|
||||
--model_org 选项指定原始的Stable Diffusion模型。如果要应用创建的LoRA模型,则需要指定该模型并将其应用。可以指定.ckpt或.safetensors文件。
|
||||
|
||||
--model_tuned 选项指定要提取差分的目标Stable Diffusion模型。例如,可以指定经过Fine Tuning或DreamBooth后的模型。可以指定.ckpt或.safetensors文件。
|
||||
|
||||
--save_to 指定LoRA模型的保存路径。--dim指定LoRA的维数。
|
||||
|
||||
生成的LoRA模型可以像已训练的LoRA模型一样使用。
|
||||
|
||||
当两个模型的文本编码器相同时,LoRA将成为仅包含U-Net的LoRA。
|
||||
|
||||
### 其他选项
|
||||
|
||||
- `--v2`
|
||||
- 如果使用v2.x的稳定扩散模型,请指定此选项。
|
||||
- `--device`
|
||||
- 指定为 ``--device cuda`` 可在GPU上执行计算。这会使处理速度更快(即使在CPU上也不会太慢,大约快几倍)。
|
||||
- `--save_precision`
|
||||
- 指定LoRA的保存格式为“float”、“fp16”、“bf16”。如果省略,将使用float。
|
||||
- `--conv_dim`
|
||||
- 指定后,将扩展LoRA的应用范围到Conv2d 3x3。指定Conv2d 3x3的rank。
|
||||
-
|
||||
## 图像大小调整脚本
|
||||
|
||||
(稍后将整理文件,但现在先在这里写下说明。)
|
||||
|
||||
在 Aspect Ratio Bucketing 的功能扩展中,现在可以将小图像直接用作教师数据,而无需进行放大。我收到了一个用于前处理的脚本,其中包括将原始教师图像缩小的图像添加到教师数据中可以提高准确性的报告。我整理了这个脚本并加入了感谢 bmaltais 先生。
|
||||
|
||||
### 执行脚本的方法如下。
|
||||
原始图像以及调整大小后的图像将保存到转换目标文件夹中。调整大小后的图像将在文件名中添加“+512x512”之类的调整后的分辨率(与图像大小不同)。小于调整大小后分辨率的图像将不会被放大。
|
||||
|
||||
```
|
||||
python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
|
||||
--copy_associated_files 源图像文件夹目标文件夹
|
||||
```
|
||||
|
||||
在元画像文件夹中的图像文件将被调整大小以达到指定的分辨率(可以指定多个),并保存到目标文件夹中。除图像外的文件将被保留为原样。
|
||||
|
||||
请使用“--max_resolution”选项指定调整大小后的大小,使其达到指定的面积大小。如果指定多个,则会在每个分辨率上进行调整大小。例如,“512x512,384x384,256x256”将使目标文件夹中的图像变为原始大小和调整大小后的大小×3共计4张图像。
|
||||
|
||||
如果使用“--save_as_png”选项,则会以PNG格式保存。如果省略,则默认以JPEG格式(quality=100)保存。
|
||||
|
||||
如果使用“--copy_associated_files”选项,则会将与图像相同的文件名(例如标题等)的文件复制到调整大小后的图像文件的文件名相同的位置,但不包括扩展名。
|
||||
|
||||
### 其他选项
|
||||
|
||||
- divisible_by
|
||||
- 将图像中心裁剪到能够被该值整除的大小(分别是垂直和水平的大小),以便调整大小后的图像大小可以被该值整除。
|
||||
- interpolation
|
||||
- 指定缩小时的插值方法。可从``area、cubic、lanczos4``中选择,默认为``area``。
|
||||
|
||||
|
||||
# 追加信息
|
||||
|
||||
## 与cloneofsimo的代码库的区别
|
||||
|
||||
截至2022年12月25日,本代码库将LoRA应用扩展到了Text Encoder的MLP、U-Net的FFN以及Transformer的输入/输出投影中,从而增强了表现力。但是,内存使用量增加了,接近了8GB的限制。
|
||||
|
||||
此外,模块交换机制也完全不同。
|
||||
|
||||
## 关于未来的扩展
|
||||
|
||||
除了LoRA之外,我们还计划添加其他扩展,以支持更多的功能。
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。
|
||||
|
||||
学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。
|
||||
学習したモデルはWeb UIでもそのまま使えます。
|
||||
|
||||
# 学習の手順
|
||||
|
||||
717
fine_tune.py
717
fine_tune.py
@@ -5,361 +5,492 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [{
|
||||
"subsets": [{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}]
|
||||
}]
|
||||
}
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
||||
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||
src_diffusers_model_path = None
|
||||
else:
|
||||
src_stable_diffusion_ckpt = None
|
||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||
|
||||
if args.save_model_as is None:
|
||||
save_stable_diffusion_format = load_stable_diffusion_format
|
||||
use_safetensors = args.use_safetensors
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
|
||||
# Diffusers版のxformers使用フラグを設定する関数
|
||||
def set_diffusers_xformers_flag(model, valid):
|
||||
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
||||
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
|
||||
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
|
||||
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
||||
# gets the message
|
||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||
module.set_use_memory_efficient_attention_xformers(valid)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
||||
fn_recursive_set_mem_eff(model)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
if args.diffusers_xformers:
|
||||
print("Use xformers by Diffusers")
|
||||
set_diffusers_xformers_flag(unet, True)
|
||||
else:
|
||||
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
||||
print("Disable Diffusers' xformers")
|
||||
set_diffusers_xformers_flag(unet, False)
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
training_models = []
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
training_models.append(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
print("enable text encoder training")
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
training_models.append(text_encoder)
|
||||
else:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False) # text encoderは学習しない
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
text_encoder.train() # required for gradient_checkpointing
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
text_encoder.eval()
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
for m in training_models:
|
||||
m.requires_grad_(True)
|
||||
params = []
|
||||
for m in training_models:
|
||||
params.extend(m.parameters())
|
||||
params_to_optimize = params
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
print(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
print("enable full fp16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder.to(weight_dtype)
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||
src_diffusers_model_path = None
|
||||
else:
|
||||
src_stable_diffusion_ckpt = None
|
||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
if args.save_model_as is None:
|
||||
save_stable_diffusion_format = load_stable_diffusion_format
|
||||
use_safetensors = args.use_safetensors
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
# Diffusers版のxformers使用フラグを設定する関数
|
||||
def set_diffusers_xformers_flag(model, valid):
|
||||
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
||||
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
|
||||
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
|
||||
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
||||
# gets the message
|
||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||
module.set_use_memory_efficient_attention_xformers(valid)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
print("running training / 学習開始")
|
||||
print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
fn_recursive_set_mem_eff(model)
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
if args.diffusers_xformers:
|
||||
accelerator.print("Use xformers by Diffusers")
|
||||
set_diffusers_xformers_flag(unet, True)
|
||||
else:
|
||||
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
||||
accelerator.print("Disable Diffusers' xformers")
|
||||
set_diffusers_xformers_flag(unet, False)
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000, clip_sample=False)
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("finetuning")
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset_group.set_current_epoch(epoch + 1)
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
training_models = []
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
training_models.append(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
accelerator.print("enable text encoder training")
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
training_models.append(text_encoder)
|
||||
else:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False) # text encoderは学習しない
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
text_encoder.train() # required for gradient_checkpointing
|
||||
else:
|
||||
text_encoder.eval()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
m.requires_grad_(True)
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
trainable_params = []
|
||||
if args.learning_rate_te is None or not args.train_text_encoder:
|
||||
for m in training_models:
|
||||
trainable_params.extend(m.parameters())
|
||||
else:
|
||||
trainable_params = [
|
||||
{"params": list(unet.parameters()), "lr": args.learning_rate},
|
||||
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
|
||||
]
|
||||
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
accelerator.print(
|
||||
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
)
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
||||
accelerator.log(logs, step=global_step)
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
# TODO moving averageにする
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step+1)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||
accelerator.log(logs, step=epoch+1)
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
||||
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
accelerator.end_training()
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
|
||||
save_dtype, epoch, global_step, text_encoder, unet, vae)
|
||||
print("model saved.")
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_train_end(
|
||||
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
||||
)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument(
|
||||
"--learning_rate_te",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
|
||||
@@ -163,13 +163,19 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
if len(unknown) == 1:
|
||||
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
||||
|
||||
@@ -3,160 +3,200 @@ import glob
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from blip.blip import blip_decoder
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
from blip.blip import blip_decoder, is_url
|
||||
import library.train_util as train_util
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
IMAGE_SIZE = 384
|
||||
|
||||
# 正方形でいいのか? という気がするがソースがそうなので
|
||||
IMAGE_TRANSFORM = transforms.Compose([
|
||||
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])
|
||||
IMAGE_TRANSFORM = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# 共通化したいが微妙に処理が異なる……
|
||||
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.images[idx]
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.images[idx]
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
# convert to tensor temporarily so dataloader will accept it
|
||||
tensor = IMAGE_TRANSFORM(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
# convert to tensor temporarily so dataloader will accept it
|
||||
tensor = IMAGE_TRANSFORM(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (tensor, img_path)
|
||||
return (tensor, img_path)
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# fix the seed for reproducibility
|
||||
seed = args.seed # + utils.get_rank()
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
# fix the seed for reproducibility
|
||||
seed = args.seed # + utils.get_rank()
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
if not os.path.exists("blip"):
|
||||
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
||||
if not os.path.exists("blip"):
|
||||
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
||||
|
||||
cwd = os.getcwd()
|
||||
print('Current Working Directory is: ', cwd)
|
||||
os.chdir('finetune')
|
||||
cwd = os.getcwd()
|
||||
print("Current Working Directory is: ", cwd)
|
||||
os.chdir("finetune")
|
||||
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
|
||||
args.caption_weights = os.path.join("..", args.caption_weights)
|
||||
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
print(f"loading BLIP caption: {args.caption_weights}")
|
||||
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
|
||||
model.eval()
|
||||
model = model.to(DEVICE)
|
||||
print("BLIP loaded")
|
||||
print(f"loading BLIP caption: {args.caption_weights}")
|
||||
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
|
||||
model.eval()
|
||||
model = model.to(DEVICE)
|
||||
print("BLIP loaded")
|
||||
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
if args.beam_search:
|
||||
captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
|
||||
max_length=args.max_length, min_length=args.min_length)
|
||||
else:
|
||||
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
|
||||
with torch.no_grad():
|
||||
if args.beam_search:
|
||||
captions = model.generate(
|
||||
imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length
|
||||
)
|
||||
else:
|
||||
captions = model.generate(
|
||||
imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length
|
||||
)
|
||||
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingTransformDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingTransformDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
img_tensor, image_path = data
|
||||
if img_tensor is None:
|
||||
try:
|
||||
raw_image = Image.open(image_path)
|
||||
if raw_image.mode != 'RGB':
|
||||
raw_image = raw_image.convert("RGB")
|
||||
img_tensor = IMAGE_TRANSFORM(raw_image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
img_tensor, image_path = data
|
||||
if img_tensor is None:
|
||||
try:
|
||||
raw_image = Image.open(image_path)
|
||||
if raw_image.mode != "RGB":
|
||||
raw_image = raw_image.convert("RGB")
|
||||
img_tensor = IMAGE_TRANSFORM(raw_image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
b_imgs.append((image_path, img_tensor))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs.append((image_path, img_tensor))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
print("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
|
||||
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--beam_search", action="store_true",
|
||||
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
||||
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--caption_weights",
|
||||
type=str,
|
||||
default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
|
||||
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_extention",
|
||||
type=str,
|
||||
default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
|
||||
)
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument(
|
||||
"--beam_search",
|
||||
action="store_true",
|
||||
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
||||
parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
|
||||
|
||||
args = parser.parse_args()
|
||||
return parser
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
main(args)
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
main(args)
|
||||
|
||||
@@ -2,6 +2,7 @@ import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -11,135 +12,165 @@ from transformers.generation.utils import GenerationMixin
|
||||
import library.train_util as train_util
|
||||
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
PATTERN_REPLACE = [
|
||||
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
|
||||
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
|
||||
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
|
||||
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
|
||||
re.compile(r"with the number \d+ on (it|\w+ \w+)"),
|
||||
re.compile(r'with the words "'),
|
||||
re.compile(r'word \w+ on it'),
|
||||
re.compile(r'that says the word \w+ on it'),
|
||||
re.compile('that says\'the word "( on it)?'),
|
||||
re.compile(r"word \w+ on it"),
|
||||
re.compile(r"that says the word \w+ on it"),
|
||||
re.compile("that says'the word \"( on it)?"),
|
||||
]
|
||||
|
||||
# 誤検知しまくりの with the word xxxx を消す
|
||||
|
||||
|
||||
def remove_words(captions, debug):
|
||||
removed_caps = []
|
||||
for caption in captions:
|
||||
cap = caption
|
||||
for pat in PATTERN_REPLACE:
|
||||
cap = pat.sub("", cap)
|
||||
if debug and cap != caption:
|
||||
print(caption)
|
||||
print(cap)
|
||||
removed_caps.append(cap)
|
||||
return removed_caps
|
||||
removed_caps = []
|
||||
for caption in captions:
|
||||
cap = caption
|
||||
for pat in PATTERN_REPLACE:
|
||||
cap = pat.sub("", cap)
|
||||
if debug and cap != caption:
|
||||
print(caption)
|
||||
print(cap)
|
||||
removed_caps.append(cap)
|
||||
return removed_caps
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
||||
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
||||
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
||||
r"""
|
||||
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
|
||||
|
||||
# input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
|
||||
# ここより上で置き換えようとするとすごく大変
|
||||
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
|
||||
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
|
||||
if input_ids.size()[0] != curr_batch_size[0]:
|
||||
input_ids = input_ids.repeat(curr_batch_size[0], 1)
|
||||
return input_ids
|
||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
||||
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
||||
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
||||
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
# input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
|
||||
# ここより上で置き換えようとするとすごく大変
|
||||
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
|
||||
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
|
||||
if input_ids.size()[0] != curr_batch_size[0]:
|
||||
input_ids = input_ids.repeat(curr_batch_size[0], 1)
|
||||
return input_ids
|
||||
|
||||
# できればcacheに依存せず明示的にダウンロードしたい
|
||||
print(f"loading GIT: {args.model_id}")
|
||||
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
||||
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
||||
print("GIT loaded")
|
||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||
"""
|
||||
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = [im for _, im in path_imgs]
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
curr_batch_size[0] = len(path_imgs)
|
||||
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
||||
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
||||
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
# できればcacheに依存せず明示的にダウンロードしたい
|
||||
print(f"loading GIT: {args.model_id}")
|
||||
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
||||
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
||||
print("GIT loaded")
|
||||
|
||||
if args.remove_words:
|
||||
captions = remove_words(captions, args.debug)
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
imgs = [im for _, im in path_imgs]
|
||||
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
# curr_batch_size[0] = len(path_imgs)
|
||||
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
||||
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
||||
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
if args.remove_words:
|
||||
captions = remove_words(captions, args.debug)
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
for (image_path, _), caption in zip(path_imgs, captions):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
print(image_path, caption)
|
||||
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
b_imgs.append((image_path, image))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
b_imgs.append((image_path, image))
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
print("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
|
||||
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
|
||||
parser.add_argument("--remove_words", action="store_true",
|
||||
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
type=str,
|
||||
default="microsoft/git-large-textcaps",
|
||||
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
|
||||
parser.add_argument(
|
||||
"--remove_words",
|
||||
action="store_true",
|
||||
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import library.train_util as train_util
|
||||
|
||||
import os
|
||||
|
||||
def main(args):
|
||||
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||
@@ -29,6 +29,9 @@ def main(args):
|
||||
caption_path = image_path.with_suffix(args.caption_extension)
|
||||
caption = caption_path.read_text(encoding='utf-8').strip()
|
||||
|
||||
if not os.path.exists(caption_path):
|
||||
caption_path = os.path.join(image_path, args.caption_extension)
|
||||
|
||||
image_key = str(image_path) if args.full_path else image_path.stem
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
@@ -43,7 +46,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
@@ -58,6 +61,12 @@ if __name__ == '__main__':
|
||||
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import library.train_util as train_util
|
||||
|
||||
import os
|
||||
|
||||
def main(args):
|
||||
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||
@@ -29,6 +29,9 @@ def main(args):
|
||||
tags_path = image_path.with_suffix(args.caption_extension)
|
||||
tags = tags_path.read_text(encoding='utf-8').strip()
|
||||
|
||||
if not os.path.exists(tags_path):
|
||||
tags_path = os.path.join(image_path, args.caption_extension)
|
||||
|
||||
image_key = str(image_path) if args.full_path else image_path.stem
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
@@ -44,7 +47,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
@@ -58,5 +61,11 @@ if __name__ == '__main__':
|
||||
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -2,6 +2,8 @@ import argparse
|
||||
import os
|
||||
import json
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -12,7 +14,7 @@ from torchvision import transforms
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
@@ -23,239 +25,233 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def get_latents(vae, images, weight_dtype):
|
||||
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
|
||||
img_tensors = torch.stack(img_tensors)
|
||||
img_tensors = img_tensors.to(DEVICE, weight_dtype)
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
|
||||
return latents
|
||||
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
|
||||
else:
|
||||
base_name = image_key
|
||||
relative_path = ""
|
||||
|
||||
|
||||
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
else:
|
||||
base_name = image_key
|
||||
if flip:
|
||||
base_name += '_flip'
|
||||
return os.path.join(data_dir, base_name)
|
||||
if recursive and relative_path:
|
||||
return os.path.join(data_dir, relative_path, base_name) + ".npz"
|
||||
else:
|
||||
return os.path.join(data_dir, base_name) + ".npz"
|
||||
|
||||
|
||||
def main(args):
|
||||
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
||||
if args.bucket_reso_steps % 8 > 0:
|
||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
||||
if args.bucket_reso_steps % 8 > 0:
|
||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
if args.bucket_reso_steps % 32 > 0:
|
||||
print(
|
||||
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
|
||||
)
|
||||
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if os.path.exists(args.in_json):
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
||||
vae.eval()
|
||||
vae.to(DEVICE, dtype=weight_dtype)
|
||||
|
||||
# bucketのサイズを計算する
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
|
||||
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
|
||||
if not args.bucket_no_upscale:
|
||||
bucket_manager.make_buckets()
|
||||
else:
|
||||
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
||||
|
||||
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
||||
img_ar_errors = []
|
||||
|
||||
def process_batch(is_last):
|
||||
for bucket in bucket_manager.buckets:
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
|
||||
assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
|
||||
f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
||||
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
|
||||
np.savez(npz_file_name, latent)
|
||||
|
||||
# flip
|
||||
if args.flip_aug:
|
||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
|
||||
np.savez(npz_file_name, latent)
|
||||
else:
|
||||
# remove existing flipped npz
|
||||
for image_key, _ in bucket:
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
|
||||
if os.path.isfile(npz_file_name):
|
||||
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
||||
os.remove(npz_file_name)
|
||||
|
||||
bucket.clear()
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
bucket_counts = {}
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
if data_entry[0] is None:
|
||||
continue
|
||||
|
||||
img_tensor, image_path = data_entry[0]
|
||||
if img_tensor is not None:
|
||||
image = transforms.functional.to_pil_image(img_tensor)
|
||||
if os.path.exists(args.in_json):
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
||||
vae.eval()
|
||||
vae.to(DEVICE, dtype=weight_dtype)
|
||||
|
||||
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
|
||||
|
||||
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
|
||||
metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
|
||||
# bucketのサイズを計算する
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_manager = train_util.BucketManager(
|
||||
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
|
||||
)
|
||||
if not args.bucket_no_upscale:
|
||||
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
|
||||
assert resized_size[0] == reso[0] or resized_size[1] == reso[
|
||||
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
bucket_manager.make_buckets()
|
||||
else:
|
||||
print(
|
||||
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
|
||||
)
|
||||
|
||||
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
|
||||
1], f"internal error resized size is small: {resized_size}, {reso}"
|
||||
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
||||
img_ar_errors = []
|
||||
|
||||
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
||||
if args.skip_existing:
|
||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
|
||||
if args.flip_aug:
|
||||
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
|
||||
def process_batch(is_last):
|
||||
for bucket in bucket_manager.buckets:
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
|
||||
bucket.clear()
|
||||
|
||||
found = True
|
||||
for npz_file in npz_files:
|
||||
if not os.path.exists(npz_file):
|
||||
found = False
|
||||
break
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = train_util.ImageLoadingDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
dat = np.load(npz_file)['arr_0']
|
||||
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
|
||||
found = False
|
||||
break
|
||||
if found:
|
||||
continue
|
||||
bucket_counts = {}
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
if data_entry[0] is None:
|
||||
continue
|
||||
|
||||
# 画像をリサイズしてトリミングする
|
||||
# PILにinter_areaがないのでcv2で……
|
||||
image = np.array(image)
|
||||
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
|
||||
img_tensor, image_path = data_entry[0]
|
||||
if img_tensor is not None:
|
||||
image = transforms.functional.to_pil_image(img_tensor)
|
||||
else:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
if resized_size[0] > reso[0]:
|
||||
trim_size = resized_size[0] - reso[0]
|
||||
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
metadata[image_key] = {}
|
||||
|
||||
if resized_size[1] > reso[1]:
|
||||
trim_size = resized_size[1] - reso[1]
|
||||
image = image[trim_size//2:trim_size//2 + reso[1]]
|
||||
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
|
||||
|
||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
|
||||
|
||||
# # debug
|
||||
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
|
||||
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
|
||||
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
|
||||
|
||||
# バッチへ追加
|
||||
bucket_manager.add_image(reso, (image_key, image))
|
||||
if not args.bucket_no_upscale:
|
||||
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
|
||||
assert (
|
||||
resized_size[0] == reso[0] or resized_size[1] == reso[1]
|
||||
), f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
assert (
|
||||
resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
|
||||
), f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
|
||||
|
||||
# バッチを推論するか判定して推論する
|
||||
process_batch(False)
|
||||
assert (
|
||||
resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
|
||||
), f"internal error resized size is small: {resized_size}, {reso}"
|
||||
|
||||
# 残りを処理する
|
||||
process_batch(True)
|
||||
# 既に存在するファイルがあればshape等を確認して同じならskipする
|
||||
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
|
||||
continue
|
||||
|
||||
bucket_manager.sort()
|
||||
for i, reso in enumerate(bucket_manager.resos):
|
||||
count = bucket_counts.get(reso, 0)
|
||||
if count > 0:
|
||||
print(f"bucket {i} {reso}: {count}")
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||
# バッチへ追加
|
||||
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
|
||||
image_info.latents_npz = npz_file_name
|
||||
image_info.bucket_reso = reso
|
||||
image_info.resized_size = resized_size
|
||||
image_info.image = image
|
||||
bucket_manager.add_image(reso, image_info)
|
||||
|
||||
# metadataを書き出して終わり
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
print("done!")
|
||||
# バッチを推論するか判定して推論する
|
||||
process_batch(False)
|
||||
|
||||
# 残りを処理する
|
||||
process_batch(True)
|
||||
|
||||
bucket_manager.sort()
|
||||
for i, reso in enumerate(bucket_manager.resos):
|
||||
count = bucket_counts.get(reso, 0)
|
||||
if count > 0:
|
||||
print(f"bucket {i} {reso}: {count}")
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||
|
||||
# metadataを書き出して終わり
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--max_resolution", type=str, default="512,512",
|
||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||
parser.add_argument("--full_path", action="store_true",
|
||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
||||
parser.add_argument("--flip_aug", action="store_true",
|
||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
||||
parser.add_argument("--skip_existing", action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_resolution",
|
||||
type=str,
|
||||
default="512,512",
|
||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)",
|
||||
)
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
type=int,
|
||||
default=64,
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full_path",
|
||||
action="store_true",
|
||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive",
|
||||
action="store_true",
|
||||
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import argparse
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from tensorflow.keras.models import load_model
|
||||
from huggingface_hub import hf_hub_download
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
@@ -17,184 +16,364 @@ import library.train_util as train_util
|
||||
IMAGE_SIZE = 448
|
||||
|
||||
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
||||
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
|
||||
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
||||
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
||||
FILES_ONNX = ["model.onnx"]
|
||||
SUB_DIR = "variables"
|
||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||
CSV_FILE = FILES[-1]
|
||||
|
||||
|
||||
def preprocess_image(image):
|
||||
image = np.array(image)
|
||||
image = image[:, :, ::-1] # RGB->BGR
|
||||
image = np.array(image)
|
||||
image = image[:, :, ::-1] # RGB->BGR
|
||||
|
||||
# pad to square
|
||||
size = max(image.shape[0:2])
|
||||
pad_x = size - image.shape[1]
|
||||
pad_y = size - image.shape[0]
|
||||
pad_l = pad_x // 2
|
||||
pad_t = pad_y // 2
|
||||
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
|
||||
# pad to square
|
||||
size = max(image.shape[0:2])
|
||||
pad_x = size - image.shape[1]
|
||||
pad_y = size - image.shape[0]
|
||||
pad_l = pad_x // 2
|
||||
pad_t = pad_y // 2
|
||||
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
||||
|
||||
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
||||
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
|
||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
|
||||
|
||||
image = image.astype(np.float32)
|
||||
return image
|
||||
image = image.astype(np.float32)
|
||||
return image
|
||||
|
||||
|
||||
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.images[idx]
|
||||
def __getitem__(self, idx):
|
||||
img_path = str(self.images[idx])
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
tensor = torch.tensor(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
tensor = torch.tensor(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (tensor, img_path)
|
||||
return (tensor, img_path)
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
The 'None's in the batch are removed.
|
||||
"""
|
||||
# Filter out all the Nones (corrupted examples)
|
||||
batch = list(filter(lambda x: x is not None, batch))
|
||||
return batch
|
||||
|
||||
|
||||
def main(args):
|
||||
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
||||
# depreacatedの警告が出るけどなくなったらその時
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
if not os.path.exists(args.model_dir) or args.force_download:
|
||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
for file in FILES:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
|
||||
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
|
||||
else:
|
||||
print("using existing wd14 tagger model")
|
||||
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
|
||||
# depreacatedの警告が出るけどなくなったらその時
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
if not os.path.exists(args.model_dir) or args.force_download:
|
||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files += FILES_ONNX
|
||||
for file in files:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
args.repo_id,
|
||||
file,
|
||||
subfolder=SUB_DIR,
|
||||
cache_dir=os.path.join(args.model_dir, SUB_DIR),
|
||||
force_download=True,
|
||||
force_filename=file,
|
||||
)
|
||||
else:
|
||||
print("using existing wd14 tagger model")
|
||||
|
||||
# 画像を読み込む
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
# 画像を読み込む
|
||||
if args.onnx:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
print("loading model and labels")
|
||||
model = load_model(args.model_dir)
|
||||
onnx_path = f"{args.model_dir}/model.onnx"
|
||||
print("Running wd14 tagger with onnx")
|
||||
print(f"loading onnx model: {onnx_path}")
|
||||
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
l = [row for row in reader]
|
||||
header = l[0] # tag_id,name,category,count
|
||||
rows = l[1:]
|
||||
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
|
||||
if not os.path.exists(onnx_path):
|
||||
raise Exception(
|
||||
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
|
||||
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
|
||||
)
|
||||
|
||||
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
|
||||
|
||||
# 推論する
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
# 最初の4つはratingなので無視する
|
||||
# # First 4 labels are actually ratings: pick one with argmax
|
||||
# ratings_names = label_names[:4]
|
||||
# rating_index = ratings_names["probs"].argmax()
|
||||
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
|
||||
|
||||
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
||||
# Everything else is tags: pick any where prediction confidence > threshold
|
||||
tag_text = ""
|
||||
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
||||
if p >= args.thresh and i < len(tags):
|
||||
tag_text += ", " + tags[i]
|
||||
|
||||
if len(tag_text) > 0:
|
||||
tag_text = tag_text[2:] # 最初の ", " を消す
|
||||
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||
f.write(tag_text + '\n')
|
||||
if args.debug:
|
||||
print(image_path, tag_text)
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is not None:
|
||||
image = image.detach().numpy()
|
||||
else:
|
||||
model = onnx.load(onnx_path)
|
||||
input_name = model.graph.input[0].name
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
||||
except:
|
||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
if args.batch_size != batch_size and type(batch_size) != str:
|
||||
# some rebatch model may use 'N' as dynamic axes
|
||||
print(
|
||||
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||
)
|
||||
args.batch_size = batch_size
|
||||
|
||||
del model
|
||||
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=["CUDAExecutionProvider"]
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else ["CPUExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
model = load_model(f"{args.model_dir}")
|
||||
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
|
||||
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
l = [row for row in reader]
|
||||
header = l[0] # tag_id,name,category,count
|
||||
rows = l[1:]
|
||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||
|
||||
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
|
||||
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
|
||||
|
||||
# 画像を読み込む
|
||||
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
tag_freq = {}
|
||||
|
||||
caption_separator = args.caption_separator
|
||||
stripped_caption_separator = caption_separator.strip()
|
||||
undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))
|
||||
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
|
||||
if args.onnx:
|
||||
if len(imgs) < args.batch_size:
|
||||
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
probs = probs[: len(path_imgs)]
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
# 最初の4つはratingなので無視する
|
||||
# # First 4 labels are actually ratings: pick one with argmax
|
||||
# ratings_names = label_names[:4]
|
||||
# rating_index = ratings_names["probs"].argmax()
|
||||
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
|
||||
|
||||
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
||||
# Everything else is tags: pick any where prediction confidence > threshold
|
||||
combined_tags = []
|
||||
general_tag_text = ""
|
||||
character_tag_text = ""
|
||||
for i, p in enumerate(prob[4:]):
|
||||
if i < len(general_tags) and p >= args.general_threshold:
|
||||
tag_name = general_tags[i]
|
||||
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
|
||||
tag_name = tag_name.replace("_", " ")
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
if args.remove_underscore and len(tag_name) > 3:
|
||||
tag_name = tag_name.replace("_", " ")
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
|
||||
# 先頭のカンマを取る
|
||||
if len(general_tag_text) > 0:
|
||||
general_tag_text = general_tag_text[len(caption_separator) :]
|
||||
if len(character_tag_text) > 0:
|
||||
character_tag_text = character_tag_text[len(caption_separator) :]
|
||||
|
||||
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
|
||||
tag_text = caption_separator.join(combined_tags)
|
||||
|
||||
if args.append_tags:
|
||||
# Check if file exists
|
||||
if os.path.exists(caption_file):
|
||||
with open(caption_file, "rt", encoding="utf-8") as f:
|
||||
# Read file and remove new lines
|
||||
existing_content = f.read().strip("\n") # Remove newlines
|
||||
|
||||
# Split the content into tags and store them in a list
|
||||
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
|
||||
|
||||
# Check and remove repeating tags in tag_text
|
||||
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
|
||||
|
||||
# Create new tag_text
|
||||
tag_text = caption_separator.join(existing_tags + new_tags)
|
||||
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is not None:
|
||||
image = image.detach().numpy()
|
||||
else:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
if args.frequency_tags:
|
||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
print("\nTag frequencies:")
|
||||
for tag, freq in sorted_tags:
|
||||
print(f"{tag}: {freq}")
|
||||
|
||||
print("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
|
||||
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
|
||||
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
|
||||
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
|
||||
parser.add_argument("--force_download", action='store_true',
|
||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
|
||||
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
default=DEFAULT_WD14_TAGGER_REPO,
|
||||
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default="wd14_tagger_model",
|
||||
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_extention",
|
||||
type=str,
|
||||
default=None,
|
||||
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
|
||||
)
|
||||
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
||||
parser.add_argument(
|
||||
"--general_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
|
||||
)
|
||||
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
|
||||
parser.add_argument(
|
||||
"--remove_underscore",
|
||||
action="store_true",
|
||||
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument(
|
||||
"--undesired_tags",
|
||||
type=str,
|
||||
default="",
|
||||
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
||||
)
|
||||
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
|
||||
parser.add_argument(
|
||||
"--caption_separator",
|
||||
type=str,
|
||||
default=", ",
|
||||
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return parser
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
main(args)
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
if args.caption_extention is not None:
|
||||
args.caption_extension = args.caption_extention
|
||||
|
||||
if args.general_threshold is None:
|
||||
args.general_threshold = args.thresh
|
||||
if args.character_threshold is None:
|
||||
args.character_threshold = args.thresh
|
||||
|
||||
main(args)
|
||||
|
||||
5564
gen_img_diffusers.py
5564
gen_img_diffusers.py
File diff suppressed because it is too large
Load Diff
227
library/attention_processors.py
Normal file
227
library/attention_processors.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import math
|
||||
from typing import Any
|
||||
from einops import rearrange
|
||||
import torch
|
||||
from diffusers.models.attention_processor import Attention
|
||||
|
||||
|
||||
# flash attention forwards and backwards
|
||||
|
||||
# https://arxiv.org/abs/2205.14135
|
||||
|
||||
EPSILON = 1e-6
|
||||
|
||||
|
||||
class FlashAttentionFunction(torch.autograd.function.Function):
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
||||
"""Algorithm 2 in the paper"""
|
||||
|
||||
device = q.device
|
||||
dtype = q.dtype
|
||||
max_neg_value = -torch.finfo(q.dtype).max
|
||||
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
||||
|
||||
o = torch.zeros_like(q)
|
||||
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
||||
all_row_maxes = torch.full(
|
||||
(*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
scale = q.shape[-1] ** -0.5
|
||||
|
||||
if mask is None:
|
||||
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
||||
else:
|
||||
mask = rearrange(mask, "b n -> b 1 1 n")
|
||||
mask = mask.split(q_bucket_size, dim=-1)
|
||||
|
||||
row_splits = zip(
|
||||
q.split(q_bucket_size, dim=-2),
|
||||
o.split(q_bucket_size, dim=-2),
|
||||
mask,
|
||||
all_row_sums.split(q_bucket_size, dim=-2),
|
||||
all_row_maxes.split(q_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
||||
q_start_index = ind * q_bucket_size - qk_len_diff
|
||||
|
||||
col_splits = zip(
|
||||
k.split(k_bucket_size, dim=-2),
|
||||
v.split(k_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for k_ind, (kc, vc) in enumerate(col_splits):
|
||||
k_start_index = k_ind * k_bucket_size
|
||||
|
||||
attn_weights = (
|
||||
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
||||
)
|
||||
|
||||
if row_mask is not None:
|
||||
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
||||
|
||||
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
||||
causal_mask = torch.ones(
|
||||
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
||||
).triu(q_start_index - k_start_index + 1)
|
||||
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
||||
|
||||
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
||||
attn_weights -= block_row_maxes
|
||||
exp_weights = torch.exp(attn_weights)
|
||||
|
||||
if row_mask is not None:
|
||||
exp_weights.masked_fill_(~row_mask, 0.0)
|
||||
|
||||
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
|
||||
min=EPSILON
|
||||
)
|
||||
|
||||
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
||||
|
||||
exp_values = torch.einsum(
|
||||
"... i j, ... j d -> ... i d", exp_weights, vc
|
||||
)
|
||||
|
||||
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
||||
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
||||
|
||||
new_row_sums = (
|
||||
exp_row_max_diff * row_sums
|
||||
+ exp_block_row_max_diff * block_row_sums
|
||||
)
|
||||
|
||||
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
|
||||
(exp_block_row_max_diff / new_row_sums) * exp_values
|
||||
)
|
||||
|
||||
row_maxes.copy_(new_row_maxes)
|
||||
row_sums.copy_(new_row_sums)
|
||||
|
||||
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
||||
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
||||
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def backward(ctx, do):
|
||||
"""Algorithm 4 in the paper"""
|
||||
|
||||
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
||||
q, k, v, o, l, m = ctx.saved_tensors
|
||||
|
||||
device = q.device
|
||||
|
||||
max_neg_value = -torch.finfo(q.dtype).max
|
||||
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
||||
|
||||
dq = torch.zeros_like(q)
|
||||
dk = torch.zeros_like(k)
|
||||
dv = torch.zeros_like(v)
|
||||
|
||||
row_splits = zip(
|
||||
q.split(q_bucket_size, dim=-2),
|
||||
o.split(q_bucket_size, dim=-2),
|
||||
do.split(q_bucket_size, dim=-2),
|
||||
mask,
|
||||
l.split(q_bucket_size, dim=-2),
|
||||
m.split(q_bucket_size, dim=-2),
|
||||
dq.split(q_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
||||
q_start_index = ind * q_bucket_size - qk_len_diff
|
||||
|
||||
col_splits = zip(
|
||||
k.split(k_bucket_size, dim=-2),
|
||||
v.split(k_bucket_size, dim=-2),
|
||||
dk.split(k_bucket_size, dim=-2),
|
||||
dv.split(k_bucket_size, dim=-2),
|
||||
)
|
||||
|
||||
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
||||
k_start_index = k_ind * k_bucket_size
|
||||
|
||||
attn_weights = (
|
||||
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
||||
)
|
||||
|
||||
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
||||
causal_mask = torch.ones(
|
||||
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
||||
).triu(q_start_index - k_start_index + 1)
|
||||
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
||||
|
||||
exp_attn_weights = torch.exp(attn_weights - mc)
|
||||
|
||||
if row_mask is not None:
|
||||
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
||||
|
||||
p = exp_attn_weights / lc
|
||||
|
||||
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
||||
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
||||
|
||||
D = (doc * oc).sum(dim=-1, keepdims=True)
|
||||
ds = p * scale * (dp - D)
|
||||
|
||||
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
||||
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
||||
|
||||
dqc.add_(dq_chunk)
|
||||
dkc.add_(dk_chunk)
|
||||
dvc.add_(dv_chunk)
|
||||
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnProcessor:
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
) -> Any:
|
||||
q_bucket_size = 512
|
||||
k_bucket_size = 1024
|
||||
|
||||
h = attn.heads
|
||||
q = attn.to_q(hidden_states)
|
||||
|
||||
encoder_hidden_states = (
|
||||
encoder_hidden_states
|
||||
if encoder_hidden_states is not None
|
||||
else hidden_states
|
||||
)
|
||||
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
|
||||
context_k, context_v = attn.hypernetwork.forward(
|
||||
hidden_states, encoder_hidden_states
|
||||
)
|
||||
context_k = context_k.to(hidden_states.dtype)
|
||||
context_v = context_v.to(hidden_states.dtype)
|
||||
else:
|
||||
context_k = encoder_hidden_states
|
||||
context_v = encoder_hidden_states
|
||||
|
||||
k = attn.to_k(context_k)
|
||||
v = attn.to_v(context_v)
|
||||
del encoder_hidden_states, hidden_states
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
out = FlashAttentionFunction.apply(
|
||||
q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
|
||||
)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
|
||||
out = attn.to_out[0](out)
|
||||
out = attn.to_out[1](out)
|
||||
return out
|
||||
File diff suppressed because it is too large
Load Diff
529
library/custom_train_functions.py
Normal file
529
library/custom_train_functions.py
Normal file
@@ -0,0 +1,529 @@
|
||||
import torch
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
return
|
||||
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
||||
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
||||
alpha = sqrt_alphas_cumprod
|
||||
sigma = sqrt_one_minus_alphas_cumprod
|
||||
all_snr = (alpha / sigma) ** 2
|
||||
|
||||
noise_scheduler.all_snr = all_snr.to(device)
|
||||
|
||||
|
||||
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
# fix beta: zero terminal SNR
|
||||
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
||||
|
||||
def enforce_zero_terminal_snr(betas):
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1 - betas
|
||||
alphas_bar = alphas.cumprod(0)
|
||||
alphas_bar_sqrt = alphas_bar.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
# Shift so last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
# Scale so first timestep is back to old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
betas = 1 - alphas
|
||||
return betas
|
||||
|
||||
betas = noise_scheduler.betas
|
||||
betas = enforce_zero_terminal_snr(betas)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
# print("original:", noise_scheduler.betas)
|
||||
# print("fixed:", betas)
|
||||
|
||||
noise_scheduler.betas = betas
|
||||
noise_scheduler.alphas = alphas
|
||||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
|
||||
else:
|
||||
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
||||
loss = loss * snr_weight
|
||||
return loss
|
||||
|
||||
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
loss = loss * scale
|
||||
return loss
|
||||
|
||||
|
||||
def get_snr_scale(timesteps, noise_scheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
scale = snr_t / (snr_t + 1)
|
||||
# # show debug info
|
||||
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
||||
return scale
|
||||
|
||||
|
||||
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
return loss
|
||||
|
||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
weight = 1/torch.sqrt(snr_t)
|
||||
loss = weight * loss
|
||||
return loss
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
|
||||
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
||||
parser.add_argument(
|
||||
"--min_snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_v_pred_loss_like_noise_pred",
|
||||
action="store_true",
|
||||
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--v_pred_like_loss",
|
||||
type=float,
|
||||
default=None,
|
||||
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debiased_estimation_loss",
|
||||
action="store_true",
|
||||
help="debiased estimation loss / debiased estimation loss",
|
||||
)
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
||||
)
|
||||
|
||||
|
||||
re_attention = re.compile(
|
||||
r"""
|
||||
\\\(|
|
||||
\\\)|
|
||||
\\\[|
|
||||
\\]|
|
||||
\\\\|
|
||||
\\|
|
||||
\(|
|
||||
\[|
|
||||
:([+-]?[.\d]+)\)|
|
||||
\)|
|
||||
]|
|
||||
[^\\()\[\]:]+|
|
||||
:
|
||||
""",
|
||||
re.X,
|
||||
)
|
||||
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
Accepted tokens are:
|
||||
(abc) - increases attention to abc by a multiplier of 1.1
|
||||
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
[abc] - decreases attention to abc by a multiplier of 1.1
|
||||
\( - literal character '('
|
||||
\[ - literal character '['
|
||||
\) - literal character ')'
|
||||
\] - literal character ']'
|
||||
\\ - literal character '\'
|
||||
anything else - just text
|
||||
>>> parse_prompt_attention('normal text')
|
||||
[['normal text', 1.0]]
|
||||
>>> parse_prompt_attention('an (important) word')
|
||||
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||
>>> parse_prompt_attention('(unbalanced')
|
||||
[['unbalanced', 1.1]]
|
||||
>>> parse_prompt_attention('\(literal\]')
|
||||
[['(literal]', 1.0]]
|
||||
>>> parse_prompt_attention('(unnecessary)(parens)')
|
||||
[['unnecessaryparens', 1.1]]
|
||||
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||
[['a ', 1.0],
|
||||
['house', 1.5730000000000004],
|
||||
[' ', 1.1],
|
||||
['on', 1.0],
|
||||
[' a ', 1.1],
|
||||
['hill', 0.55],
|
||||
[', sun, ', 1.1],
|
||||
['sky', 1.4641000000000006],
|
||||
['.', 1.1]]
|
||||
"""
|
||||
|
||||
res = []
|
||||
round_brackets = []
|
||||
square_brackets = []
|
||||
|
||||
round_bracket_multiplier = 1.1
|
||||
square_bracket_multiplier = 1 / 1.1
|
||||
|
||||
def multiply_range(start_position, multiplier):
|
||||
for p in range(start_position, len(res)):
|
||||
res[p][1] *= multiplier
|
||||
|
||||
for m in re_attention.finditer(text):
|
||||
text = m.group(0)
|
||||
weight = m.group(1)
|
||||
|
||||
if text.startswith("\\"):
|
||||
res.append([text[1:], 1.0])
|
||||
elif text == "(":
|
||||
round_brackets.append(len(res))
|
||||
elif text == "[":
|
||||
square_brackets.append(len(res))
|
||||
elif weight is not None and len(round_brackets) > 0:
|
||||
multiply_range(round_brackets.pop(), float(weight))
|
||||
elif text == ")" and len(round_brackets) > 0:
|
||||
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||
elif text == "]" and len(square_brackets) > 0:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
res.append([text, 1.0])
|
||||
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
|
||||
|
||||
for pos in square_brackets:
|
||||
multiply_range(pos, square_bracket_multiplier)
|
||||
|
||||
if len(res) == 0:
|
||||
res = [["", 1.0]]
|
||||
|
||||
# merge runs of identical weights
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
||||
r"""
|
||||
Tokenize a list of prompts and return its tokens with weights of each token.
|
||||
|
||||
No padding, starting or ending token is included.
|
||||
"""
|
||||
tokens = []
|
||||
weights = []
|
||||
truncated = False
|
||||
for text in prompt:
|
||||
texts_and_weights = parse_prompt_attention(text)
|
||||
text_token = []
|
||||
text_weight = []
|
||||
for word, weight in texts_and_weights:
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = tokenizer(word).input_ids[1:-1]
|
||||
text_token += token
|
||||
# copy the weight by length of token
|
||||
text_weight += [weight] * len(token)
|
||||
# stop if the text is too long (longer than truncation limit)
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
break
|
||||
# truncate
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
text_token = text_token[:max_length]
|
||||
text_weight = text_weight[:max_length]
|
||||
tokens.append(text_token)
|
||||
weights.append(text_weight)
|
||||
if truncated:
|
||||
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||
return tokens, weights
|
||||
|
||||
|
||||
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
||||
r"""
|
||||
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
||||
"""
|
||||
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
||||
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
||||
for i in range(len(tokens)):
|
||||
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
||||
if no_boseos_middle:
|
||||
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
||||
else:
|
||||
w = []
|
||||
if len(weights[i]) == 0:
|
||||
w = [1.0] * weights_length
|
||||
else:
|
||||
for j in range(max_embeddings_multiples):
|
||||
w.append(1.0) # weight for starting token in this chunk
|
||||
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
||||
w.append(1.0) # weight for ending token in this chunk
|
||||
w += [1.0] * (weights_length - len(w))
|
||||
weights[i] = w[:]
|
||||
|
||||
return tokens, weights
|
||||
|
||||
|
||||
def get_unweighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
text_input: torch.Tensor,
|
||||
chunk_length: int,
|
||||
clip_skip: int,
|
||||
eos: int,
|
||||
pad: int,
|
||||
no_boseos_middle: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||
it should be split into chunks and sent to the text encoder individually.
|
||||
"""
|
||||
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
||||
if max_embeddings_multiples > 1:
|
||||
text_embeddings = []
|
||||
for i in range(max_embeddings_multiples):
|
||||
# extract the i-th chunk
|
||||
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
if pad == eos: # v1
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
else: # v2
|
||||
for j in range(len(text_input_chunk)):
|
||||
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
||||
text_input_chunk[j, -1] = eos
|
||||
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
||||
text_input_chunk[j, 1] = eos
|
||||
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embedding = text_encoder(text_input_chunk)[0]
|
||||
else:
|
||||
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
text_embedding = text_embedding[:, :-1]
|
||||
elif i == max_embeddings_multiples - 1:
|
||||
# discard the starting token
|
||||
text_embedding = text_embedding[:, 1:]
|
||||
else:
|
||||
# discard both starting and ending tokens
|
||||
text_embedding = text_embedding[:, 1:-1]
|
||||
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embeddings = text_encoder(text_input)[0]
|
||||
else:
|
||||
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||
return text_embeddings
|
||||
|
||||
|
||||
def get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt: Union[str, List[str]],
|
||||
device,
|
||||
max_embeddings_multiples: Optional[int] = 3,
|
||||
no_boseos_middle: Optional[bool] = False,
|
||||
clip_skip=None,
|
||||
):
|
||||
r"""
|
||||
Prompts can be assigned with local weights using brackets. For example,
|
||||
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
||||
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
||||
|
||||
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
||||
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
||||
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
||||
ending token in each of the chunk in the middle.
|
||||
skip_parsing (`bool`, *optional*, defaults to `False`):
|
||||
Skip the parsing of brackets.
|
||||
skip_weighting (`bool`, *optional*, defaults to `False`):
|
||||
Skip the weighting. When the parsing is skipped, it is forced True.
|
||||
"""
|
||||
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
|
||||
|
||||
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
||||
max_length = max([len(token) for token in prompt_tokens])
|
||||
|
||||
max_embeddings_multiples = min(
|
||||
max_embeddings_multiples,
|
||||
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
|
||||
)
|
||||
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
||||
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
|
||||
# pad the length of tokens and weights
|
||||
bos = tokenizer.bos_token_id
|
||||
eos = tokenizer.eos_token_id
|
||||
pad = tokenizer.pad_token_id
|
||||
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
||||
prompt_tokens,
|
||||
prompt_weights,
|
||||
max_length,
|
||||
bos,
|
||||
eos,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
chunk_length=tokenizer.model_max_length,
|
||||
)
|
||||
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
|
||||
|
||||
# get the embeddings
|
||||
text_embeddings = get_unweighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt_tokens,
|
||||
tokenizer.model_max_length,
|
||||
clip_skip,
|
||||
eos,
|
||||
pad,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
|
||||
|
||||
# assign weights to the prompts and normalize in the sense of mean
|
||||
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
|
||||
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
return text_embeddings
|
||||
|
||||
|
||||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
||||
for i in range(iterations):
|
||||
r = random.random() * 2 + 2 # Rather than always going 2x,
|
||||
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
||||
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
||||
if wn == 1 or hn == 1:
|
||||
break # Lowest resolution is 1x1
|
||||
return noise / noise.std() # Scaled back to roughly unit variance
|
||||
|
||||
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
if noise_offset is None:
|
||||
return noise
|
||||
if adaptive_noise_scale is not None:
|
||||
# latent shape: (batch_size, channels, height, width)
|
||||
# abs mean value for each channel
|
||||
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
||||
|
||||
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
||||
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
||||
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
||||
|
||||
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
return noise
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||
d = (shape[0] // res[0], shape[1] // res[1])
|
||||
|
||||
grid = (
|
||||
torch.stack(
|
||||
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
||||
dim=-1,
|
||||
)
|
||||
% 1
|
||||
)
|
||||
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
||||
|
||||
tile_grads = (
|
||||
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||
.repeat_interleave(d[0], 0)
|
||||
.repeat_interleave(d[1], 1)
|
||||
)
|
||||
dot = lambda grad, shift: (
|
||||
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
||||
* grad[: shape[0], : shape[1]]
|
||||
).sum(dim=-1)
|
||||
|
||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
||||
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
||||
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
||||
t = fade(grid[: shape[0], : shape[1]])
|
||||
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||||
|
||||
|
||||
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
||||
noise = torch.zeros(shape, device=device)
|
||||
frequency = 1
|
||||
amplitude = 1
|
||||
for _ in range(octaves):
|
||||
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
||||
frequency *= 2
|
||||
amplitude *= persistence
|
||||
return noise
|
||||
|
||||
|
||||
def perlin_noise(noise, device, octaves):
|
||||
_, c, w, h = noise.shape
|
||||
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
||||
noise_perlin = []
|
||||
for _ in range(c):
|
||||
noise_perlin.append(perlin())
|
||||
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
||||
noise += noise_perlin # broadcast for each batch
|
||||
return noise / noise.std() # Scaled back to roughly unit variance
|
||||
"""
|
||||
81
library/huggingface_util.py
Normal file
81
library/huggingface_util.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import Union, BinaryIO
|
||||
from huggingface_hub import HfApi
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import os
|
||||
from library.utils import fire_in_thread
|
||||
|
||||
|
||||
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
||||
api = HfApi(
|
||||
token=token,
|
||||
)
|
||||
try:
|
||||
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def upload(
|
||||
args: argparse.Namespace,
|
||||
src: Union[str, Path, bytes, BinaryIO],
|
||||
dest_suffix: str = "",
|
||||
force_sync_upload: bool = False,
|
||||
):
|
||||
repo_id = args.huggingface_repo_id
|
||||
repo_type = args.huggingface_repo_type
|
||||
token = args.huggingface_token
|
||||
path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
|
||||
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
||||
api = HfApi(token=token)
|
||||
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
||||
try:
|
||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
||||
print("===========================================")
|
||||
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
||||
print("===========================================")
|
||||
|
||||
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
||||
|
||||
def uploader():
|
||||
try:
|
||||
if is_folder:
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
folder_path=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
else:
|
||||
api.upload_file(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
path_or_fileobj=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
||||
print("===========================================")
|
||||
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
||||
print("===========================================")
|
||||
|
||||
if args.async_upload and not force_sync_upload:
|
||||
fire_in_thread(uploader)
|
||||
else:
|
||||
uploader()
|
||||
|
||||
|
||||
def list_dir(
|
||||
repo_id: str,
|
||||
subfolder: str,
|
||||
repo_type: str,
|
||||
revision: str = "main",
|
||||
token: str = None,
|
||||
):
|
||||
api = HfApi(
|
||||
token=token,
|
||||
)
|
||||
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||
file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
|
||||
return file_list
|
||||
223
library/hypernetwork.py
Normal file
223
library/hypernetwork.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import (
|
||||
Attention,
|
||||
AttnProcessor2_0,
|
||||
SlicedAttnProcessor,
|
||||
XFormersAttnProcessor
|
||||
)
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
except:
|
||||
xformers = None
|
||||
|
||||
|
||||
loaded_networks = []
|
||||
|
||||
|
||||
def apply_single_hypernetwork(
|
||||
hypernetwork, hidden_states, encoder_hidden_states
|
||||
):
|
||||
context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def apply_hypernetworks(context_k, context_v, layer=None):
|
||||
if len(loaded_networks) == 0:
|
||||
return context_v, context_v
|
||||
for hypernetwork in loaded_networks:
|
||||
context_k, context_v = hypernetwork.forward(context_k, context_v)
|
||||
|
||||
context_k = context_k.to(dtype=context_k.dtype)
|
||||
context_v = context_v.to(dtype=context_k.dtype)
|
||||
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
self: XFormersAttnProcessor,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
):
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape
|
||||
if encoder_hidden_states is None
|
||||
else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size
|
||||
)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(context_k)
|
||||
value = attn.to_v(context_v)
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attention_mask,
|
||||
op=self.attention_op,
|
||||
scale=attn.scale,
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def sliced_attn_forward(
|
||||
self: SlicedAttnProcessor,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor = None,
|
||||
):
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape
|
||||
if encoder_hidden_states is None
|
||||
else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size
|
||||
)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(context_k)
|
||||
value = attn.to_v(context_v)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
batch_size_attention, query_tokens, _ = query.shape
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, query_tokens, dim // attn.heads),
|
||||
device=query.device,
|
||||
dtype=query.dtype,
|
||||
)
|
||||
|
||||
for i in range(batch_size_attention // self.slice_size):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = (
|
||||
attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
)
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def v2_0_forward(
|
||||
self: AttnProcessor2_0,
|
||||
attn: Attention,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape
|
||||
if encoder_hidden_states is None
|
||||
else encoder_hidden_states.shape
|
||||
)
|
||||
inner_dim = hidden_states.shape[-1]
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size
|
||||
)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(
|
||||
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
||||
)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(context_k)
|
||||
value = attn.to_v(context_v)
|
||||
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def replace_attentions_for_hypernetwork():
|
||||
import diffusers.models.attention_processor
|
||||
|
||||
diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
|
||||
xformers_forward
|
||||
)
|
||||
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
|
||||
sliced_attn_forward
|
||||
)
|
||||
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
|
||||
175
library/ipex/__init__.py
Normal file
175
library/ipex/__init__.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import os
|
||||
import sys
|
||||
import contextlib
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
from .hijacks import ipex_hijacks
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
def ipex_init(): # pylint: disable=too-many-statements
|
||||
try:
|
||||
# Replace cuda with xpu:
|
||||
torch.cuda.current_device = torch.xpu.current_device
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.device = torch.xpu.device
|
||||
torch.cuda.device_count = torch.xpu.device_count
|
||||
torch.cuda.device_of = torch.xpu.device_of
|
||||
torch.cuda.get_device_name = torch.xpu.get_device_name
|
||||
torch.cuda.get_device_properties = torch.xpu.get_device_properties
|
||||
torch.cuda.init = torch.xpu.init
|
||||
torch.cuda.is_available = torch.xpu.is_available
|
||||
torch.cuda.is_initialized = torch.xpu.is_initialized
|
||||
torch.cuda.is_current_stream_capturing = lambda: False
|
||||
torch.cuda.set_device = torch.xpu.set_device
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.synchronize = torch.xpu.synchronize
|
||||
torch.cuda.Event = torch.xpu.Event
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.Tensor.cuda = torch.Tensor.xpu
|
||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
||||
torch.cuda._tls = torch.xpu.lazy_init._tls
|
||||
torch.cuda.threading = torch.xpu.lazy_init.threading
|
||||
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
||||
torch.cuda.Optional = torch.xpu.Optional
|
||||
torch.cuda.__cached__ = torch.xpu.__cached__
|
||||
torch.cuda.__loader__ = torch.xpu.__loader__
|
||||
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.streams = torch.xpu.streams
|
||||
torch.cuda._lazy_new = torch.xpu._lazy_new
|
||||
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
||||
torch.cuda.Any = torch.xpu.Any
|
||||
torch.cuda.__doc__ = torch.xpu.__doc__
|
||||
torch.cuda.default_generators = torch.xpu.default_generators
|
||||
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
||||
torch.cuda._get_device_index = torch.xpu._get_device_index
|
||||
torch.cuda.__path__ = torch.xpu.__path__
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.IntTensor = torch.xpu.IntTensor
|
||||
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
||||
torch.cuda.set_stream = torch.xpu.set_stream
|
||||
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.torch = torch.xpu.torch
|
||||
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
||||
torch.cuda.Union = torch.xpu.Union
|
||||
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
||||
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
||||
torch.cuda.LongTensor = torch.xpu.LongTensor
|
||||
torch.cuda.IntStorage = torch.xpu.IntStorage
|
||||
torch.cuda.LongStorage = torch.xpu.LongStorage
|
||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||
torch.cuda.__package__ = torch.xpu.__package__
|
||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||
torch.cuda.CharTensor = torch.xpu.CharTensor
|
||||
torch.cuda.List = torch.xpu.List
|
||||
torch.cuda._lazy_init = torch.xpu._lazy_init
|
||||
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
||||
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
||||
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
||||
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
||||
torch.cuda.random = torch.xpu.random
|
||||
torch.cuda._device = torch.xpu._device
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.cuda.__name__ = torch.xpu.__name__
|
||||
torch.cuda._device_t = torch.xpu._device_t
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.__spec__ = torch.xpu.__spec__
|
||||
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
# Memory:
|
||||
torch.cuda.memory = torch.xpu.memory
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
torch.xpu.empty_cache = lambda: None
|
||||
torch.cuda.empty_cache = torch.xpu.empty_cache
|
||||
torch.cuda.memory_stats = torch.xpu.memory_stats
|
||||
torch.cuda.memory_summary = torch.xpu.memory_summary
|
||||
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
||||
torch.cuda.memory_allocated = torch.xpu.memory_allocated
|
||||
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
|
||||
torch.cuda.memory_reserved = torch.xpu.memory_reserved
|
||||
torch.cuda.memory_cached = torch.xpu.memory_reserved
|
||||
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
|
||||
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
|
||||
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
||||
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
||||
|
||||
# RNG:
|
||||
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
||||
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
||||
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
||||
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
|
||||
torch.cuda.manual_seed = torch.xpu.manual_seed
|
||||
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
|
||||
torch.cuda.seed = torch.xpu.seed
|
||||
torch.cuda.seed_all = torch.xpu.seed_all
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
# AMP:
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
|
||||
try:
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
try:
|
||||
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
|
||||
gradscaler_init()
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
|
||||
# C
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
|
||||
ipex._C._DeviceProperties.major = 2023
|
||||
ipex._C._DeviceProperties.minor = 2
|
||||
|
||||
# Fix functions with ipex:
|
||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||
torch._utils._get_available_device_type = lambda: "xpu"
|
||||
torch.has_cuda = True
|
||||
torch.cuda.has_half = True
|
||||
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "12.1"
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
||||
torch.cuda.get_device_properties.major = 12
|
||||
torch.cuda.get_device_properties.minor = 1
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
|
||||
ipex_hijacks()
|
||||
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
except Exception as e:
|
||||
return False, e
|
||||
return True, None
|
||||
177
library/ipex/attention.py
Normal file
177
library/ipex/attention.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import os
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
from functools import cache
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
|
||||
|
||||
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||
|
||||
# Find something divisible with the input_tokens
|
||||
@cache
|
||||
def find_slice_size(slice_size, slice_block_size):
|
||||
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||
slice_size = slice_size // 2
|
||||
if slice_size <= 1:
|
||||
slice_size = 1
|
||||
break
|
||||
return slice_size
|
||||
|
||||
# Find slice sizes for SDPA
|
||||
@cache
|
||||
def find_sdpa_slice_sizes(query_shape, query_element_size):
|
||||
if len(query_shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_three = query_shape
|
||||
shape_four = 1
|
||||
else:
|
||||
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||
|
||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
split_2_slice_size = query_tokens
|
||||
split_3_slice_size = shape_three
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if block_size > sdpa_slice_trigger_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
# Find slice sizes for BMM
|
||||
@cache
|
||||
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
|
||||
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
|
||||
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
split_2_slice_size = input_tokens
|
||||
split_3_slice_size = mat2_atten_shape
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if block_size > attention_slice_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
|
||||
original_torch_bmm = torch.bmm
|
||||
def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||
if input.device.type != "xpu":
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
||||
|
||||
# Slice BMM
|
||||
if do_split:
|
||||
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
||||
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
||||
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
||||
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
||||
input[start_idx:end_idx],
|
||||
mat2[start_idx:end_idx],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
torch.xpu.synchronize(input.device)
|
||||
return hidden_states
|
||||
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
if query.device.type != "xpu":
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
||||
|
||||
# Slice SDPA
|
||||
if do_split:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
torch.xpu.synchronize(query.device)
|
||||
return hidden_states
|
||||
312
library/ipex/diffusers.py
Normal file
312
library/ipex/diffusers.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import os
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import diffusers #0.24.0 # pylint: disable=import-error
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
from functools import cache
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||
|
||||
@cache
|
||||
def find_slice_size(slice_size, slice_block_size):
|
||||
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||
slice_size = slice_size // 2
|
||||
if slice_size <= 1:
|
||||
slice_size = 1
|
||||
break
|
||||
return slice_size
|
||||
|
||||
@cache
|
||||
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
||||
if len(query_shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_three = query_shape
|
||||
shape_four = 1
|
||||
else:
|
||||
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||
if slice_size is not None:
|
||||
batch_size_attention = slice_size
|
||||
|
||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
split_2_slice_size = query_tokens
|
||||
split_3_slice_size = shape_three
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if query_device_type != "xpu":
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
if block_size > attention_slice_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
r"""
|
||||
Processor for implementing sliced attention.
|
||||
|
||||
Args:
|
||||
slice_size (`int`, *optional*):
|
||||
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
||||
`attention_head_dim` must be a multiple of the `slice_size`.
|
||||
"""
|
||||
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
batch_size_attention, query_tokens, shape_three = query.shape
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
####################################################################
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
||||
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
del attn_slice
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
del attn_slice
|
||||
####################################################################
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
"""
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states=None, attention_mask=None,
|
||||
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
####################################################################
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
||||
|
||||
if do_split:
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
del attn_slice
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
####################################################################
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def ipex_diffusers():
|
||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
||||
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
||||
183
library/ipex/gradscaler.py
Normal file
183
library/ipex/gradscaler.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
||||
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
||||
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
||||
|
||||
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
|
||||
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
||||
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
||||
|
||||
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
|
||||
# There could be hundreds of grads, so we'd like to iterate through them just once.
|
||||
# However, we don't know their devices or dtypes in advance.
|
||||
|
||||
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
||||
# Google says mypy struggles with defaultdicts type annotations.
|
||||
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
|
||||
# sync grad to master weight
|
||||
if hasattr(optimizer, "sync_grad"):
|
||||
optimizer.sync_grad()
|
||||
with torch.no_grad():
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
if param.grad is None:
|
||||
continue
|
||||
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
||||
raise ValueError("Attempting to unscale FP16 gradients.")
|
||||
if param.grad.is_sparse:
|
||||
# is_coalesced() == False means the sparse grad has values with duplicate indices.
|
||||
# coalesce() deduplicates indices and adds all values that have the same index.
|
||||
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
|
||||
# so we should check the coalesced _values().
|
||||
if param.grad.dtype is torch.float16:
|
||||
param.grad = param.grad.coalesce()
|
||||
to_unscale = param.grad._values()
|
||||
else:
|
||||
to_unscale = param.grad
|
||||
|
||||
# -: is there a way to split by device and dtype without appending in the inner loop?
|
||||
to_unscale = to_unscale.to("cpu")
|
||||
per_device_and_dtype_grads[to_unscale.device][
|
||||
to_unscale.dtype
|
||||
].append(to_unscale)
|
||||
|
||||
for _, per_dtype_grads in per_device_and_dtype_grads.items():
|
||||
for grads in per_dtype_grads.values():
|
||||
core._amp_foreach_non_finite_check_and_unscale_(
|
||||
grads,
|
||||
per_device_found_inf.get("cpu"),
|
||||
per_device_inv_scale.get("cpu"),
|
||||
)
|
||||
|
||||
return per_device_found_inf._per_device_tensors
|
||||
|
||||
def unscale_(self, optimizer):
|
||||
"""
|
||||
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
||||
:meth:`unscale_` is optional, serving cases where you need to
|
||||
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
||||
between the backward pass(es) and :meth:`step`.
|
||||
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
||||
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
||||
...
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
||||
.. warning::
|
||||
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
||||
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
||||
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
||||
.. warning::
|
||||
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
self._check_scale_growth_tracker("unscale_")
|
||||
|
||||
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
||||
|
||||
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
|
||||
raise RuntimeError(
|
||||
"unscale_() has already been called on this optimizer since the last update()."
|
||||
)
|
||||
elif optimizer_state["stage"] is OptState.STEPPED:
|
||||
raise RuntimeError("unscale_() is being called after step().")
|
||||
|
||||
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
||||
assert self._scale is not None
|
||||
if device_supports_fp64:
|
||||
inv_scale = self._scale.double().reciprocal().float()
|
||||
else:
|
||||
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
||||
found_inf = torch.full(
|
||||
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
||||
)
|
||||
|
||||
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
|
||||
optimizer, inv_scale, found_inf, False
|
||||
)
|
||||
optimizer_state["stage"] = OptState.UNSCALED
|
||||
|
||||
def update(self, new_scale=None):
|
||||
"""
|
||||
Updates the scale factor.
|
||||
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
|
||||
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
|
||||
the scale is multiplied by ``growth_factor`` to increase it.
|
||||
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
|
||||
used directly, it's used to fill GradScaler's internal scale tensor. So if
|
||||
``new_scale`` was a tensor, later in-place changes to that tensor will not further
|
||||
affect the scale GradScaler uses internally.)
|
||||
Args:
|
||||
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
|
||||
.. warning::
|
||||
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
||||
been invoked for all optimizers used this iteration.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
|
||||
|
||||
if new_scale is not None:
|
||||
# Accept a new user-defined scale.
|
||||
if isinstance(new_scale, float):
|
||||
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
|
||||
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
|
||||
assert new_scale.numel() == 1, reason
|
||||
assert new_scale.requires_grad is False, reason
|
||||
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
# Consume shared inf/nan data collected from optimizers to update the scale.
|
||||
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
||||
found_infs = [
|
||||
found_inf.to(device="cpu", non_blocking=True)
|
||||
for state in self._per_optimizer_states.values()
|
||||
for found_inf in state["found_inf_per_device"].values()
|
||||
]
|
||||
|
||||
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
||||
|
||||
found_inf_combined = found_infs[0]
|
||||
if len(found_infs) > 1:
|
||||
for i in range(1, len(found_infs)):
|
||||
found_inf_combined += found_infs[i]
|
||||
|
||||
to_device = _scale.device
|
||||
_scale = _scale.to("cpu")
|
||||
_growth_tracker = _growth_tracker.to("cpu")
|
||||
|
||||
core._amp_update_scale_(
|
||||
_scale,
|
||||
_growth_tracker,
|
||||
found_inf_combined,
|
||||
self._growth_factor,
|
||||
self._backoff_factor,
|
||||
self._growth_interval,
|
||||
)
|
||||
|
||||
_scale = _scale.to(to_device)
|
||||
_growth_tracker = _growth_tracker.to(to_device)
|
||||
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||
|
||||
def gradscaler_init():
|
||||
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
|
||||
torch.xpu.amp.GradScaler.unscale_ = unscale_
|
||||
torch.xpu.amp.GradScaler.update = update
|
||||
return torch.xpu.amp.GradScaler
|
||||
298
library/ipex/hijacks.py
Normal file
298
library/ipex/hijacks.py
Normal file
@@ -0,0 +1,298 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import numpy as np
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||
|
||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
||||
return module.to("xpu")
|
||||
|
||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return nullcontext()
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
||||
|
||||
def check_device(device):
|
||||
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
||||
|
||||
def return_xpu(device):
|
||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
||||
|
||||
|
||||
# Autocast
|
||||
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
||||
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
||||
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
||||
if device_type == "cuda":
|
||||
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
else:
|
||||
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
|
||||
# Latent Antialias CPU Offload:
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
@wraps(torch.nn.functional.interpolate)
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
if antialias or align_corners is not None:
|
||||
return_device = tensor.device
|
||||
return_dtype = tensor.dtype
|
||||
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
||||
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
|
||||
else:
|
||||
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
||||
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
||||
|
||||
|
||||
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||
original_from_numpy = torch.from_numpy
|
||||
@wraps(torch.from_numpy)
|
||||
def from_numpy(ndarray):
|
||||
if ndarray.dtype == float:
|
||||
return original_from_numpy(ndarray.astype('float32'))
|
||||
else:
|
||||
return original_from_numpy(ndarray)
|
||||
|
||||
original_as_tensor = torch.as_tensor
|
||||
@wraps(torch.as_tensor)
|
||||
def as_tensor(data, dtype=None, device=None):
|
||||
if check_device(device):
|
||||
device = return_xpu(device)
|
||||
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
||||
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
||||
return original_as_tensor(data, dtype=torch.float32, device=device)
|
||||
else:
|
||||
return original_as_tensor(data, dtype=dtype, device=device)
|
||||
|
||||
|
||||
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
else:
|
||||
# 32 bit attention workarounds for Alchemist:
|
||||
try:
|
||||
from .attention import torch_bmm_32_bit as original_torch_bmm
|
||||
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
|
||||
|
||||
# Data Type Errors:
|
||||
@wraps(torch.bmm)
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
|
||||
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
if query.dtype != key.dtype:
|
||||
key = key.to(dtype=query.dtype)
|
||||
if query.dtype != value.dtype:
|
||||
value = value.to(dtype=query.dtype)
|
||||
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
||||
attn_mask = attn_mask.to(dtype=query.dtype)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
|
||||
# A1111 FP16
|
||||
original_functional_group_norm = torch.nn.functional.group_norm
|
||||
@wraps(torch.nn.functional.group_norm)
|
||||
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
||||
if weight is not None and input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
|
||||
|
||||
# A1111 BF16
|
||||
original_functional_layer_norm = torch.nn.functional.layer_norm
|
||||
@wraps(torch.nn.functional.layer_norm)
|
||||
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if weight is not None and input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
|
||||
|
||||
# Training
|
||||
original_functional_linear = torch.nn.functional.linear
|
||||
@wraps(torch.nn.functional.linear)
|
||||
def functional_linear(input, weight, bias=None):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_linear(input, weight, bias=bias)
|
||||
|
||||
original_functional_conv2d = torch.nn.functional.conv2d
|
||||
@wraps(torch.nn.functional.conv2d)
|
||||
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
# A1111 Embedding BF16
|
||||
original_torch_cat = torch.cat
|
||||
@wraps(torch.cat)
|
||||
def torch_cat(tensor, *args, **kwargs):
|
||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
||||
else:
|
||||
return original_torch_cat(tensor, *args, **kwargs)
|
||||
|
||||
# SwinIR BF16:
|
||||
original_functional_pad = torch.nn.functional.pad
|
||||
@wraps(torch.nn.functional.pad)
|
||||
def functional_pad(input, pad, mode='constant', value=None):
|
||||
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
||||
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||
else:
|
||||
return original_functional_pad(input, pad, mode=mode, value=value)
|
||||
|
||||
|
||||
original_torch_tensor = torch.tensor
|
||||
@wraps(torch.tensor)
|
||||
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
device = return_xpu(device)
|
||||
if not device_supports_fp64:
|
||||
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
||||
if dtype == torch.float64:
|
||||
dtype = torch.float32
|
||||
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
||||
dtype = torch.float32
|
||||
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
original_Tensor_to = torch.Tensor.to
|
||||
@wraps(torch.Tensor.to)
|
||||
def Tensor_to(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_Tensor_to(self, device, *args, **kwargs)
|
||||
|
||||
original_Tensor_cuda = torch.Tensor.cuda
|
||||
@wraps(torch.Tensor.cuda)
|
||||
def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||
@wraps(torch.UntypedStorage.__init__)
|
||||
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
||||
|
||||
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
||||
@wraps(torch.UntypedStorage.cuda)
|
||||
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_torch_empty = torch.empty
|
||||
@wraps(torch.empty)
|
||||
def torch_empty(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_empty(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_randn = torch.randn
|
||||
@wraps(torch.randn)
|
||||
def torch_randn(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_randn(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_ones = torch.ones
|
||||
@wraps(torch.ones)
|
||||
def torch_ones(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_ones(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_zeros = torch.zeros
|
||||
@wraps(torch.zeros)
|
||||
def torch_zeros(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_zeros(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_linspace = torch.linspace
|
||||
@wraps(torch.linspace)
|
||||
def torch_linspace(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_linspace(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_Generator = torch.Generator
|
||||
@wraps(torch.Generator)
|
||||
def torch_Generator(device=None):
|
||||
if check_device(device):
|
||||
return original_torch_Generator(return_xpu(device))
|
||||
else:
|
||||
return original_torch_Generator(device)
|
||||
|
||||
original_torch_load = torch.load
|
||||
@wraps(torch.load)
|
||||
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
||||
if check_device(map_location):
|
||||
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
else:
|
||||
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
|
||||
|
||||
# Hijack Functions:
|
||||
def ipex_hijacks():
|
||||
torch.tensor = torch_tensor
|
||||
torch.Tensor.to = Tensor_to
|
||||
torch.Tensor.cuda = Tensor_cuda
|
||||
torch.UntypedStorage.__init__ = UntypedStorage_init
|
||||
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
||||
torch.empty = torch_empty
|
||||
torch.randn = torch_randn
|
||||
torch.ones = torch_ones
|
||||
torch.zeros = torch_zeros
|
||||
torch.linspace = torch_linspace
|
||||
torch.Generator = torch_Generator
|
||||
torch.load = torch_load
|
||||
|
||||
torch.backends.cuda.sdp_kernel = return_null_context
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
torch.UntypedStorage.is_cuda = is_cuda
|
||||
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
||||
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
torch.nn.functional.group_norm = functional_group_norm
|
||||
torch.nn.functional.layer_norm = functional_layer_norm
|
||||
torch.nn.functional.linear = functional_linear
|
||||
torch.nn.functional.conv2d = functional_conv2d
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.nn.functional.pad = functional_pad
|
||||
|
||||
torch.bmm = torch_bmm
|
||||
torch.cat = torch_cat
|
||||
if not device_supports_fp64:
|
||||
torch.from_numpy = from_numpy
|
||||
torch.as_tensor = as_tensor
|
||||
24
library/ipex_interop.py
Normal file
24
library/ipex_interop.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
def init_ipex():
|
||||
"""
|
||||
Try to import `intel_extension_for_pytorch`, and apply
|
||||
the hijacks using `library.ipex.ipex_init`.
|
||||
|
||||
If IPEX is not installed, this function does nothing.
|
||||
"""
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
try:
|
||||
from library.ipex import ipex_init
|
||||
|
||||
if torch.xpu.is_available():
|
||||
is_initialized, error_message = ipex_init()
|
||||
if not is_initialized:
|
||||
print("failed to initialize ipex:", error_message)
|
||||
except Exception as e:
|
||||
print("failed to initialize ipex:", e)
|
||||
1234
library/lpw_stable_diffusion.py
Normal file
1234
library/lpw_stable_diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1915
library/original_unet.py
Normal file
1915
library/original_unet.py
Normal file
File diff suppressed because it is too large
Load Diff
305
library/sai_model_spec.py
Normal file
305
library/sai_model_spec.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# based on https://github.com/Stability-AI/ModelSpec
|
||||
import datetime
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import safetensors
|
||||
|
||||
r"""
|
||||
# Metadata Example
|
||||
metadata = {
|
||||
# === Must ===
|
||||
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
||||
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
|
||||
"modelspec.implementation": "sgm",
|
||||
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
|
||||
# === Should ===
|
||||
"modelspec.author": "Example Corp", # Your name or company name
|
||||
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
|
||||
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
|
||||
# === Can ===
|
||||
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
|
||||
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
|
||||
}
|
||||
"""
|
||||
|
||||
BASE_METADATA = {
|
||||
# === Must ===
|
||||
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
||||
"modelspec.architecture": None,
|
||||
"modelspec.implementation": None,
|
||||
"modelspec.title": None,
|
||||
"modelspec.resolution": None,
|
||||
# === Should ===
|
||||
"modelspec.description": None,
|
||||
"modelspec.author": None,
|
||||
"modelspec.date": None,
|
||||
# === Can ===
|
||||
"modelspec.license": None,
|
||||
"modelspec.tags": None,
|
||||
"modelspec.merged_from": None,
|
||||
"modelspec.prediction_type": None,
|
||||
"modelspec.timestep_range": None,
|
||||
"modelspec.encoder_layer": None,
|
||||
}
|
||||
|
||||
# 別に使うやつだけ定義
|
||||
MODELSPEC_TITLE = "modelspec.title"
|
||||
|
||||
ARCH_SD_V1 = "stable-diffusion-v1"
|
||||
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
||||
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
||||
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
|
||||
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
|
||||
|
||||
def load_bytes_in_safetensors(tensors):
|
||||
bytes = safetensors.torch.save(tensors)
|
||||
b = BytesIO(bytes)
|
||||
|
||||
b.seek(0)
|
||||
header = b.read(8)
|
||||
n = int.from_bytes(header, "little")
|
||||
|
||||
offset = n + 8
|
||||
b.seek(offset)
|
||||
|
||||
return b.read()
|
||||
|
||||
|
||||
def precalculate_safetensors_hashes(state_dict):
|
||||
# calculate each tensor one by one to reduce memory usage
|
||||
hash_sha256 = hashlib.sha256()
|
||||
for tensor in state_dict.values():
|
||||
single_tensor_sd = {"tensor": tensor}
|
||||
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
|
||||
hash_sha256.update(bytes_for_tensor)
|
||||
|
||||
return f"0x{hash_sha256.hexdigest()}"
|
||||
|
||||
|
||||
def update_hash_sha256(metadata: dict, state_dict: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def build_metadata(
|
||||
state_dict: Optional[dict],
|
||||
v2: bool,
|
||||
v_parameterization: bool,
|
||||
sdxl: bool,
|
||||
lora: bool,
|
||||
textual_inversion: bool,
|
||||
timestamp: float,
|
||||
title: Optional[str] = None,
|
||||
reso: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None,
|
||||
author: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
license: Optional[str] = None,
|
||||
tags: Optional[str] = None,
|
||||
merged_from: Optional[str] = None,
|
||||
timesteps: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
metadata = {}
|
||||
metadata.update(BASE_METADATA)
|
||||
|
||||
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
|
||||
# if state_dict is not None:
|
||||
# hash = precalculate_safetensors_hashes(state_dict)
|
||||
# metadata["modelspec.hash_sha256"] = hash
|
||||
|
||||
if sdxl:
|
||||
arch = ARCH_SD_XL_V1_BASE
|
||||
elif v2:
|
||||
if v_parameterization:
|
||||
arch = ARCH_SD_V2_768_V
|
||||
else:
|
||||
arch = ARCH_SD_V2_512
|
||||
else:
|
||||
arch = ARCH_SD_V1
|
||||
|
||||
if lora:
|
||||
arch += f"/{ADAPTER_LORA}"
|
||||
elif textual_inversion:
|
||||
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
|
||||
|
||||
metadata["modelspec.architecture"] = arch
|
||||
|
||||
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
|
||||
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
else:
|
||||
# v1/v2 LoRA or Diffusers
|
||||
impl = IMPL_DIFFUSERS
|
||||
metadata["modelspec.implementation"] = impl
|
||||
|
||||
if title is None:
|
||||
if lora:
|
||||
title = "LoRA"
|
||||
elif textual_inversion:
|
||||
title = "TextualInversion"
|
||||
else:
|
||||
title = "Checkpoint"
|
||||
title += f"@{timestamp}"
|
||||
metadata[MODELSPEC_TITLE] = title
|
||||
|
||||
if author is not None:
|
||||
metadata["modelspec.author"] = author
|
||||
else:
|
||||
del metadata["modelspec.author"]
|
||||
|
||||
if description is not None:
|
||||
metadata["modelspec.description"] = description
|
||||
else:
|
||||
del metadata["modelspec.description"]
|
||||
|
||||
if merged_from is not None:
|
||||
metadata["modelspec.merged_from"] = merged_from
|
||||
else:
|
||||
del metadata["modelspec.merged_from"]
|
||||
|
||||
if license is not None:
|
||||
metadata["modelspec.license"] = license
|
||||
else:
|
||||
del metadata["modelspec.license"]
|
||||
|
||||
if tags is not None:
|
||||
metadata["modelspec.tags"] = tags
|
||||
else:
|
||||
del metadata["modelspec.tags"]
|
||||
|
||||
# remove microsecond from time
|
||||
int_ts = int(timestamp)
|
||||
|
||||
# time to iso-8601 compliant date
|
||||
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
|
||||
metadata["modelspec.date"] = date
|
||||
|
||||
if reso is not None:
|
||||
# comma separated to tuple
|
||||
if isinstance(reso, str):
|
||||
reso = tuple(map(int, reso.split(",")))
|
||||
if len(reso) == 1:
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# resolution is defined in dataset, so use default
|
||||
if sdxl:
|
||||
reso = 1024
|
||||
elif v2 and v_parameterization:
|
||||
reso = 768
|
||||
else:
|
||||
reso = 512
|
||||
if isinstance(reso, int):
|
||||
reso = (reso, reso)
|
||||
|
||||
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
||||
|
||||
if v_parameterization:
|
||||
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
||||
else:
|
||||
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
||||
|
||||
if timesteps is not None:
|
||||
if isinstance(timesteps, str) or isinstance(timesteps, int):
|
||||
timesteps = (timesteps, timesteps)
|
||||
if len(timesteps) == 1:
|
||||
timesteps = (timesteps[0], timesteps[0])
|
||||
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
|
||||
else:
|
||||
del metadata["modelspec.timestep_range"]
|
||||
|
||||
if clip_skip is not None:
|
||||
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
|
||||
else:
|
||||
del metadata["modelspec.encoder_layer"]
|
||||
|
||||
# # assert all values are filled
|
||||
# assert all([v is not None for v in metadata.values()]), metadata
|
||||
if not all([v is not None for v in metadata.values()]):
|
||||
print(f"Internal error: some metadata values are None: {metadata}")
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
# region utils
|
||||
|
||||
|
||||
def get_title(metadata: dict) -> Optional[str]:
|
||||
return metadata.get(MODELSPEC_TITLE, None)
|
||||
|
||||
|
||||
def load_metadata_from_safetensors(model: str) -> dict:
|
||||
if not model.endswith(".safetensors"):
|
||||
return {}
|
||||
|
||||
with safetensors.safe_open(model, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
return metadata
|
||||
|
||||
|
||||
def build_merged_from(models: List[str]) -> str:
|
||||
def get_title(model: str):
|
||||
metadata = load_metadata_from_safetensors(model)
|
||||
title = metadata.get(MODELSPEC_TITLE, None)
|
||||
if title is None:
|
||||
title = os.path.splitext(os.path.basename(model))[0] # use filename
|
||||
return title
|
||||
|
||||
titles = [get_title(model) for model in models]
|
||||
return ", ".join(titles)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
r"""
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from library import train_util
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Loading {args.ckpt}")
|
||||
state_dict = load_file(args.ckpt)
|
||||
|
||||
print(f"Calculating metadata")
|
||||
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
|
||||
print(metadata)
|
||||
del state_dict
|
||||
|
||||
# by reference implementation
|
||||
with open(args.ckpt, mode="rb") as file_data:
|
||||
file_hash = hashlib.sha256()
|
||||
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
|
||||
header = json.loads(file_data.read(head_len[0])) # header itself, json string
|
||||
content = (
|
||||
file_data.read()
|
||||
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
|
||||
file_hash.update(content)
|
||||
# ===== Update the hash for modelspec =====
|
||||
by_ref = f"0x{file_hash.hexdigest()}"
|
||||
print(by_ref)
|
||||
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
|
||||
|
||||
"""
|
||||
1347
library/sdxl_lpw_stable_diffusion.py
Normal file
1347
library/sdxl_lpw_stable_diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
574
library/sdxl_model_util.py
Normal file
574
library/sdxl_model_util.py
Normal file
@@ -0,0 +1,574 @@
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||
from safetensors.torch import load_file, save_file
|
||||
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from typing import List
|
||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
||||
from library import model_util
|
||||
from library import sdxl_original_unet
|
||||
|
||||
|
||||
VAE_SCALE_FACTOR = 0.13025
|
||||
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
|
||||
|
||||
# Diffusersの設定を読み込むための参照モデル
|
||||
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
DIFFUSERS_SDXL_UNET_CONFIG = {
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": "text_time",
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": 256,
|
||||
"attention_head_dim": [5, 10, 20],
|
||||
"block_out_channels": [320, 640, 1280],
|
||||
"center_input_sample": False,
|
||||
"class_embed_type": None,
|
||||
"class_embeddings_concat": False,
|
||||
"conv_in_kernel": 3,
|
||||
"conv_out_kernel": 3,
|
||||
"cross_attention_dim": 2048,
|
||||
"cross_attention_norm": None,
|
||||
"down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
|
||||
"downsample_padding": 1,
|
||||
"dual_cross_attention": False,
|
||||
"encoder_hid_dim": None,
|
||||
"encoder_hid_dim_type": None,
|
||||
"flip_sin_to_cos": True,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_only_cross_attention": None,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": None,
|
||||
"num_class_embeds": None,
|
||||
"only_cross_attention": False,
|
||||
"out_channels": 4,
|
||||
"projection_class_embeddings_input_dim": 2816,
|
||||
"resnet_out_scale_factor": 1.0,
|
||||
"resnet_skip_time_act": False,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"sample_size": 128,
|
||||
"time_cond_proj_dim": None,
|
||||
"time_embedding_act_fn": None,
|
||||
"time_embedding_dim": None,
|
||||
"time_embedding_type": "positional",
|
||||
"timestep_post_act": None,
|
||||
"transformer_layers_per_block": [1, 2, 10],
|
||||
"up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
|
||||
"upcast_attention": False,
|
||||
"use_linear_projection": True,
|
||||
}
|
||||
|
||||
|
||||
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
|
||||
|
||||
# SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す
|
||||
# logit_scaleはcheckpointの保存時に使用する
|
||||
def convert_key(key):
|
||||
# common conversion
|
||||
key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
|
||||
key = key.replace(SDXL_KEY_PREFIX, "text_model.")
|
||||
|
||||
if "resblocks" in key:
|
||||
# resblocks conversion
|
||||
key = key.replace(".resblocks.", ".layers.")
|
||||
if ".ln_" in key:
|
||||
key = key.replace(".ln_", ".layer_norm")
|
||||
elif ".mlp." in key:
|
||||
key = key.replace(".c_fc.", ".fc1.")
|
||||
key = key.replace(".c_proj.", ".fc2.")
|
||||
elif ".attn.out_proj" in key:
|
||||
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
||||
elif ".attn.in_proj" in key:
|
||||
key = None # 特殊なので後で処理する
|
||||
else:
|
||||
raise ValueError(f"unexpected key in SD: {key}")
|
||||
elif ".positional_embedding" in key:
|
||||
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
||||
elif ".text_projection" in key:
|
||||
key = key.replace("text_model.text_projection", "text_projection.weight")
|
||||
elif ".logit_scale" in key:
|
||||
key = None # 後で処理する
|
||||
elif ".token_embedding" in key:
|
||||
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
||||
elif ".ln_final" in key:
|
||||
key = key.replace(".ln_final", ".final_layer_norm")
|
||||
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
|
||||
elif ".embeddings.position_ids" in key:
|
||||
key = None # remove this key: position_ids is not used in newer transformers
|
||||
return key
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
new_sd = {}
|
||||
for key in keys:
|
||||
new_key = convert_key(key)
|
||||
if new_key is None:
|
||||
continue
|
||||
new_sd[new_key] = checkpoint[key]
|
||||
|
||||
# attnの変換
|
||||
for key in keys:
|
||||
if ".resblocks" in key and ".attn.in_proj_" in key:
|
||||
# 三つに分割
|
||||
values = torch.chunk(checkpoint[key], 3)
|
||||
|
||||
key_suffix = ".weight" if "weight" in key else ".bias"
|
||||
key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
|
||||
key_pfx = key_pfx.replace("_weight", "")
|
||||
key_pfx = key_pfx.replace("_bias", "")
|
||||
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
||||
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
||||
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
||||
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
||||
|
||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||
|
||||
# temporary workaround for text_projection.weight.weight for Playground-v2
|
||||
if "text_projection.weight.weight" in new_sd:
|
||||
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
||||
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
||||
del new_sd["text_projection.weight.weight"]
|
||||
|
||||
return new_sd, logit_scale
|
||||
|
||||
|
||||
# load state_dict without allocating new tensors
|
||||
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
|
||||
# dtype will use fp32 as default
|
||||
missing_keys = list(model.state_dict().keys() - state_dict.keys())
|
||||
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
|
||||
|
||||
# similar to model.load_state_dict()
|
||||
if not missing_keys and not unexpected_keys:
|
||||
for k in list(state_dict.keys()):
|
||||
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
|
||||
return "<All keys matched successfully>"
|
||||
|
||||
# error_msgs
|
||||
error_msgs: List[str] = []
|
||||
if missing_keys:
|
||||
error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
|
||||
if unexpected_keys:
|
||||
error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
|
||||
|
||||
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
|
||||
# model_version is reserved for future use
|
||||
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
|
||||
|
||||
# Load the state dict
|
||||
if model_util.is_safetensors(ckpt_path):
|
||||
checkpoint = None
|
||||
try:
|
||||
state_dict = load_file(ckpt_path, device=map_location)
|
||||
except:
|
||||
state_dict = load_file(ckpt_path) # prevent device invalid Error
|
||||
epoch = None
|
||||
global_step = None
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location=map_location)
|
||||
if "state_dict" in checkpoint:
|
||||
state_dict = checkpoint["state_dict"]
|
||||
epoch = checkpoint.get("epoch", 0)
|
||||
global_step = checkpoint.get("global_step", 0)
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
epoch = 0
|
||||
global_step = 0
|
||||
checkpoint = None
|
||||
|
||||
# U-Net
|
||||
print("building U-Net")
|
||||
with init_empty_weights():
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
|
||||
print("loading U-Net from checkpoint")
|
||||
unet_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith("model.diffusion_model."):
|
||||
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
||||
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
|
||||
print("U-Net: ", info)
|
||||
|
||||
# Text Encoders
|
||||
print("building text encoders")
|
||||
|
||||
# Text Encoder 1 is same to Stability AI's SDXL
|
||||
text_model1_cfg = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="quick_gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=768,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
text_model1 = CLIPTextModel._from_config(text_model1_cfg)
|
||||
|
||||
# Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
|
||||
# Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
|
||||
text_model2_cfg = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=1280,
|
||||
intermediate_size=5120,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=20,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=1280,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
|
||||
|
||||
print("loading text encoders from checkpoint")
|
||||
te1_sd = {}
|
||||
te2_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith("conditioner.embedders.0.transformer."):
|
||||
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
|
||||
elif k.startswith("conditioner.embedders.1.model."):
|
||||
te2_sd[k] = state_dict.pop(k)
|
||||
|
||||
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
|
||||
if "text_model.embeddings.position_ids" in te1_sd:
|
||||
te1_sd.pop("text_model.embeddings.position_ids")
|
||||
|
||||
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
||||
print("text encoder 1:", info1)
|
||||
|
||||
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
||||
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
|
||||
print("text encoder 2:", info2)
|
||||
|
||||
# prepare vae
|
||||
print("building VAE")
|
||||
vae_config = model_util.create_vae_diffusers_config()
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
print("loading VAE from checkpoint")
|
||||
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
|
||||
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
|
||||
print("VAE:", info)
|
||||
|
||||
ckpt_info = (epoch, global_step) if epoch is not None else None
|
||||
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def make_unet_conversion_map():
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
|
||||
unet_conversion_map = make_unet_conversion_map()
|
||||
|
||||
conversion_map = {hf: sd for sd, hf in unet_conversion_map}
|
||||
return convert_unet_state_dict(du_sd, conversion_map)
|
||||
|
||||
|
||||
def convert_unet_state_dict(src_sd, conversion_map):
|
||||
converted_sd = {}
|
||||
for src_key, value in src_sd.items():
|
||||
# さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す
|
||||
src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
|
||||
while len(src_key_fragments) > 0:
|
||||
src_key_prefix = ".".join(src_key_fragments) + "."
|
||||
if src_key_prefix in conversion_map:
|
||||
converted_prefix = conversion_map[src_key_prefix]
|
||||
converted_key = converted_prefix + src_key[len(src_key_prefix) :]
|
||||
converted_sd[converted_key] = value
|
||||
break
|
||||
src_key_fragments.pop(-1)
|
||||
assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
|
||||
|
||||
return converted_sd
|
||||
|
||||
|
||||
def convert_sdxl_unet_state_dict_to_diffusers(sd):
|
||||
unet_conversion_map = make_unet_conversion_map()
|
||||
|
||||
conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
|
||||
return convert_unet_state_dict(sd, conversion_dict)
|
||||
|
||||
|
||||
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
||||
def convert_key(key):
|
||||
# position_idsの除去
|
||||
if ".position_ids" in key:
|
||||
return None
|
||||
|
||||
# common
|
||||
key = key.replace("text_model.encoder.", "transformer.")
|
||||
key = key.replace("text_model.", "")
|
||||
if "layers" in key:
|
||||
# resblocks conversion
|
||||
key = key.replace(".layers.", ".resblocks.")
|
||||
if ".layer_norm" in key:
|
||||
key = key.replace(".layer_norm", ".ln_")
|
||||
elif ".mlp." in key:
|
||||
key = key.replace(".fc1.", ".c_fc.")
|
||||
key = key.replace(".fc2.", ".c_proj.")
|
||||
elif ".self_attn.out_proj" in key:
|
||||
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
||||
elif ".self_attn." in key:
|
||||
key = None # 特殊なので後で処理する
|
||||
else:
|
||||
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
||||
elif ".position_embedding" in key:
|
||||
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
||||
elif ".token_embedding" in key:
|
||||
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
||||
elif "text_projection" in key: # no dot in key
|
||||
key = key.replace("text_projection.weight", "text_projection")
|
||||
elif "final_layer_norm" in key:
|
||||
key = key.replace("final_layer_norm", "ln_final")
|
||||
return key
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
new_sd = {}
|
||||
for key in keys:
|
||||
new_key = convert_key(key)
|
||||
if new_key is None:
|
||||
continue
|
||||
new_sd[new_key] = checkpoint[key]
|
||||
|
||||
# attnの変換
|
||||
for key in keys:
|
||||
if "layers" in key and "q_proj" in key:
|
||||
# 三つを結合
|
||||
key_q = key
|
||||
key_k = key.replace("q_proj", "k_proj")
|
||||
key_v = key.replace("q_proj", "v_proj")
|
||||
|
||||
value_q = checkpoint[key_q]
|
||||
value_k = checkpoint[key_k]
|
||||
value_v = checkpoint[key_v]
|
||||
value = torch.cat([value_q, value_k, value_v])
|
||||
|
||||
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
||||
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
||||
new_sd[new_key] = value
|
||||
|
||||
if logit_scale is not None:
|
||||
new_sd["logit_scale"] = logit_scale
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
def save_stable_diffusion_checkpoint(
|
||||
output_file,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
epochs,
|
||||
steps,
|
||||
ckpt_info,
|
||||
vae,
|
||||
logit_scale,
|
||||
metadata,
|
||||
save_dtype=None,
|
||||
):
|
||||
state_dict = {}
|
||||
|
||||
def update_sd(prefix, sd):
|
||||
for k, v in sd.items():
|
||||
key = prefix + k
|
||||
if save_dtype is not None:
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
# Convert the UNet model
|
||||
update_sd("model.diffusion_model.", unet.state_dict())
|
||||
|
||||
# Convert the text encoders
|
||||
update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
|
||||
|
||||
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
|
||||
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
||||
|
||||
# Convert the VAE
|
||||
vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
|
||||
update_sd("first_stage_model.", vae_dict)
|
||||
|
||||
# Put together new checkpoint
|
||||
key_count = len(state_dict.keys())
|
||||
new_ckpt = {"state_dict": state_dict}
|
||||
|
||||
# epoch and global_step are sometimes not int
|
||||
if ckpt_info is not None:
|
||||
epochs += ckpt_info[0]
|
||||
steps += ckpt_info[1]
|
||||
|
||||
new_ckpt["epoch"] = epochs
|
||||
new_ckpt["global_step"] = steps
|
||||
|
||||
if model_util.is_safetensors(output_file):
|
||||
save_file(state_dict, output_file, metadata)
|
||||
else:
|
||||
torch.save(new_ckpt, output_file)
|
||||
|
||||
return key_count
|
||||
|
||||
|
||||
def save_diffusers_checkpoint(
|
||||
output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
|
||||
):
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
# convert U-Net
|
||||
unet_sd = unet.state_dict()
|
||||
du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
|
||||
|
||||
diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
|
||||
if save_dtype is not None:
|
||||
diffusers_unet.to(save_dtype)
|
||||
diffusers_unet.load_state_dict(du_unet_sd)
|
||||
|
||||
# create pipeline to save
|
||||
if pretrained_model_name_or_path is None:
|
||||
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL
|
||||
|
||||
scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
||||
tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
|
||||
if vae is None:
|
||||
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
||||
|
||||
# prevent local path from being saved
|
||||
def remove_name_or_path(model):
|
||||
if hasattr(model, "config"):
|
||||
model.config._name_or_path = None
|
||||
model.config._name_or_path = None
|
||||
|
||||
remove_name_or_path(diffusers_unet)
|
||||
remove_name_or_path(text_encoder1)
|
||||
remove_name_or_path(text_encoder2)
|
||||
remove_name_or_path(scheduler)
|
||||
remove_name_or_path(tokenizer1)
|
||||
remove_name_or_path(tokenizer2)
|
||||
remove_name_or_path(vae)
|
||||
|
||||
pipeline = StableDiffusionXLPipeline(
|
||||
unet=diffusers_unet,
|
||||
text_encoder=text_encoder1,
|
||||
text_encoder_2=text_encoder2,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
tokenizer=tokenizer1,
|
||||
tokenizer_2=tokenizer2,
|
||||
)
|
||||
if save_dtype is not None:
|
||||
pipeline.to(None, save_dtype)
|
||||
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
||||
1281
library/sdxl_original_unet.py
Normal file
1281
library/sdxl_original_unet.py
Normal file
File diff suppressed because it is too large
Load Diff
367
library/sdxl_train_util.py
Normal file
367
library/sdxl_train_util.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
|
||||
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
# DEFAULT_NOISE_OFFSET = 0.0357
|
||||
|
||||
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
# load models for each process
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = _load_target_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
args.vae,
|
||||
model_version,
|
||||
weight_dtype,
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
model_dtype,
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder1.to(accelerator.device)
|
||||
text_encoder2.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def _load_target_model(
|
||||
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
|
||||
):
|
||||
# model_dtype only work with full fp16/bf16
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
|
||||
if load_stable_diffusion_format:
|
||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||
(
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
|
||||
else:
|
||||
# Diffusers model is loaded to CPU
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
variant = "fp16" if weight_dtype == torch.float16 else None
|
||||
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
||||
try:
|
||||
try:
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
|
||||
)
|
||||
except EnvironmentError as ex:
|
||||
if variant is not None:
|
||||
print("try to load fp32 model")
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
|
||||
else:
|
||||
raise ex
|
||||
except EnvironmentError as ex:
|
||||
print(
|
||||
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
||||
)
|
||||
raise ex
|
||||
|
||||
text_encoder1 = pipe.text_encoder
|
||||
text_encoder2 = pipe.text_encoder_2
|
||||
|
||||
# convert to fp32 for cache text_encoders outputs
|
||||
if text_encoder1.dtype != torch.float32:
|
||||
text_encoder1 = text_encoder1.to(dtype=torch.float32)
|
||||
if text_encoder2.dtype != torch.float32:
|
||||
text_encoder2 = text_encoder2.to(dtype=torch.float32)
|
||||
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
# Diffusers U-Net to original U-Net
|
||||
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
||||
with init_empty_weights():
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
|
||||
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
|
||||
print("U-Net converted to original U-Net")
|
||||
|
||||
logit_scale = None
|
||||
ckpt_info = None
|
||||
|
||||
# VAEを読み込む
|
||||
if vae_path is not None:
|
||||
vae = model_util.load_vae(vae_path, weight_dtype)
|
||||
print("additional VAE loaded")
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
print("prepare tokenizers")
|
||||
|
||||
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
|
||||
tokeniers = []
|
||||
for i, original_path in enumerate(original_paths):
|
||||
tokenizer: CLIPTokenizer = None
|
||||
if args.tokenizer_cache_dir:
|
||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
||||
|
||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
tokenizer.save_pretrained(local_tokenizer_path)
|
||||
|
||||
if i == 1:
|
||||
tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer
|
||||
|
||||
tokeniers.append(tokenizer)
|
||||
|
||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
||||
print(f"update token length: {args.max_token_length}")
|
||||
|
||||
return tokeniers
|
||||
|
||||
|
||||
def match_mixed_precision(args, weight_dtype):
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
weight_dtype == torch.float16
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
return weight_dtype
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
weight_dtype == torch.bfloat16
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
return weight_dtype
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
device=timesteps.device
|
||||
)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
def get_timestep_embedding(x, outdim):
|
||||
assert len(x.shape) == 2
|
||||
b, dims = x.shape[0], x.shape[1]
|
||||
x = torch.flatten(x)
|
||||
emb = timestep_embedding(x, outdim)
|
||||
emb = torch.reshape(emb, (b, dims * outdim))
|
||||
return emb
|
||||
|
||||
|
||||
def get_size_embeddings(orig_size, crop_size, target_size, device):
|
||||
emb1 = get_timestep_embedding(orig_size, 256)
|
||||
emb2 = get_timestep_embedding(crop_size, 256)
|
||||
emb3 = get_timestep_embedding(target_size, 256)
|
||||
vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
|
||||
return vector
|
||||
|
||||
|
||||
def save_sd_model_on_train_end(
|
||||
args: argparse.Namespace,
|
||||
src_path: str,
|
||||
save_stable_diffusion_format: bool,
|
||||
use_safetensors: bool,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
vae,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||
ckpt_file,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
epoch_no,
|
||||
global_step,
|
||||
ckpt_info,
|
||||
vae,
|
||||
logit_scale,
|
||||
sai_metadata,
|
||||
save_dtype,
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
sdxl_model_util.save_diffusers_checkpoint(
|
||||
out_dir,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
src_path,
|
||||
vae,
|
||||
use_safetensors=use_safetensors,
|
||||
save_dtype=save_dtype,
|
||||
)
|
||||
|
||||
train_util.save_sd_model_on_train_end_common(
|
||||
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
|
||||
)
|
||||
|
||||
|
||||
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
||||
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
||||
def save_sd_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format: bool,
|
||||
use_safetensors: bool,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
vae,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||
ckpt_file,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
epoch_no,
|
||||
global_step,
|
||||
ckpt_info,
|
||||
vae,
|
||||
logit_scale,
|
||||
sai_metadata,
|
||||
save_dtype,
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
sdxl_model_util.save_diffusers_checkpoint(
|
||||
out_dir,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
src_path,
|
||||
vae,
|
||||
use_safetensors=use_safetensors,
|
||||
save_dtype=save_dtype,
|
||||
)
|
||||
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
on_epoch_end,
|
||||
accelerator,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
sd_saver,
|
||||
diffusers_saver,
|
||||
)
|
||||
|
||||
|
||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs_to_disk",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
if args.v_parameterization:
|
||||
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
||||
|
||||
if args.clip_skip is not None:
|
||||
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||
|
||||
# if args.multires_noise_iterations:
|
||||
# print(
|
||||
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
||||
# )
|
||||
# else:
|
||||
# if args.noise_offset is None:
|
||||
# args.noise_offset = DEFAULT_NOISE_OFFSET
|
||||
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
||||
# print(
|
||||
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
||||
# )
|
||||
# print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||
|
||||
assert (
|
||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
|
||||
if supportTextEncoderCaching:
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
args.cache_text_encoder_outputs = True
|
||||
print(
|
||||
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
||||
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
||||
)
|
||||
|
||||
|
||||
def sample_images(*args, **kwargs):
|
||||
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||
679
library/slicing_vae.py
Normal file
679
library/slicing_vae.py
Normal file
@@ -0,0 +1,679 @@
|
||||
# Modified from Diffusers to reduce VRAM usage
|
||||
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
||||
|
||||
|
||||
def slice_h(x, num_slices):
|
||||
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
||||
# Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする
|
||||
# NCHWでもNHWCでもどちらでも動く
|
||||
size = (x.shape[2] + num_slices - 1) // num_slices
|
||||
sliced = []
|
||||
for i in range(num_slices):
|
||||
if i == 0:
|
||||
sliced.append(x[:, :, : size + 1, :])
|
||||
else:
|
||||
end = size * (i + 1) + 1
|
||||
if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う
|
||||
end = x.shape[2]
|
||||
sliced.append(x[:, :, size * i - 1 : end, :])
|
||||
if end >= x.shape[2]:
|
||||
break
|
||||
return sliced
|
||||
|
||||
|
||||
def cat_h(sliced):
|
||||
# padding分を除いて結合する
|
||||
cat = []
|
||||
for i, x in enumerate(sliced):
|
||||
if i == 0:
|
||||
cat.append(x[:, :, :-1, :])
|
||||
elif i == len(sliced) - 1:
|
||||
cat.append(x[:, :, 1:, :])
|
||||
else:
|
||||
cat.append(x[:, :, 1:-1, :])
|
||||
del x
|
||||
x = torch.cat(cat, dim=2)
|
||||
return x
|
||||
|
||||
|
||||
def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
|
||||
assert _self.upsample is None and _self.downsample is None
|
||||
assert _self.norm1.num_groups == _self.norm2.num_groups
|
||||
assert temb is None
|
||||
|
||||
# make sure norms are on cpu
|
||||
org_device = input_tensor.device
|
||||
cpu_device = torch.device("cpu")
|
||||
_self.norm1.to(cpu_device)
|
||||
_self.norm2.to(cpu_device)
|
||||
|
||||
# GroupNormがCPUでfp16で動かない対策
|
||||
org_dtype = input_tensor.dtype
|
||||
if org_dtype == torch.float16:
|
||||
_self.norm1.to(torch.float32)
|
||||
_self.norm2.to(torch.float32)
|
||||
|
||||
# すべてのテンソルをCPUに移動する
|
||||
input_tensor = input_tensor.to(cpu_device)
|
||||
hidden_states = input_tensor
|
||||
|
||||
# どうもこれは結果が異なるようだ……
|
||||
# def sliced_norm1(norm, x):
|
||||
# num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
|
||||
# sliced_tensor = torch.chunk(x, num_div, dim=1)
|
||||
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
|
||||
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
|
||||
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
|
||||
# normed_tensor = []
|
||||
# for i in range(num_div):
|
||||
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
|
||||
# normed_tensor.append(n)
|
||||
# del n
|
||||
# x = torch.cat(normed_tensor, dim=1)
|
||||
# return num_div, x
|
||||
|
||||
# normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
hidden_states = _self.norm1(hidden_states) # run on cpu
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
# 計算する部分だけGPUに移動する、以下同様
|
||||
x = x.to(org_device)
|
||||
x = _self.nonlinearity(x)
|
||||
x = _self.conv1(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
hidden_states = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
hidden_states = _self.norm2(hidden_states) # run on cpu
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = _self.nonlinearity(x)
|
||||
x = _self.dropout(x)
|
||||
x = _self.conv2(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
hidden_states = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
# make shortcut
|
||||
if _self.conv_shortcut is not None:
|
||||
sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする
|
||||
del input_tensor
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = _self.conv_shortcut(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
input_tensor = torch.cat(sliced, dim=2)
|
||||
del sliced
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
|
||||
|
||||
output_tensor = output_tensor.to(org_device) # 次のレイヤーがGPUで計算する
|
||||
return output_tensor
|
||||
|
||||
|
||||
class SlicingEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
double_z=True,
|
||||
num_slices=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid_block = None
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=self.layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=1e-6,
|
||||
downsample_padding=0,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=output_channel,
|
||||
temb_channels=None,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default",
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=None,
|
||||
)
|
||||
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
conv_out_channels = 2 * out_channels if double_z else out_channels
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
||||
|
||||
# replace forward of ResBlocks
|
||||
def wrapper(func, module, num_slices):
|
||||
def forward(*args, **kwargs):
|
||||
return func(module, num_slices, *args, **kwargs)
|
||||
|
||||
return forward
|
||||
|
||||
self.num_slices = num_slices
|
||||
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
|
||||
# print(f"initial divisor: {div}")
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
for resnet in self.mid_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
# midblock doesn't have downsample
|
||||
|
||||
for i, down_block in enumerate(self.down_blocks[::-1]):
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
# print(f"down block: {i} divisor: {div}")
|
||||
for resnet in down_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
if down_block.downsamplers is not None:
|
||||
# print("has downsample")
|
||||
for downsample in down_block.downsamplers:
|
||||
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
|
||||
div *= 2
|
||||
|
||||
def forward(self, x):
|
||||
sample = x
|
||||
del x
|
||||
|
||||
org_device = sample.device
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
# sample = self.conv_in(sample)
|
||||
sample = sample.to(cpu_device)
|
||||
sliced = slice_h(sample, self.num_slices)
|
||||
del sample
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = self.conv_in(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
sample = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
sample = sample.to(org_device)
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
sample = down_block(sample)
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample)
|
||||
|
||||
# post-process
|
||||
# ここも省メモリ化したいが、恐らくそこまでメモリを食わないので省略
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def downsample_forward(self, _self, num_slices, hidden_states):
|
||||
assert hidden_states.shape[1] == _self.channels
|
||||
assert _self.use_conv and _self.padding == 0
|
||||
print("downsample forward", num_slices, hidden_states.shape)
|
||||
|
||||
org_device = hidden_states.device
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
hidden_states = hidden_states.to(cpu_device)
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
# slice with even number because of stride 2
|
||||
# strideが2なので偶数でスライスする
|
||||
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
||||
size = (hidden_states.shape[2] + num_slices - 1) // num_slices
|
||||
size = size + 1 if size % 2 == 1 else size
|
||||
|
||||
sliced = []
|
||||
for i in range(num_slices):
|
||||
if i == 0:
|
||||
sliced.append(hidden_states[:, :, : size + 1, :])
|
||||
else:
|
||||
end = size * (i + 1) + 1
|
||||
if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
|
||||
end = hidden_states.shape[2]
|
||||
sliced.append(hidden_states[:, :, size * i - 1 : end, :])
|
||||
if end >= hidden_states.shape[2]:
|
||||
break
|
||||
del hidden_states
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = _self.conv(x)
|
||||
x = x.to(cpu_device)
|
||||
|
||||
# ここだけ雰囲気が違うのはCopilotのせい
|
||||
if i == 0:
|
||||
hidden_states = x
|
||||
else:
|
||||
hidden_states = torch.cat([hidden_states, x], dim=2)
|
||||
|
||||
hidden_states = hidden_states.to(org_device)
|
||||
# print("downsample forward done", hidden_states.shape)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicingDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
num_slices=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default",
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=None,
|
||||
)
|
||||
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=None,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=output_channel,
|
||||
temb_channels=None,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
# replace forward of ResBlocks
|
||||
def wrapper(func, module, num_slices):
|
||||
def forward(*args, **kwargs):
|
||||
return func(module, num_slices, *args, **kwargs)
|
||||
|
||||
return forward
|
||||
|
||||
self.num_slices = num_slices
|
||||
div = num_slices / (2 ** (len(self.up_blocks) - 1))
|
||||
print(f"initial divisor: {div}")
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
for resnet in self.mid_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
# midblock doesn't have upsample
|
||||
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
# print(f"up block: {i} divisor: {div}")
|
||||
for resnet in up_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
if up_block.upsamplers is not None:
|
||||
# print("has upsample")
|
||||
for upsample in up_block.upsamplers:
|
||||
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
|
||||
div *= 2
|
||||
|
||||
def forward(self, z):
|
||||
sample = z
|
||||
del z
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample)
|
||||
|
||||
# up
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
sample = up_block(sample)
|
||||
|
||||
# post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
|
||||
# conv_out with slicing because of VRAM usage
|
||||
# conv_outはとてもVRAM使うのでスライスして対応
|
||||
org_device = sample.device
|
||||
cpu_device = torch.device("cpu")
|
||||
sample = sample.to(cpu_device)
|
||||
|
||||
sliced = slice_h(sample, self.num_slices)
|
||||
del sample
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = self.conv_out(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
sample = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
sample = sample.to(org_device)
|
||||
return sample
|
||||
|
||||
def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
|
||||
assert hidden_states.shape[1] == _self.channels
|
||||
assert _self.use_conv_transpose == False and _self.use_conv
|
||||
|
||||
org_dtype = hidden_states.dtype
|
||||
org_device = hidden_states.device
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
hidden_states = hidden_states.to(cpu_device)
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
# PyTorch 2で直らないかね……
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.to(org_dtype)
|
||||
|
||||
x = _self.conv(x)
|
||||
|
||||
# upsampleされてるのでpadは2になる
|
||||
if i == 0:
|
||||
x = x[:, :, :-2, :]
|
||||
elif i == num_slices - 1:
|
||||
x = x[:, :, 2:, :]
|
||||
else:
|
||||
x = x[:, :, 2:-2, :]
|
||||
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
hidden_states = torch.cat(sliced, dim=2)
|
||||
# print("us hidden_states", hidden_states.shape)
|
||||
del sliced
|
||||
|
||||
hidden_states = hidden_states.to(org_device)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
|
||||
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
|
||||
and Max Welling.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the model (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
||||
obj:`(64,)`): Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): TODO
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
norm_num_groups: int = 32,
|
||||
sample_size: int = 32,
|
||||
num_slices: int = 16,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = SlicingEncoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
double_z=True,
|
||||
num_slices=num_slices,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = SlicingDecoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
act_fn=act_fn,
|
||||
num_slices=num_slices,
|
||||
)
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
self.use_slicing = False
|
||||
|
||||
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# これはバッチ方向のスライシング 紛らわしい
|
||||
def enable_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
File diff suppressed because it is too large
Load Diff
6
library/utils.py
Normal file
6
library/utils.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import threading
|
||||
from typing import *
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
@@ -5,28 +5,41 @@ from safetensors.torch import load_file
|
||||
|
||||
|
||||
def main(file):
|
||||
print(f"loading: {file}")
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
sd = load_file(file)
|
||||
else:
|
||||
sd = torch.load(file, map_location='cpu')
|
||||
print(f"loading: {file}")
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
sd = load_file(file)
|
||||
else:
|
||||
sd = torch.load(file, map_location="cpu")
|
||||
|
||||
values = []
|
||||
values = []
|
||||
|
||||
keys = list(sd.keys())
|
||||
for key in keys:
|
||||
if 'lora_up' in key or 'lora_down' in key:
|
||||
values.append((key, sd[key]))
|
||||
print(f"number of LoRA modules: {len(values)}")
|
||||
keys = list(sd.keys())
|
||||
for key in keys:
|
||||
if "lora_up" in key or "lora_down" in key:
|
||||
values.append((key, sd[key]))
|
||||
print(f"number of LoRA modules: {len(values)}")
|
||||
|
||||
for key, value in values:
|
||||
value = value.to(torch.float32)
|
||||
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
||||
if args.show_all_keys:
|
||||
for key in [k for k in keys if k not in values]:
|
||||
values.append((key, sd[key]))
|
||||
print(f"number of all modules: {len(values)}")
|
||||
|
||||
for key, value in values:
|
||||
value = value.to(torch.float32)
|
||||
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
|
||||
args = parser.parse_args()
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
|
||||
parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")
|
||||
|
||||
main(args.file)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.file)
|
||||
|
||||
446
networks/control_net_lllite.py
Normal file
446
networks/control_net_lllite.py
Normal file
@@ -0,0 +1,446 @@
|
||||
import os
|
||||
from typing import Optional, List, Type
|
||||
import torch
|
||||
from library import sdxl_original_unet
|
||||
|
||||
|
||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||
SKIP_INPUT_BLOCKS = False
|
||||
|
||||
# output_blocksに適用するかどうか / if True, output_blocks are not applied
|
||||
SKIP_OUTPUT_BLOCKS = True
|
||||
|
||||
# conv2dに適用するかどうか / if True, conv2d are not applied
|
||||
SKIP_CONV2D = False
|
||||
|
||||
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
|
||||
# if True, only transformer_blocks are applied, and ResBlocks are not applied
|
||||
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
|
||||
|
||||
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
|
||||
ATTN1_2_ONLY = True
|
||||
|
||||
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
|
||||
ATTN_QKV_ONLY = True
|
||||
|
||||
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
||||
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
|
||||
ATTN1_ETC_ONLY = False # True
|
||||
|
||||
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
|
||||
# max index of transformer_blocks. if None, apply to all transformer_blocks
|
||||
TRANSFORMER_MAX_BLOCK_INDEX = None
|
||||
|
||||
|
||||
class LLLiteModule(torch.nn.Module):
|
||||
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
|
||||
super().__init__()
|
||||
|
||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
self.lllite_name = name
|
||||
self.cond_emb_dim = cond_emb_dim
|
||||
self.org_module = [org_module]
|
||||
self.dropout = dropout
|
||||
self.multiplier = multiplier
|
||||
|
||||
if self.is_conv2d:
|
||||
in_dim = org_module.in_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
|
||||
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
|
||||
# conditioning1 embeds conditioning image. it is not called for each timestep
|
||||
modules = []
|
||||
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
|
||||
if depth == 1:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
elif depth == 2:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
||||
elif depth == 3:
|
||||
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
|
||||
self.conditioning1 = torch.nn.Sequential(*modules)
|
||||
|
||||
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
|
||||
# midでconditioning image embeddingと入力を結合する
|
||||
# upで元の次元数に戻す
|
||||
# これらはtimestepごとに呼ばれる
|
||||
# reduce the number of input dimensions with down. inspired by LoRA
|
||||
# combine conditioning image embedding and input with mid
|
||||
# restore to the original dimension with up
|
||||
# these are called for each timestep
|
||||
|
||||
if self.is_conv2d:
|
||||
self.down = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.mid = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.up = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
else:
|
||||
# midの前にconditioningをreshapeすること / reshape conditioning before mid
|
||||
self.down = torch.nn.Sequential(
|
||||
torch.nn.Linear(in_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.mid = torch.nn.Sequential(
|
||||
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.up = torch.nn.Sequential(
|
||||
torch.nn.Linear(mlp_dim, in_dim),
|
||||
)
|
||||
|
||||
# Zero-Convにする / set to Zero-Conv
|
||||
torch.nn.init.zeros_(self.up[0].weight) # zero conv
|
||||
|
||||
self.depth = depth # 1~3
|
||||
self.cond_emb = None
|
||||
self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
|
||||
self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
|
||||
|
||||
# batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
|
||||
# Controlの種類によっては使えるかも
|
||||
# both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
|
||||
# it may be available depending on the type of Control
|
||||
|
||||
def set_cond_image(self, cond_image):
|
||||
r"""
|
||||
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
if cond_image is None:
|
||||
self.cond_emb = None
|
||||
return
|
||||
|
||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
||||
cx = self.conditioning1(cond_image)
|
||||
if not self.is_conv2d:
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = cx.shape
|
||||
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
||||
self.cond_emb = cx
|
||||
|
||||
def set_batch_cond_only(self, cond_only, zeros):
|
||||
self.batch_cond_only = cond_only
|
||||
self.use_zeros_for_batch_uncond = zeros
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module[0].forward
|
||||
self.org_module[0].forward = self.forward
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
学習用の便利forward。元のモジュールのforwardを呼び出す
|
||||
/ convenient forward for training. call the forward of the original module
|
||||
"""
|
||||
if self.multiplier == 0.0 or self.cond_emb is None:
|
||||
return self.org_forward(x)
|
||||
|
||||
cx = self.cond_emb
|
||||
|
||||
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
|
||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||
if self.use_zeros_for_batch_uncond:
|
||||
cx[0::2] = 0.0 # uncond is zero
|
||||
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
||||
|
||||
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
||||
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
||||
# down reduces the number of input dimensions and combines it with conditioning image embedding
|
||||
# we expect that it will mix well by combining in the channel direction instead of adding
|
||||
|
||||
cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.mid(cx)
|
||||
|
||||
if self.dropout is not None and self.training:
|
||||
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
||||
|
||||
cx = self.up(cx) * self.multiplier
|
||||
|
||||
# residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
||||
if self.batch_cond_only:
|
||||
zx = torch.zeros_like(x)
|
||||
zx[1::2] += cx
|
||||
cx = zx
|
||||
|
||||
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
||||
return x
|
||||
|
||||
|
||||
class ControlNetLLLite(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: sdxl_original_unet.SdxlUNet2DConditionModel,
|
||||
cond_emb_dim: int = 16,
|
||||
mlp_dim: int = 16,
|
||||
dropout: Optional[float] = None,
|
||||
varbose: Optional[bool] = False,
|
||||
multiplier: Optional[float] = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# self.unets = [unet]
|
||||
|
||||
def create_modules(
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[torch.nn.Module],
|
||||
module_class: Type[object],
|
||||
) -> List[torch.nn.Module]:
|
||||
prefix = "lllite_unet"
|
||||
|
||||
modules = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
|
||||
if is_linear or (is_conv2d and not SKIP_CONV2D):
|
||||
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
|
||||
# block index to depth: depth is using to calculate conditioning size and channels
|
||||
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
||||
index1 = int(index1)
|
||||
if block_name == "input_blocks":
|
||||
if SKIP_INPUT_BLOCKS:
|
||||
continue
|
||||
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
||||
elif block_name == "middle_block":
|
||||
depth = 3
|
||||
elif block_name == "output_blocks":
|
||||
if SKIP_OUTPUT_BLOCKS:
|
||||
continue
|
||||
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
|
||||
if int(index2) >= 2:
|
||||
depth -= 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
lllite_name = prefix + "." + name + "." + child_name
|
||||
lllite_name = lllite_name.replace(".", "_")
|
||||
|
||||
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
||||
p = lllite_name.find("transformer_blocks")
|
||||
if p >= 0:
|
||||
tf_index = int(lllite_name[p:].split("_")[2])
|
||||
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
||||
continue
|
||||
|
||||
# time embは適用外とする
|
||||
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
||||
# time emb is not applied
|
||||
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
||||
if "emb_layers" in lllite_name or (
|
||||
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
|
||||
):
|
||||
continue
|
||||
|
||||
if ATTN1_2_ONLY:
|
||||
if not ("attn1" in lllite_name or "attn2" in lllite_name):
|
||||
continue
|
||||
if ATTN_QKV_ONLY:
|
||||
if "to_out" in lllite_name:
|
||||
continue
|
||||
|
||||
if ATTN1_ETC_ONLY:
|
||||
if "proj_out" in lllite_name:
|
||||
pass
|
||||
elif "attn1" in lllite_name and (
|
||||
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
|
||||
):
|
||||
pass
|
||||
elif "ff_net_2" in lllite_name:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
module = module_class(
|
||||
depth,
|
||||
cond_emb_dim,
|
||||
lllite_name,
|
||||
child_module,
|
||||
mlp_dim,
|
||||
dropout=dropout,
|
||||
multiplier=multiplier,
|
||||
)
|
||||
modules.append(module)
|
||||
return modules
|
||||
|
||||
target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
|
||||
if not TRANSFORMER_ONLY:
|
||||
target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
# create module instances
|
||||
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
|
||||
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
||||
|
||||
def forward(self, x):
|
||||
return x # dummy
|
||||
|
||||
def set_cond_image(self, cond_image):
|
||||
r"""
|
||||
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
for module in self.unet_modules:
|
||||
module.set_cond_image(cond_image)
|
||||
|
||||
def set_batch_cond_only(self, cond_only, zeros):
|
||||
for module in self.unet_modules:
|
||||
module.set_batch_cond_only(cond_only, zeros)
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
for module in self.unet_modules:
|
||||
module.multiplier = multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self):
|
||||
print("applying LLLite for U-Net...")
|
||||
for module in self.unet_modules:
|
||||
module.apply_to()
|
||||
self.add_module(module.lllite_name, module)
|
||||
|
||||
# マージできるかどうかを返す
|
||||
def is_mergeable(self):
|
||||
return False
|
||||
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
raise NotImplementedError()
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
pass
|
||||
|
||||
def prepare_optimizer_params(self):
|
||||
self.requires_grad_(True)
|
||||
return self.parameters()
|
||||
|
||||
def prepare_grad_etc(self):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# デバッグ用 / for debug
|
||||
|
||||
# sdxl_original_unet.USE_REENTRANT = False
|
||||
|
||||
# test shape etc
|
||||
print("create unet")
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
unet.to("cuda").to(torch.float16)
|
||||
|
||||
print("create ControlNet-LLLite")
|
||||
control_net = ControlNetLLLite(unet, 32, 64)
|
||||
control_net.apply_to()
|
||||
control_net.to("cuda")
|
||||
|
||||
print(control_net)
|
||||
|
||||
# print number of parameters
|
||||
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
|
||||
|
||||
input()
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
unet.set_gradient_checkpointing(True)
|
||||
unet.train() # for gradient checkpointing
|
||||
|
||||
control_net.train()
|
||||
|
||||
# # visualize
|
||||
# import torchviz
|
||||
# print("run visualize")
|
||||
# controlnet.set_control(conditioning_image)
|
||||
# output = unet(x, t, ctx, y)
|
||||
# print("make_dot")
|
||||
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
||||
# print("render")
|
||||
# image.format = "svg" # "png"
|
||||
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
||||
# input()
|
||||
|
||||
import bitsandbytes
|
||||
|
||||
optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||
|
||||
print("start training")
|
||||
steps = 10
|
||||
|
||||
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
|
||||
for step in range(steps):
|
||||
print(f"step {step}")
|
||||
|
||||
batch_size = 1
|
||||
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
||||
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
||||
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
|
||||
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
||||
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
control_net.set_cond_image(conditioning_image)
|
||||
|
||||
output = unet(x, t, ctx, y)
|
||||
target = torch.randn_like(output)
|
||||
loss = torch.nn.functional.mse_loss(output, target)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
print(sample_param)
|
||||
|
||||
# from safetensors.torch import save_file
|
||||
|
||||
# save_file(control_net.state_dict(), "logs/control_net.safetensors")
|
||||
502
networks/control_net_lllite_for_train.py
Normal file
502
networks/control_net_lllite_for_train.py
Normal file
@@ -0,0 +1,502 @@
|
||||
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装
|
||||
# ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, List, Type
|
||||
import torch
|
||||
from library import sdxl_original_unet
|
||||
|
||||
|
||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||
SKIP_INPUT_BLOCKS = False
|
||||
|
||||
# output_blocksに適用するかどうか / if True, output_blocks are not applied
|
||||
SKIP_OUTPUT_BLOCKS = True
|
||||
|
||||
# conv2dに適用するかどうか / if True, conv2d are not applied
|
||||
SKIP_CONV2D = False
|
||||
|
||||
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
|
||||
# if True, only transformer_blocks are applied, and ResBlocks are not applied
|
||||
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
|
||||
|
||||
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
|
||||
ATTN1_2_ONLY = True
|
||||
|
||||
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
|
||||
ATTN_QKV_ONLY = True
|
||||
|
||||
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
||||
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
|
||||
ATTN1_ETC_ONLY = False # True
|
||||
|
||||
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
|
||||
# max index of transformer_blocks. if None, apply to all transformer_blocks
|
||||
TRANSFORMER_MAX_BLOCK_INDEX = None
|
||||
|
||||
ORIGINAL_LINEAR = torch.nn.Linear
|
||||
ORIGINAL_CONV2D = torch.nn.Conv2d
|
||||
|
||||
|
||||
def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
|
||||
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
|
||||
# conditioning1 embeds conditioning image. it is not called for each timestep
|
||||
modules = []
|
||||
modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
|
||||
if depth == 1:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
elif depth == 2:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
||||
elif depth == 3:
|
||||
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
|
||||
module.lllite_conditioning1 = torch.nn.Sequential(*modules)
|
||||
|
||||
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
|
||||
# midでconditioning image embeddingと入力を結合する
|
||||
# upで元の次元数に戻す
|
||||
# これらはtimestepごとに呼ばれる
|
||||
# reduce the number of input dimensions with down. inspired by LoRA
|
||||
# combine conditioning image embedding and input with mid
|
||||
# restore to the original dimension with up
|
||||
# these are called for each timestep
|
||||
|
||||
module.lllite_down = torch.nn.Sequential(
|
||||
ORIGINAL_LINEAR(in_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
module.lllite_mid = torch.nn.Sequential(
|
||||
ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
module.lllite_up = torch.nn.Sequential(
|
||||
ORIGINAL_LINEAR(mlp_dim, in_dim),
|
||||
)
|
||||
|
||||
# Zero-Convにする / set to Zero-Conv
|
||||
torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv
|
||||
|
||||
|
||||
class LLLiteLinear(ORIGINAL_LINEAR):
|
||||
def __init__(self, in_features: int, out_features: int, **kwargs):
|
||||
super().__init__(in_features, out_features, **kwargs)
|
||||
self.enabled = False
|
||||
|
||||
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
|
||||
self.enabled = True
|
||||
self.lllite_name = name
|
||||
self.cond_emb_dim = cond_emb_dim
|
||||
self.dropout = dropout
|
||||
self.multiplier = multiplier # ignored
|
||||
|
||||
in_dim = self.in_features
|
||||
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
||||
|
||||
self.cond_image = None
|
||||
self.cond_emb = None
|
||||
|
||||
def set_cond_image(self, cond_image):
|
||||
self.cond_image = cond_image
|
||||
self.cond_emb = None
|
||||
|
||||
def forward(self, x):
|
||||
if not self.enabled:
|
||||
return super().forward(x)
|
||||
|
||||
if self.cond_emb is None:
|
||||
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
||||
cx = self.cond_emb
|
||||
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = cx.shape
|
||||
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
||||
|
||||
cx = torch.cat([cx, self.lllite_down(x)], dim=2)
|
||||
cx = self.lllite_mid(cx)
|
||||
|
||||
if self.dropout is not None and self.training:
|
||||
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
||||
|
||||
cx = self.lllite_up(cx) * self.multiplier
|
||||
|
||||
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
||||
return x
|
||||
|
||||
|
||||
class LLLiteConv2d(ORIGINAL_CONV2D):
|
||||
def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
|
||||
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
|
||||
self.enabled = False
|
||||
|
||||
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
|
||||
self.enabled = True
|
||||
self.lllite_name = name
|
||||
self.cond_emb_dim = cond_emb_dim
|
||||
self.dropout = dropout
|
||||
self.multiplier = multiplier # ignored
|
||||
|
||||
in_dim = self.in_channels
|
||||
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
||||
|
||||
self.cond_image = None
|
||||
self.cond_emb = None
|
||||
|
||||
def set_cond_image(self, cond_image):
|
||||
self.cond_image = cond_image
|
||||
self.cond_emb = None
|
||||
|
||||
def forward(self, x): # , cond_image=None):
|
||||
if not self.enabled:
|
||||
return super().forward(x)
|
||||
|
||||
if self.cond_emb is None:
|
||||
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
||||
cx = self.cond_emb
|
||||
|
||||
cx = torch.cat([cx, self.down(x)], dim=1)
|
||||
cx = self.mid(cx)
|
||||
|
||||
if self.dropout is not None and self.training:
|
||||
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
||||
|
||||
cx = self.up(cx) * self.multiplier
|
||||
|
||||
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
||||
return x
|
||||
|
||||
|
||||
class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
LLLITE_PREFIX = "lllite_unet"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def apply_lllite(
|
||||
self,
|
||||
cond_emb_dim: int = 16,
|
||||
mlp_dim: int = 16,
|
||||
dropout: Optional[float] = None,
|
||||
varbose: Optional[bool] = False,
|
||||
multiplier: Optional[float] = 1.0,
|
||||
) -> None:
|
||||
def apply_to_modules(
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[torch.nn.Module],
|
||||
) -> List[torch.nn.Module]:
|
||||
prefix = "lllite_unet"
|
||||
|
||||
modules = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "LLLiteLinear"
|
||||
is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
|
||||
|
||||
if is_linear or (is_conv2d and not SKIP_CONV2D):
|
||||
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
|
||||
# block index to depth: depth is using to calculate conditioning size and channels
|
||||
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
||||
index1 = int(index1)
|
||||
if block_name == "input_blocks":
|
||||
if SKIP_INPUT_BLOCKS:
|
||||
continue
|
||||
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
||||
elif block_name == "middle_block":
|
||||
depth = 3
|
||||
elif block_name == "output_blocks":
|
||||
if SKIP_OUTPUT_BLOCKS:
|
||||
continue
|
||||
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
|
||||
if int(index2) >= 2:
|
||||
depth -= 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
lllite_name = prefix + "." + name + "." + child_name
|
||||
lllite_name = lllite_name.replace(".", "_")
|
||||
|
||||
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
||||
p = lllite_name.find("transformer_blocks")
|
||||
if p >= 0:
|
||||
tf_index = int(lllite_name[p:].split("_")[2])
|
||||
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
||||
continue
|
||||
|
||||
# time embは適用外とする
|
||||
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
||||
# time emb is not applied
|
||||
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
||||
if "emb_layers" in lllite_name or (
|
||||
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
|
||||
):
|
||||
continue
|
||||
|
||||
if ATTN1_2_ONLY:
|
||||
if not ("attn1" in lllite_name or "attn2" in lllite_name):
|
||||
continue
|
||||
if ATTN_QKV_ONLY:
|
||||
if "to_out" in lllite_name:
|
||||
continue
|
||||
|
||||
if ATTN1_ETC_ONLY:
|
||||
if "proj_out" in lllite_name:
|
||||
pass
|
||||
elif "attn1" in lllite_name and (
|
||||
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
|
||||
):
|
||||
pass
|
||||
elif "ff_net_2" in lllite_name:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
|
||||
modules.append(child_module)
|
||||
|
||||
return modules
|
||||
|
||||
target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
|
||||
if not TRANSFORMER_ONLY:
|
||||
target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
# create module instances
|
||||
self.lllite_modules = apply_to_modules(self, target_modules)
|
||||
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
|
||||
|
||||
# def prepare_optimizer_params(self):
|
||||
def prepare_params(self):
|
||||
train_params = []
|
||||
non_train_params = []
|
||||
for name, p in self.named_parameters():
|
||||
if "lllite" in name:
|
||||
train_params.append(p)
|
||||
else:
|
||||
non_train_params.append(p)
|
||||
print(f"count of trainable parameters: {len(train_params)}")
|
||||
print(f"count of non-trainable parameters: {len(non_train_params)}")
|
||||
|
||||
for p in non_train_params:
|
||||
p.requires_grad_(False)
|
||||
|
||||
# without this, an error occurs in the optimizer
|
||||
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
|
||||
non_train_params[0].requires_grad_(True)
|
||||
|
||||
for p in train_params:
|
||||
p.requires_grad_(True)
|
||||
|
||||
return train_params
|
||||
|
||||
# def prepare_grad_etc(self):
|
||||
# self.requires_grad_(True)
|
||||
|
||||
# def on_epoch_start(self):
|
||||
# self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return [p[1] for p in self.named_parameters() if "lllite" in p[0]]
|
||||
|
||||
def save_lllite_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
org_state_dict = self.state_dict()
|
||||
|
||||
# copy LLLite keys from org_state_dict to state_dict with key conversion
|
||||
state_dict = {}
|
||||
for key in org_state_dict.keys():
|
||||
# split with ".lllite"
|
||||
pos = key.find(".lllite")
|
||||
if pos < 0:
|
||||
continue
|
||||
lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos]
|
||||
lllite_key = lllite_key.replace(".", "_") + key[pos:]
|
||||
lllite_key = lllite_key.replace(".lllite_", ".")
|
||||
state_dict[lllite_key] = org_state_dict[key]
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def load_lllite_weights(self, file, non_lllite_unet_sd=None):
|
||||
r"""
|
||||
LLLiteの重みを読み込まない(initされた値を使う)場合はfileにNoneを指定する。
|
||||
この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。
|
||||
|
||||
If you do not want to load LLLite weights (use initialized values), specify None for file.
|
||||
In this case, specify the state_dict of U-Net for non_lllite_unet_sd.
|
||||
"""
|
||||
if not file:
|
||||
state_dict = self.state_dict()
|
||||
for key in non_lllite_unet_sd:
|
||||
if key in state_dict:
|
||||
state_dict[key] = non_lllite_unet_sd[key]
|
||||
info = self.load_state_dict(state_dict, False)
|
||||
return info
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# module_name = module_name.replace("_block", "@blocks")
|
||||
# module_name = module_name.replace("_layer", "@layer")
|
||||
# module_name = module_name.replace("to_", "to@")
|
||||
# module_name = module_name.replace("time_embed", "time@embed")
|
||||
# module_name = module_name.replace("label_emb", "label@emb")
|
||||
# module_name = module_name.replace("skip_connection", "skip@connection")
|
||||
# module_name = module_name.replace("proj_in", "proj@in")
|
||||
# module_name = module_name.replace("proj_out", "proj@out")
|
||||
pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
|
||||
|
||||
# convert to lllite with U-Net state dict
|
||||
state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
|
||||
for key in weights_sd.keys():
|
||||
# split with "."
|
||||
pos = key.find(".")
|
||||
if pos < 0:
|
||||
continue
|
||||
|
||||
module_name = key[:pos]
|
||||
weight_name = key[pos + 1 :] # exclude "."
|
||||
module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "")
|
||||
|
||||
# これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. bad design because I didn't think about inverse conversion
|
||||
# module_name = module_name.replace("_", ".")
|
||||
|
||||
# ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@"
|
||||
matches = pattern.findall(module_name)
|
||||
if matches is not None:
|
||||
for m in matches:
|
||||
print(module_name, m)
|
||||
module_name = module_name.replace(m, m.replace("_", "@"))
|
||||
module_name = module_name.replace("_", ".")
|
||||
module_name = module_name.replace("@", "_")
|
||||
|
||||
lllite_key = module_name + ".lllite_" + weight_name
|
||||
|
||||
state_dict[lllite_key] = weights_sd[key]
|
||||
|
||||
info = self.load_state_dict(state_dict, False)
|
||||
return info
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
|
||||
for m in self.lllite_modules:
|
||||
m.set_cond_image(cond_image)
|
||||
return super().forward(x, timesteps, context, y, **kwargs)
|
||||
|
||||
|
||||
def replace_unet_linear_and_conv2d():
|
||||
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
|
||||
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
|
||||
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# デバッグ用 / for debug
|
||||
|
||||
# sdxl_original_unet.USE_REENTRANT = False
|
||||
replace_unet_linear_and_conv2d()
|
||||
|
||||
# test shape etc
|
||||
print("create unet")
|
||||
unet = SdxlUNet2DConditionModelControlNetLLLite()
|
||||
|
||||
print("enable ControlNet-LLLite")
|
||||
unet.apply_lllite(32, 64, None, False, 1.0)
|
||||
unet.to("cuda") # .to(torch.float16)
|
||||
|
||||
# from safetensors.torch import load_file
|
||||
|
||||
# model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors")
|
||||
# unet_sd = {}
|
||||
|
||||
# # copy U-Net keys from unet_state_dict to state_dict
|
||||
# prefix = "model.diffusion_model."
|
||||
# for key in model_sd.keys():
|
||||
# if key.startswith(prefix):
|
||||
# converted_key = key[len(prefix) :]
|
||||
# unet_sd[converted_key] = model_sd[key]
|
||||
|
||||
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
|
||||
# print(info)
|
||||
|
||||
# print(unet)
|
||||
|
||||
# print number of parameters
|
||||
params = unet.prepare_params()
|
||||
print("number of parameters", sum(p.numel() for p in params))
|
||||
# print("type any key to continue")
|
||||
# input()
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
unet.set_gradient_checkpointing(True)
|
||||
unet.train() # for gradient checkpointing
|
||||
|
||||
# # visualize
|
||||
# import torchviz
|
||||
# print("run visualize")
|
||||
# controlnet.set_control(conditioning_image)
|
||||
# output = unet(x, t, ctx, y)
|
||||
# print("make_dot")
|
||||
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
||||
# print("render")
|
||||
# image.format = "svg" # "png"
|
||||
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
||||
# input()
|
||||
|
||||
import bitsandbytes
|
||||
|
||||
optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||
|
||||
print("start training")
|
||||
steps = 10
|
||||
batch_size = 1
|
||||
|
||||
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
|
||||
for step in range(steps):
|
||||
print(f"step {step}")
|
||||
|
||||
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
||||
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
||||
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
|
||||
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
||||
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
||||
output = unet(x, t, ctx, y, conditioning_image)
|
||||
target = torch.randn_like(output)
|
||||
loss = torch.nn.functional.mse_loss(output, target)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
print(sample_param)
|
||||
|
||||
# from safetensors.torch import save_file
|
||||
|
||||
# print("save weights")
|
||||
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
|
||||
450
networks/dylora.py
Normal file
450
networks/dylora.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# some codes are copied from:
|
||||
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
|
||||
|
||||
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Changes made to the original code:
|
||||
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import List, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DyLoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
# NOTE: support dropout in future
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
self.unit = unit
|
||||
assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
|
||||
|
||||
if self.is_conv2d and self.is_conv2d_3x3:
|
||||
kernel_size = org_module.kernel_size
|
||||
self.stride = org_module.stride
|
||||
self.padding = org_module.padding
|
||||
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
|
||||
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
|
||||
else:
|
||||
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
|
||||
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
|
||||
|
||||
# same as microsoft's
|
||||
for lora in self.lora_A:
|
||||
torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
|
||||
for lora in self.lora_B:
|
||||
torch.nn.init.zeros_(lora)
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
result = self.org_forward(x)
|
||||
|
||||
# specify the dynamic rank
|
||||
trainable_rank = random.randint(0, self.lora_dim - 1)
|
||||
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
|
||||
|
||||
# 一部のパラメータを固定して、残りのパラメータを学習する
|
||||
for i in range(0, trainable_rank):
|
||||
self.lora_A[i].requires_grad = False
|
||||
self.lora_B[i].requires_grad = False
|
||||
for i in range(trainable_rank, trainable_rank + self.unit):
|
||||
self.lora_A[i].requires_grad = True
|
||||
self.lora_B[i].requires_grad = True
|
||||
for i in range(trainable_rank + self.unit, self.lora_dim):
|
||||
self.lora_A[i].requires_grad = False
|
||||
self.lora_B[i].requires_grad = False
|
||||
|
||||
lora_A = torch.cat(tuple(self.lora_A), dim=0)
|
||||
lora_B = torch.cat(tuple(self.lora_B), dim=1)
|
||||
|
||||
# calculate with lora_A and lora_B
|
||||
if self.is_conv2d_3x3:
|
||||
ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
|
||||
ab = torch.nn.functional.conv2d(ab, lora_B)
|
||||
else:
|
||||
ab = x
|
||||
if self.is_conv2d:
|
||||
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
|
||||
|
||||
ab = torch.nn.functional.linear(ab, lora_A)
|
||||
ab = torch.nn.functional.linear(ab, lora_B)
|
||||
|
||||
if self.is_conv2d:
|
||||
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
|
||||
|
||||
# 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな)
|
||||
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
|
||||
|
||||
# NOTE weightに加算してからlinear/conv2dを呼んだほうが速いかも
|
||||
return result
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
# state dictを通常のLoRAと同じにする:
|
||||
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
|
||||
sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
|
||||
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
key_a = f"{self.lora_name}.lora_A.{i}"
|
||||
key_b = f"{self.lora_name}.lora_B.{i}"
|
||||
if key_a in sd:
|
||||
sd.pop(key_a)
|
||||
sd.pop(key_b)
|
||||
else:
|
||||
break
|
||||
i += 1
|
||||
return sd
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
# 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた
|
||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
|
||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
|
||||
|
||||
if lora_A_weight is None or lora_B_weight is None:
|
||||
if strict:
|
||||
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
|
||||
else:
|
||||
return
|
||||
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
||||
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
||||
|
||||
state_dict.update(
|
||||
{f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
|
||||
)
|
||||
state_dict.update(
|
||||
{f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# extract dim/alpha for conv2d, and block dim
|
||||
conv_dim = kwargs.get("conv_dim", None)
|
||||
conv_alpha = kwargs.get("conv_alpha", None)
|
||||
unit = kwargs.get("unit", None)
|
||||
if conv_dim is not None:
|
||||
conv_dim = int(conv_dim)
|
||||
assert conv_dim == network_dim, "conv_dim must be same as network_dim"
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
if unit is not None:
|
||||
unit = int(unit)
|
||||
else:
|
||||
unit = 1
|
||||
|
||||
network = DyLoRANetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
lora_dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
apply_to_conv=conv_dim is not None,
|
||||
unit=unit,
|
||||
varbose=True,
|
||||
)
|
||||
return network
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# get dim/alpha mapping
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split(".")[0]
|
||||
if "alpha" in key:
|
||||
modules_alpha[lora_name] = value
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
if key not in modules_alpha:
|
||||
modules_alpha = modules_dim[key]
|
||||
|
||||
module_class = DyLoRAModule
|
||||
|
||||
network = DyLoRANetwork(
|
||||
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class DyLoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
apply_to_conv=False,
|
||||
modules_dim=None,
|
||||
modules_alpha=None,
|
||||
unit=1,
|
||||
module_class=DyLoRAModule,
|
||||
varbose=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
self.alpha = alpha
|
||||
self.apply_to_conv = apply_to_conv
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
else:
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
||||
if self.apply_to_conv:
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3).")
|
||||
|
||||
# create module instances
|
||||
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
|
||||
prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
loras = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
if modules_dim is not None:
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
else:
|
||||
if is_linear or is_conv2d_1x1 or apply_to_conv:
|
||||
dim = self.lora_dim
|
||||
alpha = self.alpha
|
||||
|
||||
if dim is None or dim == 0:
|
||||
continue
|
||||
|
||||
# dropout and fan_in_fan_out is default
|
||||
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
if modules_dim is not None or self.apply_to_conv:
|
||||
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras = create_modules(True, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
"""
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
apply_text_encoder = apply_unet = False
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||
apply_text_encoder = True
|
||||
elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
sd_for_lora = {}
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
print(f"weights are merged")
|
||||
"""
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
if self.unet_loras:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
return all_params
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
pass
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self, text_encoder, unet):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
# Precalculate model hashes to save time on indexing
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
# mask is a tensor with values from 0 to 1
|
||||
def set_region(self, sub_prompt_index, is_last_network, mask):
|
||||
pass
|
||||
|
||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
||||
pass
|
||||
125
networks/extract_lora_from_dylora.py
Normal file
125
networks/extract_lora_from_dylora.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
||||
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_state_dict(file_name):
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = None
|
||||
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, metadata):
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def split_lora_model(lora_sd, unit):
|
||||
max_rank = 0
|
||||
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
if "lora_down" in key:
|
||||
rank = value.size()[0]
|
||||
if rank > max_rank:
|
||||
max_rank = rank
|
||||
print(f"Max rank: {max_rank}")
|
||||
|
||||
rank = unit
|
||||
split_models = []
|
||||
new_alpha = None
|
||||
while rank < max_rank:
|
||||
print(f"Splitting rank {rank}")
|
||||
new_sd = {}
|
||||
for key, value in lora_sd.items():
|
||||
if "lora_down" in key:
|
||||
new_sd[key] = value[:rank].contiguous()
|
||||
elif "lora_up" in key:
|
||||
new_sd[key] = value[:, :rank].contiguous()
|
||||
else:
|
||||
# なぜかscaleするとおかしくなる……
|
||||
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
||||
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
||||
# print(key, value.size(), this_rank, rank, value, scale)
|
||||
# new_alpha = value * scale # always same
|
||||
# new_sd[key] = new_alpha
|
||||
new_sd[key] = value
|
||||
|
||||
split_models.append((new_sd, rank, new_alpha))
|
||||
rank += unit
|
||||
|
||||
return max_rank, split_models
|
||||
|
||||
|
||||
def split(args):
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model)
|
||||
|
||||
print("Splitting Model...")
|
||||
original_rank, split_models = split_lora_model(lora_sd, args.unit)
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
for state_dict, new_rank, new_alpha in split_models:
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
new_metadata = {}
|
||||
else:
|
||||
new_metadata = metadata.copy()
|
||||
|
||||
new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
|
||||
new_metadata["ss_network_dim"] = str(new_rank)
|
||||
# new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
filename, ext = os.path.splitext(args.save_to)
|
||||
model_file_name = filename + f"-{new_rank:04d}{ext}"
|
||||
|
||||
print(f"saving model to: {model_file_name}")
|
||||
save_to_file(model_file_name, state_dict, new_metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
split(args)
|
||||
@@ -3,181 +3,355 @@
|
||||
# Thanks to cloneofsimo!
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import library.model_util as model_util
|
||||
from library import sai_model_spec, model_util, sdxl_model_util
|
||||
import lora
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
MIN_DIFF = 1e-6
|
||||
# CLAMP_QUANTILE = 0.99
|
||||
# MIN_DIFF = 1e-1
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def svd(args):
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
def svd(
|
||||
model_org=None,
|
||||
model_tuned=None,
|
||||
save_to=None,
|
||||
dim=4,
|
||||
v2=None,
|
||||
sdxl=None,
|
||||
conv_dim=None,
|
||||
v_parameterization=None,
|
||||
device=None,
|
||||
save_precision=None,
|
||||
clamp_quantile=0.99,
|
||||
min_diff=0.01,
|
||||
no_metadata=False,
|
||||
load_precision=None,
|
||||
load_original_model_to=None,
|
||||
load_tuned_model_to=None,
|
||||
):
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
||||
if v_parameterization is None:
|
||||
v_parameterization = v2
|
||||
|
||||
print(f"loading SD model : {args.model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
||||
print(f"loading SD model : {args.model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
||||
load_dtype = str_to_dtype(load_precision) if load_precision else None
|
||||
save_dtype = str_to_dtype(save_precision)
|
||||
work_device = "cpu"
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
if args.conv_dim is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
||||
# load models
|
||||
if not sdxl:
|
||||
print(f"loading original SD model : {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o = text_encoder_o.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(
|
||||
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||
print(f"loading tuned SD model : {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t = text_encoder_t.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
# get diffs
|
||||
diffs = {}
|
||||
text_encoder_different = False
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||
else:
|
||||
device_org = load_original_model_to if load_original_model_to else "cpu"
|
||||
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
||||
|
||||
# Text Encoder might be same
|
||||
if torch.max(torch.abs(diff)) > MIN_DIFF:
|
||||
text_encoder_different = True
|
||||
print(f"loading original SDXL model : {model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o1 = text_encoder_o1.to(load_dtype)
|
||||
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
diff = diff.float()
|
||||
diffs[lora_name] = diff
|
||||
print(f"loading original SDXL model : {model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t1 = text_encoder_t1.to(load_dtype)
|
||||
text_encoder_t2 = text_encoder_t2.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
if not text_encoder_different:
|
||||
print("Text encoder is same. Extract U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
if conv_dim is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
|
||||
|
||||
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(
|
||||
lora_network_t.text_encoder_loras
|
||||
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||
|
||||
# get diffs
|
||||
diffs = {}
|
||||
text_encoder_different = False
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
diff = diff.float()
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
if args.device:
|
||||
diff = diff.to(args.device)
|
||||
# Text Encoder might be same
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
|
||||
diffs[lora_name] = diff
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# make LoRA with svd
|
||||
print("calculating by svd")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
conv2d = (len(mat.size()) == 4)
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
# clear target Text Encoder to save memory
|
||||
for text_encoder in text_encoders_t:
|
||||
del text_encoder
|
||||
|
||||
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
if not text_encoder_different:
|
||||
print("Text encoder is same. Extract U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs = {} # clear diffs
|
||||
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
if conv2d:
|
||||
if conv2d_3x3:
|
||||
mat = mat.flatten(start_dim=1)
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
diffs[lora_name] = diff
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
# clear LoRA network, target U-Net to save memory
|
||||
del lora_network_o
|
||||
del lora_network_t
|
||||
del unet_t
|
||||
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
U = U @ torch.diag(S)
|
||||
# make LoRA with svd
|
||||
print("calculating by svd")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
mat = mat.to(torch.float) # calc by float
|
||||
|
||||
Vh = Vh[:rank, :]
|
||||
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
|
||||
U = U.to("cpu").contiguous()
|
||||
Vh = Vh.to("cpu").contiguous()
|
||||
if conv2d:
|
||||
if conv2d_3x3:
|
||||
mat = mat.flatten(start_dim=1)
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
# make state dict for LoRA
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + '.lora_up.weight'] = up_weight
|
||||
lora_sd[lora_name + '.lora_down.weight'] = down_weight
|
||||
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
print(f"Loading extracted LoRA weights: {info}")
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, clamp_quantile)
|
||||
low_val = -hi_val
|
||||
|
||||
dir_name = os.path.dirname(args.save_to)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
# minimum metadata
|
||||
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
U = U.to(work_device, dtype=save_dtype).contiguous()
|
||||
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
# make state dict for LoRA
|
||||
lora_sd = {}
|
||||
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
||||
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
||||
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
||||
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
|
||||
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
print(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
dir_name = os.path.dirname(save_to)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
# minimum metadata
|
||||
net_kwargs = {}
|
||||
if conv_dim is not None:
|
||||
net_kwargs["conv_dim"] = str(conv_dim)
|
||||
net_kwargs["conv_alpha"] = str(float(conv_dim))
|
||||
|
||||
metadata = {
|
||||
"ss_v2": str(v2),
|
||||
"ss_base_model_version": model_version,
|
||||
"ss_network_module": "networks.lora",
|
||||
"ss_network_dim": str(dim),
|
||||
"ss_network_alpha": str(float(dim)),
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
}
|
||||
|
||||
if not no_metadata:
|
||||
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||
print(f"LoRA weights are saved to: {save_to}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
|
||||
parser.add_argument("--model_org", type=str, default=None,
|
||||
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--model_tuned", type=str, default=None,
|
||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||
parser.add_argument("--conv_dim", type=int, default=None,
|
||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||
parser.add_argument(
|
||||
"--v_parameterization",
|
||||
action="store_true",
|
||||
default=None,
|
||||
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_org",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_tuned",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||
parser.add_argument(
|
||||
"--conv_dim",
|
||||
type=int,
|
||||
default=None,
|
||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||
)
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument(
|
||||
"--clamp_quantile",
|
||||
type=float,
|
||||
default=0.99,
|
||||
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_diff",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
||||
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_original_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_tuned_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
svd(args)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
svd(**vars(args))
|
||||
|
||||
1518
networks/lora.py
1518
networks/lora.py
File diff suppressed because it is too large
Load Diff
609
networks/lora_diffusers.py
Normal file
609
networks/lora_diffusers.py
Normal file
@@ -0,0 +1,609 @@
|
||||
# Diffusersで動くLoRA。このファイル単独で完結する。
|
||||
# LoRA module for Diffusers. This file works independently.
|
||||
|
||||
import bisect
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
from diffusers import UNet2DConditionModel
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTextModel
|
||||
import torch
|
||||
|
||||
|
||||
def make_unet_conversion_map() -> Dict[str, str]:
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
||||
return sd_hf_conversion_map
|
||||
|
||||
|
||||
UNET_CONVERSION_MAP = make_unet_conversion_map()
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
):
|
||||
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
||||
kernel_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||
else:
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
|
||||
|
||||
# same as microsoft's
|
||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||
torch.nn.init.zeros_(self.lora_up.weight)
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = [org_module]
|
||||
self.enabled = True
|
||||
self.network: LoRANetwork = None
|
||||
self.org_forward = None
|
||||
|
||||
# override org_module's forward method
|
||||
def apply_to(self, multiplier=None):
|
||||
if multiplier is not None:
|
||||
self.multiplier = multiplier
|
||||
if self.org_forward is None:
|
||||
self.org_forward = self.org_module[0].forward
|
||||
self.org_module[0].forward = self.forward
|
||||
|
||||
# restore org_module's forward method
|
||||
def unapply_to(self):
|
||||
if self.org_forward is not None:
|
||||
self.org_module[0].forward = self.org_forward
|
||||
|
||||
# forward with lora
|
||||
# scale is used LoRACompatibleConv, but we ignore it because we have multiplier
|
||||
def forward(self, x, scale=1.0):
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
# merge lora weight to org weight
|
||||
def merge_to(self, multiplier=1.0):
|
||||
# get lora weight
|
||||
lora_weight = self.get_weight(multiplier)
|
||||
|
||||
# get org weight
|
||||
org_sd = self.org_module[0].state_dict()
|
||||
org_weight = org_sd["weight"]
|
||||
weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
# restore org weight from lora weight
|
||||
def restore_from(self, multiplier=1.0):
|
||||
# get lora weight
|
||||
lora_weight = self.get_weight(multiplier)
|
||||
|
||||
# get org weight
|
||||
org_sd = self.org_module[0].state_dict()
|
||||
org_weight = org_sd["weight"]
|
||||
weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
# return lora weight
|
||||
def get_weight(self, multiplier=None):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
# get up/down weight from module
|
||||
up_weight = self.lora_up.weight.to(torch.float)
|
||||
down_weight = self.lora_down.weight.to(torch.float)
|
||||
|
||||
# pre-calculated weight
|
||||
if len(down_weight.size()) == 2:
|
||||
# linear
|
||||
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
self.multiplier
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* self.scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
weight = self.multiplier * conved * self.scale
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here
|
||||
def create_network_from_weights(
|
||||
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
|
||||
):
|
||||
# get dim/alpha mapping
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split(".")[0]
|
||||
if "alpha" in key:
|
||||
modules_alpha[lora_name] = value
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
if key not in modules_alpha:
|
||||
modules_alpha[key] = modules_dim[key]
|
||||
|
||||
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
||||
|
||||
|
||||
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
|
||||
unet = pipe.unet
|
||||
|
||||
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
|
||||
lora_network.load_state_dict(weights_sd)
|
||||
lora_network.merge_to(multiplier=multiplier)
|
||||
|
||||
|
||||
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
|
||||
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
||||
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
||||
unet: UNet2DConditionModel,
|
||||
multiplier: float = 1.0,
|
||||
modules_dim: Optional[Dict[str, int]] = None,
|
||||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
varbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
print(f"create LoRA network from weights")
|
||||
|
||||
# convert SDXL Stability AI's U-Net modules to Diffusers
|
||||
converted = self.convert_unet_modules(modules_dim, modules_alpha)
|
||||
if converted:
|
||||
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
is_unet: bool,
|
||||
text_encoder_idx: Optional[int], # None, 1, 2
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[torch.nn.Module],
|
||||
) -> List[LoRAModule]:
|
||||
prefix = (
|
||||
self.LORA_PREFIX_UNET
|
||||
if is_unet
|
||||
else (
|
||||
self.LORA_PREFIX_TEXT_ENCODER
|
||||
if text_encoder_idx is None
|
||||
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
||||
)
|
||||
)
|
||||
loras = []
|
||||
skipped = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = (
|
||||
child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
|
||||
)
|
||||
is_conv2d = (
|
||||
child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
|
||||
)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if lora_name not in modules_dim:
|
||||
# print(f"skipped {lora_name} (not found in modules_dim)")
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
lora = LoRAModule(
|
||||
lora_name,
|
||||
child_module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha,
|
||||
)
|
||||
loras.append(lora)
|
||||
return loras, skipped
|
||||
|
||||
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
||||
|
||||
# create LoRA for text encoder
|
||||
# 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider
|
||||
self.text_encoder_loras: List[LoRAModule] = []
|
||||
skipped_te = []
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if len(text_encoders) > 1:
|
||||
index = i + 1
|
||||
else:
|
||||
index = None
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
if len(skipped_te) > 0:
|
||||
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||
|
||||
# extend U-Net target modules to include Conv2d 3x3
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras: List[LoRAModule]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
if len(skipped_un) > 0:
|
||||
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
names.add(lora.lora_name)
|
||||
for lora_name in modules_dim.keys():
|
||||
assert lora_name in names, f"{lora_name} is not found in created LoRA modules."
|
||||
|
||||
# make to work load_state_dict
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
# SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
|
||||
def convert_unet_modules(self, modules_dim, modules_alpha):
|
||||
converted_count = 0
|
||||
not_converted_count = 0
|
||||
|
||||
map_keys = list(UNET_CONVERSION_MAP.keys())
|
||||
map_keys.sort()
|
||||
|
||||
for key in list(modules_dim.keys()):
|
||||
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
||||
search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
|
||||
position = bisect.bisect_right(map_keys, search_key)
|
||||
map_key = map_keys[position - 1]
|
||||
if search_key.startswith(map_key):
|
||||
new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
|
||||
modules_dim[new_key] = modules_dim[key]
|
||||
modules_alpha[new_key] = modules_alpha[key]
|
||||
del modules_dim[key]
|
||||
del modules_alpha[key]
|
||||
converted_count += 1
|
||||
else:
|
||||
not_converted_count += 1
|
||||
assert (
|
||||
converted_count == 0 or not_converted_count == 0
|
||||
), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
|
||||
return converted_count
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.apply_to(multiplier)
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
for lora in self.unet_loras:
|
||||
lora.apply_to(multiplier)
|
||||
|
||||
def unapply_to(self):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.unapply_to()
|
||||
|
||||
def merge_to(self, multiplier=1.0):
|
||||
print("merge LoRA weights to original weights")
|
||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||
lora.merge_to(multiplier)
|
||||
print(f"weights are merged")
|
||||
|
||||
def restore_from(self, multiplier=1.0):
|
||||
print("restore LoRA weights from original weights")
|
||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||
lora.restore_from(multiplier)
|
||||
print(f"weights are restored")
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
||||
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
|
||||
map_keys.sort()
|
||||
for key in list(state_dict.keys()):
|
||||
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
||||
search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
|
||||
position = bisect.bisect_right(map_keys, search_key)
|
||||
map_key = map_keys[position - 1]
|
||||
if search_key.startswith(map_key):
|
||||
new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
|
||||
state_dict[new_key] = state_dict[key]
|
||||
del state_dict[key]
|
||||
|
||||
# in case of V2, some weights have different shape, so we need to convert them
|
||||
# because V2 LoRA is based on U-Net created by use_linear_projection=False
|
||||
my_state_dict = self.state_dict()
|
||||
for key in state_dict.keys():
|
||||
if state_dict[key].size() != my_state_dict[key].size():
|
||||
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
||||
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
|
||||
|
||||
return super().load_state_dict(state_dict, strict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# sample code to use LoRANetwork
|
||||
import os
|
||||
import argparse
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
|
||||
parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights")
|
||||
parser.add_argument("--sdxl", action="store_true", help="use SDXL model")
|
||||
parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text")
|
||||
parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text")
|
||||
parser.add_argument("--seed", type=int, default=0, help="random seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
image_prefix = args.model_id.replace("/", "_") + "_"
|
||||
|
||||
# load Diffusers model
|
||||
print(f"load model from {args.model_id}")
|
||||
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
|
||||
if args.sdxl:
|
||||
# use_safetensors=True does not work with 0.18.2
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(device)
|
||||
pipe.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
|
||||
|
||||
# load LoRA weights
|
||||
print(f"load LoRA weights from {args.lora_weights}")
|
||||
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
lora_sd = load_file(args.lora_weights)
|
||||
else:
|
||||
lora_sd = torch.load(args.lora_weights)
|
||||
|
||||
# create by LoRA weights and load weights
|
||||
print(f"create LoRA network")
|
||||
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
|
||||
|
||||
print(f"load LoRA network weights")
|
||||
lora_network.load_state_dict(lora_sd)
|
||||
|
||||
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
|
||||
|
||||
# 必要があれば、元のモデルの重みをバックアップしておく
|
||||
# back-up unet/text encoder weights if necessary
|
||||
def detach_and_move_to_cpu(state_dict):
|
||||
for k, v in state_dict.items():
|
||||
state_dict[k] = v.detach().cpu()
|
||||
return state_dict
|
||||
|
||||
org_unet_sd = pipe.unet.state_dict()
|
||||
detach_and_move_to_cpu(org_unet_sd)
|
||||
|
||||
org_text_encoder_sd = pipe.text_encoder.state_dict()
|
||||
detach_and_move_to_cpu(org_text_encoder_sd)
|
||||
|
||||
if args.sdxl:
|
||||
org_text_encoder_2_sd = pipe.text_encoder_2.state_dict()
|
||||
detach_and_move_to_cpu(org_text_encoder_2_sd)
|
||||
|
||||
def seed_everything(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
# create image with original weights
|
||||
print(f"create image with original weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "original.png")
|
||||
|
||||
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
|
||||
print(f"apply LoRA network to the model")
|
||||
lora_network.apply_to(multiplier=1.0)
|
||||
|
||||
print(f"create image with applied LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "applied_lora.png")
|
||||
|
||||
# unapply LoRA network to the model
|
||||
print(f"unapply LoRA network to the model")
|
||||
lora_network.unapply_to()
|
||||
|
||||
print(f"create image with unapplied LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "unapplied_lora.png")
|
||||
|
||||
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
|
||||
print(f"merge LoRA network to the model")
|
||||
lora_network.merge_to(multiplier=1.0)
|
||||
|
||||
print(f"create image with LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "merged_lora.png")
|
||||
|
||||
# restore (unmerge) LoRA weights: numerically unstable
|
||||
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
|
||||
# 保存したstate_dictから元の重みを復元するのが確実
|
||||
print(f"restore (unmerge) LoRA weights")
|
||||
lora_network.restore_from(multiplier=1.0)
|
||||
|
||||
print(f"create image without LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "unmerged_lora.png")
|
||||
|
||||
# restore original weights
|
||||
print(f"restore original weights")
|
||||
pipe.unet.load_state_dict(org_unet_sd)
|
||||
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
|
||||
if args.sdxl:
|
||||
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
|
||||
|
||||
print(f"create image with restored original weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "restore_original.png")
|
||||
|
||||
# use convenience function to merge LoRA weights
|
||||
print(f"merge LoRA weights with convenience function")
|
||||
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
||||
|
||||
print(f"create image with merged LoRA weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "convenience_merged_lora.png")
|
||||
1241
networks/lora_fa.py
Normal file
1241
networks/lora_fa.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
|
||||
from tqdm import tqdm
|
||||
from library import model_util
|
||||
import library.train_util as train_util
|
||||
import argparse
|
||||
from transformers import CLIPTokenizer
|
||||
import torch
|
||||
@@ -16,16 +17,20 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def interrogate(args):
|
||||
weights_dtype = torch.float16
|
||||
|
||||
# いろいろ準備する
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
args.pretrained_model_name_or_path = args.sd_model
|
||||
args.vae = None
|
||||
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
|
||||
|
||||
print(f"loading LoRA: {args.model}")
|
||||
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||
|
||||
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
||||
has_te_weight = False
|
||||
for key in network.weights_sd.keys():
|
||||
for key in weights_sd.keys():
|
||||
if 'lora_te' in key:
|
||||
has_te_weight = True
|
||||
break
|
||||
@@ -40,9 +45,9 @@ def interrogate(args):
|
||||
else:
|
||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
||||
|
||||
text_encoder.to(DEVICE)
|
||||
text_encoder.to(DEVICE, dtype=weights_dtype)
|
||||
text_encoder.eval()
|
||||
unet.to(DEVICE)
|
||||
unet.to(DEVICE, dtype=weights_dtype)
|
||||
unet.eval() # U-Netは呼び出さないので不要だけど
|
||||
|
||||
# トークンをひとつひとつ当たっていく
|
||||
@@ -78,9 +83,14 @@ def interrogate(args):
|
||||
orig_embs = get_all_embeddings(text_encoder)
|
||||
|
||||
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
||||
network.to(DEVICE)
|
||||
info = network.load_state_dict(weights_sd, strict=False)
|
||||
print(f"Loading LoRA weights: {info}")
|
||||
|
||||
network.to(DEVICE, dtype=weights_dtype)
|
||||
network.eval()
|
||||
|
||||
del unet
|
||||
|
||||
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||||
print("get text encoder embeddings with lora.")
|
||||
lora_embs = get_all_embeddings(text_encoder)
|
||||
@@ -105,8 +115,9 @@ def interrogate(args):
|
||||
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
@@ -118,5 +129,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--clip_skip", type=int, default=None,
|
||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
interrogate(args)
|
||||
|
||||
@@ -1,218 +1,357 @@
|
||||
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from library import sai_model_spec, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
sd = load_file(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
return sd
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
sd = load_file(file_name)
|
||||
metadata = train_util.load_metadata_from_safetensors(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = {}
|
||||
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name, metadata=metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
text_encoder.to(merge_dtype)
|
||||
unet.to(merge_dtype)
|
||||
text_encoder.to(merge_dtype)
|
||||
unet.to(merge_dtype)
|
||||
|
||||
# create module map
|
||||
name_to_module = {}
|
||||
for i, root_module in enumerate([text_encoder, unet]):
|
||||
if i == 0:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
||||
|
||||
# find original module for this lora
|
||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# print(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
# print(module_name, down_weight.size(), up_weight.size())
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
# create module map
|
||||
name_to_module = {}
|
||||
for i, root_module in enumerate([text_encoder, unet]):
|
||||
if i == 0:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + ratio * conved * scale
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||
target_replace_modules = (
|
||||
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
)
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||
|
||||
# find original module for this lora
|
||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# print(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
if len(up_weight.size()) == 4: # use linear projection mismatch
|
||||
up_weight = up_weight.squeeze(3).squeeze(2)
|
||||
down_weight = down_weight.squeeze(3).squeeze(2)
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
weight
|
||||
+ ratio
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, merge_dtype):
|
||||
base_alphas = {} # alpha for merged model
|
||||
base_dims = {}
|
||||
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
base_alphas = {} # alpha for merged model
|
||||
base_dims = {}
|
||||
|
||||
merged_sd = {}
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
merged_sd = {}
|
||||
v2 = None
|
||||
base_model = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||
|
||||
# get alpha and dim
|
||||
alphas = {} # alpha for current model
|
||||
dims = {} # dims for current model
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
lora_module_name = key[:key.rfind(".alpha")]
|
||||
alpha = float(lora_sd[key].detach().numpy())
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
elif "lora_down" in key:
|
||||
lora_module_name = key[:key.rfind(".lora_down")]
|
||||
dim = lora_sd[key].size()[0]
|
||||
dims[lora_module_name] = dim
|
||||
if lora_module_name not in base_dims:
|
||||
base_dims[lora_module_name] = dim
|
||||
if lora_metadata is not None:
|
||||
if v2 is None:
|
||||
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string
|
||||
if base_model is None:
|
||||
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
||||
|
||||
for lora_module_name in dims.keys():
|
||||
if lora_module_name not in alphas:
|
||||
alpha = dims[lora_module_name]
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
# get alpha and dim
|
||||
alphas = {} # alpha for current model
|
||||
dims = {} # dims for current model
|
||||
for key in lora_sd.keys():
|
||||
if "alpha" in key:
|
||||
lora_module_name = key[: key.rfind(".alpha")]
|
||||
alpha = float(lora_sd[key].detach().numpy())
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
elif "lora_down" in key:
|
||||
lora_module_name = key[: key.rfind(".lora_down")]
|
||||
dim = lora_sd[key].size()[0]
|
||||
dims[lora_module_name] = dim
|
||||
if lora_module_name not in base_dims:
|
||||
base_dims[lora_module_name] = dim
|
||||
|
||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
for lora_module_name in dims.keys():
|
||||
if lora_module_name not in alphas:
|
||||
alpha = dims[lora_module_name]
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
continue
|
||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
|
||||
lora_module_name = key[:key.rfind(".lora_")]
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if "alpha" in key:
|
||||
continue
|
||||
if "lora_up" in key and concat:
|
||||
concat_dim = 1
|
||||
elif "lora_down" in key and concat:
|
||||
concat_dim = 0
|
||||
else:
|
||||
concat_dim = None
|
||||
|
||||
base_alpha = base_alphas[lora_module_name]
|
||||
alpha = alphas[lora_module_name]
|
||||
lora_module_name = key[: key.rfind(".lora_")]
|
||||
|
||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||
base_alpha = base_alphas[lora_module_name]
|
||||
alpha = alphas[lora_module_name]
|
||||
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key].size() == lora_sd[key].size(
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||
else:
|
||||
merged_sd[key] = lora_sd[key] * scale
|
||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
|
||||
|
||||
# set alpha to sd
|
||||
for lora_module_name, alpha in base_alphas.items():
|
||||
key = lora_module_name + ".alpha"
|
||||
merged_sd[key] = torch.tensor(alpha)
|
||||
if key in merged_sd:
|
||||
assert (
|
||||
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
if concat_dim is not None:
|
||||
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
|
||||
else:
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||
else:
|
||||
merged_sd[key] = lora_sd[key] * scale
|
||||
|
||||
print("merged model")
|
||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
# set alpha to sd
|
||||
for lora_module_name, alpha in base_alphas.items():
|
||||
key = lora_module_name + ".alpha"
|
||||
merged_sd[key] = torch.tensor(alpha)
|
||||
if shuffle:
|
||||
key_down = lora_module_name + ".lora_down.weight"
|
||||
key_up = lora_module_name + ".lora_up.weight"
|
||||
dim = merged_sd[key_down].shape[0]
|
||||
perm = torch.randperm(dim)
|
||||
merged_sd[key_down] = merged_sd[key_down][perm]
|
||||
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
||||
|
||||
return merged_sd
|
||||
print("merged model")
|
||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
|
||||
# check all dims are same
|
||||
dims_list = list(set(base_dims.values()))
|
||||
alphas_list = list(set(base_alphas.values()))
|
||||
all_same_dims = True
|
||||
all_same_alphas = True
|
||||
for dims in dims_list:
|
||||
if dims != dims_list[0]:
|
||||
all_same_dims = False
|
||||
break
|
||||
for alphas in alphas_list:
|
||||
if alphas != alphas_list[0]:
|
||||
all_same_alphas = False
|
||||
break
|
||||
|
||||
# build minimum metadata
|
||||
dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
|
||||
alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
|
||||
metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
|
||||
|
||||
return merged_sd, metadata, v2 == "True"
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
if args.sd_model is not None:
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
if args.sd_model is not None:
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
|
||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
if args.no_metadata:
|
||||
sai_metadata = None
|
||||
else:
|
||||
merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
|
||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(
|
||||
None,
|
||||
args.v2,
|
||||
args.v2,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
time.time(),
|
||||
title=title,
|
||||
merged_from=merged_from,
|
||||
is_stable_diffusion_ckpt=True,
|
||||
)
|
||||
if args.v2:
|
||||
# TODO read sai modelspec
|
||||
print(
|
||||
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||
)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
|
||||
)
|
||||
else:
|
||||
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
||||
|
||||
print(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
if not args.no_metadata:
|
||||
merged_from = sai_model_spec.build_merged_from(args.models)
|
||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(
|
||||
state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from
|
||||
)
|
||||
if v2:
|
||||
# TODO read sai modelspec
|
||||
print(
|
||||
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||
)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
parser.add_argument("--precision", type=str, default="float",
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--models", type=str, nargs='*',
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="float",
|
||||
choices=["float", "fp16", "bf16"],
|
||||
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd_model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concat",
|
||||
action="store_true",
|
||||
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
|
||||
+ "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shuffle",
|
||||
action="store_true",
|
||||
help="shuffle lora weight./ "
|
||||
+ "LoRAの重みをシャッフルする",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
|
||||
@@ -148,17 +148,17 @@ def merge(args):
|
||||
|
||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
print(f"\nsaving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
print(f"\nsaving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
@@ -175,5 +175,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
|
||||
430
networks/oft.py
Normal file
430
networks/oft.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# OFT network module
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
|
||||
class OFTModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
oft_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
dim=4,
|
||||
alpha=1,
|
||||
):
|
||||
"""
|
||||
dim -> num blocks
|
||||
alpha -> constraint
|
||||
"""
|
||||
super().__init__()
|
||||
self.oft_name = oft_name
|
||||
|
||||
self.num_blocks = dim
|
||||
|
||||
if "Linear" in org_module.__class__.__name__:
|
||||
out_dim = org_module.out_features
|
||||
elif "Conv" in org_module.__class__.__name__:
|
||||
out_dim = org_module.out_channels
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().numpy()
|
||||
self.constraint = alpha * out_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha))
|
||||
|
||||
self.block_size = out_dim // self.num_blocks
|
||||
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
|
||||
|
||||
self.out_dim = out_dim
|
||||
self.shape = org_module.weight.shape
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = [org_module] # moduleにならないようにlistに入れる
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module[0].forward
|
||||
self.org_module[0].forward = self.forward
|
||||
|
||||
def get_weight(self, multiplier=None):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
|
||||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
|
||||
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
|
||||
|
||||
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
|
||||
R = torch.block_diag(*block_R_weighted)
|
||||
|
||||
return R
|
||||
|
||||
def forward(self, x, scale=None):
|
||||
x = self.org_forward(x)
|
||||
if self.multiplier == 0.0:
|
||||
return x
|
||||
|
||||
R = self.get_weight().to(x.device, dtype=x.dtype)
|
||||
if x.dim() == 4:
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = torch.matmul(x, R)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
else:
|
||||
x = torch.matmul(x, R)
|
||||
return x
|
||||
|
||||
|
||||
class OFTInfModule(OFTModule):
|
||||
def __init__(
|
||||
self,
|
||||
oft_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
dim=4,
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference
|
||||
super().__init__(oft_name, org_module, multiplier, dim, alpha)
|
||||
self.enabled = True
|
||||
self.network: OFTNetwork = None
|
||||
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
def forward(self, x, scale=None):
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
return super().forward(x, scale)
|
||||
|
||||
def merge_to(self, multiplier=None, sign=1):
|
||||
R = self.get_weight(multiplier) * sign
|
||||
|
||||
# get org weight
|
||||
org_sd = self.org_module[0].state_dict()
|
||||
org_weight = org_sd["weight"]
|
||||
R = R.to(org_weight.device, dtype=org_weight.dtype)
|
||||
|
||||
if org_weight.dim() == 4:
|
||||
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
|
||||
else:
|
||||
weight = torch.einsum("oi, op -> pi", org_weight, R)
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
network_dim: Optional[int],
|
||||
network_alpha: Optional[float],
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
||||
unet,
|
||||
neuron_dropout: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
enable_all_linear = kwargs.get("enable_all_linear", None)
|
||||
enable_conv = kwargs.get("enable_conv", None)
|
||||
if enable_all_linear is not None:
|
||||
enable_all_linear = bool(enable_all_linear)
|
||||
if enable_conv is not None:
|
||||
enable_conv = bool(enable_conv)
|
||||
|
||||
network = OFTNetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
enable_all_linear=enable_all_linear,
|
||||
enable_conv=enable_conv,
|
||||
varbose=True,
|
||||
)
|
||||
return network
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# check dim, alpha and if weights have for conv2d
|
||||
dim = None
|
||||
alpha = None
|
||||
has_conv2d = None
|
||||
all_linear = None
|
||||
for name, param in weights_sd.items():
|
||||
if name.endswith(".alpha"):
|
||||
if alpha is None:
|
||||
alpha = param.item()
|
||||
else:
|
||||
if dim is None:
|
||||
dim = param.size()[0]
|
||||
if has_conv2d is None and param.dim() == 4:
|
||||
has_conv2d = True
|
||||
if all_linear is None:
|
||||
if param.dim() == 3 and "attn" not in name:
|
||||
all_linear = True
|
||||
if dim is not None and alpha is not None and has_conv2d is not None:
|
||||
break
|
||||
if has_conv2d is None:
|
||||
has_conv2d = False
|
||||
if all_linear is None:
|
||||
all_linear = False
|
||||
|
||||
module_class = OFTInfModule if for_inference else OFTModule
|
||||
network = OFTNetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
dim=dim,
|
||||
alpha=alpha,
|
||||
enable_all_linear=all_linear,
|
||||
enable_conv=has_conv2d,
|
||||
module_class=module_class,
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class OFTNetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
|
||||
UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
||||
unet,
|
||||
multiplier: float = 1.0,
|
||||
dim: int = 4,
|
||||
alpha: float = 1,
|
||||
enable_all_linear: Optional[bool] = False,
|
||||
enable_conv: Optional[bool] = False,
|
||||
module_class: Type[object] = OFTModule,
|
||||
varbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
self.dim = dim
|
||||
self.alpha = alpha
|
||||
|
||||
print(
|
||||
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
|
||||
)
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[torch.nn.Module],
|
||||
) -> List[OFTModule]:
|
||||
prefix = self.OFT_PREFIX_UNET
|
||||
ofts = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = "Linear" in child_module.__class__.__name__
|
||||
is_conv2d = "Conv2d" in child_module.__class__.__name__
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
|
||||
oft_name = prefix + "." + name + "." + child_name
|
||||
oft_name = oft_name.replace(".", "_")
|
||||
# print(oft_name)
|
||||
|
||||
oft = module_class(
|
||||
oft_name,
|
||||
child_module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha,
|
||||
)
|
||||
ofts.append(oft)
|
||||
return ofts
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
if enable_all_linear:
|
||||
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
|
||||
else:
|
||||
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
|
||||
if enable_conv:
|
||||
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
|
||||
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
for oft in self.unet_ofts:
|
||||
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
|
||||
names.add(oft.oft_name)
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for oft in self.unet_ofts:
|
||||
oft.multiplier = self.multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
assert apply_unet, "apply_unet must be True"
|
||||
|
||||
for oft in self.unet_ofts:
|
||||
oft.apply_to()
|
||||
self.add_module(oft.oft_name, oft)
|
||||
|
||||
# マージできるかどうかを返す
|
||||
def is_mergeable(self):
|
||||
return True
|
||||
|
||||
# TODO refactor to common function with apply_to
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
print("enable OFT for U-Net")
|
||||
|
||||
for oft in self.unet_ofts:
|
||||
sd_for_lora = {}
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(oft.oft_name):
|
||||
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
|
||||
oft.load_state_dict(sd_for_lora, False)
|
||||
oft.merge_to()
|
||||
|
||||
print(f"weights are merged")
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(ofts):
|
||||
params = []
|
||||
for oft in ofts:
|
||||
params.extend(oft.parameters())
|
||||
|
||||
# print num of params
|
||||
num_params = 0
|
||||
for p in params:
|
||||
num_params += p.numel()
|
||||
print(f"OFT params: {num_params}")
|
||||
return params
|
||||
|
||||
param_data = {"params": enumerate_params(self.unet_ofts)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
return all_params
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
pass
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self, text_encoder, unet):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
# Precalculate model hashes to save time on indexing
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def backup_weights(self):
|
||||
# 重みのバックアップを行う
|
||||
ofts: List[OFTInfModule] = self.unet_ofts
|
||||
for oft in ofts:
|
||||
org_module = oft.org_module[0]
|
||||
if not hasattr(org_module, "_lora_org_weight"):
|
||||
sd = org_module.state_dict()
|
||||
org_module._lora_org_weight = sd["weight"].detach().clone()
|
||||
org_module._lora_restored = True
|
||||
|
||||
def restore_weights(self):
|
||||
# 重みのリストアを行う
|
||||
ofts: List[OFTInfModule] = self.unet_ofts
|
||||
for oft in ofts:
|
||||
org_module = oft.org_module[0]
|
||||
if not org_module._lora_restored:
|
||||
sd = org_module.state_dict()
|
||||
sd["weight"] = org_module._lora_org_weight
|
||||
org_module.load_state_dict(sd)
|
||||
org_module._lora_restored = True
|
||||
|
||||
def pre_calculation(self):
|
||||
# 事前計算を行う
|
||||
ofts: List[OFTInfModule] = self.unet_ofts
|
||||
for oft in ofts:
|
||||
org_module = oft.org_module[0]
|
||||
oft.merge_to()
|
||||
# sd = org_module.state_dict()
|
||||
# org_weight = sd["weight"]
|
||||
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
||||
# sd["weight"] = org_weight + lora_weight
|
||||
# assert sd["weight"].shape == org_weight.shape
|
||||
# org_module.load_state_dict(sd)
|
||||
|
||||
org_module._lora_restored = False
|
||||
oft.enabled = False
|
||||
@@ -11,6 +11,8 @@ import numpy as np
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
# Model save and load functions
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
@@ -39,12 +41,13 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
# Indexing functions
|
||||
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
if index >= len(S):
|
||||
index = len(S) - 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -54,8 +57,16 @@ def index_sv_fro(S, target):
|
||||
s_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
if index >= len(S):
|
||||
index = len(S) - 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv/target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
|
||||
@@ -125,26 +136,24 @@ def merge_linear(lora_down, lora_up, device):
|
||||
return weight
|
||||
|
||||
|
||||
# Calculate new rank
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
if dynamic_method=="sv_ratio":
|
||||
# Calculate new dim and alpha based off ratio
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv/dynamic_param
|
||||
new_rank = max(torch.sum(S > min_sv).item(),1)
|
||||
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method=="sv_cumulative":
|
||||
# Calculate new dim and alpha based off cumulative sum
|
||||
new_rank = index_sv_cumulative(S, dynamic_param)
|
||||
new_rank = max(new_rank, 1)
|
||||
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method=="sv_fro":
|
||||
# Calculate new dim and alpha based off sqrt sum of squares
|
||||
new_rank = index_sv_fro(S, dynamic_param)
|
||||
new_rank = min(max(new_rank, 1), len(S)-1)
|
||||
new_rank = index_sv_fro(S, dynamic_param) + 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
else:
|
||||
new_rank = rank
|
||||
@@ -172,7 +181,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict["new_alpha"] = new_alpha
|
||||
param_dict["sum_retained"] = (s_rank)/s_sum
|
||||
param_dict["fro_retained"] = fro_percent
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank]
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
|
||||
|
||||
return param_dict
|
||||
|
||||
@@ -208,18 +217,28 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
||||
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if 'lora_down' in key:
|
||||
block_down_name = key.split(".")[0]
|
||||
block_down_name = key.rsplit('.lora_down', 1)[0]
|
||||
weight_name = key.rsplit(".", 1)[-1]
|
||||
lora_down_weight = value
|
||||
if 'lora_up' in key:
|
||||
block_up_name = key.split(".")[0]
|
||||
lora_up_weight = value
|
||||
else:
|
||||
continue
|
||||
|
||||
# find corresponding lora_up and alpha
|
||||
block_up_name = block_down_name
|
||||
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
|
||||
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
|
||||
|
||||
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
||||
|
||||
if (block_down_name == block_up_name) and weights_loaded:
|
||||
if weights_loaded:
|
||||
|
||||
conv2d = (len(lora_down_weight.size()) == 4)
|
||||
if lora_alpha is None:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = lora_alpha/lora_down_weight.size()[0]
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
@@ -264,7 +283,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
||||
|
||||
|
||||
def resize(args):
|
||||
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
|
||||
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
@@ -311,7 +333,7 @@ def resize(args):
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
@@ -329,7 +351,12 @@ if __name__ == '__main__':
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
|
||||
parser.add_argument("--dynamic_param", type=float, default=None,
|
||||
help="Specify target for dynamic reduction")
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
resize(args)
|
||||
|
||||
348
networks/sdxl_merge_lora.py
Normal file
348
networks/sdxl_merge_lora.py
Normal file
@@ -0,0 +1,348 @@
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, sdxl_model_util, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
sd = load_file(file_name)
|
||||
metadata = train_util.load_metadata_from_safetensors(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = {}
|
||||
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(model, file_name, metadata=metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
|
||||
text_encoder1.to(merge_dtype)
|
||||
text_encoder1.to(merge_dtype)
|
||||
unet.to(merge_dtype)
|
||||
|
||||
# create module map
|
||||
name_to_module = {}
|
||||
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
|
||||
if i <= 1:
|
||||
if i == 0:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
|
||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||
target_replace_modules = (
|
||||
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
)
|
||||
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||
|
||||
print(f"merging...")
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||
|
||||
# find original module for this lora
|
||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# print(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
# print(module_name, down_weight.size(), up_weight.size())
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
weight
|
||||
+ ratio
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
base_alphas = {} # alpha for merged model
|
||||
base_dims = {}
|
||||
|
||||
merged_sd = {}
|
||||
v2 = None
|
||||
base_model = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||
|
||||
if lora_metadata is not None:
|
||||
if v2 is None:
|
||||
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
|
||||
if base_model is None:
|
||||
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
||||
|
||||
# get alpha and dim
|
||||
alphas = {} # alpha for current model
|
||||
dims = {} # dims for current model
|
||||
for key in lora_sd.keys():
|
||||
if "alpha" in key:
|
||||
lora_module_name = key[: key.rfind(".alpha")]
|
||||
alpha = float(lora_sd[key].detach().numpy())
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
elif "lora_down" in key:
|
||||
lora_module_name = key[: key.rfind(".lora_down")]
|
||||
dim = lora_sd[key].size()[0]
|
||||
dims[lora_module_name] = dim
|
||||
if lora_module_name not in base_dims:
|
||||
base_dims[lora_module_name] = dim
|
||||
|
||||
for lora_module_name in dims.keys():
|
||||
if lora_module_name not in alphas:
|
||||
alpha = dims[lora_module_name]
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
|
||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "alpha" in key:
|
||||
continue
|
||||
|
||||
if "lora_up" in key and concat:
|
||||
concat_dim = 1
|
||||
elif "lora_down" in key and concat:
|
||||
concat_dim = 0
|
||||
else:
|
||||
concat_dim = None
|
||||
|
||||
lora_module_name = key[: key.rfind(".lora_")]
|
||||
|
||||
base_alpha = base_alphas[lora_module_name]
|
||||
alpha = alphas[lora_module_name]
|
||||
|
||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
|
||||
|
||||
if key in merged_sd:
|
||||
assert (
|
||||
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
if concat_dim is not None:
|
||||
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
|
||||
else:
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||
else:
|
||||
merged_sd[key] = lora_sd[key] * scale
|
||||
|
||||
# set alpha to sd
|
||||
for lora_module_name, alpha in base_alphas.items():
|
||||
key = lora_module_name + ".alpha"
|
||||
merged_sd[key] = torch.tensor(alpha)
|
||||
if shuffle:
|
||||
key_down = lora_module_name + ".lora_down.weight"
|
||||
key_up = lora_module_name + ".lora_up.weight"
|
||||
dim = merged_sd[key_down].shape[0]
|
||||
perm = torch.randperm(dim)
|
||||
merged_sd[key_down] = merged_sd[key_down][perm]
|
||||
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
||||
|
||||
print("merged model")
|
||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
|
||||
# check all dims are same
|
||||
dims_list = list(set(base_dims.values()))
|
||||
alphas_list = list(set(base_alphas.values()))
|
||||
all_same_dims = True
|
||||
all_same_alphas = True
|
||||
for dims in dims_list:
|
||||
if dims != dims_list[0]:
|
||||
all_same_dims = False
|
||||
break
|
||||
for alphas in alphas_list:
|
||||
if alphas != alphas_list[0]:
|
||||
all_same_alphas = False
|
||||
break
|
||||
|
||||
# build minimum metadata
|
||||
dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
|
||||
alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
|
||||
metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
|
||||
|
||||
return merged_sd, metadata
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
if args.sd_model is not None:
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
|
||||
(
|
||||
text_model1,
|
||||
text_model2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
|
||||
|
||||
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
if args.no_metadata:
|
||||
sai_metadata = None
|
||||
else:
|
||||
merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
|
||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(
|
||||
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
|
||||
)
|
||||
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
|
||||
)
|
||||
else:
|
||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
||||
|
||||
print(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
if not args.no_metadata:
|
||||
merged_from = sai_model_spec.build_merged_from(args.models)
|
||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(
|
||||
state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from
|
||||
)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="float",
|
||||
choices=["float", "fp16", "bf16"],
|
||||
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd_model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concat",
|
||||
action="store_true",
|
||||
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
|
||||
+ "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shuffle",
|
||||
action="store_true",
|
||||
help="shuffle lora weight./ "
|
||||
+ "LoRAの重みをシャッフルする",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
@@ -1,182 +1,257 @@
|
||||
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
from library import sai_model_spec, train_util
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
sd = load_file(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
return sd
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
sd = load_file(file_name)
|
||||
metadata = train_util.load_metadata_from_safetensors(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = {}
|
||||
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(state_dict, file_name)
|
||||
else:
|
||||
torch.save(state_dict, file_name)
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
save_file(state_dict, file_name, metadata=metadata)
|
||||
else:
|
||||
torch.save(state_dict, file_name)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
||||
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||
merged_sd = {}
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||
merged_sd = {}
|
||||
v2 = None
|
||||
base_model = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in tqdm(list(lora_sd.keys())):
|
||||
if 'lora_down' not in key:
|
||||
continue
|
||||
if lora_metadata is not None:
|
||||
if v2 is None:
|
||||
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string
|
||||
if base_model is None:
|
||||
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
||||
|
||||
lora_module_name = key[:key.rfind(".lora_down")]
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in tqdm(list(lora_sd.keys())):
|
||||
if "lora_down" not in key:
|
||||
continue
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
network_dim = down_weight.size()[0]
|
||||
lora_module_name = key[: key.rfind(".lora_down")]
|
||||
|
||||
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
|
||||
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
|
||||
down_weight = lora_sd[key]
|
||||
network_dim = down_weight.size()[0]
|
||||
|
||||
in_dim = down_weight.size()[1]
|
||||
out_dim = up_weight.size()[0]
|
||||
conv2d = len(down_weight.size()) == 4
|
||||
kernel_size = None if not conv2d else down_weight.size()[2:4]
|
||||
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
||||
up_weight = lora_sd[lora_module_name + ".lora_up.weight"]
|
||||
alpha = lora_sd.get(lora_module_name + ".alpha", network_dim)
|
||||
|
||||
# make original weight if not exist
|
||||
if lora_module_name not in merged_sd:
|
||||
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
else:
|
||||
weight = merged_sd[lora_module_name]
|
||||
in_dim = down_weight.size()[1]
|
||||
out_dim = up_weight.size()[0]
|
||||
conv2d = len(down_weight.size()) == 4
|
||||
kernel_size = None if not conv2d else down_weight.size()[2:4]
|
||||
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
||||
|
||||
# merge to weight
|
||||
if device:
|
||||
up_weight = up_weight.to(device)
|
||||
down_weight = down_weight.to(device)
|
||||
# make original weight if not exist
|
||||
if lora_module_name not in merged_sd:
|
||||
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
else:
|
||||
weight = merged_sd[lora_module_name]
|
||||
|
||||
# W <- W + U * D
|
||||
scale = (alpha / network_dim)
|
||||
if not conv2d: # linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif kernel_size == (1, 1):
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
else:
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
weight = weight + ratio * conved * scale
|
||||
# merge to weight
|
||||
if device:
|
||||
up_weight = up_weight.to(device)
|
||||
down_weight = down_weight.to(device)
|
||||
|
||||
merged_sd[lora_module_name] = weight
|
||||
# W <- W + U * D
|
||||
scale = alpha / network_dim
|
||||
|
||||
# extract from merged weights
|
||||
print("extract new lora...")
|
||||
merged_lora_sd = {}
|
||||
with torch.no_grad():
|
||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||
conv2d = (len(mat.size()) == 4)
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
if device: # and isinstance(scale, torch.Tensor):
|
||||
scale = scale.to(device)
|
||||
|
||||
if conv2d:
|
||||
if conv2d_3x3:
|
||||
mat = mat.flatten(start_dim=1)
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
if not conv2d: # linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif kernel_size == (1, 1):
|
||||
weight = (
|
||||
weight
|
||||
+ ratio
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
||||
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
merged_sd[lora_module_name] = weight
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
# extract from merged weights
|
||||
print("extract new lora...")
|
||||
merged_lora_sd = {}
|
||||
with torch.no_grad():
|
||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
|
||||
U = U[:, :module_new_rank]
|
||||
S = S[:module_new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
if conv2d:
|
||||
if conv2d_3x3:
|
||||
mat = mat.flatten(start_dim=1)
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
|
||||
Vh = Vh[:module_new_rank, :]
|
||||
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
||||
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
U = U[:, :module_new_rank]
|
||||
S = S[:module_new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, module_new_rank, 1, 1)
|
||||
Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
Vh = Vh[:module_new_rank, :]
|
||||
|
||||
up_weight = U
|
||||
down_weight = Vh
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
return merged_lora_sd
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, module_new_rank, 1, 1)
|
||||
Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
up_weight = U
|
||||
down_weight = Vh
|
||||
|
||||
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
|
||||
|
||||
# build minimum metadata
|
||||
dims = f"{new_rank}"
|
||||
alphas = f"{new_rank}"
|
||||
if new_conv_rank is not None:
|
||||
network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank}
|
||||
else:
|
||||
network_args = None
|
||||
metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args)
|
||||
|
||||
return merged_lora_sd, metadata, v2 == "True", base_model
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
||||
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
|
||||
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
||||
state_dict, metadata, v2, base_model = merge_lora_models(
|
||||
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
|
||||
)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, save_dtype)
|
||||
print(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
if not args.no_metadata:
|
||||
is_sdxl = base_model is not None and base_model.lower().startswith("sdxl")
|
||||
merged_from = sai_model_spec.build_merged_from(args.models)
|
||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(
|
||||
state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from
|
||||
)
|
||||
if v2:
|
||||
# TODO read sai modelspec
|
||||
print(
|
||||
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||
)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
parser.add_argument("--precision", type=str, default="float",
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--models", type=str, nargs='*',
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument("--new_rank", type=int, default=4,
|
||||
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument("--new_conv_rank", type=int, default=None,
|
||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="float",
|
||||
choices=["float", "fp16", "bf16"],
|
||||
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
|
||||
)
|
||||
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument(
|
||||
"--new_conv_rank",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
|
||||
)
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
|
||||
@@ -1,26 +1,33 @@
|
||||
accelerate==0.15.0
|
||||
transformers==4.26.0
|
||||
accelerate==0.25.0
|
||||
transformers==4.36.2
|
||||
diffusers[torch]==0.25.0
|
||||
ftfy==6.1.1
|
||||
albumentations==1.3.0
|
||||
# albumentations==1.3.0
|
||||
opencv-python==4.7.0.68
|
||||
einops==0.6.0
|
||||
diffusers[torch]==0.10.2
|
||||
einops==0.6.1
|
||||
pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.35.0
|
||||
# bitsandbytes==0.39.1
|
||||
tensorboard==2.10.1
|
||||
safetensors==0.2.6
|
||||
gradio==3.16.2
|
||||
safetensors==0.4.2
|
||||
# gradio==3.16.2
|
||||
altair==4.2.2
|
||||
easygui==0.98.3
|
||||
toml==0.10.2
|
||||
voluptuous==0.13.1
|
||||
huggingface-hub==0.20.1
|
||||
# for BLIP captioning
|
||||
requests==2.28.2
|
||||
timm==0.6.12
|
||||
fairscale==0.4.13
|
||||
# for WD14 captioning
|
||||
# tensorflow<2.11
|
||||
tensorflow==2.10.1
|
||||
huggingface-hub==0.12.0
|
||||
# requests==2.28.2
|
||||
# timm==0.6.12
|
||||
# fairscale==0.4.13
|
||||
# for WD14 captioning (tensorflow)
|
||||
# tensorflow==2.10.1
|
||||
# for WD14 captioning (onnx)
|
||||
# onnx==1.14.1
|
||||
# onnxruntime-gpu==1.16.0
|
||||
# onnxruntime==1.16.0
|
||||
# this is for onnx:
|
||||
# protobuf==3.20.3
|
||||
# open clip for SDXL
|
||||
open-clip-torch==2.20.0
|
||||
# for kohya_ss library
|
||||
.
|
||||
-e .
|
||||
|
||||
2828
sdxl_gen_img.py
Executable file
2828
sdxl_gen_img.py
Executable file
File diff suppressed because it is too large
Load Diff
326
sdxl_minimal_inference.py
Normal file
326
sdxl_minimal_inference.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# 手元で推論を行うための最低限のコード。HuggingFace/DiffusersのCLIP、schedulerとVAEを使う
|
||||
# Minimal code for performing inference at local. Use HuggingFace/Diffusers CLIP, scheduler and VAE
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from einops import repeat
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
from PIL import Image
|
||||
import open_clip
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from library import model_util, sdxl_model_util
|
||||
import networks.lora as lora
|
||||
|
||||
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
|
||||
# scheduler: The settings around here seem to be the same as SD1/2
|
||||
SCHEDULER_LINEAR_START = 0.00085
|
||||
SCHEDULER_LINEAR_END = 0.0120
|
||||
SCHEDULER_TIMESTEPS = 1000
|
||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
|
||||
# Time EmbeddingはDiffusersからのコピー
|
||||
# Time Embedding is copied from Diffusers
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
device=timesteps.device
|
||||
)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
else:
|
||||
embedding = repeat(timesteps, "b -> b d", d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
def get_timestep_embedding(x, outdim):
|
||||
assert len(x.shape) == 2
|
||||
b, dims = x.shape[0], x.shape[1]
|
||||
# x = rearrange(x, "b d -> (b d)")
|
||||
x = torch.flatten(x)
|
||||
emb = timestep_embedding(x, outdim)
|
||||
# emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim)
|
||||
emb = torch.reshape(emb, (b, dims * outdim))
|
||||
return emb
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions
|
||||
|
||||
# SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL
|
||||
target_height = 1024
|
||||
target_width = 1024
|
||||
original_height = target_height
|
||||
original_width = target_width
|
||||
crop_top = 0
|
||||
crop_left = 0
|
||||
|
||||
steps = 50
|
||||
guidance_scale = 7
|
||||
seed = None # 1
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.float16 # bfloat16 may work
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||
parser.add_argument("--prompt2", type=str, default=None)
|
||||
parser.add_argument("--negative_prompt", type=str, default="")
|
||||
parser.add_argument("--output_dir", type=str, default=".")
|
||||
parser.add_argument(
|
||||
"--lora_weights",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
|
||||
)
|
||||
parser.add_argument("--interactive", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.prompt2 is None:
|
||||
args.prompt2 = args.prompt
|
||||
|
||||
# HuggingFaceのmodel id
|
||||
text_encoder_1_name = "openai/clip-vit-large-patch14"
|
||||
text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
# checkpointを読み込む。モデル変換についてはそちらの関数を参照
|
||||
# Load checkpoint. For model conversion, see this function
|
||||
|
||||
# 本体RAMが少ない場合はGPUにロードするといいかも
|
||||
# If the main RAM is small, it may be better to load it on the GPU
|
||||
text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu"
|
||||
)
|
||||
|
||||
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
|
||||
# In SDXL, Text Encoder 1 is also using HuggingFace's
|
||||
|
||||
# Text Encoder 2はSDXL本体ではopen_clipを使っている
|
||||
# それを使ってもいいが、SD2のDiffusers版に合わせる形で、HuggingFaceのものを使う
|
||||
# 重みの変換コードはSD2とほぼ同じ
|
||||
# In SDXL, Text Encoder 2 is using open_clip
|
||||
# It's okay to use it, but to match the Diffusers version of SD2, use HuggingFace's
|
||||
# The weight conversion code is almost the same as SD2
|
||||
|
||||
# VAEの構造はSDXLもSD1/2と同じだが、重みは異なるようだ。何より謎のscale値が違う
|
||||
# fp16でNaNが出やすいようだ
|
||||
# The structure of VAE is the same as SD1/2, but the weights seem to be different. Above all, the mysterious scale value is different.
|
||||
# NaN seems to be more likely to occur in fp16
|
||||
|
||||
unet.to(DEVICE, dtype=DTYPE)
|
||||
unet.eval()
|
||||
|
||||
vae_dtype = DTYPE
|
||||
if DTYPE == torch.float16:
|
||||
print("use float32 for vae")
|
||||
vae_dtype = torch.float32
|
||||
vae.to(DEVICE, dtype=vae_dtype)
|
||||
vae.eval()
|
||||
|
||||
text_model1.to(DEVICE, dtype=DTYPE)
|
||||
text_model1.eval()
|
||||
text_model2.to(DEVICE, dtype=DTYPE)
|
||||
text_model2.eval()
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
vae.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Tokenizers
|
||||
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
|
||||
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
|
||||
|
||||
# LoRA
|
||||
for weights_file in args.lora_weights:
|
||||
if ";" in weights_file:
|
||||
weights_file, multiplier = weights_file.split(";")
|
||||
multiplier = float(multiplier)
|
||||
else:
|
||||
multiplier = 1.0
|
||||
|
||||
lora_model, weights_sd = lora.create_network_from_weights(
|
||||
multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
|
||||
)
|
||||
lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
|
||||
|
||||
# scheduler
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
beta_end=SCHEDULER_LINEAR_END,
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
)
|
||||
|
||||
def generate_image(prompt, prompt2, negative_prompt, seed=None):
|
||||
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
|
||||
# prepare embedding
|
||||
with torch.no_grad():
|
||||
# vector
|
||||
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
||||
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
||||
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
||||
# print("emb1", emb1.shape)
|
||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
||||
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
||||
|
||||
# crossattn
|
||||
|
||||
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
|
||||
def call_text_encoder(text, text2):
|
||||
# text encoder 1
|
||||
batch_encoding = tokenizer1(
|
||||
text,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding1 = enc_out["hidden_states"][11]
|
||||
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
|
||||
|
||||
# text encoder 2
|
||||
with torch.no_grad():
|
||||
tokens = tokenizer2(text2).to(DEVICE)
|
||||
|
||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||
# print("hidden_states2", text_embedding2_penu.shape)
|
||||
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
|
||||
|
||||
# 連結して終了 concat and finish
|
||||
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
||||
return text_embedding, text_embedding2_pool
|
||||
|
||||
# cond
|
||||
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
|
||||
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
||||
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
||||
|
||||
# uncond
|
||||
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt)
|
||||
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
|
||||
|
||||
text_embeddings = torch.cat([uc_ctx, c_ctx])
|
||||
vector_embeddings = torch.cat([uc_vector, c_vector])
|
||||
|
||||
# メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する
|
||||
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# # random generator for initial noise
|
||||
# generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||
generator = None
|
||||
else:
|
||||
generator = None
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
# SDXLはCPUでlatentsを作成しているので一応合わせておく、Diffusersはtarget deviceでlatentsを作成している
|
||||
# SDXL creates latents in CPU, Diffusers creates latents in target device
|
||||
latents_shape = (1, 4, target_height // 8, target_width // 8)
|
||||
latents = torch.randn(
|
||||
latents_shape,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
).to(DEVICE, dtype=DTYPE)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
# set timesteps
|
||||
scheduler.set_timesteps(steps, DEVICE)
|
||||
|
||||
# このへんはDiffusersからのコピペ
|
||||
# Copy from Diffusers
|
||||
timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE)
|
||||
num_latent_input = 2
|
||||
with torch.no_grad():
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
# latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
|
||||
# latents = 1 / 0.18215 * latents
|
||||
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
||||
latents = latents.to(vae_dtype)
|
||||
image = vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
# image = self.numpy_to_pil(image)
|
||||
image = (image * 255).round().astype("uint8")
|
||||
image = [Image.fromarray(im) for im in image]
|
||||
|
||||
# 保存して終了 save and finish
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
for i, img in enumerate(image):
|
||||
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
|
||||
|
||||
if not args.interactive:
|
||||
generate_image(args.prompt, args.prompt2, args.negative_prompt, seed)
|
||||
else:
|
||||
# loop for interactive
|
||||
while True:
|
||||
prompt = input("prompt: ")
|
||||
if prompt == "":
|
||||
break
|
||||
prompt2 = input("prompt2: ")
|
||||
if prompt2 == "":
|
||||
prompt2 = prompt
|
||||
negative_prompt = input("negative prompt: ")
|
||||
seed = input("seed: ")
|
||||
if seed == "":
|
||||
seed = None
|
||||
else:
|
||||
seed = int(seed)
|
||||
generate_image(prompt, prompt2, negative_prompt, seed)
|
||||
|
||||
print("Done!")
|
||||
779
sdxl_train.py
Normal file
779
sdxl_train.py
Normal file
@@ -0,0 +1,779 @@
|
||||
# training with captions
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
from typing import List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import sdxl_model_util
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
import library.sdxl_train_util as sdxl_train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
|
||||
|
||||
UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
|
||||
|
||||
|
||||
def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
|
||||
block_params = [[] for _ in range(len(block_lrs))]
|
||||
|
||||
for i, (name, param) in enumerate(unet.named_parameters()):
|
||||
if name.startswith("time_embed.") or name.startswith("label_emb."):
|
||||
block_index = 0 # 0
|
||||
elif name.startswith("input_blocks."): # 1-9
|
||||
block_index = 1 + int(name.split(".")[1])
|
||||
elif name.startswith("middle_block."): # 10-12
|
||||
block_index = 10 + int(name.split(".")[1])
|
||||
elif name.startswith("output_blocks."): # 13-21
|
||||
block_index = 13 + int(name.split(".")[1])
|
||||
elif name.startswith("out."): # 22
|
||||
block_index = 22
|
||||
else:
|
||||
raise ValueError(f"unexpected parameter name: {name}")
|
||||
|
||||
block_params[block_index].append(param)
|
||||
|
||||
params_to_optimize = []
|
||||
for i, params in enumerate(block_params):
|
||||
if block_lrs[i] == 0: # 0のときは学習しない do not optimize when lr is 0
|
||||
continue
|
||||
params_to_optimize.append({"params": params, "lr": block_lrs[i]})
|
||||
|
||||
return params_to_optimize
|
||||
|
||||
|
||||
def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
|
||||
names = []
|
||||
block_index = 0
|
||||
while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
|
||||
if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
|
||||
if block_lrs[block_index] == 0:
|
||||
block_index += 1
|
||||
continue
|
||||
names.append(f"block{block_index}")
|
||||
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
|
||||
names.append("text_encoder1")
|
||||
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
|
||||
names.append("text_encoder2")
|
||||
|
||||
block_index += 1
|
||||
|
||||
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
assert (
|
||||
not args.train_text_encoder or not args.cache_text_encoder_outputs
|
||||
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
|
||||
|
||||
if args.block_lr:
|
||||
block_lrs = [float(lr) for lr in args.block_lr.split(",")]
|
||||
assert (
|
||||
len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
|
||||
), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
|
||||
else:
|
||||
block_lrs = None
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
print(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||
src_diffusers_model_path = None
|
||||
else:
|
||||
src_stable_diffusion_ckpt = None
|
||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||
|
||||
if args.save_model_as is None:
|
||||
save_stable_diffusion_format = load_stable_diffusion_format
|
||||
use_safetensors = args.use_safetensors
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
# assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
|
||||
|
||||
# Diffusers版のxformers使用フラグを設定する関数
|
||||
def set_diffusers_xformers_flag(model, valid):
|
||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||
module.set_use_memory_efficient_attention_xformers(valid)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
||||
fn_recursive_set_mem_eff(model)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
if args.diffusers_xformers:
|
||||
# もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず
|
||||
accelerator.print("Use xformers by Diffusers")
|
||||
# set_diffusers_xformers_flag(unet, True)
|
||||
set_diffusers_xformers_flag(vae, True)
|
||||
else:
|
||||
# Windows版のxformersはfloatで学習できなかったりするのでxformersを使わない設定も可能にしておく必要がある
|
||||
accelerator.print("Disable Diffusers' xformers")
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
train_unet = args.learning_rate > 0
|
||||
train_text_encoder1 = False
|
||||
train_text_encoder2 = False
|
||||
|
||||
if args.train_text_encoder:
|
||||
# TODO each option for two text encoders?
|
||||
accelerator.print("enable text encoder training")
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder1.gradient_checkpointing_enable()
|
||||
text_encoder2.gradient_checkpointing_enable()
|
||||
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
||||
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
|
||||
train_text_encoder1 = lr_te1 > 0
|
||||
train_text_encoder2 = lr_te2 > 0
|
||||
|
||||
# caching one text encoder output is not supported
|
||||
if not train_text_encoder1:
|
||||
text_encoder1.to(weight_dtype)
|
||||
if not train_text_encoder2:
|
||||
text_encoder2.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(train_text_encoder1)
|
||||
text_encoder2.requires_grad_(train_text_encoder2)
|
||||
text_encoder1.train(train_text_encoder1)
|
||||
text_encoder2.train(train_text_encoder2)
|
||||
else:
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(False)
|
||||
text_encoder2.requires_grad_(False)
|
||||
text_encoder1.eval()
|
||||
text_encoder2.eval()
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
train_dataset_group.cache_text_encoder_outputs(
|
||||
(tokenizer1, tokenizer2),
|
||||
(text_encoder1, text_encoder2),
|
||||
accelerator.device,
|
||||
None,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
|
||||
unet.requires_grad_(train_unet)
|
||||
if not train_unet:
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
||||
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
if train_unet:
|
||||
training_models.append(unet)
|
||||
if block_lrs is None:
|
||||
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
|
||||
else:
|
||||
params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
|
||||
|
||||
if train_text_encoder1:
|
||||
training_models.append(text_encoder1)
|
||||
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
||||
if train_text_encoder2:
|
||||
training_models.append(text_encoder2)
|
||||
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for params in params_to_optimize:
|
||||
for p in params["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
|
||||
accelerator.print(f"number of models: {len(training_models)}")
|
||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
args.mixed_precision == "bf16"
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
if train_text_encoder1:
|
||||
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
text_encoder2.to("cpu", dtype=torch.float32)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
text_encoder2.to(accelerator.device)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
# accelerator.print(
|
||||
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
# )
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
# TODO support weighted captions
|
||||
# if args.weighted_captions:
|
||||
# encoder_hidden_states = get_weighted_text_embeddings(
|
||||
# tokenizer,
|
||||
# text_encoder,
|
||||
# batch["captions"],
|
||||
# accelerator.device,
|
||||
# args.max_token_length // 75 if args.max_token_length else 1,
|
||||
# clip_skip=args.clip_skip,
|
||||
# )
|
||||
# else:
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
# unwrap_model is fine for models not wrapped by accelerator
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizer1,
|
||||
tokenizer2,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
||||
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
# # verify that the text encoder outputs are correct
|
||||
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
||||
# args.max_token_length,
|
||||
# batch["input_ids"].to(text_encoder1.device),
|
||||
# batch["input_ids2"].to(text_encoder1.device),
|
||||
# tokenizer1,
|
||||
# tokenizer2,
|
||||
# text_encoder1,
|
||||
# text_encoder2,
|
||||
# None if not args.full_fp16 else weight_dtype,
|
||||
# )
|
||||
# b_size = encoder_hidden_states1.shape[0]
|
||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# print("text encoder outputs verified")
|
||||
|
||||
# get size embeddings
|
||||
orig_size = batch["original_sizes_hw"]
|
||||
crop_size = batch["crop_top_lefts"]
|
||||
target_size = batch["target_sizes_hw"]
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# concat embeddings
|
||||
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
|
||||
target = noise
|
||||
|
||||
if (
|
||||
args.min_snr_gamma
|
||||
or args.scale_v_pred_loss_like_noise_pred
|
||||
or args.v_pred_like_loss
|
||||
or args.debiased_estimation_loss
|
||||
):
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(text_encoder1),
|
||||
accelerator.unwrap_model(text_encoder2),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss}
|
||||
if block_lrs is None:
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
|
||||
else:
|
||||
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
|
||||
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(text_encoder1),
|
||||
accelerator.unwrap_model(text_encoder2),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
)
|
||||
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
# if is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
text_encoder1 = accelerator.unwrap_model(text_encoder1)
|
||||
text_encoder2 = accelerator.unwrap_model(text_encoder2)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state: # and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
sdxl_train_util.save_sd_model_on_train_end(
|
||||
args,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
global_step,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
vae,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--learning_rate_te1",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate_te2",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
|
||||
)
|
||||
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_lr",
|
||||
type=str,
|
||||
default=None,
|
||||
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
|
||||
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
603
sdxl_train_control_net_lllite.py
Normal file
603
sdxl_train_control_net_lllite.py
Normal file
@@ -0,0 +1,603 @@
|
||||
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
|
||||
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from multiprocessing import Value
|
||||
from types import SimpleNamespace
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
import accelerate
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
from safetensors.torch import load_file
|
||||
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
add_v_prediction_like_loss,
|
||||
apply_snr_weight,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {
|
||||
"loss/current": current_loss,
|
||||
"loss/average": avr_loss,
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()):
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_user_config = args.dataset_config is not None
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32)
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if use_user_config:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
||||
args.train_data_dir,
|
||||
args.conditioning_data_dir,
|
||||
args.caption_extension,
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
print(
|
||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
else:
|
||||
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(
|
||||
vae,
|
||||
args.vae_batch_size,
|
||||
args.cache_latents_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_text_encoder_outputs(
|
||||
(tokenizer1, tokenizer2),
|
||||
(text_encoder1, text_encoder2),
|
||||
accelerator.device,
|
||||
None,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare ControlNet-LLLite
|
||||
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
|
||||
|
||||
if args.network_weights is not None:
|
||||
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
|
||||
with accelerate.init_empty_weights():
|
||||
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
||||
unet_lllite.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
unet_sd = unet.state_dict()
|
||||
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
|
||||
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
|
||||
else:
|
||||
# cosumes large memory, so send to GPU before creating the LLLite model
|
||||
accelerator.print("sending U-Net to GPU")
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
unet_sd = unet.state_dict()
|
||||
|
||||
# init LLLite weights
|
||||
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
|
||||
|
||||
if args.lowram:
|
||||
with accelerate.init_on_device(accelerator.device):
|
||||
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
||||
else:
|
||||
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
||||
unet_lllite.to(weight_dtype)
|
||||
|
||||
info = unet_lllite.load_lllite_weights(None, unet_sd)
|
||||
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
|
||||
del unet_sd, unet
|
||||
|
||||
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
|
||||
del unet_lllite
|
||||
|
||||
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = list(unet.prepare_params())
|
||||
print(f"trainable params count: {len(trainable_params)}")
|
||||
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
# if args.full_fp16:
|
||||
# assert (
|
||||
# args.mixed_precision == "fp16"
|
||||
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
# accelerator.print("enable full fp16 training.")
|
||||
# unet.to(weight_dtype)
|
||||
# elif args.full_bf16:
|
||||
# assert (
|
||||
# args.mixed_precision == "bf16"
|
||||
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
# accelerator.print("enable full bf16 training.")
|
||||
# unet.to(weight_dtype)
|
||||
|
||||
unet.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
else:
|
||||
unet.eval()
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
text_encoder2.to("cpu", dtype=torch.float32)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
text_encoder2.to(accelerator.device)
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
# TODO: find a way to handle total batch size when there are multiple datasets
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(
|
||||
ckpt_name,
|
||||
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
|
||||
steps,
|
||||
epoch_no,
|
||||
force_sync_upload=False,
|
||||
):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
|
||||
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
|
||||
|
||||
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(unet):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
with torch.no_grad():
|
||||
# Get the text embedding for conditioning
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizer1,
|
||||
tokenizer2,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
||||
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
# get size embeddings
|
||||
orig_size = batch["original_sizes_hw"]
|
||||
crop_size = batch["crop_top_lefts"]
|
||||
target_size = batch["target_sizes_hw"]
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# concat embeddings
|
||||
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
|
||||
# 内部でcond_embに変換される / it will be converted to cond_emb inside
|
||||
|
||||
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = unet.get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
# self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# end of epoch
|
||||
|
||||
if is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if is_main_process and args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="safetensors",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
|
||||
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
|
||||
parser.add_argument(
|
||||
"--network_dropout",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# sdxl_original_unet.USE_REENTRANT = False
|
||||
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
571
sdxl_train_control_net_lllite_old.py
Normal file
571
sdxl_train_control_net_lllite_old.py
Normal file
@@ -0,0 +1,571 @@
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from multiprocessing import Value
|
||||
from types import SimpleNamespace
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
from safetensors.torch import load_file
|
||||
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
add_v_prediction_like_loss,
|
||||
apply_snr_weight,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import networks.control_net_lllite as control_net_lllite
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {
|
||||
"loss/current": current_loss,
|
||||
"loss/average": avr_loss,
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()):
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_user_config = args.dataset_config is not None
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32)
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if use_user_config:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
||||
args.train_data_dir,
|
||||
args.conditioning_data_dir,
|
||||
args.caption_extension,
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
print(
|
||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
else:
|
||||
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(
|
||||
vae,
|
||||
args.vae_batch_size,
|
||||
args.cache_latents_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_text_encoder_outputs(
|
||||
(tokenizer1, tokenizer2),
|
||||
(text_encoder1, text_encoder2),
|
||||
accelerator.device,
|
||||
None,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare ControlNet
|
||||
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
|
||||
network.apply_to()
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
network.enable_gradient_checkpointing() # may have no effect
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = list(network.prepare_optimizer_params())
|
||||
print(f"trainable params count: {len(trainable_params)}")
|
||||
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
unet.to(weight_dtype)
|
||||
network.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
args.mixed_precision == "bf16"
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
unet.to(weight_dtype)
|
||||
network.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
network: control_net_lllite.ControlNetLLLite
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
else:
|
||||
unet.eval()
|
||||
|
||||
network.prepare_grad_etc()
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
text_encoder2.to("cpu", dtype=torch.float32)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
text_encoder2.to(accelerator.device)
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
# TODO: find a way to handle total batch size when there are multiple datasets
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
|
||||
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
|
||||
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
network.on_epoch_start() # train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(network):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
with torch.no_grad():
|
||||
# Get the text embedding for conditioning
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizer1,
|
||||
tokenizer2,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
||||
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
# get size embeddings
|
||||
orig_size = batch["original_sizes_hw"]
|
||||
crop_size = batch["crop_top_lefts"]
|
||||
target_size = batch["target_sizes_hw"]
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# concat embeddings
|
||||
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
|
||||
# 内部でcond_embに変換される / it will be converted to cond_emb inside
|
||||
network.set_cond_image(controlnet_image)
|
||||
|
||||
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = network.get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
# self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# end of epoch
|
||||
|
||||
if is_main_process:
|
||||
network = accelerator.unwrap_model(network)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if is_main_process and args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="safetensors",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
|
||||
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
|
||||
parser.add_argument(
|
||||
"--network_dropout",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# sdxl_original_unet.USE_REENTRANT = False
|
||||
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
184
sdxl_train_network.py
Normal file
184
sdxl_train_network.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
import train_network
|
||||
|
||||
|
||||
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
assert (
|
||||
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||
|
||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||
self.logit_scale = logit_scale
|
||||
self.ckpt_info = ckpt_info
|
||||
|
||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
|
||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||
return tokenizer
|
||||
|
||||
def is_text_encoder_outputs_cached(self, args):
|
||||
return args.cache_text_encoder_outputs
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||
):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
# メモリ消費を減らす
|
||||
print("move vae and unet to cpu to save memory")
|
||||
org_vae_device = vae.device
|
||||
org_unet_device = unet.device
|
||||
vae.to("cpu")
|
||||
unet.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
with accelerator.autocast():
|
||||
dataset.cache_text_encoder_outputs(
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not args.lowram:
|
||||
print("move vae and unet back to original device")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
with torch.enable_grad():
|
||||
# Get the text embedding for conditioning
|
||||
# TODO support weighted captions
|
||||
# if args.weighted_captions:
|
||||
# encoder_hidden_states = get_weighted_text_embeddings(
|
||||
# tokenizer,
|
||||
# text_encoder,
|
||||
# batch["captions"],
|
||||
# accelerator.device,
|
||||
# args.max_token_length // 75 if args.max_token_length else 1,
|
||||
# clip_skip=args.clip_skip,
|
||||
# )
|
||||
# else:
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizers[0],
|
||||
tokenizers[1],
|
||||
text_encoders[0],
|
||||
text_encoders[1],
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
||||
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
# # verify that the text encoder outputs are correct
|
||||
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
||||
# args.max_token_length,
|
||||
# batch["input_ids"].to(text_encoders[0].device),
|
||||
# batch["input_ids2"].to(text_encoders[0].device),
|
||||
# tokenizers[0],
|
||||
# tokenizers[1],
|
||||
# text_encoders[0],
|
||||
# text_encoders[1],
|
||||
# None if not args.full_fp16 else weight_dtype,
|
||||
# )
|
||||
# b_size = encoder_hidden_states1.shape[0]
|
||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# print("text encoder outputs verified")
|
||||
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# get size embeddings
|
||||
orig_size = batch["original_sizes_hw"]
|
||||
crop_size = batch["crop_top_lefts"]
|
||||
target_size = batch["target_sizes_hw"]
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# concat embeddings
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
trainer = SdxlNetworkTrainer()
|
||||
trainer.train(args)
|
||||
137
sdxl_train_textual_inversion.py
Normal file
137
sdxl_train_textual_inversion.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import regex
|
||||
import torch
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
import open_clip
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
|
||||
import train_textual_inversion
|
||||
|
||||
|
||||
class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||
|
||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||
self.logit_scale = logit_scale
|
||||
self.ckpt_info = ckpt_info
|
||||
|
||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
|
||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||
return tokenizer
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
input_ids1 = batch["input_ids"]
|
||||
input_ids2 = batch["input_ids2"]
|
||||
with torch.enable_grad():
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizers[0],
|
||||
tokenizers[1],
|
||||
text_encoders[0],
|
||||
text_encoders[1],
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# get size embeddings
|
||||
orig_size = batch["original_sizes_hw"]
|
||||
crop_size = batch["crop_top_lefts"]
|
||||
target_size = batch["target_sizes_hw"]
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# concat embeddings
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
)
|
||||
|
||||
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
||||
state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
|
||||
|
||||
if save_dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
data = load_file(file)
|
||||
else:
|
||||
data = torch.load(file, map_location="cpu")
|
||||
|
||||
emb_l = data.get("clip_l", None) # ViT-L text encoder 1
|
||||
emb_g = data.get("clip_g", None) # BiG-G text encoder 2
|
||||
|
||||
assert (
|
||||
emb_l is not None or emb_g is not None
|
||||
), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
|
||||
|
||||
return [emb_l, emb_g]
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_textual_inversion.setup_parser()
|
||||
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
|
||||
# sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
trainer = SdxlTextualInversionTrainer()
|
||||
trainer.train(args)
|
||||
194
tools/cache_latents.py
Normal file
194
tools/cache_latents.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# latentsのdiskへの事前キャッシュを行う / cache latents to disk
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from multiprocessing import Value
|
||||
import os
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library import sdxl_train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
# check cache latents arg
|
||||
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
|
||||
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
# tokenizerを準備する:datasetを動かすために必要
|
||||
if args.sdxl:
|
||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||
tokenizers = [tokenizer1, tokenizer2]
|
||||
else:
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenizers = [tokenizer]
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
|
||||
|
||||
# datasetのcache_latentsを呼ばなければ、生の画像が返る
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
print("load model")
|
||||
if args.sdxl:
|
||||
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||
else:
|
||||
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
# dataloaderを準備する
|
||||
train_dataset_group.set_caching_mode("latents")
|
||||
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
|
||||
train_dataloader = accelerator.prepare(train_dataloader)
|
||||
|
||||
# データ取得のためのループ
|
||||
for batch in tqdm(train_dataloader):
|
||||
b_size = len(batch["images"])
|
||||
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
|
||||
flip_aug = batch["flip_aug"]
|
||||
random_crop = batch["random_crop"]
|
||||
bucket_reso = batch["bucket_reso"]
|
||||
|
||||
# バッチを分割して処理する
|
||||
for i in range(0, b_size, vae_batch_size):
|
||||
images = batch["images"][i : i + vae_batch_size]
|
||||
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
|
||||
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
|
||||
|
||||
image_infos = []
|
||||
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
|
||||
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
|
||||
image_info.image = image
|
||||
image_info.bucket_reso = bucket_reso
|
||||
image_info.resized_size = resized_size
|
||||
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
|
||||
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
|
||||
print(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_infos.append(image_info)
|
||||
|
||||
if len(image_infos) > 0:
|
||||
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
config_util.add_config_arguments(parser)
|
||||
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
cache_to_disk(args)
|
||||
191
tools/cache_text_encoder_outputs.py
Normal file
191
tools/cache_text_encoder_outputs.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from multiprocessing import Value
|
||||
import os
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library import sdxl_train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
# check cache arg
|
||||
assert (
|
||||
args.cache_text_encoder_outputs_to_disk
|
||||
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
|
||||
|
||||
# できるだけ準備はしておくが今のところSDXLのみしか動かない
|
||||
assert (
|
||||
args.sdxl
|
||||
), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
|
||||
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
# tokenizerを準備する:datasetを動かすために必要
|
||||
if args.sdxl:
|
||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||
tokenizers = [tokenizer1, tokenizer2]
|
||||
else:
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenizers = [tokenizer]
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
print("load model")
|
||||
if args.sdxl:
|
||||
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
else:
|
||||
text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
text_encoders = [text_encoder1]
|
||||
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
# dataloaderを準備する
|
||||
train_dataset_group.set_caching_mode("text")
|
||||
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
|
||||
train_dataloader = accelerator.prepare(train_dataloader)
|
||||
|
||||
# データ取得のためのループ
|
||||
for batch in tqdm(train_dataloader):
|
||||
absolute_paths = batch["absolute_paths"]
|
||||
input_ids1_list = batch["input_ids1_list"]
|
||||
input_ids2_list = batch["input_ids2_list"]
|
||||
|
||||
image_infos = []
|
||||
for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
|
||||
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
|
||||
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
|
||||
image_info
|
||||
|
||||
if args.skip_existing:
|
||||
if os.path.exists(image_info.text_encoder_outputs_npz):
|
||||
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_info.input_ids1 = input_ids1
|
||||
image_info.input_ids2 = input_ids2
|
||||
image_infos.append(image_info)
|
||||
|
||||
if len(image_infos) > 0:
|
||||
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
|
||||
b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
|
||||
train_util.cache_batch_text_encoder_outputs(
|
||||
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
config_util.add_config_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
cache_to_disk(args)
|
||||
@@ -13,12 +13,18 @@ def canny(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", type=str, default=None, help="input path")
|
||||
parser.add_argument("--output", type=str, default=None, help="output path")
|
||||
parser.add_argument("--thres1", type=int, default=32, help="thres1")
|
||||
parser.add_argument("--thres2", type=int, default=224, help="thres2")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
canny(args)
|
||||
|
||||
@@ -9,81 +9,152 @@ import library.model_util as model_util
|
||||
|
||||
|
||||
def convert(args):
|
||||
# 引数を確認する
|
||||
load_dtype = torch.float16 if args.fp16 else None
|
||||
# 引数を確認する
|
||||
load_dtype = torch.float16 if args.fp16 else None
|
||||
|
||||
save_dtype = None
|
||||
if args.fp16:
|
||||
save_dtype = torch.float16
|
||||
elif args.bf16:
|
||||
save_dtype = torch.bfloat16
|
||||
elif args.float:
|
||||
save_dtype = torch.float
|
||||
save_dtype = None
|
||||
if args.fp16 or args.save_precision_as == "fp16":
|
||||
save_dtype = torch.float16
|
||||
elif args.bf16 or args.save_precision_as == "bf16":
|
||||
save_dtype = torch.bfloat16
|
||||
elif args.float or args.save_precision_as == "float":
|
||||
save_dtype = torch.float
|
||||
|
||||
is_load_ckpt = os.path.isfile(args.model_to_load)
|
||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||
is_load_ckpt = os.path.isfile(args.model_to_load)
|
||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||
|
||||
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||
assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||
# assert (
|
||||
# is_save_ckpt or args.reference_model is not None
|
||||
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||
|
||||
# モデルを読み込む
|
||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||
print(f"loading {msg}: {args.model_to_load}")
|
||||
# モデルを読み込む
|
||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||
print(f"loading {msg}: {args.model_to_load}")
|
||||
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
|
||||
if args.v1 == args.v2:
|
||||
# 自動判定する
|
||||
v2_model = unet.config.cross_attention_dim == 1024
|
||||
print("checking model version: model is " + ('v2' if v2_model else 'v1'))
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
|
||||
v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
|
||||
)
|
||||
else:
|
||||
v2_model = not args.v1
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
|
||||
)
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
|
||||
# 変換して保存する
|
||||
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
||||
print(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
if args.v1 == args.v2:
|
||||
# 自動判定する
|
||||
v2_model = unet.config.cross_attention_dim == 1024
|
||||
print("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||
else:
|
||||
v2_model = not args.v1
|
||||
|
||||
if is_save_ckpt:
|
||||
original_model = args.model_to_load if is_load_ckpt else None
|
||||
key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
|
||||
original_model, args.epoch, args.global_step, save_dtype, vae)
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
||||
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
|
||||
print(f"model saved.")
|
||||
# 変換して保存する
|
||||
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
||||
print(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
|
||||
if is_save_ckpt:
|
||||
original_model = args.model_to_load if is_load_ckpt else None
|
||||
key_count = model_util.save_stable_diffusion_checkpoint(
|
||||
v2_model,
|
||||
args.model_to_save,
|
||||
text_encoder,
|
||||
unet,
|
||||
original_model,
|
||||
args.epoch,
|
||||
args.global_step,
|
||||
None if args.metadata is None else eval(args.metadata),
|
||||
save_dtype=save_dtype,
|
||||
vae=vae,
|
||||
)
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
print(
|
||||
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
||||
)
|
||||
model_util.save_diffusers_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||
)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v1", action='store_true',
|
||||
help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
|
||||
parser.add_argument("--fp16", action='store_true',
|
||||
help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)')
|
||||
parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)')
|
||||
parser.add_argument("--float", action='store_true',
|
||||
help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)')
|
||||
parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
|
||||
parser.add_argument("--global_step", type=int, default=0,
|
||||
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
|
||||
parser.add_argument("--reference_model", type=str, default=None,
|
||||
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
|
||||
parser.add_argument("--use_safetensors", action='store_true',
|
||||
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unet_use_linear_projection",
|
||||
action="store_true",
|
||||
help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)",
|
||||
)
|
||||
parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)")
|
||||
parser.add_argument(
|
||||
"--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_precision_as",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["fp16", "bf16", "float"],
|
||||
help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
|
||||
)
|
||||
parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
|
||||
parser.add_argument(
|
||||
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
type=str,
|
||||
default=None,
|
||||
help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference_model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_safetensors",
|
||||
action="store_true",
|
||||
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)",
|
||||
)
|
||||
|
||||
parser.add_argument("model_to_load", type=str, default=None,
|
||||
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
|
||||
parser.add_argument("model_to_save", type=str, default=None,
|
||||
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
|
||||
parser.add_argument(
|
||||
"model_to_load",
|
||||
type=str,
|
||||
default=None,
|
||||
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model_to_save",
|
||||
type=str,
|
||||
default=None,
|
||||
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存",
|
||||
)
|
||||
return parser
|
||||
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
|
||||
@@ -214,7 +214,7 @@ def process(args):
|
||||
buf.tofile(f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
|
||||
parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
|
||||
@@ -234,6 +234,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--multiple_faces", action="store_true",
|
||||
help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
|
||||
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
process(args)
|
||||
|
||||
348
tools/latent_upscaler.py
Normal file
348
tools/latent_upscaler.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# 外部から簡単にupscalerを呼ぶためのスクリプト
|
||||
# 単体で動くようにモデル定義も含めている
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import cv2
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
from typing import Dict, List
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも
|
||||
|
||||
# initialize weights
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
out += residual
|
||||
|
||||
out = self.relu2(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Upscaler(nn.Module):
|
||||
def __init__(self):
|
||||
super(Upscaler, self).__init__()
|
||||
|
||||
# define layers
|
||||
# latent has 4 channels
|
||||
|
||||
self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(128)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
# resblocks
|
||||
# 数の暴力で20個:次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ
|
||||
self.resblock1 = ResidualBlock(128)
|
||||
self.resblock2 = ResidualBlock(128)
|
||||
self.resblock3 = ResidualBlock(128)
|
||||
self.resblock4 = ResidualBlock(128)
|
||||
self.resblock5 = ResidualBlock(128)
|
||||
self.resblock6 = ResidualBlock(128)
|
||||
self.resblock7 = ResidualBlock(128)
|
||||
self.resblock8 = ResidualBlock(128)
|
||||
self.resblock9 = ResidualBlock(128)
|
||||
self.resblock10 = ResidualBlock(128)
|
||||
self.resblock11 = ResidualBlock(128)
|
||||
self.resblock12 = ResidualBlock(128)
|
||||
self.resblock13 = ResidualBlock(128)
|
||||
self.resblock14 = ResidualBlock(128)
|
||||
self.resblock15 = ResidualBlock(128)
|
||||
self.resblock16 = ResidualBlock(128)
|
||||
self.resblock17 = ResidualBlock(128)
|
||||
self.resblock18 = ResidualBlock(128)
|
||||
self.resblock19 = ResidualBlock(128)
|
||||
self.resblock20 = ResidualBlock(128)
|
||||
|
||||
# last convs
|
||||
self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(64)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(64)
|
||||
self.relu3 = nn.ReLU(inplace=True)
|
||||
|
||||
# final conv: output 4 channels
|
||||
self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
||||
|
||||
# initialize weights
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# initialize final conv weights to 0: 流行りのzero conv
|
||||
nn.init.constant_(self.conv_final.weight, 0)
|
||||
|
||||
def forward(self, x):
|
||||
inp = x
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
# いくつかのresblockを通した後に、residualを足すことで精度向上と学習速度向上が見込めるはず
|
||||
residual = x
|
||||
x = self.resblock1(x)
|
||||
x = self.resblock2(x)
|
||||
x = self.resblock3(x)
|
||||
x = self.resblock4(x)
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = self.resblock5(x)
|
||||
x = self.resblock6(x)
|
||||
x = self.resblock7(x)
|
||||
x = self.resblock8(x)
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = self.resblock9(x)
|
||||
x = self.resblock10(x)
|
||||
x = self.resblock11(x)
|
||||
x = self.resblock12(x)
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = self.resblock13(x)
|
||||
x = self.resblock14(x)
|
||||
x = self.resblock15(x)
|
||||
x = self.resblock16(x)
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = self.resblock17(x)
|
||||
x = self.resblock18(x)
|
||||
x = self.resblock19(x)
|
||||
x = self.resblock20(x)
|
||||
x = x + residual
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.bn3(x)
|
||||
|
||||
# ここにreluを入れないほうがいい気がする
|
||||
|
||||
x = self.conv_final(x)
|
||||
|
||||
# network estimates the difference between the input and the output
|
||||
x = x + inp
|
||||
|
||||
return x
|
||||
|
||||
def support_latents(self) -> bool:
|
||||
return False
|
||||
|
||||
def upscale(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
lowreso_images: List[Image.Image],
|
||||
lowreso_latents: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
width: int,
|
||||
height: int,
|
||||
batch_size: int = 1,
|
||||
vae_batch_size: int = 1,
|
||||
):
|
||||
# assertion
|
||||
assert lowreso_images is not None, "Upscaler requires lowreso image"
|
||||
|
||||
# make upsampled image with lanczos4
|
||||
upsampled_images = []
|
||||
for lowreso_image in lowreso_images:
|
||||
upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
|
||||
upsampled_images.append(upsampled_image)
|
||||
|
||||
# convert to tensor: this tensor is too large to be converted to cuda
|
||||
upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
|
||||
upsampled_images = torch.stack(upsampled_images, dim=0)
|
||||
upsampled_images = upsampled_images.to(dtype)
|
||||
|
||||
# normalize to [-1, 1]
|
||||
upsampled_images = upsampled_images / 127.5 - 1.0
|
||||
|
||||
# convert upsample images to latents with batch size
|
||||
# print("Encoding upsampled (LANCZOS4) images...")
|
||||
upsampled_latents = []
|
||||
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
||||
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
||||
with torch.no_grad():
|
||||
batch = vae.encode(batch).latent_dist.sample()
|
||||
upsampled_latents.append(batch)
|
||||
|
||||
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
||||
|
||||
# upscale (refine) latents with this model with batch size
|
||||
print("Upscaling latents...")
|
||||
upscaled_latents = []
|
||||
for i in range(0, upsampled_latents.shape[0], batch_size):
|
||||
with torch.no_grad():
|
||||
upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
|
||||
upscaled_latents = torch.cat(upscaled_latents, dim=0)
|
||||
|
||||
return upscaled_latents * 0.18215
|
||||
|
||||
|
||||
# external interface: returns a model
|
||||
def create_upscaler(**kwargs):
|
||||
weights = kwargs["weights"]
|
||||
model = Upscaler()
|
||||
|
||||
print(f"Loading weights from {weights}...")
|
||||
if os.path.splitext(weights)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
sd = load_file(weights)
|
||||
else:
|
||||
sd = torch.load(weights, map_location=torch.device("cpu"))
|
||||
model.load_state_dict(sd)
|
||||
return model
|
||||
|
||||
|
||||
# another interface: upscale images with a model for given images from command line
|
||||
def upscale_images(args: argparse.Namespace):
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
us_dtype = torch.float16 # TODO: support fp32/bf16
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# load VAE with Diffusers
|
||||
assert args.vae_path is not None, "VAE path is required"
|
||||
print(f"Loading VAE from {args.vae_path}...")
|
||||
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
||||
vae.to(DEVICE, dtype=us_dtype)
|
||||
|
||||
# prepare model
|
||||
print("Preparing model...")
|
||||
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
||||
# print("Loading weights from", args.weights)
|
||||
# upscaler.load_state_dict(torch.load(args.weights))
|
||||
upscaler.eval()
|
||||
upscaler.to(DEVICE, dtype=us_dtype)
|
||||
|
||||
# load images
|
||||
image_paths = glob.glob(args.image_pattern)
|
||||
images = []
|
||||
for image_path in image_paths:
|
||||
image = Image.open(image_path)
|
||||
image = image.convert("RGB")
|
||||
|
||||
# make divisible by 8
|
||||
width = image.width
|
||||
height = image.height
|
||||
if width % 8 != 0:
|
||||
width = width - (width % 8)
|
||||
if height % 8 != 0:
|
||||
height = height - (height % 8)
|
||||
if width != image.width or height != image.height:
|
||||
image = image.crop((0, 0, width, height))
|
||||
|
||||
images.append(image)
|
||||
|
||||
# debug output
|
||||
if args.debug:
|
||||
for image, image_path in zip(images, image_paths):
|
||||
image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
|
||||
|
||||
basename = os.path.basename(image_path)
|
||||
basename_wo_ext, ext = os.path.splitext(basename)
|
||||
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
|
||||
image_debug.save(dest_file_name)
|
||||
|
||||
# upscale
|
||||
print("Upscaling...")
|
||||
upscaled_latents = upscaler.upscale(
|
||||
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
||||
)
|
||||
upscaled_latents /= 0.18215
|
||||
|
||||
# decode with batch
|
||||
print("Decoding...")
|
||||
upscaled_images = []
|
||||
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
||||
with torch.no_grad():
|
||||
batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
|
||||
batch = batch.to("cpu")
|
||||
upscaled_images.append(batch)
|
||||
upscaled_images = torch.cat(upscaled_images, dim=0)
|
||||
|
||||
# tensor to numpy
|
||||
upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
|
||||
upscaled_images = (upscaled_images + 1.0) * 127.5
|
||||
upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
|
||||
|
||||
upscaled_images = upscaled_images[..., ::-1]
|
||||
|
||||
# save images
|
||||
for i, image in enumerate(upscaled_images):
|
||||
basename = os.path.basename(image_paths[i])
|
||||
basename_wo_ext, ext = os.path.splitext(basename)
|
||||
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
|
||||
cv2.imwrite(dest_file_name, image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
|
||||
parser.add_argument("--weights", type=str, default=None, help="Weights path")
|
||||
parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
|
||||
parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
|
||||
parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
|
||||
parser.add_argument("--debug", action="store_true", help="Debug mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
upscale_images(args)
|
||||
168
tools/merge_models.py
Normal file
168
tools/merge_models.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def is_unet_key(key):
|
||||
# VAE or TextEncoder, the last one is for SDXL
|
||||
return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." in key)
|
||||
|
||||
|
||||
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
||||
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
|
||||
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
|
||||
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
|
||||
]
|
||||
|
||||
|
||||
# support for models with different text encoder keys
|
||||
def replace_text_encoder_key(key):
|
||||
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
||||
if key.startswith(rep_from):
|
||||
return True, rep_to + key[len(rep_from) :]
|
||||
return False, key
|
||||
|
||||
|
||||
def merge(args):
|
||||
if args.precision == "fp16":
|
||||
dtype = torch.float16
|
||||
elif args.precision == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
dtype = torch.float
|
||||
|
||||
if args.saving_precision == "fp16":
|
||||
save_dtype = torch.float16
|
||||
elif args.saving_precision == "bf16":
|
||||
save_dtype = torch.bfloat16
|
||||
else:
|
||||
save_dtype = torch.float
|
||||
|
||||
# check if all models are safetensors
|
||||
for model in args.models:
|
||||
if not model.endswith("safetensors"):
|
||||
print(f"Model {model} is not a safetensors model")
|
||||
exit()
|
||||
if not os.path.isfile(model):
|
||||
print(f"Model {model} does not exist")
|
||||
exit()
|
||||
|
||||
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
|
||||
|
||||
# load and merge
|
||||
ratio = 1.0 / len(args.models) # default
|
||||
supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later
|
||||
|
||||
merged_sd = None
|
||||
first_model_keys = set() # check missing keys in other models
|
||||
for i, model in enumerate(args.models):
|
||||
if args.ratios is not None:
|
||||
ratio = args.ratios[i]
|
||||
|
||||
if merged_sd is None:
|
||||
# load first model
|
||||
print(f"Loading model {model}, ratio = {ratio}...")
|
||||
merged_sd = {}
|
||||
with safe_open(model, framework="pt", device=args.device) as f:
|
||||
for key in tqdm(f.keys()):
|
||||
value = f.get_tensor(key)
|
||||
_, key = replace_text_encoder_key(key)
|
||||
|
||||
first_model_keys.add(key)
|
||||
|
||||
if not is_unet_key(key) and args.unet_only:
|
||||
supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder
|
||||
continue
|
||||
|
||||
value = ratio * value.to(dtype) # first model's value * ratio
|
||||
merged_sd[key] = value
|
||||
|
||||
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
|
||||
continue
|
||||
|
||||
# load other models
|
||||
print(f"Loading model {model}, ratio = {ratio}...")
|
||||
|
||||
with safe_open(model, framework="pt", device=args.device) as f:
|
||||
model_keys = f.keys()
|
||||
for key in tqdm(model_keys):
|
||||
_, new_key = replace_text_encoder_key(key)
|
||||
if new_key not in merged_sd:
|
||||
if args.show_skipped and new_key not in first_model_keys:
|
||||
print(f"Skip: {new_key}")
|
||||
continue
|
||||
|
||||
value = f.get_tensor(key)
|
||||
merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype)
|
||||
|
||||
# enumerate keys not in this model
|
||||
model_keys = set(model_keys)
|
||||
for key in merged_sd.keys():
|
||||
if key in model_keys:
|
||||
continue
|
||||
print(f"Key {key} not in model {model}, use first model's value")
|
||||
if key in supplementary_key_ratios:
|
||||
supplementary_key_ratios[key] += ratio
|
||||
else:
|
||||
supplementary_key_ratios[key] = ratio
|
||||
|
||||
# add supplementary keys' value (including VAE and TextEncoder)
|
||||
if len(supplementary_key_ratios) > 0:
|
||||
print("add first model's value")
|
||||
with safe_open(args.models[0], framework="pt", device=args.device) as f:
|
||||
for key in tqdm(f.keys()):
|
||||
_, new_key = replace_text_encoder_key(key)
|
||||
if new_key not in supplementary_key_ratios:
|
||||
continue
|
||||
|
||||
if is_unet_key(new_key): # not VAE or TextEncoder
|
||||
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
|
||||
|
||||
value = f.get_tensor(key) # original key
|
||||
|
||||
if new_key not in merged_sd:
|
||||
merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype)
|
||||
else:
|
||||
merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype)
|
||||
|
||||
# save
|
||||
output_file = args.output
|
||||
if not output_file.endswith(".safetensors"):
|
||||
output_file = output_file + ".safetensors"
|
||||
|
||||
print(f"Saving to {output_file}...")
|
||||
|
||||
# convert to save_dtype
|
||||
for k in merged_sd.keys():
|
||||
merged_sd[k] = merged_sd[k].to(save_dtype)
|
||||
|
||||
save_file(merged_sd, output_file)
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Merge models")
|
||||
parser.add_argument("--models", nargs="+", type=str, help="Models to merge")
|
||||
parser.add_argument("--output", type=str, help="Output model")
|
||||
parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0")
|
||||
parser.add_argument("--unet_only", action="store_true", help="Only merge unet")
|
||||
parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu")
|
||||
parser.add_argument(
|
||||
"--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--saving_precision",
|
||||
type=str,
|
||||
default="float",
|
||||
choices=["float", "fp16", "bf16"],
|
||||
help="Saving precision, default is float",
|
||||
)
|
||||
parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)")
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
@@ -4,175 +4,187 @@ import cv2
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
from library.original_unet import UNet2DConditionModel, SampleOutput
|
||||
|
||||
import library.model_util as model_util
|
||||
|
||||
|
||||
class ControlNetInfo(NamedTuple):
|
||||
unet: Any
|
||||
net: Any
|
||||
prep: Any
|
||||
weight: float
|
||||
ratio: float
|
||||
unet: Any
|
||||
net: Any
|
||||
prep: Any
|
||||
weight: float
|
||||
ratio: float
|
||||
|
||||
|
||||
class ControlNet(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# make control model
|
||||
self.control_model = torch.nn.Module()
|
||||
# make control model
|
||||
self.control_model = torch.nn.Module()
|
||||
|
||||
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
|
||||
zero_convs = torch.nn.ModuleList()
|
||||
for i, dim in enumerate(dims):
|
||||
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
|
||||
zero_convs.append(sub_list)
|
||||
self.control_model.add_module("zero_convs", zero_convs)
|
||||
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
|
||||
zero_convs = torch.nn.ModuleList()
|
||||
for i, dim in enumerate(dims):
|
||||
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
|
||||
zero_convs.append(sub_list)
|
||||
self.control_model.add_module("zero_convs", zero_convs)
|
||||
|
||||
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
|
||||
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
|
||||
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
|
||||
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
|
||||
|
||||
dims = [16, 16, 32, 32, 96, 96, 256, 320]
|
||||
strides = [1, 1, 2, 1, 2, 1, 2, 1]
|
||||
prev_dim = 3
|
||||
input_hint_block = torch.nn.Sequential()
|
||||
for i, (dim, stride) in enumerate(zip(dims, strides)):
|
||||
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
|
||||
if i < len(dims) - 1:
|
||||
input_hint_block.append(torch.nn.SiLU())
|
||||
prev_dim = dim
|
||||
self.control_model.add_module("input_hint_block", input_hint_block)
|
||||
dims = [16, 16, 32, 32, 96, 96, 256, 320]
|
||||
strides = [1, 1, 2, 1, 2, 1, 2, 1]
|
||||
prev_dim = 3
|
||||
input_hint_block = torch.nn.Sequential()
|
||||
for i, (dim, stride) in enumerate(zip(dims, strides)):
|
||||
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
|
||||
if i < len(dims) - 1:
|
||||
input_hint_block.append(torch.nn.SiLU())
|
||||
prev_dim = dim
|
||||
self.control_model.add_module("input_hint_block", input_hint_block)
|
||||
|
||||
|
||||
def load_control_net(v2, unet, model):
|
||||
device = unet.device
|
||||
device = unet.device
|
||||
|
||||
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
|
||||
# state dictを読み込む
|
||||
print(f"ControlNet: loading control SD model : {model}")
|
||||
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
|
||||
# state dictを読み込む
|
||||
print(f"ControlNet: loading control SD model : {model}")
|
||||
|
||||
if model_util.is_safetensors(model):
|
||||
ctrl_sd_sd = load_file(model)
|
||||
else:
|
||||
ctrl_sd_sd = torch.load(model, map_location='cpu')
|
||||
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
|
||||
if model_util.is_safetensors(model):
|
||||
ctrl_sd_sd = load_file(model)
|
||||
else:
|
||||
ctrl_sd_sd = torch.load(model, map_location="cpu")
|
||||
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
|
||||
|
||||
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
||||
is_difference = "difference" in ctrl_sd_sd
|
||||
print("ControlNet: loading difference")
|
||||
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
||||
is_difference = "difference" in ctrl_sd_sd
|
||||
print("ControlNet: loading difference:", is_difference)
|
||||
|
||||
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
||||
# またTransfer Controlの元weightとなる
|
||||
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
||||
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
||||
# またTransfer Controlの元weightとなる
|
||||
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
||||
|
||||
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
|
||||
for key in list(ctrl_unet_sd_sd.keys()):
|
||||
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
|
||||
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
|
||||
for key in list(ctrl_unet_sd_sd.keys()):
|
||||
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
|
||||
|
||||
zero_conv_sd = {}
|
||||
for key in list(ctrl_sd_sd.keys()):
|
||||
if key.startswith("control_"):
|
||||
unet_key = "model.diffusion_" + key[len("control_"):]
|
||||
if unet_key not in ctrl_unet_sd_sd: # zero conv
|
||||
zero_conv_sd[key] = ctrl_sd_sd[key]
|
||||
continue
|
||||
if is_difference: # Transfer Control
|
||||
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
|
||||
else:
|
||||
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
|
||||
zero_conv_sd = {}
|
||||
for key in list(ctrl_sd_sd.keys()):
|
||||
if key.startswith("control_"):
|
||||
unet_key = "model.diffusion_" + key[len("control_") :]
|
||||
if unet_key not in ctrl_unet_sd_sd: # zero conv
|
||||
zero_conv_sd[key] = ctrl_sd_sd[key]
|
||||
continue
|
||||
if is_difference: # Transfer Control
|
||||
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
|
||||
else:
|
||||
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
|
||||
|
||||
unet_config = model_util.create_unet_diffusers_config(v2)
|
||||
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
|
||||
unet_config = model_util.create_unet_diffusers_config(v2)
|
||||
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
|
||||
|
||||
# ControlNetのU-Netを作成する
|
||||
ctrl_unet = UNet2DConditionModel(**unet_config)
|
||||
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
|
||||
print("ControlNet: loading Control U-Net:", info)
|
||||
# ControlNetのU-Netを作成する
|
||||
ctrl_unet = UNet2DConditionModel(**unet_config)
|
||||
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
|
||||
print("ControlNet: loading Control U-Net:", info)
|
||||
|
||||
# U-Net以外のControlNetを作成する
|
||||
# TODO support middle only
|
||||
ctrl_net = ControlNet()
|
||||
info = ctrl_net.load_state_dict(zero_conv_sd)
|
||||
print("ControlNet: loading ControlNet:", info)
|
||||
# U-Net以外のControlNetを作成する
|
||||
# TODO support middle only
|
||||
ctrl_net = ControlNet()
|
||||
info = ctrl_net.load_state_dict(zero_conv_sd)
|
||||
print("ControlNet: loading ControlNet:", info)
|
||||
|
||||
ctrl_unet.to(unet.device, dtype=unet.dtype)
|
||||
ctrl_net.to(unet.device, dtype=unet.dtype)
|
||||
return ctrl_unet, ctrl_net
|
||||
ctrl_unet.to(unet.device, dtype=unet.dtype)
|
||||
ctrl_net.to(unet.device, dtype=unet.dtype)
|
||||
return ctrl_unet, ctrl_net
|
||||
|
||||
|
||||
def load_preprocess(prep_type: str):
|
||||
if prep_type is None or prep_type.lower() == "none":
|
||||
if prep_type is None or prep_type.lower() == "none":
|
||||
return None
|
||||
|
||||
if prep_type.startswith("canny"):
|
||||
args = prep_type.split("_")
|
||||
th1 = int(args[1]) if len(args) >= 2 else 63
|
||||
th2 = int(args[2]) if len(args) >= 3 else 191
|
||||
|
||||
def canny(img):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||
return cv2.Canny(img, th1, th2)
|
||||
|
||||
return canny
|
||||
|
||||
print("Unsupported prep type:", prep_type)
|
||||
return None
|
||||
|
||||
if prep_type.startswith("canny"):
|
||||
args = prep_type.split("_")
|
||||
th1 = int(args[1]) if len(args) >= 2 else 63
|
||||
th2 = int(args[2]) if len(args) >= 3 else 191
|
||||
|
||||
def canny(img):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||
return cv2.Canny(img, th1, th2)
|
||||
return canny
|
||||
|
||||
print("Unsupported prep type:", prep_type)
|
||||
return None
|
||||
|
||||
|
||||
def preprocess_ctrl_net_hint_image(image):
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[:, :, ::-1].copy() # rgb to bgr
|
||||
image = image[None].transpose(0, 3, 1, 2) # nchw
|
||||
image = torch.from_numpy(image)
|
||||
return image # 0 to 1
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
# ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている
|
||||
# image = image[:, :, ::-1].copy() # rgb to bgr
|
||||
image = image[None].transpose(0, 3, 1, 2) # nchw
|
||||
image = torch.from_numpy(image)
|
||||
return image # 0 to 1
|
||||
|
||||
|
||||
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
|
||||
guided_hints = []
|
||||
for i, cnet_info in enumerate(control_nets):
|
||||
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
|
||||
b_hints = []
|
||||
if len(hints) == 1: # すべて同じ画像をhintとして使う
|
||||
hint = hints[0]
|
||||
if cnet_info.prep is not None:
|
||||
hint = cnet_info.prep(hint)
|
||||
hint = preprocess_ctrl_net_hint_image(hint)
|
||||
b_hints = [hint for _ in range(b_size)]
|
||||
else:
|
||||
for bi in range(b_size):
|
||||
hint = hints[(bi * len(control_nets) + i) % len(hints)]
|
||||
if cnet_info.prep is not None:
|
||||
hint = cnet_info.prep(hint)
|
||||
hint = preprocess_ctrl_net_hint_image(hint)
|
||||
b_hints.append(hint)
|
||||
b_hints = torch.cat(b_hints, dim=0)
|
||||
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
|
||||
guided_hints = []
|
||||
for i, cnet_info in enumerate(control_nets):
|
||||
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
|
||||
b_hints = []
|
||||
if len(hints) == 1: # すべて同じ画像をhintとして使う
|
||||
hint = hints[0]
|
||||
if cnet_info.prep is not None:
|
||||
hint = cnet_info.prep(hint)
|
||||
hint = preprocess_ctrl_net_hint_image(hint)
|
||||
b_hints = [hint for _ in range(b_size)]
|
||||
else:
|
||||
for bi in range(b_size):
|
||||
hint = hints[(bi * len(control_nets) + i) % len(hints)]
|
||||
if cnet_info.prep is not None:
|
||||
hint = cnet_info.prep(hint)
|
||||
hint = preprocess_ctrl_net_hint_image(hint)
|
||||
b_hints.append(hint)
|
||||
b_hints = torch.cat(b_hints, dim=0)
|
||||
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
|
||||
|
||||
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
|
||||
guided_hints.append(guided_hint)
|
||||
return guided_hints
|
||||
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
|
||||
guided_hints.append(guided_hint)
|
||||
return guided_hints
|
||||
|
||||
|
||||
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
|
||||
# ControlNet
|
||||
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
|
||||
cnet_cnt = len(control_nets)
|
||||
cnet_idx = step % cnet_cnt
|
||||
cnet_info = control_nets[cnet_idx]
|
||||
def call_unet_and_control_net(
|
||||
step,
|
||||
num_latent_input,
|
||||
original_unet,
|
||||
control_nets: List[ControlNetInfo],
|
||||
guided_hints,
|
||||
current_ratio,
|
||||
sample,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states_for_control_net,
|
||||
):
|
||||
# ControlNet
|
||||
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
|
||||
cnet_cnt = len(control_nets)
|
||||
cnet_idx = step % cnet_cnt
|
||||
cnet_info = control_nets[cnet_idx]
|
||||
|
||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
if cnet_info.ratio < current_ratio:
|
||||
return original_unet(sample, timestep, encoder_hidden_states)
|
||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
if cnet_info.ratio < current_ratio:
|
||||
return original_unet(sample, timestep, encoder_hidden_states)
|
||||
|
||||
guided_hint = guided_hints[cnet_idx]
|
||||
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
|
||||
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
|
||||
outs = [o * cnet_info.weight for o in outs]
|
||||
guided_hint = guided_hints[cnet_idx]
|
||||
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
|
||||
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
|
||||
outs = [o * cnet_info.weight for o in outs]
|
||||
|
||||
# U-Net
|
||||
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
|
||||
# U-Net
|
||||
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
|
||||
|
||||
|
||||
"""
|
||||
@@ -203,118 +215,123 @@ def call_unet_and_control_net(step, num_latent_input, original_unet, control_net
|
||||
"""
|
||||
|
||||
|
||||
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
|
||||
# copy from UNet2DConditionModel
|
||||
default_overall_up_factor = 2**unet.num_upsamplers
|
||||
def unet_forward(
|
||||
is_control_net,
|
||||
control_net: ControlNet,
|
||||
unet: UNet2DConditionModel,
|
||||
guided_hint,
|
||||
ctrl_outs,
|
||||
sample,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
):
|
||||
# copy from UNet2DConditionModel
|
||||
default_overall_up_factor = 2**unet.num_upsamplers
|
||||
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
print("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
print("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# 0. center input if necessary
|
||||
if unet.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
t_emb = unet.time_proj(timesteps)
|
||||
|
||||
t_emb = unet.time_proj(timesteps)
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=unet.dtype)
|
||||
emb = unet.time_embedding(t_emb)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=unet.dtype)
|
||||
emb = unet.time_embedding(t_emb)
|
||||
outs = [] # output of ControlNet
|
||||
zc_idx = 0
|
||||
|
||||
outs = [] # output of ControlNet
|
||||
zc_idx = 0
|
||||
|
||||
# 2. pre-process
|
||||
sample = unet.conv_in(sample)
|
||||
if is_control_net:
|
||||
sample += guided_hint
|
||||
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
|
||||
zc_idx += 1
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in unet.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
# 2. pre-process
|
||||
sample = unet.conv_in(sample)
|
||||
if is_control_net:
|
||||
for rs in res_samples:
|
||||
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
|
||||
sample += guided_hint
|
||||
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
|
||||
zc_idx += 1
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in unet.down_blocks:
|
||||
if downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
if is_control_net:
|
||||
for rs in res_samples:
|
||||
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
|
||||
zc_idx += 1
|
||||
|
||||
# 4. mid
|
||||
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
if is_control_net:
|
||||
outs.append(control_net.control_model.middle_block_out[0](sample))
|
||||
return outs
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
if not is_control_net:
|
||||
sample += ctrl_outs.pop()
|
||||
# 4. mid
|
||||
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
if is_control_net:
|
||||
outs.append(control_net.control_model.middle_block_out[0](sample))
|
||||
return outs
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(unet.up_blocks):
|
||||
is_final_block = i == len(unet.up_blocks) - 1
|
||||
if not is_control_net:
|
||||
sample += ctrl_outs.pop()
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(unet.up_blocks):
|
||||
is_final_block = i == len(unet.up_blocks) - 1
|
||||
|
||||
if not is_control_net and len(ctrl_outs) > 0:
|
||||
res_samples = list(res_samples)
|
||||
apply_ctrl_outs = ctrl_outs[-len(res_samples):]
|
||||
ctrl_outs = ctrl_outs[:-len(res_samples)]
|
||||
for j in range(len(res_samples)):
|
||||
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
|
||||
res_samples = tuple(res_samples)
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
if not is_control_net and len(ctrl_outs) > 0:
|
||||
res_samples = list(res_samples)
|
||||
apply_ctrl_outs = ctrl_outs[-len(res_samples) :]
|
||||
ctrl_outs = ctrl_outs[: -len(res_samples)]
|
||||
for j in range(len(res_samples)):
|
||||
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
|
||||
res_samples = tuple(res_samples)
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||
)
|
||||
# 6. post-process
|
||||
sample = unet.conv_norm_out(sample)
|
||||
sample = unet.conv_act(sample)
|
||||
sample = unet.conv_out(sample)
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
return UNet2DConditionOutput(sample=sample)
|
||||
if upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||
)
|
||||
# 6. post-process
|
||||
sample = unet.conv_norm_out(sample)
|
||||
sample = unet.conv_act(sample)
|
||||
sample = unet.conv_out(sample)
|
||||
|
||||
return SampleOutput(sample=sample)
|
||||
|
||||
@@ -98,7 +98,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
||||
|
||||
|
||||
def main():
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
|
||||
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
|
||||
@@ -113,6 +113,12 @@ def main():
|
||||
parser.add_argument('--copy_associated_files', action='store_true',
|
||||
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
|
||||
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
|
||||
|
||||
19
tools/show_metadata.py
Normal file
19
tools/show_metadata.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import json
|
||||
import argparse
|
||||
from safetensors import safe_open
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
with safe_open(args.model, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
if metadata is None:
|
||||
print("No metadata found")
|
||||
else:
|
||||
# metadata is json dict, but not pretty printed
|
||||
# sort by key and pretty print
|
||||
print(json.dumps(metadata, indent=4, sort_keys=True))
|
||||
|
||||
|
||||
612
train_controlnet.py
Normal file
612
train_controlnet.py
Normal file
@@ -0,0 +1,612 @@
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from multiprocessing import Value
|
||||
from types import SimpleNamespace
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
)
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {
|
||||
"loss/current": current_loss,
|
||||
"loss/average": avr_loss,
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()):
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
def train(args):
|
||||
# session_id = random.randint(0, 2**32)
|
||||
# training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_user_config = args.dataset_config is not None
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32)
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if use_user_config:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
||||
args.train_data_dir,
|
||||
args.conditioning_data_dir,
|
||||
args.caption_extension,
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
print(
|
||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(
|
||||
args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
|
||||
)
|
||||
|
||||
# DiffusersのControlNetが使用するデータを準備する
|
||||
if args.v2:
|
||||
unet.config = {
|
||||
"act_fn": "silu",
|
||||
"attention_head_dim": [5, 10, 20, 20],
|
||||
"block_out_channels": [320, 640, 1280, 1280],
|
||||
"center_input_sample": False,
|
||||
"cross_attention_dim": 1024,
|
||||
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
|
||||
"downsample_padding": 1,
|
||||
"dual_cross_attention": False,
|
||||
"flip_sin_to_cos": True,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_class_embeds": None,
|
||||
"only_cross_attention": False,
|
||||
"out_channels": 4,
|
||||
"sample_size": 96,
|
||||
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
||||
"use_linear_projection": True,
|
||||
"upcast_attention": True,
|
||||
"only_cross_attention": False,
|
||||
"downsample_padding": 1,
|
||||
"use_linear_projection": True,
|
||||
"class_embed_type": None,
|
||||
"num_class_embeds": None,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"projection_class_embeddings_input_dim": None,
|
||||
}
|
||||
else:
|
||||
unet.config = {
|
||||
"act_fn": "silu",
|
||||
"attention_head_dim": 8,
|
||||
"block_out_channels": [320, 640, 1280, 1280],
|
||||
"center_input_sample": False,
|
||||
"cross_attention_dim": 768,
|
||||
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
|
||||
"downsample_padding": 1,
|
||||
"flip_sin_to_cos": True,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"out_channels": 4,
|
||||
"sample_size": 64,
|
||||
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
||||
"only_cross_attention": False,
|
||||
"downsample_padding": 1,
|
||||
"use_linear_projection": False,
|
||||
"class_embed_type": None,
|
||||
"num_class_embeds": None,
|
||||
"upcast_attention": False,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"projection_class_embeddings_input_dim": None,
|
||||
}
|
||||
unet.config = SimpleNamespace(**unet.config)
|
||||
|
||||
controlnet = ControlNetModel.from_unet(unet)
|
||||
|
||||
if args.controlnet_model_name_or_path:
|
||||
filename = args.controlnet_model_name_or_path
|
||||
if os.path.isfile(filename):
|
||||
if os.path.splitext(filename)[1] == ".safetensors":
|
||||
state_dict = load_file(filename)
|
||||
else:
|
||||
state_dict = torch.load(filename)
|
||||
state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
|
||||
controlnet.load_state_dict(state_dict)
|
||||
elif os.path.isdir(filename):
|
||||
controlnet = ControlNetModel.from_pretrained(filename)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(
|
||||
vae,
|
||||
args.vae_batch_size,
|
||||
args.cache_latents_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
controlnet.enable_gradient_checkpointing()
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = controlnet.parameters()
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
controlnet.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.to(accelerator.device)
|
||||
text_encoder.to(accelerator.device)
|
||||
|
||||
# transform DDP after prepare
|
||||
controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
|
||||
|
||||
controlnet.train()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
# TODO: find a way to handle total batch size when there are multiple datasets
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="steps",
|
||||
)
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, model, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
|
||||
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
|
||||
|
||||
if save_dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(ckpt_file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file(state_dict, ckpt_file)
|
||||
else:
|
||||
torch.save(state_dict, ckpt_file)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(
|
||||
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
|
||||
)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(controlnet):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(
|
||||
noise,
|
||||
latents.device,
|
||||
args.multires_noise_iterations,
|
||||
args.multires_noise_discount,
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.config.num_train_timesteps,
|
||||
(b_size,),
|
||||
device=latents.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
down_block_res_samples, mid_block_res_sample = controlnet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
controlnet_cond=controlnet_image,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
|
||||
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
||||
).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = controlnet.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(
|
||||
ckpt_name,
|
||||
accelerator.unwrap_model(controlnet),
|
||||
)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, accelerator.unwrap_model(controlnet))
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
|
||||
# end of epoch
|
||||
if is_main_process:
|
||||
controlnet = accelerator.unwrap_model(controlnet)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if is_main_process and args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, controlnet, force_sync_upload=True)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="safetensors",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
760
train_db.py
760
train_db.py
@@ -2,363 +2,497 @@
|
||||
# XXX dropped option: fine_tune
|
||||
|
||||
import gc
|
||||
import time
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
# perlin_noise,
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, False)
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, False)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
|
||||
}]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
if args.no_token_padding:
|
||||
train_dataset_group.disable_token_padding()
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong")
|
||||
print(
|
||||
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です")
|
||||
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
||||
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||
src_diffusers_model_path = None
|
||||
else:
|
||||
src_stable_diffusion_ckpt = None
|
||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||
|
||||
if args.save_model_as is None:
|
||||
save_stable_diffusion_format = load_stable_diffusion_format
|
||||
use_safetensors = args.use_safetensors
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
||||
unet.requires_grad_(True) # 念のため追加
|
||||
text_encoder.requires_grad_(train_text_encoder)
|
||||
if not train_text_encoder:
|
||||
print("Text Encoder is not trained.")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
if train_text_encoder:
|
||||
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
||||
else:
|
||||
trainable_params = unet.parameters()
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
if args.stop_text_encoder_training is None:
|
||||
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
||||
|
||||
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
print("enable full fp16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if not train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
print("running training / 学習開始")
|
||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000, clip_sample=False)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth")
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset_group.set_current_epoch(epoch + 1)
|
||||
|
||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||
unet.train()
|
||||
# train==True is required to enable gradient_checkpointing
|
||||
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# 指定したステップ数でText Encoderの学習を止める
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
print(f"stop text encoder training at step {global_step}")
|
||||
if not args.gradient_checkpointing:
|
||||
text_encoder.train(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
if cache_latents:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
target = noise
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
if args.no_token_padding:
|
||||
train_dataset_group.disable_token_padding()
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
if train_text_encoder:
|
||||
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
||||
else:
|
||||
params_to_clip = unet.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
|
||||
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
print(
|
||||
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
|
||||
)
|
||||
print(
|
||||
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
|
||||
accelerator.log(logs, step=global_step)
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
accelerator.log(logs, step=epoch+1)
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||
src_diffusers_model_path = None
|
||||
else:
|
||||
src_stable_diffusion_ckpt = None
|
||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if args.save_model_as is None:
|
||||
save_stable_diffusion_format = load_stable_diffusion_format
|
||||
use_safetensors = args.use_safetensors
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
||||
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
accelerator.end_training()
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
||||
unet.requires_grad_(True) # 念のため追加
|
||||
text_encoder.requires_grad_(train_text_encoder)
|
||||
if not train_text_encoder:
|
||||
accelerator.print("Text Encoder is not trained.")
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
if is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
|
||||
save_dtype, epoch, global_step, text_encoder, unet, vae)
|
||||
print("model saved.")
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
if train_text_encoder:
|
||||
if args.learning_rate_te is None:
|
||||
# wightout list, adamw8bit is crashed
|
||||
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
|
||||
else:
|
||||
trainable_params = [
|
||||
{"params": list(unet.parameters()), "lr": args.learning_rate},
|
||||
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
|
||||
]
|
||||
else:
|
||||
trainable_params = unet.parameters()
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
if args.stop_text_encoder_training is None:
|
||||
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
||||
|
||||
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if not train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
accelerator.print(
|
||||
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
)
|
||||
accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||
unet.train()
|
||||
# train==True is required to enable gradient_checkpointing
|
||||
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
# 指定したステップ数でText Encoderの学習を止める
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
accelerator.print(f"stop text encoder training at step {global_step}")
|
||||
if not args.gradient_checkpointing:
|
||||
text_encoder.train(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
if cache_latents:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
if train_text_encoder:
|
||||
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
else:
|
||||
params_to_clip = unet.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
# checking for saving is in util
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_train_end(
|
||||
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
||||
)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, False, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, False, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument("--no_token_padding", action="store_true",
|
||||
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
||||
parser.add_argument("--stop_text_encoder_training", type=int, default=None,
|
||||
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない")
|
||||
parser.add_argument(
|
||||
"--learning_rate_te",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_token_padding",
|
||||
action="store_true",
|
||||
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop_text_encoder_training",
|
||||
type=int,
|
||||
default=None,
|
||||
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
|
||||
1538
train_network.py
1538
train_network.py
File diff suppressed because it is too large
Load Diff
1039
train_network_appl_weights.py
Normal file
1039
train_network_appl_weights.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
696
train_textual_inversion_XTI.py
Normal file
696
train_textual_inversion_XTI.py
Normal file
@@ -0,0 +1,696 @@
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
import library
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import library.original_unet as original_unet
|
||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
def train(args):
|
||||
if args.output_name is None:
|
||||
args.output_name = args.token_string
|
||||
use_template = args.use_object_template or args.use_style_template
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
||||
print(
|
||||
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
||||
)
|
||||
assert (
|
||||
args.dataset_class is None
|
||||
), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# Convert the init_word to token_id
|
||||
if args.init_word is not None:
|
||||
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||
print(
|
||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
||||
)
|
||||
else:
|
||||
init_token_ids = None
|
||||
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||
assert (
|
||||
num_added_tokens == args.num_vectors_per_token
|
||||
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
||||
|
||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||
print(f"tokens are added: {token_ids}")
|
||||
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||
|
||||
token_strings_XTI = []
|
||||
XTI_layers = [
|
||||
"IN01",
|
||||
"IN02",
|
||||
"IN04",
|
||||
"IN05",
|
||||
"IN07",
|
||||
"IN08",
|
||||
"MID",
|
||||
"OUT03",
|
||||
"OUT04",
|
||||
"OUT05",
|
||||
"OUT06",
|
||||
"OUT07",
|
||||
"OUT08",
|
||||
"OUT09",
|
||||
"OUT10",
|
||||
"OUT11",
|
||||
]
|
||||
for layer_name in XTI_layers:
|
||||
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
|
||||
|
||||
tokenizer.add_tokens(token_strings_XTI)
|
||||
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
|
||||
print(f"tokens are added (XTI): {token_ids_XTI}")
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
if init_token_ids is not None:
|
||||
for i, token_id in enumerate(token_ids_XTI):
|
||||
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
|
||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
|
||||
# load weights
|
||||
if args.weights is not None:
|
||||
embeddings = load_weights(args.weights)
|
||||
assert len(token_ids) == len(
|
||||
embeddings
|
||||
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
||||
# print(token_ids, embeddings.size())
|
||||
for token_id, embedding in zip(token_ids_XTI, embeddings):
|
||||
token_embeds[token_id] = embedding
|
||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
print(f"weighs loaded")
|
||||
|
||||
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
use_dreambooth_method = args.in_json is None
|
||||
if use_dreambooth_method:
|
||||
print("Use DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Train with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||
if use_template:
|
||||
print(f"use template for training captions. is object: {args.use_object_template}")
|
||||
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
||||
replace_to = " ".join(token_strings)
|
||||
captions = []
|
||||
for tmpl in templates:
|
||||
captions.append(tmpl.format(replace_to))
|
||||
train_dataset_group.add_replacement("", captions)
|
||||
|
||||
if args.num_vectors_per_token > 1:
|
||||
prompt_replacement = (args.token_string, replace_to)
|
||||
else:
|
||||
prompt_replacement = None
|
||||
else:
|
||||
if args.num_vectors_per_token > 1:
|
||||
replace_to = " ".join(token_strings)
|
||||
train_dataset_group.add_replacement(args.token_string, replace_to)
|
||||
prompt_replacement = (args.token_string, replace_to)
|
||||
else:
|
||||
prompt_replacement = None
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
original_unet.UNet2DConditionModel.forward = unet_forward_XTI
|
||||
original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||
original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.requires_grad_(True)
|
||||
text_encoder.text_model.encoder.requires_grad_(False)
|
||||
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
else:
|
||||
unet.eval()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
print("running training / 学習開始")
|
||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
if args.zero_terminal_snr:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, embs, save_dtype)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
text_encoder.train()
|
||||
|
||||
loss_total = 0
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(text_encoder):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||
encoder_hidden_states = torch.stack(
|
||||
[
|
||||
train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
|
||||
for s in torch.split(input_ids, 1, dim=1)
|
||||
]
|
||||
)
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
loss = loss * loss_weights
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
# TODO: fix sample_images
|
||||
# train_util.sample_images(
|
||||
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
# )
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs = (
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight[token_ids_XTI]
|
||||
.data.detach()
|
||||
.clone()
|
||||
)
|
||||
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, updated_embs, global_step, epoch)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step + 1)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if accelerator.is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, updated_embs, epoch + 1, global_step)
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
# TODO: fix sample_images
|
||||
# train_util.sample_images(
|
||||
# accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
# )
|
||||
|
||||
# end of epoch
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def save_weights(file, updated_embs, save_dtype):
|
||||
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
|
||||
updated_embs = updated_embs.chunk(16)
|
||||
XTI_layers = [
|
||||
"IN01",
|
||||
"IN02",
|
||||
"IN04",
|
||||
"IN05",
|
||||
"IN07",
|
||||
"IN08",
|
||||
"MID",
|
||||
"OUT03",
|
||||
"OUT04",
|
||||
"OUT05",
|
||||
"OUT06",
|
||||
"OUT07",
|
||||
"OUT08",
|
||||
"OUT09",
|
||||
"OUT10",
|
||||
"OUT11",
|
||||
]
|
||||
state_dict = {}
|
||||
for i, layer_name in enumerate(XTI_layers):
|
||||
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
|
||||
|
||||
# if save_dtype is not None:
|
||||
# for key in list(state_dict.keys()):
|
||||
# v = state_dict[key]
|
||||
# v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
# state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file(state_dict, file)
|
||||
else:
|
||||
torch.save(state_dict, file) # can be loaded in Web UI
|
||||
|
||||
|
||||
def load_weights(file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
data = load_file(file)
|
||||
else:
|
||||
raise ValueError(f"NOT XTI: {file}")
|
||||
|
||||
if len(data.values()) != 16:
|
||||
raise ValueError(f"NOT XTI: {file}")
|
||||
|
||||
emb = torch.concat([x for x in data.values()])
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, False)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser, False)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
default="pt",
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
|
||||
)
|
||||
|
||||
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
|
||||
parser.add_argument(
|
||||
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_string",
|
||||
type=str,
|
||||
default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
||||
)
|
||||
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||
parser.add_argument(
|
||||
"--use_object_template",
|
||||
action="store_true",
|
||||
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_style_template",
|
||||
action="store_true",
|
||||
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
Reference in New Issue
Block a user