@@ -94,7 +94,6 @@ def forward(self, hidden_states):
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
@@ -118,6 +117,9 @@ def cos_cached(self):
         return self._cos_cached

     def forward(self, x, position_ids, seq_len=None):
+        if seq_len is not None:
+            logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
+
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
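For reference, the cache-free path set up above reduces to an outer product between `inv_freq` and the float position ids, with no `seq_len`-dependent buffers. A minimal standalone sketch of that computation (not part of the diff; `dim`, `base`, and the shapes are illustrative assumptions, and the multiplication step is assumed to follow the shape setup shown in the hunk):

import torch

# Illustrative values; the real ones come from the model config.
dim, base = 128, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))  # [dim // 2]

position_ids = torch.arange(6)[None, :]  # [bs=1, seq_len=6]

# Same kind of math as the new cache-free forward: an outer product of frequencies
# and positions, duplicated along the last dim before taking cos/sin.
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # [bs, dim//2, 1]
position_ids_expanded = position_ids[:, None, :].float()                                  # [bs, 1, seq_len]
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)                       # [bs, seq_len, dim//2]
emb = torch.cat((freqs, freqs), dim=-1)                                                   # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([1, 6, 128]) torch.Size([1, 6, 128])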
@@ -138,16 +140,11 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-        t = t / self.scaling_factor
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+    def forward(self, x, position_ids, seq_len=None):
+        # difference to the original RoPE: a scaling factor is applied to the position ids
+        position_ids = position_ids.float() / self.scaling_factor
+        cos, sin = super().forward(x, position_ids, seq_len)
+        return cos, sin


 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
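With the sin/cos cache gone, the linear-scaling `forward` in the hunk above reduces to rescaling the position ids before the parent class runs. A quick hedged illustration of the effect on plain tensors (the window size and factor are made-up example values, no model classes involved):

import torch

scaling_factor = 4.0
position_ids = torch.arange(8192, dtype=torch.float32)  # positions beyond an assumed 2048-token training window

# Linear scaling squeezes the positions back into the range the base RoPE was trained on.
scaled = position_ids / scaling_factor
print(position_ids.max().item(), "->", scaled.max().item())  # 8191.0 -> 2047.75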
@@ -157,23 +154,20 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-
+    def forward(self, x, position_ids, seq_len=None):
+        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+        seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_position_embeddings:
             base = self.base * (
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        cos, sin = super().forward(x, position_ids, seq_len)
+        return cos, sin


 def rotate_half(x):
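The dynamic-NTK branch above only changes `base` (and hence `inv_freq`) once the observed sequence length exceeds `max_position_embeddings`. A standalone sketch of that rescaling, using illustrative numbers rather than real config values:

import torch

dim, base = 128, 10000.0
max_position_embeddings, scaling_factor = 2048, 1.0
seq_len = 4096  # pretend the largest position id seen so far is 4095

if seq_len > max_position_embeddings:
    # Same formula as the hunk above: grow the base so the frequencies stretch with the context.
    base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(base)          # ~ 10000 * 2 ** (128 / 126) ≈ 20221
print(inv_freq[-1])  # the lowest frequency shrinks accordingly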
@@ -183,17 +177,16 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
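Because `position_ids` is now dead weight in `apply_rotary_pos_emb`, the call sites drop it, as the attention hunks below show. A hedged, self-contained sketch of the new calling convention, re-implementing the helper locally for illustration (the real function body lives further down the file and may differ in details):

import torch

def rotate_half(x):
    # mirrors the helper in the file: split the last dim in two and swap halves with a sign flip
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin now arrive already indexed per position ([bs, seq_len, head_dim]),
    # so no position_ids lookup is needed here; just add a heads dimension.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

bs, heads, seq_len, head_dim = 1, 4, 6, 128
q = torch.randn(bs, heads, seq_len, head_dim)
k = torch.randn(bs, heads, seq_len, head_dim)
cos = torch.randn(bs, seq_len, head_dim)  # stand-ins for the rotary module's outputs
sin = torch.randn(bs, seq_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb_sketch(q, k, cos, sin)
print(q_rot.shape)  # torch.Size([1, 4, 6, 128])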
@@ -360,8 +353,8 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         past_key_value = getattr(self, "past_key_value", past_key_value)
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
@@ -447,8 +440,8 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         past_key_value = getattr(self, "past_key_value", past_key_value)

@@ -645,8 +638,8 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         past_key_value = getattr(self, "past_key_value", past_key_value)
