
Commit 476957b

gante authored and ArthurZucker committed
🚨 Llama: update rope scaling to match static cache changes (#29143)
1 parent 7a4bec6 commit 476957b

File tree

7 files changed: +38 -44 lines changed


src/transformers/models/deprecated/open_llama/modeling_open_llama.py

+2 -2

@@ -100,7 +100,7 @@ def forward(self, x, seq_len=None):
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->OpenLlama
 class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
     """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

@@ -120,7 +120,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->OpenLlama
 class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
     """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

src/transformers/models/falcon/modeling_falcon.py

+4 -2

@@ -167,7 +167,8 @@ def forward(self, x, seq_len=None):
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
+# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
+# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
 class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
     """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

@@ -187,7 +188,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
+# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
+# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
 class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
     """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

src/transformers/models/llama/modeling_llama.py

+26 -33

@@ -94,7 +94,6 @@ def forward(self, hidden_states):
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base

@@ -118,6 +117,9 @@ def cos_cached(self):
         return self._cos_cached

     def forward(self, x, position_ids, seq_len=None):
+        if seq_len is not None:
+            logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
+
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
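The hunk above cuts off before the end of the rewritten `forward`. For illustration, a minimal self-contained sketch of the position-ids-driven computation it sets up (assumption: simplified from the pattern shown in the diff, not the verbatim library code; `SketchRotaryEmbedding` is a made-up name):

import torch
from torch import nn

class SketchRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        # One inverse frequency per pair of channels, shape [dim // 2]
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]; position_ids: [bs, seq_len]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # [bs, dim/2, 1] @ [bs, 1, seq_len] -> [bs, dim/2, seq_len] -> [bs, seq_len, dim/2]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # [bs, seq_len, dim]
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

Because cos/sin now depend only on the `position_ids` the caller passes in, no `seq_len`-keyed cache is needed, which is what the static cache path requires.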
@@ -138,16 +140,11 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-        t = t / self.scaling_factor
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+    def forward(self, x, position_ids, seq_len=None):
+        # difference to the original RoPE: a scaling factor is aplied to the position ids
+        position_ids = position_ids.float() / self.scaling_factor
+        cos, sin = super().forward(x, position_ids, seq_len)
+        return cos, sin


 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
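Linear scaling now reduces to dividing the position ids by the factor before deferring to the base module. A small illustrative check of that behaviour, reusing the `SketchRotaryEmbedding` sketch from above (class name, positions 8/2, and factor 4.0 are all made-up example values, not part of the commit):

class SketchLinearScalingRotaryEmbedding(SketchRotaryEmbedding):
    def __init__(self, dim, base=10000, scaling_factor=1.0):
        super().__init__(dim, base)
        self.scaling_factor = scaling_factor

    def forward(self, x, position_ids):
        # Same trick as the diff: rescale the position ids, then reuse the base forward.
        return super().forward(x, position_ids.float() / self.scaling_factor)

x = torch.zeros(1, 1, 1, 8)  # only the dtype of x is consumed by the sketch
scaled = SketchLinearScalingRotaryEmbedding(dim=8, scaling_factor=4.0)
plain = SketchRotaryEmbedding(dim=8)

cos_s, sin_s = scaled(x, torch.tensor([[8]]))
cos_p, sin_p = plain(x, torch.tensor([[2]]))
# Position 8 at factor 4 sees the same rotary angles as position 2 without scaling.
assert torch.allclose(cos_s, cos_p) and torch.allclose(sin_s, sin_p)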
@@ -157,23 +154,20 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-
+    def forward(self, x, position_ids, seq_len=None):
+        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+        seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_position_embeddings:
             base = self.base * (
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        cos, sin = super().forward(x, position_ids, seq_len)
+        return cos, sin


 def rotate_half(x):
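The dynamic NTK variant keeps the same base-adjustment formula as before; it just applies it inside `forward`, driven by `torch.max(position_ids) + 1` instead of a cached `seq_len`. Plugging illustrative numbers into that formula (all values made up, only to show the size of the adjustment):

# Illustration of the base adjustment in the forward above.
dim = 128
base = 10000.0
max_position_embeddings = 4096
scaling_factor = 2.0

seq_len = 8192  # stands in for torch.max(position_ids) + 1
if seq_len > max_position_embeddings:
    base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
print(base)  # ~30_500: the rotary base roughly triples, stretching the longest wavelengths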
@@ -183,17 +177,16 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
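Since the rotary module already returns per-position cos/sin, `apply_rotary_pos_emb` no longer needs `position_ids` to index a cache; it only unsqueezes and broadcasts. A hedged sketch of the rotation and of the new call pattern used in the attention hunks below (`apply_rotary_pos_emb_sketch` is an illustrative stand-in, not the verbatim library function):

import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input (same helper the hunk header points at).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin: [bs, seq_len, head_dim]; unsqueeze dim 1 so they broadcast over the heads axis of q/k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# New call pattern after this commit (q/k/v: [bs, heads, seq_len, head_dim]):
#     cos, sin = self.rotary_emb(value_states, position_ids)
#     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)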
@@ -360,8 +353,8 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         past_key_value = getattr(self, "past_key_value", past_key_value)
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
@@ -447,8 +440,8 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         past_key_value = getattr(self, "past_key_value", past_key_value)

@@ -645,8 +638,8 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         past_key_value = getattr(self, "past_key_value", past_key_value)

src/transformers/models/persimmon/modeling_persimmon.py

+2 -2

@@ -77,7 +77,7 @@ def forward(self, x, seq_len=None):
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Persimmon
 class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
     """PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

@@ -97,7 +97,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Persimmon
 class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
     """PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

src/transformers/models/phi/modeling_phi.py

+2 -2

@@ -115,7 +115,7 @@ def forward(self, x, seq_len=None):
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
 class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
     """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

@@ -135,7 +135,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
 class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
     """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

src/transformers/models/stablelm/modeling_stablelm.py

+2 -2

@@ -103,7 +103,7 @@ def forward(self, x, seq_len=None):
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
 class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
     """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

@@ -123,7 +123,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
 class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
     """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

tests/models/llama/test_modeling_llama.py

-1
@@ -362,7 +362,6 @@ def test_save_load_fast_init_from_base(self):
         pass

     @parameterized.expand([("linear",), ("dynamic",)])
-    @unittest.skip("TODO @gante fix this for Llama")
     def test_model_rope_scaling(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
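Dropping the skip re-enables `test_model_rope_scaling` for Llama for both scaling types. For context, a hedged, self-contained sketch of the kind of smoke test this parameterization covers (tiny illustrative config and assertion, not the actual test body):

import torch
from transformers import LlamaConfig, LlamaModel

long_input = torch.randint(0, 100, (1, 90))  # longer than max_position_embeddings below

for scaling_type in ("linear", "dynamic"):
    config = LlamaConfig(
        vocab_size=100, hidden_size=32, intermediate_size=64,
        num_hidden_layers=2, num_attention_heads=4, max_position_embeddings=64,
        rope_scaling={"type": scaling_type, "factor": 10.0},
    )
    model = LlamaModel(config).eval()
    with torch.no_grad():
        out = model(long_input).last_hidden_state
    # The scaled model should handle positions past max_position_embeddings without error.
    assert out.shape == (1, 90, config.hidden_size)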
