Commit Graph

231 Commits

Author SHA1 Message Date
Kohya S
e24d9606a2 add clean_memory_on_device and use it from training 2024-02-12 11:10:52 +09:00
Disty0
a6a2b5a867 Fix IPEX support and add XPU device to device_utils 2024-01-31 17:32:37 +03:00
Kohya S
2ca4d0c831 Merge pull request #1054 from akx/mps
Device support improvements (MPS)
2024-01-31 21:30:12 +09:00
DukeG
4e67fb8444 test 2024-01-26 20:22:49 +08:00
DukeG
50f631c768 test 2024-01-26 20:02:48 +08:00
DukeG
85bc371ebc test 2024-01-26 18:58:47 +08:00
Aarni Koskela
afc38707d5 Refactor memory cleaning into a single function 2024-01-23 14:28:50 +02:00
Kohya S
7a20df5ad5 Merge pull request #1064 from KohakuBlueleaf/fix-grad-sync
Avoid grad sync on each step even when doing accumulation
2024-01-23 20:33:55 +09:00
Kohya S
bea4362e21 Merge pull request #1060 from akx/refactor-xpu-init
Deduplicate ipex initialization code
2024-01-23 20:25:37 +09:00
Kohaku-Blueleaf
711b40ccda Avoid always sync 2024-01-23 11:49:03 +08:00
Kohya S
fef172966f Add network_multiplier for dataset and train LoRA 2024-01-20 16:24:43 +09:00
Kohya S
a7ef6422b6 fix to work with torch 2.0 2024-01-20 10:00:30 +09:00
Kohaku-Blueleaf
9cfa68c92f [Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)
* Add fp8 support

* remove some debug prints

* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* better impl for convert TE to fp8

* fp8 for not only unet

* Better cache TE and TE lr

* match arg name

* Fix with list

* Add timeout settings

* Fix arg style

* Add custom seperator

* Fix typo

* Fix typo again

* Fix dtype error

* Fix gradient problem

* Fix req grad

* fix merge

* Fix merge

* Resolve merge

* arrangement and document

* Resolve merge error

* Add assert for mixed precision
2024-01-20 09:46:53 +09:00
Aarni Koskela
6f3f701d3d Deduplicate ipex initialization code 2024-01-19 18:07:36 +02:00
Kohya S
976d092c68 fix text encodes are on gpu even when not trained 2024-01-17 21:31:50 +09:00
Nir Weingarten
ab716302e4 Added cli argument for wandb session name 2024-01-03 11:52:38 +02:00
Kohya S
0676f1a86f Merge pull request #1009 from liubo0902/main
speed up latents nan replace
2023-12-21 21:37:16 +09:00
liubo0902
8c7d05afd2 speed up latents nan replace 2023-12-20 09:35:17 +08:00
Kohya S
912dca8f65 fix duplicated sample gen for every epoch ref #907 2023-12-07 22:13:38 +09:00
Isotr0py
db84530074 Fix gradients synchronization for multi-GPUs training (#989)
* delete DDP wrapper

* fix train_db vae and train_network

* fix train_db vae and train_network unwrap

* network grad sync

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2023-12-07 22:01:42 +09:00
Kohya S
383b4a2c3e Merge pull request #907 from shirayu/add_option_sample_at_first
Add option --sample_at_first
2023-12-03 21:00:32 +09:00
feffy380
6b3148fd3f Fix min-snr-gamma for v-prediction and ZSNR.
This fixes min-snr for vpred+zsnr by dividing directly by SNR+1.
The old implementation did it in two steps: (min-snr/snr) * (snr/(snr+1)), which causes division by zero when combined with --zero_terminal_snr
2023-11-07 23:02:25 +01:00
Yuta Hayashibe
2c731418ad Added sample_images() for --sample_at_first 2023-10-29 22:08:42 +09:00
Kohya S
9d6a5a0c79 Merge pull request #899 from shirayu/use_moving_average
Show moving average loss in the progress bar
2023-10-29 14:37:58 +09:00
Kohaku-Blueleaf
1cefb2a753 Better implementation for te autocast (#895)
* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* Better cache TE and TE lr

* Fix with list

* Add timeout settings

* Fix arg style
2023-10-28 15:49:59 +09:00
Yuta Hayashibe
0d21925bdf Use @property 2023-10-27 18:14:27 +09:00
Yuta Hayashibe
efef5c8ead Show "avr_loss" instead of "loss" because it is moving average 2023-10-27 17:59:58 +09:00
Yuta Hayashibe
3d2bb1a8f1 Add LossRecorder and use moving average in all places 2023-10-27 17:49:49 +09:00
青龍聖者@bdsqlsz
202f2c3292 Debias Estimation loss (#889)
* update for bnb 0.41.1

* fixed generate_controlnet_subsets_config for training

* Revert "update for bnb 0.41.1"

This reverts commit 70bd3612d8.

* add debiased_estimation_loss

* add train_network

* Revert "add train_network"

This reverts commit 6539363c5c.

* Update train_network.py
2023-10-23 22:59:14 +09:00
Kohya S
025368f51c may work dropout in LyCORIS #859 2023-10-09 14:06:58 +09:00
Yuta Hayashibe
27f9b6ffeb updated typos to v1.16.15 and fix typos 2023-10-01 21:51:24 +09:00
Kohya S
4cc919607a fix placing of requires_grad_ of U-Net 2023-10-01 16:41:48 +09:00
Kohya S
81419f7f32 Fix to work training U-Net only LoRA for SD1/2 2023-10-01 16:37:23 +09:00
Kohya S
d39f1a3427 Merge pull request #808 from rockerBOO/metadata
Add ip_noise_gamma metadata
2023-09-24 14:35:18 +09:00
Disty0
b64389c8a9 Intel ARC support with IPEX 2023-09-19 18:05:05 +03:00
rockerBOO
80aca1ccc7 Add ip_noise_gamma metadata 2023-09-05 15:20:15 -04:00
Kohya S
c142dadb46 support sai model spec 2023-08-06 21:50:05 +09:00
Kohya S
0636399c8c add adding v-pred like loss for noise pred 2023-07-31 08:23:28 +09:00
Kohya S
8ba02ac829 fix to work text encoder only network with bf16 2023-07-22 09:56:36 +09:00
Kohya S
73a08c0be0 Merge pull request #630 from ddPn08/sdxl
make tracker init_kwargs configurable
2023-07-20 22:05:55 +09:00
Kohya S
acf16c063a make to work with PyTorch 1.12 2023-07-20 21:41:16 +09:00
Kohya S
225e871819 enable full bf16 trainint in train_network 2023-07-19 08:41:42 +09:00
Kohya S
7875ca8fb5 Merge pull request #645 from Ttl/prepare_order
Cast weights to correct precision before transferring them to GPU
2023-07-19 08:33:32 +09:00
Kohya S
6d2d8dfd2f add zero_terminal_snr option 2023-07-18 23:17:23 +09:00
Henrik Forstén
cdffd19f61 Cast weights to correct precision before transferring them to GPU 2023-07-13 12:45:28 +03:00
ddPn08
b841dd78fe make tracker init_kwargs configurable 2023-07-11 10:21:45 +09:00
Kohya S
0416f26a76 support multi gpu in caching text encoder outputs 2023-07-09 16:02:56 +09:00
Kohya S
ea182461d3 add min/max_timestep 2023-07-03 20:44:42 +09:00
Kohya S
d395bc0647 fix max_token_length not works for sdxl 2023-06-29 13:02:19 +09:00
Kohya S
9ebebb22db fix typos 2023-06-26 20:43:34 +09:00