Skip to content

Commit 603a365

Browse files
authored
Add initial support for new style TypeVar defaults (PEP 696) (#17985)
Add initial support for TypeVar defaults using the new syntax. Similar to the old syntax, it doesn't fully work yet for ParamSpec, TypeVarTuple and recursive TypeVar defaults. Refs: #14851
1 parent c9d4c61 commit 603a365

10 files changed

+456
-85
lines changed

mypy/checker.py

+15
Original file line numberDiff line numberDiff line change
@@ -1159,6 +1159,7 @@ def check_func_def(
11591159
) -> None:
11601160
"""Type check a function definition."""
11611161
# Expand type variables with value restrictions to ordinary types.
1162+
self.check_typevar_defaults(typ.variables)
11621163
expanded = self.expand_typevars(defn, typ)
11631164
original_typ = typ
11641165
for item, typ in expanded:
@@ -2483,6 +2484,8 @@ def visit_class_def(self, defn: ClassDef) -> None:
24832484
context=defn,
24842485
code=codes.TYPE_VAR,
24852486
)
2487+
if typ.defn.type_vars:
2488+
self.check_typevar_defaults(typ.defn.type_vars)
24862489

24872490
if typ.is_protocol and typ.defn.type_vars:
24882491
self.check_protocol_variance(defn)
@@ -2546,6 +2549,15 @@ def check_init_subclass(self, defn: ClassDef) -> None:
25462549
# all other bases have already been checked.
25472550
break
25482551

2552+
def check_typevar_defaults(self, tvars: Sequence[TypeVarLikeType]) -> None:
2553+
for tv in tvars:
2554+
if not (isinstance(tv, TypeVarType) and tv.has_default()):
2555+
continue
2556+
if not is_subtype(tv.default, tv.upper_bound):
2557+
self.fail("TypeVar default must be a subtype of the bound type", tv)
2558+
if tv.values and not any(tv.default == value for value in tv.values):
2559+
self.fail("TypeVar default must be one of the constraint types", tv)
2560+
25492561
def check_enum(self, defn: ClassDef) -> None:
25502562
assert defn.info.is_enum
25512563
if defn.info.fullname not in ENUM_BASES:
@@ -5365,6 +5377,9 @@ def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var,
53655377
del type_map[expr]
53665378

53675379
def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
5380+
if o.alias_node:
5381+
self.check_typevar_defaults(o.alias_node.alias_tvars)
5382+
53685383
with self.msg.filter_errors():
53695384
self.expr_checker.accept(o.value)
53705385

mypy/fastparse.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -1196,19 +1196,17 @@ def validate_type_param(self, type_param: ast_TypeVar) -> None:
11961196
def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
11971197
explicit_type_params = []
11981198
for p in type_params:
1199-
bound = None
1199+
bound: Type | None = None
12001200
values: list[Type] = []
1201-
if sys.version_info >= (3, 13) and p.default_value is not None:
1202-
self.fail(
1203-
message_registry.TYPE_PARAM_DEFAULT_NOT_SUPPORTED,
1204-
p.lineno,
1205-
p.col_offset,
1206-
blocker=False,
1207-
)
1201+
default: Type | None = None
1202+
if sys.version_info >= (3, 13):
1203+
default = TypeConverter(self.errors, line=p.lineno).visit(p.default_value)
12081204
if isinstance(p, ast_ParamSpec): # type: ignore[misc]
1209-
explicit_type_params.append(TypeParam(p.name, PARAM_SPEC_KIND, None, []))
1205+
explicit_type_params.append(TypeParam(p.name, PARAM_SPEC_KIND, None, [], default))
12101206
elif isinstance(p, ast_TypeVarTuple): # type: ignore[misc]
1211-
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_TUPLE_KIND, None, []))
1207+
explicit_type_params.append(
1208+
TypeParam(p.name, TYPE_VAR_TUPLE_KIND, None, [], default)
1209+
)
12121210
else:
12131211
if isinstance(p.bound, ast3.Tuple):
12141212
if len(p.bound.elts) < 2:
@@ -1224,7 +1222,9 @@ def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
12241222
elif p.bound is not None:
12251223
self.validate_type_param(p)
12261224
bound = TypeConverter(self.errors, line=p.lineno).visit(p.bound)
1227-
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_KIND, bound, values))
1225+
explicit_type_params.append(
1226+
TypeParam(p.name, TYPE_VAR_KIND, bound, values, default)
1227+
)
12281228
return explicit_type_params
12291229

