
fix: cherry-pick of PR #3445 #3457


Open · wants to merge 7 commits into base: release/2.7
1 change: 1 addition & 0 deletions .github/workflows/build-test-linux.yml
@@ -174,6 +174,7 @@ jobs:
python -m pip install -r requirements.txt
cd dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/
+ python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml test_modelopt_models.py
Comment on lines 176 to +177
Collaborator


Is this addition necessary? It looks like the previous line already tests models/, which contains test_modelopt_models.py.

Collaborator Author


I moved test_modelopt_models.py out of the models/ directory just to keep the quantized-model tests separate for debugging purposes.

popd
tests-py-dynamo-serde:
122 changes: 62 additions & 60 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -9,6 +9,7 @@
import tensorrt as trt
import torch
from torch.export import ExportedProgram
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import partitioning
@@ -144,71 +145,72 @@ def _refit_single_trt_engine_with_gm(
    Refit a TensorRT Engine in place
    """

-    refitted = set()
-    torch_device = get_model_device(new_gm)
-    refitter = trt.Refitter(old_engine, TRT_LOGGER)
-    weight_list = refitter.get_all_weights()
-
-    if weight_name_map:
-        # Get the refitting mapping
-        trt_wt_location = (
-            trt.TensorLocation.DEVICE
-            if torch_device.type == "cuda"
-            else trt.TensorLocation.HOST
-        )
-
-        constant_mapping: dict[str, Any] = weight_name_map.pop(
-            "constant_mapping", {}
-        )  # type: ignore
-        mapping = construct_refit_mapping_from_weight_name_map(
-            weight_name_map, new_gm.state_dict()
-        )
-        constant_mapping_with_type = {}
-
-        for constant_name, val in constant_mapping.items():
-            np_weight_type = val.dtype
-            val_tensor = torch.from_numpy(val).cuda()
-            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-            constant_mapping_with_type[constant_name] = (
-                val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
-                trt_dtype,
-            )
-
-        mapping.update(constant_mapping_with_type)
-
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                logger.warning(f"{layer_name} is not found in weight mapping.")
-                continue
-            # Use Numpy to create weights
-            weight, weight_dtype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(
-                weight_dtype, weight.data_ptr(), torch.numel(weight)
-            )
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        assert (
-            len(refitter.get_missing_weights()) == 0
-        ), "Fast refitting failed due to incomplete mapping"
-
-    else:
-        mapping = construct_refit_mapping(new_gm, input_list, settings)
-        trt_wt_location = trt.TensorLocation.HOST
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                raise AssertionError(f"{layer_name} is not found in weight mapping")
-            # Use Numpy to create weights
-            weight, datatype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-            refitted.add(layer_name)
-
-        if len(refitted) != len(weight_list):
-            logger.warning("Not all weights have been refitted!!!")
-
-    if not refitter.refit_cuda_engine():
-        logger.error("Error: failed to refit new weights.")
-        raise AssertionError("Refitting failed.")
+    with unset_fake_temporarily():
+        refitted = set()
+        torch_device = get_model_device(new_gm)
+        refitter = trt.Refitter(old_engine, TRT_LOGGER)
+        weight_list = refitter.get_all_weights()
+
+        if weight_name_map:
+            # Get the refitting mapping
+            trt_wt_location = (
+                trt.TensorLocation.DEVICE
+                if torch_device.type == "cuda"
+                else trt.TensorLocation.HOST
+            )
+
+            constant_mapping: dict[str, Any] = weight_name_map.pop(
+                "constant_mapping", {}
+            )  # type: ignore
+            mapping = construct_refit_mapping_from_weight_name_map(
+                weight_name_map, new_gm.state_dict()
+            )
+            constant_mapping_with_type = {}
+
+            for constant_name, val in constant_mapping.items():
+                np_weight_type = val.dtype
+                val_tensor = torch.from_numpy(val).cuda()
+                trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+                torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+                constant_mapping_with_type[constant_name] = (
+                    val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
+                    trt_dtype,
+                )
+
+            mapping.update(constant_mapping_with_type)
+
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    logger.warning(f"{layer_name} is not found in weight mapping.")
+                    continue
+                # Use Numpy to create weights
+                weight, weight_dtype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(
+                    weight_dtype, weight.data_ptr(), torch.numel(weight)
+                )
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            assert (
+                len(refitter.get_missing_weights()) == 0
+            ), "Fast refitting failed due to incomplete mapping"
+
+        else:
+            mapping = construct_refit_mapping(new_gm, input_list, settings)
+            trt_wt_location = trt.TensorLocation.HOST
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    raise AssertionError(f"{layer_name} is not found in weight mapping")
+                # Use Numpy to create weights
+                weight, datatype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+                refitted.add(layer_name)
+
+            if len(refitted) != len(weight_list):
+                logger.warning("Not all weights have been refitted!!!")
+
+        if not refitter.refit_cuda_engine():
+            logger.error("Error: failed to refit new weights.")
+            raise AssertionError("Refitting failed.")


def refit_module_weights(
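The substantive change in this file is that the body of _refit_single_trt_engine_with_gm is now wrapped in unset_fake_temporarily(). As a rough illustration of why that matters, here is a minimal sketch using only public PyTorch APIs (no TensorRT or torch_tensorrt involved): under an active FakeTensorMode, factory calls return FakeTensors that carry shapes and dtypes but no real storage, so a pointer handed to trt.Weights would not reference usable memory, while suspending the fake mode restores real tensor allocation.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

with FakeTensorMode():
    fake = torch.ones(4)  # created under fake mode: metadata only, no real storage
    with unset_fake_temporarily():
        real = torch.ones(4)  # fake mode suspended: a real, memory-backed tensor

print(isinstance(fake, FakeTensor))  # True
print(isinstance(real, FakeTensor))  # False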
40 changes: 20 additions & 20 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -21,6 +21,7 @@
import tensorrt as trt
import torch
import torch.fx
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._python_dispatch import _disable_current_modes
@@ -41,6 +42,7 @@
    get_node_io,
    get_node_name,
    get_trt_tensor,
+    to_torch,
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
from torch_tensorrt.fx.observer import Observer
@@ -408,27 +410,29 @@ def find_weight(
        np_map: the map from weight name to np values in INetworkDefinition
        state_dict: state of the graph module
        """
-        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
-        for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
-                del state_dict[sd_w_name]
-                return sd_w_name
-        return ""
+        with unset_fake_temporarily():
+            network_weight = torch.from_numpy(np_map[weight_name]).to(device)
+            for sd_w_name, sd_weight in state_dict.items():
+                if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
+                    del state_dict[sd_w_name]
+                    return sd_w_name
+            return ""

    @staticmethod
    def check_weight_equal(
        sd_weight: torch.tensor,
        network_weight: Union[torch.Tensor, np.ndarray],
        device: torch.device,
    ) -> Any:
-        if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).to(device)
-        try:
-            return sd_weight.shape == network_weight.shape and torch.all(
-                torch.abs(sd_weight - network_weight) < 0.01
-            )
-        except Exception:
-            return torch.all(sd_weight == network_weight)
+        with unset_fake_temporarily():
+            if not isinstance(network_weight, torch.Tensor):
+                network_weight = torch.from_numpy(network_weight).to(device)
+            try:
+                return sd_weight.shape == network_weight.shape and torch.all(
+                    torch.abs(sd_weight - network_weight) < 0.01
+                )
+            except Exception:
+                return torch.all(sd_weight == network_weight)

    def _save_weight_mapping(self) -> None:
        """
@@ -887,19 +891,15 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
        return converter(self.ctx, target, args, kwargs, self._cur_node_name)

    def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
-        with _disable_current_modes():
-            from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
-
+        with _disable_current_modes(), unset_fake_temporarily():
            frozen_attr = self.fetch_attr(target)

            if isinstance(frozen_attr, torch.nn.Parameter):
                constant_tensor = frozen_attr.data
            else:
                constant_tensor = frozen_attr

-            network_constant = to_numpy(constant_tensor)
-
-            return network_constant
+            return to_torch(constant_tensor)

    def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
        assert isinstance(target, str)
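The other notable change in this file is that get_attr now returns a torch tensor via to_torch instead of a numpy array via to_numpy, presumably so the frozen constant stays directly usable inside the fake-tensor-free region without an extra host-side copy. The sketch below is a hypothetical stand-in for such a helper, written only to illustrate the numpy-to-torch shift; it is not the actual torch_tensorrt implementation of to_torch.

from typing import Any, Optional

import numpy as np
import torch


def to_torch_sketch(value: Any) -> Optional[torch.Tensor]:
    # Hypothetical helper for illustration: unwrap Parameters and convert
    # numpy arrays so the caller always receives a plain torch.Tensor.
    if value is None:
        return None
    if isinstance(value, torch.nn.Parameter):
        return value.data
    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value)
    return torch.tensor(value)


print(to_torch_sketch(torch.nn.Parameter(torch.ones(2, 2))).shape)  # torch.Size([2, 2])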