mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
fix training textencoder in sdxl not working
This commit is contained in:
@@ -3761,8 +3761,9 @@ def pool_workaround(
|
||||
# get hidden states for EOS token
|
||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
|
||||
|
||||
# apply projection
|
||||
pooled_output = text_encoder.text_projection(pooled_output)
|
||||
# apply projection: projection may be of different dtype than last_hidden_state
|
||||
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
|
||||
pooled_output = pooled_output.to(last_hidden_state.dtype)
|
||||
|
||||
return pooled_output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user