remove debug print

This commit is contained in:
Kohya S
2023-08-04 08:42:54 +09:00
parent c6d52fdea4
commit 9d7619d1eb

View File

@@ -3756,8 +3756,6 @@ def pool_workaround(
# find index for EOS token
eos_token_index = torch.where(input_ids == eos_token_id)[1]
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
print(eos_token_index)
print(input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1))
# get hidden states for EOS token
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]