From 64597a5bdc5406f1a2efa2f252db03780f3451dc Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Fri, 10 Feb 2023 11:52:47 -0500 Subject: [PATCH] Reworked w/ pre-allocated matrices, verrrrrrrry slow --- .../prototype/test_generate.py | 2 +- torchtext/prototype/generate.py | 95 ++++++------------- 2 files changed, 30 insertions(+), 67 deletions(-) diff --git a/test/torchtext_unittest/prototype/test_generate.py b/test/torchtext_unittest/prototype/test_generate.py index 64cfde85e0..c7e69338f4 100644 --- a/test/torchtext_unittest/prototype/test_generate.py +++ b/test/torchtext_unittest/prototype/test_generate.py @@ -90,7 +90,7 @@ def test_hf_DELETE(self) -> None: test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, - num_beams=10, + num_beams=7, beam_size_token=t5.config.vocab_size, ) end = time.time() - start diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py index b981429916..043b25a74d 100644 --- a/torchtext/prototype/generate.py +++ b/torchtext/prototype/generate.py @@ -228,6 +228,12 @@ def beam_search( encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output" encoder_output = model_kwargs["encoder_outputs"][encoder_output_key] + num_sequences = input_ids.shape[0] + + # Pre-allocate everything + token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx).to(dtype=torch.long, device=device) + beam_idxs = torch.zeros((num_sequences, num_beams, 1)).to(dtype=torch.long, device=device) + def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep): # `emissions` and `N` are unused in this current implementation @@ -236,16 +242,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ # For first timestep, create previous step token_idxs and model_states if timestep == 0: prev_step_token_idxs = [-1] - prev_step_model_states = [ - create_emitting_model_state( - Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None) - ) - ] encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None - prev_model_state_sequences = [ - get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states - ] out_probs, model_states = [], [] start = 0 @@ -261,66 +259,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ if end > curr_beam_size: end = curr_beam_size - num_samples = end - start - if prev_step_token_idxs != [-1]: - state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0) - token_indices = ( - torch.Tensor(prev_step_token_idxs[start:end]) - .to(dtype=torch.long, device=device) - .reshape(num_samples, 1) - ) - - state_and_tokens = torch.cat( - [state_sequences, token_indices], dim=-1 - ) # [batch_size x (timestep + 1)] - assert state_and_tokens.shape == ( - num_samples, - timestep + 1, - ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}" + token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=device) + token_idxs[i, : len(token_indices), 0] = token_indices + curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1) else: - assert len(prev_model_state_sequences) == 1 - state_and_tokens = token_indices = prev_model_state_sequences[0].expand( - num_beams, -1 - ) # TODO: Make this more robust - - # Cleanup -- combine this with the above - if self.is_encoder_decoder: - # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size - # This is a view-only operation and doesn't copy - model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand( - num_samples if timestep > 0 else num_beams, -1, -1 - ) + if self.is_encoder_decoder: + # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size + # This is a view-only operation and doesn't copy + model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand( + num_beams, -1, -1 + ) + curr_token_idxs = torch.zeros((num_beams, 1)).to(dtype=torch.long, device=device) + # Preprocess inputs for generation model_inputs = self.model.prepare_inputs_for_generation( - token_indices, **model_kwargs + curr_token_idxs, **model_kwargs ) # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does) if self.is_huggingface_model: model_inputs.update(self._huggingface_model_input_values) if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None: - beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32) - - # We could store this in model_kwargs - num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0] - - num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs) - if num_finished_hyps_in_step > 0: - beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0) - - beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1) - - reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs) - - if num_finished_hyps_in_step > 0: - sliced_cache = () - for states in reordered_cached: - sliced_state = () - for state in states: - sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],) - sliced_cache = sliced_cache + (sliced_state,) - reordered_cached = sliced_cache + beam_indices = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32) + beam_idxs[i, : len(prev_step_hyp_idxs), 0] = beam_indices + curr_beam_idxs = beam_idxs[i, :, 0] + reordered_cached = self.model._reorder_cache(model_kwargs["past"], curr_beam_idxs) model_inputs["past_key_values"] = reordered_cached # Forward pass @@ -334,9 +298,12 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ if self.is_huggingface_model: self._update_model_kwargs_for_generation(outputs, model_kwargs) + # Reset + token_idxs[i, :, 0] = eos_idx + beam_idxs[i, :, 0] = 0 + # Keep track of probabilities over vocab for this pairing - # TODO: fix how we track the number here? - for i in range(lm_scores.shape[0]): + for i in range(num_beams): sample_lm_scores = lm_scores[i, -1] out_probs.append(sample_lm_scores.tolist()) # Keep track of sequence and decoder hidden states @@ -344,8 +311,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ create_emitting_model_state( Seq2SeqModelState( timestep=timestep, - sequence=state_and_tokens[i].unsqueeze(0), - lm_scores=sample_lm_scores, + sequence=[], + lm_scores=0, ) ) ) @@ -391,10 +358,6 @@ def is_not_neg_one(elem: int) -> bool: if not self.is_encoder_decoder: final_tokens = input_ids[timestep].tolist() + final_tokens - # Makeshift padding so that we can stack the tensors - while len(final_tokens) < max_len: - final_tokens += [0] - # Convert from list to tensors final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)