diff --git a/torchtext/models/t5/model.py b/torchtext/models/t5/model.py
index bb5e7fc9a5..eebc9ce83e 100644
--- a/torchtext/models/t5/model.py
+++ b/torchtext/models/t5/model.py
@@ -18,11 +18,8 @@
 
 from .modules import DECODER_OUTPUTS_TYPE, ENCODER_OUTPUTS_TYPE, PAST_KEY_VALUES_TYPE, T5Decoder, T5Encoder
 
-# logging library is not automatically supported by Torchscript
-import warnings
-
 
-@dataclass(frozen=True)
+@dataclass
 class T5Conf:
     encoder_only: bool = False
     linear_head: bool = False
@@ -215,7 +212,6 @@ def prepare_inputs_for_generation(
             "return_past_key_values": return_past_key_values,
         }
 
-    @torch.jit.export
     def get_encoder(self) -> T5Encoder:
         return self.encoder
 
@@ -292,8 +288,6 @@ def forward(
 
         # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
         if decoder_tokens is None:
-            batch_size = encoder_output.size()[0]
-            encoder_output_device = encoder_output.device
             decoder_tokens = (
                 torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx
            )
@@ -323,7 +317,7 @@ def forward(
             # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
             # same word embeddings, which is always the case in our t5 implementation.
             # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
-            decoder_output = decoder_output * (self.embedding_dim**-0.5)
+            decoder_output = decoder_output * (self.embedding_dim ** -0.5)
             decoder_output = self.lm_head(decoder_output)
             decoder_outputs["decoder_output"] = decoder_output
 
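As background for the rescaling touched by the final hunk: when the LM head shares weights with the word embeddings, T5 multiplies the decoder hidden states by `embedding_dim ** -0.5` before projecting onto the vocabulary. The snippet below is a minimal standalone sketch of that step; the tensor sizes and the `embedding`/`lm_head` names are illustrative assumptions, not the torchtext modules themselves.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (not taken from the torchtext T5 configuration).
embedding_dim, vocab_size = 512, 32128

# Tie the LM head to the word embeddings, mirroring the shared-embedding setup
# the in-code comment refers to.
embedding = nn.Embedding(vocab_size, embedding_dim)
lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)
lm_head.weight = embedding.weight  # shared weights

# Fake decoder hidden states: (batch, seq_len, embedding_dim).
decoder_output = torch.randn(2, 7, embedding_dim)

# Rescale before the vocabulary projection to compensate for the tied embeddings.
decoder_output = decoder_output * (embedding_dim ** -0.5)
logits = lm_head(decoder_output)  # (batch, seq_len, vocab_size)
```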