Skip to content

Commit 4a1940d

Browse files
authored
🧑‍💻 torch.compile is not compatible with Windows. (TissueImageAnalytics#888)
- `torch.compile` is not currently compatible with Windows. See pytorch/pytorch#122094
1 parent 5f1cecb commit 4a1940d

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

tiatoolbox/models/architecture/utils.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import sys
6-
from typing import NoReturn
76

87
import numpy as np
98
import torch
@@ -12,13 +11,17 @@
1211
from tiatoolbox import logger
1312

1413

15-
def is_torch_compile_compatible() -> NoReturn:
14+
def is_torch_compile_compatible() -> bool:
1615
"""Check if the current GPU is compatible with torch-compile.
1716
17+
Returns:
18+
True if current GPU is compatible with torch-compile, False otherwise.
19+
1820
Raises:
1921
Warning if GPU is not compatible with `torch.compile`.
2022
2123
"""
24+
gpu_compatibility = True
2225
if torch.cuda.is_available(): # pragma: no cover
2326
device_cap = torch.cuda.get_device_capability()
2427
if device_cap not in ((7, 0), (8, 0), (9, 0)):
@@ -28,13 +31,17 @@ def is_torch_compile_compatible() -> NoReturn:
2831
"Speedup numbers may be lower than expected.",
2932
stacklevel=2,
3033
)
34+
gpu_compatibility = False
3135
else:
3236
logger.warning(
3337
"No GPU detected or cuda not installed, "
3438
"torch.compile is only supported on selected NVIDIA GPUs. "
3539
"Speedup numbers may be lower than expected.",
3640
stacklevel=2,
3741
)
42+
gpu_compatibility = False
43+
44+
return gpu_compatibility
3845

3946

4047
def compile_model(
@@ -68,12 +75,24 @@ def compile_model(
6875
return model
6976

7077
# Check if GPU is compatible with torch.compile
71-
is_torch_compile_compatible()
78+
gpu_compatibility = is_torch_compile_compatible()
79+
80+
if not gpu_compatibility:
81+
return model
82+
83+
if sys.platform == "win32": # pragma: no cover
84+
msg = (
85+
"`torch.compile` is not supported on Windows. Please see "
86+
"https://github.com/pytorch/pytorch/issues/122094."
87+
)
88+
logger.warning(msg=msg)
89+
return model
7290

7391
# This check will be removed when torch.compile is supported in Python 3.12+
7492
if sys.version_info > (3, 12): # pragma: no cover
93+
msg = "torch-compile is currently not supported in Python 3.12+."
7594
logger.warning(
76-
("torch-compile is currently not supported in Python 3.12+. ",),
95+
msg=msg,
7796
)
7897
return model
7998

0 commit comments

Comments
 (0)