Skip to content

Commit ec06e00

Browse files
authored
Use TypeVar defaults instead of Any when fixing instance types (PEP 696) (#16812)
Start using TypeVar defaults when fixing instance types, instead of filling those with `Any`. This PR preserves the way an invalid amount of args is handled. I.e. filling all with `Any` / defaults, instead of cutting off additional args. Thus preserving full backwards compatibility. This can be easily changed later if necessary. `TypeVarTuple` defaults aren't handled correctly yet. Those will require additional logic which would have complicated the change here and made it more difficult to review. Ref: #14851
1 parent 7eab8a4 commit ec06e00

File tree

3 files changed

+189
-36
lines changed

3 files changed

+189
-36
lines changed

Diff for: mypy/messages.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -3017,12 +3017,15 @@ def for_function(callee: CallableType) -> str:
30173017
return ""
30183018

30193019

3020-
def wrong_type_arg_count(n: int, act: str, name: str) -> str:
3021-
s = f"{n} type arguments"
3022-
if n == 0:
3023-
s = "no type arguments"
3024-
elif n == 1:
3025-
s = "1 type argument"
3020+
def wrong_type_arg_count(low: int, high: int, act: str, name: str) -> str:
3021+
if low == high:
3022+
s = f"{low} type arguments"
3023+
if low == 0:
3024+
s = "no type arguments"
3025+
elif low == 1:
3026+
s = "1 type argument"
3027+
else:
3028+
s = f"between {low} and {high} type arguments"
30263029
if act == "0":
30273030
act = "none"
30283031
return f'"{name}" expects {s}, but {act} given'

Diff for: mypy/typeanal.py

+57-30
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from mypy import errorcodes as codes, message_registry, nodes
1111
from mypy.errorcodes import ErrorCode
12+
from mypy.expandtype import expand_type
1213
from mypy.messages import MessageBuilder, format_type_bare, quote_type_string, wrong_type_arg_count
1314
from mypy.nodes import (
1415
ARG_NAMED,
@@ -75,6 +76,7 @@
7576
TypeOfAny,
7677
TypeQuery,
7778
TypeType,
79+
TypeVarId,
7880
TypeVarLikeType,
7981
TypeVarTupleType,
8082
TypeVarType,
@@ -1834,14 +1836,14 @@ def get_omitted_any(
18341836
return any_type
18351837

18361838

1837-
def fix_type_var_tuple_argument(any_type: Type, t: Instance) -> None:
1839+
def fix_type_var_tuple_argument(t: Instance) -> None:
18381840
if t.type.has_type_var_tuple_type:
18391841
args = list(t.args)
18401842
assert t.type.type_var_tuple_prefix is not None
18411843
tvt = t.type.defn.type_vars[t.type.type_var_tuple_prefix]
18421844
assert isinstance(tvt, TypeVarTupleType)
18431845
args[t.type.type_var_tuple_prefix] = UnpackType(
1844-
Instance(tvt.tuple_fallback.type, [any_type])
1846+
Instance(tvt.tuple_fallback.type, [args[t.type.type_var_tuple_prefix]])
18451847
)
18461848
t.args = tuple(args)
18471849

@@ -1855,26 +1857,42 @@ def fix_instance(
18551857
use_generic_error: bool = False,
18561858
unexpanded_type: Type | None = None,
18571859
) -> None:
1858-
"""Fix a malformed instance by replacing all type arguments with Any.
1860+
"""Fix a malformed instance by replacing all type arguments with TypeVar default or Any.
18591861
18601862
Also emit a suitable error if this is not due to implicit Any's.
18611863
"""
1862-
if len(t.args) == 0:
1863-
if use_generic_error:
1864-
fullname: str | None = None
1865-
else:
1866-
fullname = t.type.fullname
1867-
any_type = get_omitted_any(disallow_any, fail, note, t, options, fullname, unexpanded_type)
1868-
t.args = (any_type,) * len(t.type.type_vars)
1869-
fix_type_var_tuple_argument(any_type, t)
1870-
return
1871-
# Construct the correct number of type arguments, as
1872-
# otherwise the type checker may crash as it expects
1873-
# things to be right.
1874-
any_type = AnyType(TypeOfAny.from_error)
1875-
t.args = tuple(any_type for _ in t.type.type_vars)
1876-
fix_type_var_tuple_argument(any_type, t)
1877-
t.invalid = True
1864+
arg_count = len(t.args)
1865+
min_tv_count = sum(not tv.has_default() for tv in t.type.defn.type_vars)
1866+
max_tv_count = len(t.type.type_vars)
1867+
if arg_count < min_tv_count or arg_count > max_tv_count:
1868+
# Don't use existing args if arg_count doesn't match
1869+
t.args = ()
1870+
1871+
args: list[Type] = [*(t.args[:max_tv_count])]
1872+
any_type: AnyType | None = None
1873+
env: dict[TypeVarId, Type] = {}
1874+
1875+
for tv, arg in itertools.zip_longest(t.type.defn.type_vars, t.args, fillvalue=None):
1876+
if tv is None:
1877+
continue
1878+
if arg is None:
1879+
if tv.has_default():
1880+
arg = tv.default
1881+
else:
1882+
if any_type is None:
1883+
fullname = None if use_generic_error else t.type.fullname
1884+
any_type = get_omitted_any(
1885+
disallow_any, fail, note, t, options, fullname, unexpanded_type
1886+
)
1887+
arg = any_type
1888+
args.append(arg)
1889+
env[tv.id] = arg
1890+
t.args = tuple(args)
1891+
fix_type_var_tuple_argument(t)
1892+
if not t.type.has_type_var_tuple_type:
1893+
fixed = expand_type(t, env)
1894+
assert isinstance(fixed, Instance)
1895+
t.args = fixed.args
18781896

18791897

18801898
def instantiate_type_alias(
@@ -1963,7 +1981,7 @@ def instantiate_type_alias(
19631981
if use_standard_error:
19641982
# This is used if type alias is an internal representation of another type,
19651983
# for example a generic TypedDict or NamedTuple.
1966-
msg = wrong_type_arg_count(exp_len, str(act_len), node.name)
1984+
msg = wrong_type_arg_count(exp_len, exp_len, str(act_len), node.name)
19671985
else:
19681986
if node.tvar_tuple_index is not None:
19691987
exp_len_str = f"at least {exp_len - 1}"
@@ -2217,24 +2235,27 @@ def validate_instance(t: Instance, fail: MsgCallback, empty_tuple_index: bool) -
22172235
# TODO: is it OK to fill with TypeOfAny.from_error instead of special form?
22182236
return False
22192237
if t.type.has_type_var_tuple_type:
2220-
correct = len(t.args) >= len(t.type.type_vars) - 1
2238+
min_tv_count = sum(
2239+
not tv.has_default() and not isinstance(tv, TypeVarTupleType)
2240+
for tv in t.type.defn.type_vars
2241+
)
2242+
correct = len(t.args) >= min_tv_count
22212243
if any(
22222244
isinstance(a, UnpackType) and isinstance(get_proper_type(a.type), Instance)
22232245
for a in t.args
22242246
):
22252247
correct = True
2226-
if not correct:
2227-
exp_len = f"at least {len(t.type.type_vars) - 1}"
2248+
if not t.args:
2249+
if not (empty_tuple_index and len(t.type.type_vars) == 1):
2250+
# The Any arguments should be set by the caller.
2251+
return False
2252+
elif not correct:
22282253
fail(
2229-
f"Bad number of arguments, expected: {exp_len}, given: {len(t.args)}",
2254+
f"Bad number of arguments, expected: at least {min_tv_count}, given: {len(t.args)}",
22302255
t,
22312256
code=codes.TYPE_ARG,
22322257
)
22332258
return False
2234-
elif not t.args:
2235-
if not (empty_tuple_index and len(t.type.type_vars) == 1):
2236-
# The Any arguments should be set by the caller.
2237-
return False
22382259
else:
22392260
# We also need to check if we are not performing a type variable tuple split.
22402261
unpack = find_unpack_in_list(t.args)
@@ -2254,15 +2275,21 @@ def validate_instance(t: Instance, fail: MsgCallback, empty_tuple_index: bool) -
22542275
elif any(isinstance(a, UnpackType) for a in t.args):
22552276
# A variadic unpack in fixed size instance (fixed unpacks must be flattened by the caller)
22562277
fail(message_registry.INVALID_UNPACK_POSITION, t, code=codes.VALID_TYPE)
2278+
t.args = ()
22572279
return False
22582280
elif len(t.args) != len(t.type.type_vars):
22592281
# Invalid number of type parameters.
2260-
if t.args:
2282+
arg_count = len(t.args)
2283+
min_tv_count = sum(not tv.has_default() for tv in t.type.defn.type_vars)
2284+
max_tv_count = len(t.type.type_vars)
2285+
if arg_count and (arg_count < min_tv_count or arg_count > max_tv_count):
22612286
fail(
2262-
wrong_type_arg_count(len(t.type.type_vars), str(len(t.args)), t.type.name),
2287+
wrong_type_arg_count(min_tv_count, max_tv_count, str(arg_count), t.type.name),
22632288
t,
22642289
code=codes.TYPE_ARG,
22652290
)
2291+
t.args = ()
2292+
t.invalid = True
22662293
return False
22672294
return True
22682295

Diff for: test-data/unit/check-typevar-defaults.test

+123
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,126 @@ def func_c1(x: Union[int, Callable[[Unpack[Ts1]], None]]) -> Tuple[Unpack[Ts1]]:
116116
# reveal_type(func_c1(callback1)) # Revealed type is "builtins.tuple[str]" # TODO
117117
# reveal_type(func_c1(2)) # Revealed type is "builtins.tuple[builtins.int, builtins.str]" # TODO
118118
[builtins fixtures/tuple.pyi]
119+
120+
[case testTypeVarDefaultsClass1]
121+
from typing import Generic, TypeVar
122+
123+
T1 = TypeVar("T1")
124+
T2 = TypeVar("T2", default=int)
125+
T3 = TypeVar("T3", default=str)
126+
127+
class ClassA1(Generic[T2, T3]): ...
128+
129+
def func_a1(
130+
a: ClassA1,
131+
b: ClassA1[float],
132+
c: ClassA1[float, float],
133+
d: ClassA1[float, float, float], # E: "ClassA1" expects between 0 and 2 type arguments, but 3 given
134+
) -> None:
135+
reveal_type(a) # N: Revealed type is "__main__.ClassA1[builtins.int, builtins.str]"
136+
reveal_type(b) # N: Revealed type is "__main__.ClassA1[builtins.float, builtins.str]"
137+
reveal_type(c) # N: Revealed type is "__main__.ClassA1[builtins.float, builtins.float]"
138+
reveal_type(d) # N: Revealed type is "__main__.ClassA1[builtins.int, builtins.str]"
139+
140+
class ClassA2(Generic[T1, T2, T3]): ...
141+
142+
def func_a2(
143+
a: ClassA2,
144+
b: ClassA2[float],
145+
c: ClassA2[float, float],
146+
d: ClassA2[float, float, float],
147+
e: ClassA2[float, float, float, float], # E: "ClassA2" expects between 1 and 3 type arguments, but 4 given
148+
) -> None:
149+
reveal_type(a) # N: Revealed type is "__main__.ClassA2[Any, builtins.int, builtins.str]"
150+
reveal_type(b) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.int, builtins.str]"
151+
reveal_type(c) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.float, builtins.str]"
152+
reveal_type(d) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.float, builtins.float]"
153+
reveal_type(e) # N: Revealed type is "__main__.ClassA2[Any, builtins.int, builtins.str]"
154+
155+
[case testTypeVarDefaultsClass2]
156+
from typing import Generic, ParamSpec
157+
158+
P1 = ParamSpec("P1")
159+
P2 = ParamSpec("P2", default=[int, str])
160+
P3 = ParamSpec("P3", default=...)
161+
162+
class ClassB1(Generic[P2, P3]): ...
163+
164+
def func_b1(
165+
a: ClassB1,
166+
b: ClassB1[[float]],
167+
c: ClassB1[[float], [float]],
168+
d: ClassB1[[float], [float], [float]], # E: "ClassB1" expects between 0 and 2 type arguments, but 3 given
169+
) -> None:
170+
reveal_type(a) # N: Revealed type is "__main__.ClassB1[[builtins.int, builtins.str], ...]"
171+
reveal_type(b) # N: Revealed type is "__main__.ClassB1[[builtins.float], ...]"
172+
reveal_type(c) # N: Revealed type is "__main__.ClassB1[[builtins.float], [builtins.float]]"
173+
reveal_type(d) # N: Revealed type is "__main__.ClassB1[[builtins.int, builtins.str], ...]"
174+
175+
class ClassB2(Generic[P1, P2]): ...
176+
177+
def func_b2(
178+
a: ClassB2,
179+
b: ClassB2[[float]],
180+
c: ClassB2[[float], [float]],
181+
d: ClassB2[[float], [float], [float]], # E: "ClassB2" expects between 1 and 2 type arguments, but 3 given
182+
) -> None:
183+
reveal_type(a) # N: Revealed type is "__main__.ClassB2[Any, [builtins.int, builtins.str]]"
184+
reveal_type(b) # N: Revealed type is "__main__.ClassB2[[builtins.float], [builtins.int, builtins.str]]"
185+
reveal_type(c) # N: Revealed type is "__main__.ClassB2[[builtins.float], [builtins.float]]"
186+
reveal_type(d) # N: Revealed type is "__main__.ClassB2[Any, [builtins.int, builtins.str]]"
187+
188+
[case testTypeVarDefaultsClass3]
189+
from typing import Generic, Tuple, TypeVar
190+
from typing_extensions import TypeVarTuple, Unpack
191+
192+
T1 = TypeVar("T1")
193+
T3 = TypeVar("T3", default=str)
194+
195+
Ts1 = TypeVarTuple("Ts1")
196+
Ts2 = TypeVarTuple("Ts2", default=Unpack[Tuple[int, str]])
197+
Ts3 = TypeVarTuple("Ts3", default=Unpack[Tuple[float, ...]])
198+
Ts4 = TypeVarTuple("Ts4", default=Unpack[Tuple[()]])
199+
200+
class ClassC1(Generic[Unpack[Ts2]]): ...
201+
202+
def func_c1(
203+
a: ClassC1,
204+
b: ClassC1[float],
205+
) -> None:
206+
# reveal_type(a) # Revealed type is "__main__.ClassC1[builtins.int, builtins.str]" # TODO
207+
reveal_type(b) # N: Revealed type is "__main__.ClassC1[builtins.float]"
208+
209+
class ClassC2(Generic[T3, Unpack[Ts3]]): ...
210+
211+
def func_c2(
212+
a: ClassC2,
213+
b: ClassC2[int],
214+
c: ClassC2[int, Unpack[Tuple[()]]],
215+
) -> None:
216+
reveal_type(a) # N: Revealed type is "__main__.ClassC2[builtins.str, Unpack[builtins.tuple[builtins.float, ...]]]"
217+
# reveal_type(b) # Revealed type is "__main__.ClassC2[builtins.int, Unpack[builtins.tuple[builtins.float, ...]]]" # TODO
218+
reveal_type(c) # N: Revealed type is "__main__.ClassC2[builtins.int]"
219+
220+
class ClassC3(Generic[T3, Unpack[Ts4]]): ...
221+
222+
def func_c3(
223+
a: ClassC3,
224+
b: ClassC3[int],
225+
c: ClassC3[int, Unpack[Tuple[float]]]
226+
) -> None:
227+
# reveal_type(a) # Revealed type is "__main__.ClassC3[builtins.str]" # TODO
228+
reveal_type(b) # N: Revealed type is "__main__.ClassC3[builtins.int]"
229+
reveal_type(c) # N: Revealed type is "__main__.ClassC3[builtins.int, builtins.float]"
230+
231+
class ClassC4(Generic[T1, Unpack[Ts1], T3]): ...
232+
233+
def func_c4(
234+
a: ClassC4,
235+
b: ClassC4[int],
236+
c: ClassC4[int, float],
237+
) -> None:
238+
reveal_type(a) # N: Revealed type is "__main__.ClassC4[Any, Unpack[builtins.tuple[Any, ...]], builtins.str]"
239+
# reveal_type(b) # Revealed type is "__main__.ClassC4[builtins.int, builtins.str]" # TODO
240+
reveal_type(c) # N: Revealed type is "__main__.ClassC4[builtins.int, builtins.float]"
241+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)