diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py index e08b180d47..c54813a930 100644 --- a/torchtext/prototype/generate.py +++ b/torchtext/prototype/generate.py @@ -223,6 +223,12 @@ def beam_search( encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output" encoder_output = model_kwargs["encoder_outputs"][encoder_output_key] + num_sequences = input_ids.shape[0] + + # Pre-allocate everything + token_idxs = torch.zeros((num_sequences, num_beams, max_len)).to(dtype=torch.long, device=device) + beam_idxs = torch.zeros((num_sequences, num_beams, max_len)).to(dtype=torch.long, device=device) + def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep): # `emissions` and `N` are unused in this current implementation @@ -231,16 +237,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ # For first timestep, create previous step token_idxs and model_states if timestep == 0: prev_step_token_idxs = [-1] - prev_step_model_states = [ - create_emitting_model_state( - Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None) - ) - ] encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None - prev_model_state_sequences = [ - get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states - ] out_probs, model_states = [], [] start = 0 @@ -256,66 +254,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ if end > curr_beam_size: end = curr_beam_size - num_samples = end - start - if prev_step_token_idxs != [-1]: - state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0) - token_indices = ( - torch.Tensor(prev_step_token_idxs[start:end]) - .to(dtype=torch.long, device=device) - .reshape(num_samples, 1) - ) + token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=device) + token_idxs[i, : len(token_indices), timestep] = token_indices - state_and_tokens = torch.cat( - [state_sequences, token_indices], dim=-1 - ) # [batch_size x (timestep + 1)] - assert state_and_tokens.shape == ( - num_samples, - timestep + 1, - ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}" - else: - assert len(prev_model_state_sequences) == 1 - state_and_tokens = token_indices = prev_model_state_sequences[0].expand( - num_beams, -1 - ) # TODO: Make this more robust + curr_token_idxs = token_idxs[i, :, timestep].reshape(num_beams, 1) # Cleanup -- combine this with the above if self.is_encoder_decoder: # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size # This is a view-only operation and doesn't copy model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand( - num_samples if timestep > 0 else num_beams, -1, -1 + num_beams, -1, -1 ) # Preprocess inputs for generation model_inputs = self.model.prepare_inputs_for_generation( - token_indices, **model_kwargs + curr_token_idxs, **model_kwargs ) # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does) if self.is_huggingface_model: model_inputs.update(self._huggingface_model_input_values) if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None: - beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32) - - # We could store this in model_kwargs - num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0] - - num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs) - if num_finished_hyps_in_step > 0: - beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0) - - beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1) - - reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs) - - if num_finished_hyps_in_step > 0: - sliced_cache = () - for states in reordered_cached: - sliced_state = () - for state in states: - sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],) - sliced_cache = sliced_cache + (sliced_state,) - reordered_cached = sliced_cache + beam_indices = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32) + beam_idxs[i, : len(prev_step_hyp_idxs), timestep] = beam_indices + curr_beam_idxs = beam_idxs[i, :, timestep] + reordered_cached = self.model._reorder_cache(model_kwargs["past"], curr_beam_idxs) model_inputs["past_key_values"] = reordered_cached # Forward pass @@ -330,8 +294,7 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ self._update_model_kwargs_for_generation(outputs, model_kwargs) # Keep track of probabilities over vocab for this pairing - # TODO: fix how we track the number here? - for i in range(lm_scores.shape[0]): + for i in range(num_beams): sample_lm_scores = lm_scores[i, -1] out_probs.append(sample_lm_scores.tolist()) # Keep track of sequence and decoder hidden states @@ -339,8 +302,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_ create_emitting_model_state( Seq2SeqModelState( timestep=timestep, - sequence=state_and_tokens[i].unsqueeze(0), - lm_scores=sample_lm_scores, + sequence=[], + lm_scores=0, ) ) ) @@ -386,10 +349,6 @@ def is_not_neg_one(elem: int) -> bool: if not self.is_encoder_decoder: final_tokens = input_ids[timestep].tolist() + final_tokens - # Makeshift padding so that we can stack the tensors - while len(final_tokens) < max_len: - final_tokens += [0] - # Convert from list to tensors final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)