
Commit a78dd94: Remove updates
joecummings committed Feb 28, 2023
1 parent 64597a5 commit a78dd94
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions torchtext/models/t5/model.py
@@ -18,11 +18,8 @@

from .modules import DECODER_OUTPUTS_TYPE, ENCODER_OUTPUTS_TYPE, PAST_KEY_VALUES_TYPE, T5Decoder, T5Encoder

# logging library is not automatically supported by Torchscript
import warnings


@dataclass(frozen=True)
@dataclass
class T5Conf:
encoder_only: bool = False
linear_head: bool = False
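The comment in the hunk above notes that the logging library is not automatically supported by TorchScript, which is why the module imports warnings instead. As a minimal sketch of that constraint (the Toy module below is made up for illustration and is not part of torchtext), warnings.warn compiles under torch.jit.script where a logging call would not:

import warnings

import torch
from torch import nn


class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.numel() == 0:
            # warnings.warn is understood by the TorchScript compiler;
            # calls into the logging module are not.
            warnings.warn("Received an empty tensor; returning it unchanged.")
        return x


scripted = torch.jit.script(Toy())  # scripting succeeds with warnings.warn in the body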
@@ -215,7 +212,6 @@ def prepare_inputs_for_generation(
"return_past_key_values": return_past_key_values,
}

@torch.jit.export
def get_encoder(self) -> T5Encoder:
return self.encoder

@@ -292,8 +288,6 @@ def forward(

# decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
if decoder_tokens is None:
batch_size = encoder_output.size()[0]
encoder_output_device = encoder_output.device
decoder_tokens = (
torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx
)
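The hunk above seeds generation with a single padding token per batch element when decoder_tokens is None at the start of inference. A standalone sketch of that initialization, with assumed sizes and padding_idx (the real values come from T5Conf and the encoder output):

import torch

padding_idx = 0                            # assumed; taken from the model config in practice
encoder_output = torch.randn(4, 16, 512)   # stand-in for (batch, seq_len, d_model) encoder states

batch_size = encoder_output.size()[0]
decoder_tokens = (
    torch.ones((batch_size, 1), device=encoder_output.device, dtype=torch.long) * padding_idx
)
print(decoder_tokens.shape)  # torch.Size([4, 1]), every entry equal to padding_idx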
@@ -323,7 +317,7 @@ def forward(
# Rescale output before projecting on vocab. This happens when the encoder and decoder share the
# same word embeddings, which is always the case in our t5 implementation.
# See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
decoder_output = decoder_output * (self.embedding_dim**-0.5)
decoder_output = decoder_output * (self.embedding_dim ** -0.5)
decoder_output = self.lm_head(decoder_output)
decoder_outputs["decoder_output"] = decoder_output

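The last hunk rescales the decoder hidden states by embedding_dim ** -0.5 before the lm_head projection, since the head shares weights with the word embeddings (see the linked Hugging Face implementation). A small sketch of that step with assumed dimensions (512-dimensional model, 32128-token vocabulary), not the exact torchtext module:

import torch
from torch import nn

embedding_dim, vocab_size = 512, 32128     # assumed sizes for illustration
lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

decoder_output = torch.randn(4, 1, embedding_dim)
# Scale by d_model ** -0.5 because the projection reuses the shared word embeddings.
decoder_output = decoder_output * (embedding_dim ** -0.5)
logits = lm_head(decoder_output)           # shape: (4, 1, vocab_size)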
