diff --git a/library/train_util.py b/library/train_util.py index 4d7f6727..3eda5098 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3756,8 +3756,6 @@ def pool_workaround( # find index for EOS token eos_token_index = torch.where(input_ids == eos_token_id)[1] eos_token_index = eos_token_index.to(device=last_hidden_state.device) - print(eos_token_index) - print(input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1)) # get hidden states for EOS token pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]