fix: rotary embeddings
lukasellinger committed Aug 6, 2024
1 parent 4cfcc97 commit d540b9c
Showing 7 changed files with 610 additions and 21 deletions.
18 changes: 17 additions & 1 deletion models/evidence_selection_model.py
@@ -10,9 +10,10 @@
 class EvidenceSelectionModel(nn.Module):
     """Model to compute sentence embeddings."""

-    def __init__(self, model, feed_forward=False, normalize_before_fc=True, out_features=256):
+    def __init__(self, model, feed_forward=False, normalize_before_fc=True, out_features=256, device=None):
         super().__init__()
         self.model = model
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.feed_forward = feed_forward
         self.normalize_before_fc = normalize_before_fc

@@ -21,13 +22,28 @@ def __init__(self, model, feed_forward=False, normalize_before_fc=True, out_features=256, device=None):
             param.requires_grad = False
         self.fc = nn.Linear(1024, out_features)

+    def _reset_rotary_embeddings(self):
+        for module in self.model.modules():
+            if module.__class__.__name__ == 'NomicBertDynamicNTKRotaryEmbedding':
+                module._seq_len_cached = 0
+                module._cos_cached = None
+                module._sin_cached = None
+                module._cos_k_cached = None
+                module._sin_k_cached = None
+                if hasattr(module, 'inv_freq'):
+                    module.inv_freq = module._compute_inv_freq(device=self.device)
+
     def forward(self, input_ids=None, attention_mask=None, sentence_mask=None, **kwargs):
         """Forward function."""
         if sentence_mask is None:
             sentence_mask = attention_mask.unsqueeze(dim=1)
             #sentence_mask = torch.zeros_like(attention_mask.unsqueeze(dim=1))
             #sentence_mask[:, :, 0] = 1  # try only cls

+        # When the length of the sequence exceeds the maximum position embeddings, the 'base'
+        # variable is adjusted in the dynamic rotary embedding computation. This change in 'base'
+        # affects the cached values, leading to inconsistencies in subsequent inference steps.
+        self._reset_rotary_embeddings()
         outputs = self.model(input_ids=input_ids,
                              attention_mask=attention_mask)['last_hidden_state']

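The inline comment added in forward is the crux of this fix: under dynamic NTK scaling, a batch that exceeds the model's maximum position embeddings stretches the rotary 'base', and the cos/sin tables cached from the stretched base keep being served to later, shorter batches. Below is a minimal sketch of that failure mode, using a toy cache with made-up dim/base/max_position_embeddings values and a simplified NTK scaling formula; it is not the NomicBert implementation.

    import torch

    class ToyDynamicNTKRotaryCache:
        """Toy rotary-embedding cache with dynamic NTK scaling of `base`."""

        def __init__(self, dim=64, base=10000.0, max_position_embeddings=2048):
            self.dim = dim
            self.base = base
            self.max_position_embeddings = max_position_embeddings
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self.inv_freq = self._compute_inv_freq(self.base)

        def _compute_inv_freq(self, base):
            return 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))

        def cos_sin(self, seq_len):
            # Recompute only when the requested length exceeds what is cached.
            if seq_len > self._seq_len_cached:
                if seq_len > self.max_position_embeddings:
                    # Dynamic NTK: stretch `base` so the longest position still fits.
                    scale = seq_len / self.max_position_embeddings
                    self.inv_freq = self._compute_inv_freq(self.base * scale ** (self.dim / (self.dim - 2)))
                t = torch.arange(seq_len, dtype=torch.float32)
                freqs = torch.outer(t, self.inv_freq)
                self._seq_len_cached = seq_len
                self._cos_cached, self._sin_cached = freqs.cos(), freqs.sin()
            return self._cos_cached[:seq_len], self._sin_cached[:seq_len]

    rot = ToyDynamicNTKRotaryCache()
    cos_before, _ = rot.cos_sin(128)    # computed with the original base
    rot.cos_sin(4096)                   # over-length batch: base is stretched, cache rebuilt
    cos_after, _ = rot.cos_sin(128)     # served from the stale, stretched cache
    print(torch.allclose(cos_before, cos_after))  # False: identical input, different rotary tables

Resetting _seq_len_cached, the cached cos/sin tensors, and inv_freq before every forward pass, as _reset_rotary_embeddings does, forces the next batch to rebuild the cache from the original base.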
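How the new device argument is used is only partly visible in this hunk, so the following is a hedged usage sketch: the checkpoint name, the tokenizer handling, and the interpretation of the returned value are illustrative assumptions, not taken from this repository.

    import torch
    from transformers import AutoModel, AutoTokenizer

    from models.evidence_selection_model import EvidenceSelectionModel

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Assumed checkpoint: a NomicBERT-style encoder whose rotary module is the
    # NomicBertDynamicNTKRotaryEmbedding handled by _reset_rotary_embeddings().
    encoder = AutoModel.from_pretrained("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-text-v1.5")

    model = EvidenceSelectionModel(encoder, device=device).to(device)
    model.eval()

    batch = tokenizer(["claim text", "a much longer evidence passage ..."],
                      padding=True, return_tensors="pt").to(device)

    with torch.no_grad():
        # forward() now calls _reset_rotary_embeddings() first, so a previous
        # over-length batch cannot leak a stretched rotary base into this one.
        sentence_embeddings = model(input_ids=batch["input_ids"],
                                    attention_mask=batch["attention_mask"])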