12301230
# Return(expr? value)

mypy/message_registry.py

-5
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,3 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
362362
TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage(
363363
"Await expression cannot be used within a type alias", codes.SYNTAX
364364
)
365-
366-
TYPE_PARAM_DEFAULT_NOT_SUPPORTED: Final = ErrorMessage(
367-
"Type parameter default types not supported when using Python 3.12 type parameter syntax",
368-
codes.MISC,
369-
)

mypy/nodes.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -670,19 +670,21 @@ def set_line(
670670

671671

672672
class TypeParam:
673-
__slots__ = ("name", "kind", "upper_bound", "values")
673+
__slots__ = ("name", "kind", "upper_bound", "values", "default")
674674

675675
def __init__(
676676
self,
677677
name: str,
678678
kind: int,
679679
upper_bound: mypy.types.Type | None,
680680
values: list[mypy.types.Type],
681+
default: mypy.types.Type | None,
681682
) -> None:
682683
self.name = name
683684
self.kind = kind
684685
self.upper_bound = upper_bound
685686
self.values = values
687+
self.default = default
686688

687689

688690
FUNCITEM_FLAGS: Final = FUNCBASE_FLAGS + [
@@ -782,7 +784,7 @@ class FuncDef(FuncItem, SymbolNode, Statement):
782784
"deco_line",
783785
"is_trivial_body",
784786
"is_mypy_only",
785-
# Present only when a function is decorated with @typing.datasclass_transform or similar
787+
# Present only when a function is decorated with @typing.dataclass_transform or similar
786788
"dataclass_transform_spec",
787789
"docstring",
788790
"deprecated",
@@ -1657,21 +1659,23 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
16571659

16581660

16591661
class TypeAliasStmt(Statement):
1660-
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias")
1662+
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias", "alias_node")
16611663

16621664
__match_args__ = ("name", "type_args", "value")
16631665

16641666
name: NameExpr
16651667
type_args: list[TypeParam]
16661668
value: LambdaExpr # Return value will get translated into a type
16671669
invalid_recursive_alias: bool
1670+
alias_node: TypeAlias | None
16681671

16691672
def __init__(self, name: NameExpr, type_args: list[TypeParam], value: LambdaExpr) -> None:
16701673
super().__init__()
16711674
self.name = name
16721675
self.type_args = type_args
16731676
self.value = value
16741677
self.invalid_recursive_alias = False
1678+
self.alias_node = None
16751679

16761680
def accept(self, visitor: StatementVisitor[T]) -> T:
16771681
return visitor.visit_type_alias_stmt(self)

mypy/semanal.py

+79-53
Original file line numberDiff line numberDiff line change
@@ -1808,7 +1808,26 @@ def analyze_type_param(
18081808
upper_bound = self.named_type("builtins.tuple", [self.object_type()])
18091809
else:
18101810
upper_bound = self.object_type()
1811-
default = AnyType(TypeOfAny.from_omitted_generics)
1811+
if type_param.default:
1812+
default = self.anal_type(
1813+
type_param.default,
1814+
allow_placeholder=True,
1815+
allow_unbound_tvars=True,
1816+
report_invalid_types=False,
1817+
allow_param_spec_literals=type_param.kind == PARAM_SPEC_KIND,
1818+
allow_tuple_literal=type_param.kind == PARAM_SPEC_KIND,
1819+
allow_unpack=type_param.kind == TYPE_VAR_TUPLE_KIND,
1820+
)
1821+
if default is None:
1822+
default = PlaceholderType(None, [], context.line)
1823+
elif type_param.kind == TYPE_VAR_KIND:
1824+
default = self.check_typevar_default(default, type_param.default)
1825+
elif type_param.kind == PARAM_SPEC_KIND:
1826+
default = self.check_paramspec_default(default, type_param.default)
1827+
elif type_param.kind == TYPE_VAR_TUPLE_KIND:
1828+
default = self.check_typevartuple_default(default, type_param.default)
1829+
else:
1830+
default = AnyType(TypeOfAny.from_omitted_generics)
18121831
if type_param.kind == TYPE_VAR_KIND:
18131832
values = []
18141833
if type_param.values:
@@ -2243,21 +2262,7 @@ class Foo(Bar, Generic[T]): ...
22432262
# grained incremental mode.
22442263
defn.removed_base_type_exprs.append(defn.base_type_exprs[i])
22452264
del base_type_exprs[i]
2246-
tvar_defs: list[TypeVarLikeType] = []
2247-
last_tvar_name_with_default: str | None = None
2248-
for name, tvar_expr in declared_tvars:
2249-
tvar_expr.default = tvar_expr.default.accept(
2250-
TypeVarDefaultTranslator(self, tvar_expr.name, context)
2251-
)
2252-
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
2253-
if last_tvar_name_with_default is not None and not tvar_def.has_default():
2254-
self.msg.tvar_without_default_type(
2255-
tvar_def.name, last_tvar_name_with_default, context
2256-
)
2257-
tvar_def.default = AnyType(TypeOfAny.from_error)
2258-
elif tvar_def.has_default():
2259-
last_tvar_name_with_default = tvar_def.name
2260-
tvar_defs.append(tvar_def)
2265+
tvar_defs = self.tvar_defs_from_tvars(declared_tvars, context)
22612266
return base_type_exprs, tvar_defs, is_protocol
22622267

22632268
def analyze_class_typevar_declaration(self, base: Type) -> tuple[TypeVarLikeList, bool] | None:
@@ -2358,6 +2363,26 @@ def get_all_bases_tvars(
23582363
tvars.extend(base_tvars)
23592364
return remove_dups(tvars)
23602365

2366+
def tvar_defs_from_tvars(
2367+
self, tvars: TypeVarLikeList, context: Context
2368+
) -> list[TypeVarLikeType]:
2369+
tvar_defs: list[TypeVarLikeType] = []
2370+
last_tvar_name_with_default: str | None = None
2371+
for name, tvar_expr in tvars:
2372+
tvar_expr.default = tvar_expr.default.accept(
2373+
TypeVarDefaultTranslator(self, tvar_expr.name, context)
2374+
)
2375+
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
2376+
if last_tvar_name_with_default is not None and not tvar_def.has_default():
2377+
self.msg.tvar_without_default_type(
2378+
tvar_def.name, last_tvar_name_with_default, context
2379+
)
2380+
tvar_def.default = AnyType(TypeOfAny.from_error)
2381+
elif tvar_def.has_default():
2382+
last_tvar_name_with_default = tvar_def.name
2383+
tvar_defs.append(tvar_def)
2384+
return tvar_defs
2385+
23612386
def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLikeType]:
23622387
"""Return all type variable references in item type expressions.
23632388
@@ -3833,21 +3858,8 @@ def analyze_alias(
38333858
tvar_defs: list[TypeVarLikeType] = []
38343859
namespace = self.qualified_name(name)
38353860
alias_type_vars = found_type_vars if declared_type_vars is None else declared_type_vars
3836-
last_tvar_name_with_default: str | None = None
38373861
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
3838-
for name, tvar_expr in alias_type_vars:
3839-
tvar_expr.default = tvar_expr.default.accept(
3840-
TypeVarDefaultTranslator(self, tvar_expr.name, typ)
3841-
)
3842-
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
3843-
if last_tvar_name_with_default is not None and not tvar_def.has_default():
3844-
self.msg.tvar_without_default_type(
3845-
tvar_def.name, last_tvar_name_with_default, typ
3846-
)
3847-
tvar_def.default = AnyType(TypeOfAny.from_error)
3848-
elif tvar_def.has_default():
3849-
last_tvar_name_with_default = tvar_def.name
3850-
tvar_defs.append(tvar_def)
3862+
tvar_defs = self.tvar_defs_from_tvars(alias_type_vars, typ)
38513863

38523864
if python_3_12_type_alias:
38533865
with self.allow_unbound_tvars_set():
@@ -4615,6 +4627,40 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool:
46154627
self.add_symbol(name, call.analyzed, s)
46164628
return True
46174629

4630+
def check_typevar_default(self, default: Type, context: Context) -> Type:
4631+
typ = get_proper_type(default)
4632+
if isinstance(typ, AnyType) and typ.is_from_error:
4633+
self.fail(
4634+
message_registry.TYPEVAR_ARG_MUST_BE_TYPE.format("TypeVar", "default"), context
4635+
)
4636+
return default
4637+
4638+
def check_paramspec_default(self, default: Type, context: Context) -> Type:
4639+
typ = get_proper_type(default)
4640+
if isinstance(typ, Parameters):
4641+
for i, arg_type in enumerate(typ.arg_types):
4642+
arg_ptype = get_proper_type(arg_type)
4643+
if isinstance(arg_ptype, AnyType) and arg_ptype.is_from_error:
4644+
self.fail(f"Argument {i} of ParamSpec default must be a type", context)
4645+
elif (
4646+
isinstance(typ, AnyType)
4647+
and typ.is_from_error
4648+
or not isinstance(typ, (AnyType, UnboundType))
4649+
):
4650+
self.fail(
4651+
"The default argument to ParamSpec must be a list expression, ellipsis, or a ParamSpec",
4652+
context,
4653+
)
4654+
default = AnyType(TypeOfAny.from_error)
4655+
return default
4656+
4657+
def check_typevartuple_default(self, default: Type, context: Context) -> Type:
4658+
typ = get_proper_type(default)
4659+
if not isinstance(typ, UnpackType):
4660+
self.fail("The default argument to TypeVarTuple must be an Unpacked tuple", context)
4661+
default = AnyType(TypeOfAny.from_error)
4662+
return default
4663+
46184664
def check_typevarlike_name(self, call: CallExpr, name: str, context: Context) -> bool:
46194665
"""Checks that the name of a TypeVar or ParamSpec matches its variable."""
46204666
name = unmangle(name)
@@ -4822,23 +4868,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool:
48224868
report_invalid_typevar_arg=False,
48234869
)
48244870
default = tv_arg or AnyType(TypeOfAny.from_error)
4825-
if isinstance(tv_arg, Parameters):
4826-
for i, arg_type in enumerate(tv_arg.arg_types):
4827-
typ = get_proper_type(arg_type)
4828-
if isinstance(typ, AnyType) and typ.is_from_error:
4829-
self.fail(
4830-
f"Argument {i} of ParamSpec default must be a type", param_value
4831-
)
4832-
elif (
4833-
isinstance(default, AnyType)
4834-
and default.is_from_error
4835-
or not isinstance(default, (AnyType, UnboundType))
4836-
):
4837-
self.fail(
4838-
"The default argument to ParamSpec must be a list expression, ellipsis, or a ParamSpec",
4839-
param_value,
4840-
)
4841-
default = AnyType(TypeOfAny.from_error)
4871+
default = self.check_paramspec_default(default, param_value)
48424872
else:
48434873
# ParamSpec is different from a regular TypeVar:
48444874
# arguments are not semantically valid. But, allowed in runtime.
@@ -4899,12 +4929,7 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool:
48994929
allow_unpack=True,
49004930
)
49014931
default = tv_arg or AnyType(TypeOfAny.from_error)
4902-
if not isinstance(default, UnpackType):
4903-
self.fail(
4904-
"The default argument to TypeVarTuple must be an Unpacked tuple",
4905-
param_value,
4906-
)
4907-
default = AnyType(TypeOfAny.from_error)
4932+
default = self.check_typevartuple_default(default, param_value)
49084933
else:
49094934
self.fail(f'Unexpected keyword argument "{param_name}" for "TypeVarTuple"', s)
49104935

@@ -5503,6 +5528,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
55035528
eager=eager,
55045529
python_3_12_type_alias=True,
55055530
)
5531+
s.alias_node = alias_node
55065532

55075533
if (
55085534
existing

mypy/strconv.py

+2
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ def type_param(self, p: mypy.nodes.TypeParam) -> list[Any]:
349349
a.append(p.upper_bound)
350350
if p.values:
351351
a.append(("Values", p.values))
352+
if p.default:
353+
a.append(("Default", [p.default]))
352354
return [("TypeParam", a)]
353355

354356
# Expressions

mypy/test/testparse.py

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class ParserSuite(DataSuite):
2525
files.remove("parse-python310.test")
2626
if sys.version_info < (3, 12):
2727
files.remove("parse-python312.test")
28+
if sys.version_info < (3, 13):
29+
files.remove("parse-python313.test")
2830

2931
def run_case(self, testcase: DataDrivenTestCase) -> None:
3032
test_parser(testcase)
@@ -43,6 +45,8 @@ def test_parser(testcase: DataDrivenTestCase) -> None:
4345
options.python_version = (3, 10)
4446
elif testcase.file.endswith("python312.test"):
4547
options.python_version = (3, 12)
48+
elif testcase.file.endswith("python313.test"):
49+
options.python_version = (3, 13)
4650
else:
4751
options.python_version = defaults.PYTHON3_VERSION
4852

0 commit comments

Comments
 (0)