[megatron gpt checkpoint conversion] causal mask requires pos_embed dimension (#13735)
stas00 authored Sep 26, 2021
1 parent 91df455 commit 400c5a1
Showing 1 changed file with 5 additions and 6 deletions.
@@ -121,12 +121,11 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
 
     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
-    # Read the hidden dimension.
-    n_embed = pos_embeddings.size(1)
-    # DEBUG.
+    # Read the causal mask dimension (seqlen). [max_sequence_length, hidden_size]
+    n_ctx = pos_embeddings.size(0)
     assert (
-        n_embed == heads * hidden_size_per_head
-    ), f"detected mismatch n_embed={n_embed} != heads={heads}*hidden_size_per_head={hidden_size_per_head}"
+        n_ctx == config.n_ctx
+    ), f"pos_embeddings.max_sequence_length={n_ctx} and config.n_ctx={config.n_ctx} don't match"
     # Store the position embeddings.
     output_state_dict["transformer.wpe.weight"] = pos_embeddings

@@ -175,7 +174,7 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
         ) and weight_or_bias == "weight":
 
             # Insert a tensor of 1x1xDxD bias.
-            causal_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.float16)).view(1, 1, n_embed, n_embed)
+            causal_mask = torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.float16)).view(1, 1, n_ctx, n_ctx)
             output_state_dict[layer_name + ".attn.bias"] = causal_mask
 
             # Insert a "dummy" tensor for masked_bias.
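Below is a minimal, self-contained sketch of the shape reasoning behind this fix. The sizes and the max_sequence_length / hidden_size names are illustrative, not read from a real checkpoint: Megatron stores position embeddings as [max_sequence_length, hidden_size], so size(0) is the context length the causal mask needs, while size(1) is the hidden size the old code was using.

import torch

# Illustrative sizes only; in the conversion script they come from the checkpoint.
max_sequence_length, hidden_size = 1024, 768
pos_embeddings = torch.zeros(max_sequence_length, hidden_size)

n_ctx = pos_embeddings.size(0)    # sequence length -> causal mask dimension (new code)
n_embed = pos_embeddings.size(1)  # hidden size -> not the mask dimension (old code)

# Lower-triangular mask: token i may only attend to tokens j <= i.
causal_mask = torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.float16)).view(1, 1, n_ctx, n_ctx)
assert causal_mask.shape == (1, 1, max_sequence_length, max_sequence_length)
assert causal_mask[0, 0, 1, 0] == 1 and causal_mask[0, 0, 0, 1] == 0

With sizes like those above, a mask built from n_embed would be 768x768 instead of the required 1024x1024, which is the mismatch this commit corrects.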
