fix an encoding issue in EncodeT5Text where normalization was not applied to the correct layer
Nerogar committed Jul 7, 2024
1 parent cf90cd7 commit 583b1f8
Showing 2 changed files with 11 additions and 8 deletions.
src/mgds/pipelineModules/EncodeT5Text.py (15 changes: 9 additions & 6 deletions)

@@ -13,7 +13,7 @@ class EncodeT5Text(
     def __init__(
             self,
             tokens_in_name: str,
-            tokens_attention_mask_in_name: str,
+            tokens_attention_mask_in_name: str | None,
             hidden_state_out_name: str,
             pooled_out_name: str | None,
             text_encoder: T5EncoderModel,
@@ -48,10 +48,13 @@ def get_outputs(self) -> list[str]:

     def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
         tokens = self._get_previous_item(variation, self.tokens_in_name, index)
-        tokens_attention_mask = self._get_previous_item(variation, self.tokens_attention_mask_in_name, index)

         tokens = tokens.unsqueeze(0)
-        tokens_attention_mask = tokens_attention_mask.unsqueeze(0)

+        if self.tokens_attention_mask_in_name is not None:
+            tokens_attention_mask = self._get_previous_item(variation, self.tokens_attention_mask_in_name, index)
+            tokens_attention_mask = tokens_attention_mask.unsqueeze(0)
+        else:
+            tokens_attention_mask = None

         with self._all_contexts(self.autocast_contexts):
             if tokens_attention_mask is not None and self.dtype:
@@ -64,7 +67,7 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
                 return_dict=True,
             )

-        hidden_states = text_encoder_output.hidden_states[:-1]
+        hidden_states = text_encoder_output.hidden_states
         if self.pooled_out_name:
             pooled_state = text_encoder_output.text_embeds
         else:
@@ -75,7 +78,7 @@ def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:

         hidden_state = hidden_states[self.hidden_state_output_index]

-        if self.add_layer_norm:
+        if self.hidden_state_output_index != -1 and self.add_layer_norm:
             with self._all_contexts(self.autocast_contexts):
                 final_layer_norm = self.text_encoder.encoder.final_layer_norm
                 hidden_state = final_layer_norm(
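Why the layer-norm condition matters: with output_hidden_states=True, a T5 encoder returns one hidden state per block plus the initial embedding output, and only the final entry has already passed through encoder.final_layer_norm. The old code truncated that final entry and normalized whichever layer was selected; the fixed code keeps the full tuple and applies the norm only when an intermediate layer is chosen. A minimal sketch of the corrected behavior outside the pipeline (the checkpoint name, prompt, and index value are illustrative assumptions, not taken from this commit):

import torch
from transformers import T5EncoderModel, T5TokenizerFast

# Illustrative checkpoint; any T5 encoder checkpoint behaves the same way.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-base")
text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-base")

encoding = tokenizer("a photo of a cat", return_tensors="pt")

with torch.no_grad():
    output = text_encoder(
        encoding.input_ids,
        attention_mask=encoding.attention_mask,
        output_hidden_states=True,
        return_dict=True,
    )

# output.hidden_states holds the embedding output plus one entry per block;
# only the last entry has already been passed through final_layer_norm.
hidden_state_output_index = -2  # e.g. select the penultimate layer
hidden_state = output.hidden_states[hidden_state_output_index]
if hidden_state_output_index != -1:
    # Intermediate layers are un-normalized, so apply the norm manually,
    # mirroring the condition introduced by this commit.
    hidden_state = text_encoder.encoder.final_layer_norm(hidden_state)

Making tokens_attention_mask_in_name optional follows the same defensive pattern: when no mask module is wired in, None is forwarded to the encoder, which then attends over all tokens by default.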
src/mgds/pipelineModules/Tokenize.py (4 changes: 2 additions & 2 deletions)

@@ -1,4 +1,4 @@
-from transformers import CLIPTokenizer, T5Tokenizer
+from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast

 from mgds.PipelineModule import PipelineModule
 from mgds.pipelineModuleTypes.RandomAccessPipelineModule import RandomAccessPipelineModule
@@ -13,7 +13,7 @@ def __init__(
             in_name: str,
             tokens_out_name: str,
             mask_out_name: str,
-            tokenizer: CLIPTokenizer | T5Tokenizer,
+            tokenizer: CLIPTokenizer | T5Tokenizer | T5TokenizerFast,
             max_token_length: int,
     ):
         super(Tokenize, self).__init__()
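The Tokenize change is purely a type widening: T5Tokenizer and T5TokenizerFast expose the same call interface, so the module body is untouched. A short sketch of how the fast tokenizer yields the tensors behind tokens_out_name and mask_out_name (the checkpoint name, prompt, and padding arguments are illustrative assumptions, not the module's verified internals):

from transformers import T5TokenizerFast

# Illustrative checkpoint name.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-base")

encoding = tokenizer(
    "a photo of a cat",
    padding="max_length",
    truncation=True,
    max_length=77,  # corresponds to max_token_length
    return_tensors="pt",
)

tokens = encoding.input_ids.squeeze(0)               # candidate for tokens_out_name
attention_mask = encoding.attention_mask.squeeze(0)  # candidate for mask_out_name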
