-
Notifications
You must be signed in to change notification settings - Fork 1
/
rope.py
31 lines (22 loc) · 1.08 KB
/
rope.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
def apply_rope(input_tensor, rope_theta=500000.0):
    """Apply rotary position embedding (RoPE) to a 2-D tensor.

    The input is treated as ``[seq_len, dim]``. Each row is split into
    ``dim // 2`` adjacent pairs ``(x[2j], x[2j + 1])``; every pair is read
    as the real and imaginary part of a complex number and rotated by the
    angle ``position * rope_theta ** (-j / (dim // 2))``, where ``position``
    is the row index.

    Args:
        input_tensor: Tensor of shape ``[seq_len, dim]`` with an even ``dim``.
        rope_theta: Base of the geometric frequency progression.

    Returns:
        A new tensor of shape ``[seq_len, dim]`` with the same dtype as
        ``input_tensor``, holding the rotated pairs in the original
        interleaved layout. The input is not modified.
    """
    seq_len = input_tensor.shape[0]
    half_dim = input_tensor.shape[1] // 2
    # Group adjacent features into (real, imag) pairs: [seq_len, half_dim, 2].
    pairs = input_tensor.view(seq_len, half_dim, 2)

    # Build the angle table on the input's device so the broadcasted
    # arithmetic below works for non-CPU tensors as well.
    device = input_tensor.device
    positions = torch.arange(
        seq_len, dtype=torch.float32, device=device
    ).unsqueeze(1)  # [seq_len, 1]
    pair_indices = torch.arange(half_dim, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (rope_theta ** (pair_indices / half_dim))  # [half_dim]
    angles = positions * inv_freq  # [seq_len, half_dim]
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    real = pairs[..., 0]
    imag = pairs[..., 1]
    # Complex multiplication by exp(i * angle), done for every
    # (position, pair) at once instead of a Python-level double loop.
    rotated = torch.stack(
        (real * cos - imag * sin, real * sin + imag * cos),
        dim=-1,
    )
    # The angles are float32, so low-precision inputs are promoted during
    # the rotation; cast back so the result keeps the caller's dtype.
    return rotated.to(input_tensor.dtype).reshape(seq_len, half_dim * 2)