3
3
from __future__ import annotations
4
4
5
5
import sys
6
- from typing import NoReturn
7
6
8
7
import numpy as np
9
8
import torch
12
11
from tiatoolbox import logger
13
12
14
13
15
- def is_torch_compile_compatible () -> NoReturn :
14
+ def is_torch_compile_compatible () -> bool :
16
15
"""Check if the current GPU is compatible with torch-compile.
17
16
17
+ Returns:
18
+ True if current GPU is compatible with torch-compile, False otherwise.
19
+
18
20
Raises:
19
21
Warning if GPU is not compatible with `torch.compile`.
20
22
21
23
"""
24
+ gpu_compatibility = True
22
25
if torch .cuda .is_available (): # pragma: no cover
23
26
device_cap = torch .cuda .get_device_capability ()
24
27
if device_cap not in ((7 , 0 ), (8 , 0 ), (9 , 0 )):
@@ -28,13 +31,17 @@ def is_torch_compile_compatible() -> NoReturn:
28
31
"Speedup numbers may be lower than expected." ,
29
32
stacklevel = 2 ,
30
33
)
34
+ gpu_compatibility = False
31
35
else :
32
36
logger .warning (
33
37
"No GPU detected or cuda not installed, "
34
38
"torch.compile is only supported on selected NVIDIA GPUs. "
35
39
"Speedup numbers may be lower than expected." ,
36
40
stacklevel = 2 ,
37
41
)
42
+ gpu_compatibility = False
43
+
44
+ return gpu_compatibility
38
45
39
46
40
47
def compile_model (
@@ -68,12 +75,24 @@ def compile_model(
68
75
return model
69
76
70
77
# Check if GPU is compatible with torch.compile
71
- is_torch_compile_compatible ()
78
+ gpu_compatibility = is_torch_compile_compatible ()
79
+
80
+ if not gpu_compatibility :
81
+ return model
82
+
83
+ if sys .platform == "win32" : # pragma: no cover
84
+ msg = (
85
+ "`torch.compile` is not supported on Windows. Please see "
86
+ "https://github.com/pytorch/pytorch/issues/122094."
87
+ )
88
+ logger .warning (msg = msg )
89
+ return model
72
90
73
91
# This check will be removed when torch.compile is supported in Python 3.12+
74
92
if sys .version_info > (3 , 12 ): # pragma: no cover
93
+ msg = "torch-compile is currently not supported in Python 3.12+."
75
94
logger .warning (
76
- ( "torch-compile is currently not supported in Python 3.12+. " ,) ,
95
+ msg = msg ,
77
96
)
78
97
return model
79
98
0 commit comments