fix training textencoder in sdxl not working

This commit is contained in:
Kohya S
2023-08-05 21:22:50 +09:00
parent 25d8cd473e
commit e5f9772a35

View File

@@ -3761,8 +3761,9 @@ def pool_workaround(
# get hidden states for EOS token
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
# apply projection
pooled_output = text_encoder.text_projection(pooled_output)
# apply projection: projection may be of different dtype than last_hidden_state
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
pooled_output = pooled_output.to(last_hidden_state.dtype)
return pooled_output