
Commit 72fbce4

Smoothquant refactor for 3.x API (#1792)
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
1 parent ee24dba commit 72fbce4

11 files changed, +408 -137 lines changed

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

+5 -13
@@ -361,20 +361,12 @@ def run_fn(model):
 
     from utils import get_example_inputs
     example_inputs = get_example_inputs(user_model, calib_dataloader)
-    if args.sq:
-        # currently, smooth quant only support quantize API
-        # TODO: support prepare/convert API for smooth quant
-        from neural_compressor.torch.quantization import quantize
 
-        user_model = quantize(
-            model=user_model, quant_config=quant_config, example_inputs=example_inputs, run_fn=run_fn
-        )
-    else:
-        from neural_compressor.torch.quantization import prepare, convert
-
-        user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
-        run_fn(user_model)
-        user_model = convert(user_model)
+    from neural_compressor.torch.quantization import prepare, convert
+    user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
+    run_fn(user_model)
+    user_model = convert(user_model)
+
     user_model.save(args.output_dir)
 
 
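Note: with this change the example script drives SmoothQuant through the same prepare/convert pair as every other 3.x algorithm. A minimal sketch of that flow, assuming the 3.x API exports shown in the diff (the toy model, config values, and calibration function below are illustrative, and IPEX must be installed for SmoothQuant to actually run):

import torch
from neural_compressor.torch.quantization import SmoothQuantConfig, convert, prepare

# Illustrative float model and tracing inputs; a real run uses an LLM and a
# calibration dataloader, as in run_clm_no_trainer.py.
user_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
example_inputs = torch.randn(1, 4)
quant_config = SmoothQuantConfig(alpha=0.5)  # assumed SmoothQuant config class

def run_fn(model):
    # Calibration: run a few representative batches through the prepared model.
    model(example_inputs)

user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
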
neural_compressor/torch/algorithms/base_algorithm.py

+12 -2

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 from abc import ABC, abstractmethod
-from collections import OrderedDict
 from typing import Any, Optional
 
 import torch
@@ -111,5 +111,15 @@ def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):
         elif mode == Mode.CONVERT:
             model = self.convert(model, *args, **kwargs)
         elif mode == Mode.QUANTIZE:
-            model = self.quantize(model, *args, **kwargs)
+            if not isinstance(self.quant_config, dict):
+                user_cfg = copy.deepcopy(self.quant_config).to_dict()
+            else:
+                user_cfg = copy.deepcopy(self.quant_config)
+            if "recipe_cfgs" in user_cfg:  # keep quantize API for smoothquant
+                run_fn = kwargs.get("run_fn", None)
+                example_inputs = kwargs.get("example_inputs", None)
+                inplace = kwargs.get("inplace", True)
+                model = self.quantize(model, self.quant_config, run_fn, example_inputs, inplace)
+            else:
+                model = self.quantize(model, *args, **kwargs)
         return model
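
The new QUANTIZE branch keeps the old single-call quantize signature for configs that carry recipe_cfgs (the SmoothQuant path) and falls through to the generic call for every other algorithm. A stand-alone sketch of just that routing, with a hypothetical ToyQuantizer in place of the real Quantizer base class:

import copy

class ToyQuantizer:
    def __init__(self, quant_config):
        self.quant_config = quant_config

    def quantize(self, model, *args, **kwargs):
        return "quantize({} positional args)".format(len(args))

    def execute(self, model, **kwargs):
        # Mirror of the diff: config objects are dumped to a dict, plain dicts are deep-copied.
        cfg = copy.deepcopy(self.quant_config)
        user_cfg = cfg if isinstance(cfg, dict) else cfg.to_dict()
        if "recipe_cfgs" in user_cfg:  # SmoothQuant keeps the positional quantize API
            return self.quantize(
                model,
                self.quant_config,
                kwargs.get("run_fn", None),
                kwargs.get("example_inputs", None),
                kwargs.get("inplace", True),
            )
        return self.quantize(model, **kwargs)

print(ToyQuantizer({"recipe_cfgs": {}}).execute(model=None))  # quantize(4 positional args)
print(ToyQuantizer({"op": {}}).execute(model=None))           # quantize(0 positional args)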

neural_compressor/torch/algorithms/smooth_quant/__init__.py

+1 -1

@@ -14,5 +14,5 @@
 # limitations under the License.
 
 from .utility import *
-from .smooth_quant import smooth_quantize
+from .smooth_quant import SmoothQuantQuantizer
 from .save_load import save, load, recover_model_from_json

neural_compressor/torch/algorithms/smooth_quant/smooth_quant.py

+190 -74
@@ -21,11 +21,16 @@
 
 try:
     import intel_extension_for_pytorch as ipex
-except:
+except:  # pragma: no cover
     assert False, "Please install IPEX for smooth quantization."
 
+from collections import OrderedDict
+from types import MethodType
+
 from packaging.version import Version
 
+from neural_compressor.torch.algorithms import Quantizer
+
 from .utility import (
     TorchSmoothQuant,
     cfg_to_qconfig,
@@ -41,88 +46,199 @@
 ipex_ver = get_ipex_version()
 
 
-def smooth_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
-    """Execute the quantize process on the specified model.
+class SmoothQuantQuantizer(Quantizer):
+    def __init__(self, quant_config: OrderedDict = {}):
+        """Init a SmoothQuantQuantizer object.
 
-    Args:
-        model: a float model to be quantized.
-        tune_cfg: quantization config for ops.
-        run_fn: a calibration function for calibrating the model.
-        example_inputs: used to trace torch model.
-        inplace: whether to carry out model transformations in-place.
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
 
-    Returns:
-        A quantized model.
-    """
-    assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."
+    def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
+        """Prepares a given model for quantization.
+
+        Args:
+            model: A float model to be quantized.
+            example_inputs: Used to trace torch model.
+            inplace: Whether to carry out model transformations in-place. Defaults to True.
+
+        Returns:
+            A prepared model.
+        """
+        assert example_inputs is not None, "Please provide example_inputs for smooth quantization."
+        assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."
+
+        # Note: we should make sure smoothquant is only executed once with inplacing fp32 model.
+        if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized:  # pragma: no cover
+            logger.info("The model is already optimized by SmoothQuant algorithm, skip it.")
+            return model
+
+        cfgs, op_infos_from_cfgs, output_tensor_id_op_name = (
+            model.cfgs,
+            model.op_infos_from_cfgs,
+            model.output_tensor_id_op_name,
+        )
+
+        # Update json file in ipex_config_path
+        cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        model.eval()
+
+        # check smoothquant alpha and act_algo value
+        recipe_cfgs = self.quant_config.get("recipe_cfgs", None)
+        alpha = recipe_cfgs["smooth_quant_args"]["alpha"]
+        for op, _ in self.quant_config["op"].items():
+            act_algo = self.quant_config["op"][op]["activation"]["algorithm"]
 
-    _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(model, example_inputs)
+        # Check save_qconf_summary part is a workaround for IPEX bug.
+        # Sometimes the prepared model from get_op_capablitiy loss this attribute.
+        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
+            from torch.ao.quantization.observer import MinMaxObserver
 
-    # check smoothquant folding value
-    recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
-    if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]:
-        if recipe_cfgs["smooth_quant_args"]["folding"] is None:
-            if ipex_ver.release < Version("2.1").release:
-                folding = True
+            if ipex_ver.release >= Version("2.1.1").release:
+                static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
+                    alpha=alpha, act_observer=MinMaxObserver
+                )
+            else:  # pragma: no cover
+                if act_algo == "minmax":
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
+                        alpha=alpha, act_observer=MinMaxObserver()
+                    )
+                    logger.warning(
+                        "The int8 model accuracy will be close to 0 with MinMaxobserver, "
+                        + "the suggested IPEX version is higher or equal than 2.1.100+cpu."
+                    )
+                else:
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=alpha)
+
+        if isinstance(example_inputs, dict):
+            model = ipex.quantization.prepare(
+                model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+            )
         else:
-                folding = False
-        else:
-            folding = recipe_cfgs["smooth_quant_args"]["folding"]
+            model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
 
-    # Note: we should make sure smoothquant is only executed once with inplacing fp32 model.
-    if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized:
-        logger.info("The model is already optimized by SmoothQuant algorithm, skip it.")
+        cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, smooth_quant=True)
+        model.load_qconf_summary(qconf_summary=ipex_config_path)
         return model
 
-    sq = TorchSmoothQuant(model, dataloader=None, example_inputs=example_inputs, q_func=run_fn, record_max_info=True)
-    model = sq.transform(
-        alpha=recipe_cfgs["smooth_quant_args"]["alpha"],
-        folding=folding,
-        auto_alpha_args=recipe_cfgs["smooth_quant_args"]["auto_alpha_args"],
-        scale_sharing=recipe_cfgs["smooth_quant_args"]["scale_sharing"],
-    )
-
-    # Update model parameter when smoothquant folding = False
-    if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding:
-        return qdq_quantize(
-            model, tune_cfg, run_fn, example_inputs, inplace, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, sq
-        )
+    def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
+        """Converts a prepared model to a quantized model.
 
-    # Update model parameter when smoothquant folding = True
-    if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding:
-        _apply_pre_optimization(model, tune_cfg, sq)
-        model.eval()
+        Args:
+            model: The prepared model to be converted.
+            example_inputs: Used to trace torch model.
+            inplace: Whether to carry out model transformations in-place. Defaults to True.
 
-    # Check save_qconf_summary part is a workaround for IPEX bug.
-    # Sometimes the prepared model from get_op_capablitiy loss this attribute
-    if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
-        static_qconfig = ipex.quantization.default_static_qconfig_mapping
-        if isinstance(example_inputs, dict):
-            model = ipex.quantization.prepare(
-                model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+        Returns:
+            A quantized model.
+        """
+        model.save_qconf_summary(qconf_summary=ipex_config_path)
+        model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
+
+        with open(ipex_config_path, "r") as f:
+            model.tune_cfg = json.load(f)
+        model.ipex_config_path = ipex_config_path
+        dump_model_op_stats(self.quant_config["op"])
+
+        from neural_compressor.torch.algorithms.smooth_quant import save
+
+        logger.info("Smooth quantization done.")
+        model.ori_save = model.save
+        model.save = MethodType(save, model)
+        return model
+
+    def quantize(self, model, tune_cfg, run_fn, example_inputs, inplace=True, *args, **kwargs):
+        """Execute the quantize process on the specified model.
+
+        Args:
+            model: a float model to be quantized.
+            tune_cfg: quantization config for ops.
+            run_fn: a calibration function for calibrating the model.
+            example_inputs: used to trace torch model.
+            inplace: whether to carry out model transformations in-place.
+
+        Returns:
+            A quantized model.
+        """
+        assert not ipex_ver.release < Version("2.1").release, "IPEX version >= 2.1 is required for SmoothQuant."
+
+        cfgs, op_infos_from_cfgs, output_tensor_id_op_name = (
+            model.cfgs,
+            model.op_infos_from_cfgs,
+            model.output_tensor_id_op_name,
+        )
+
+        # check smoothquant folding value
+        recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
+        if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]:
+            if recipe_cfgs["smooth_quant_args"]["folding"] is None:  # pragma: no cover
+                if ipex_ver.release < Version("2.1").release:
+                    folding = True
+                else:
+                    folding = False
+            else:
+                folding = recipe_cfgs["smooth_quant_args"]["folding"]
+
+        # Note: we should make sure smoothquant is only executed once with inplacing fp32 model.
+        if hasattr(model, "_smoothquant_optimized") and model._smoothquant_optimized:  # pragma: no cover
+            logger.info("The model is already optimized by SmoothQuant algorithm, skip it.")
+            return model
+
+        sq_info = model.sq_info
+
+        # Update model parameter when smoothquant folding = False
+        if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and not folding:
+            return qdq_quantize(
+                model,
+                tune_cfg,
+                run_fn,
+                example_inputs,
+                inplace,
+                cfgs,
+                op_infos_from_cfgs,
+                output_tensor_id_op_name,
+                sq_info,
             )
-    else:
-        model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
 
-    model.load_qconf_summary(qconf_summary=ipex_config_path)
-    run_fn(model)
-    model.save_qconf_summary(qconf_summary=ipex_config_path)
-    model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
+        # Update model parameter when smoothquant folding = True
+        if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding:
+            _apply_pre_optimization(model, tune_cfg, sq_info)
 
-    # Recover model parameter when smoothquant folding = True
-    if (
-        recipe_cfgs
-        and recipe_cfgs.get("smooth_quant", False)
-        and recipe_cfgs["smooth_quant_args"]["folding"]
-        and not inplace
-    ):  # pragma: no cover
-        _apply_pre_optimization(model, tune_cfg, sq, recover=True)
+        # Update json file in ipex_config_path
+        cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        model.eval()
 
-    with open(ipex_config_path, "r") as f:
-        model.tune_cfg = json.load(f)
-    model.ipex_config_path = ipex_config_path
-    dump_model_op_stats(tune_cfg["op"])
-    return model
+        # Check save_qconf_summary part is a workaround for IPEX bug.
+        # Sometimes the prepared model from get_op_capablitiy loss this attribute
+        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):  # pragma: no cover
+            static_qconfig = ipex.quantization.default_static_qconfig_mapping
+            if isinstance(example_inputs, dict):
+                model = ipex.quantization.prepare(
+                    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
+                )
+            else:
+                model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
+
+        model.load_qconf_summary(qconf_summary=ipex_config_path)
+        run_fn(model)
+        model.save_qconf_summary(qconf_summary=ipex_config_path)
+        model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)
+
+        # Recover model parameter when smoothquant folding = True
+        if (
+            recipe_cfgs
+            and recipe_cfgs.get("smooth_quant", False)
+            and recipe_cfgs["smooth_quant_args"]["folding"]
+            and not inplace
+        ):  # pragma: no cover
+            _apply_pre_optimization(model, tune_cfg, sq_info, recover=True)
+
+        with open(ipex_config_path, "r") as f:
+            model.tune_cfg = json.load(f)
+        model.ipex_config_path = ipex_config_path
+        dump_model_op_stats(tune_cfg["op"])
+        return model
 
 
 def qdq_quantize(
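
One detail worth noting in convert() above: model.save is rebound with types.MethodType so later calls go through the SmoothQuant-aware save helper, while the original stays reachable as model.ori_save. A minimal stand-alone sketch of that rebinding pattern (the Model class and replacement save below are hypothetical):

from types import MethodType

class Model:
    def save(self, path):
        print("original save ->", path)

def save(self, path):
    # Replacement function; 'self' is whatever instance it gets bound to.
    print("smoothquant-aware save ->", path)

model = Model()
model.ori_save = model.save           # keep the original bound method
model.save = MethodType(save, model)  # bind the new function to this instance
model.save("out.pt")                  # smoothquant-aware save -> out.pt
model.ori_save("out.pt")              # original save -> out.pt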
@@ -133,12 +249,12 @@ def qdq_quantize(
 
     # Check save_qconf_summary part is a workaround for IPEX bug.
     # Sometimes the prepared model from get_op_capablitiy loss this attribute
-    if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
+    if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):  # pragma: no cover
         from torch.ao.quantization.observer import MinMaxObserver
 
         if ipex_ver.release >= Version("2.1.1").release:
             static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver)
-        else:
+        else:  # pragma: no cover
             if sq_minmax_init:
                 static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
                     alpha=0.5, act_observer=MinMaxObserver()
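
The workaround above gates the observer choice on the IPEX version via packaging.version. The comparison pattern in isolation (the version string is illustrative; in the real code it comes from get_ipex_version()):

from packaging.version import Version

ipex_version = "2.1.0+cpu"  # illustrative; .release drops local tags like +cpu
if Version(ipex_version).release >= Version("2.1.1").release:
    choice = "pass the MinMaxObserver class"
else:
    choice = "pass a MinMaxObserver() instance"
print(choice)  # pass a MinMaxObserver() instance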
@@ -169,7 +285,7 @@ def qdq_quantize(
     # IPEX may raise an error on the second iteration.
     # OverflowError: cannot convert float infinity to integer
         run_fn(model)
-    except:
+    except:  # pragma: no cover
         logger.warning(
             "The calibration failed when calibrating with ipex, "
             + "using scale info from SmoothQuant for Linear and "
@@ -197,7 +313,7 @@ def _apply_pre_optimization(model, tune_cfg, sq, recover=False):
         tsq = TorchSmoothQuant(model, None)
         alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"]
         for op_name, info in sq_max_info.items():
-            if alpha == "auto":
+            if alpha == "auto":  # pragma: no cover
                 alpha = info["alpha"]
             absorb_layer = op_name
             absorbed_layer = info["absorbed_layer"]
@@ -237,7 +353,7 @@ def _ipex_post_quant_process(model, example_inputs, inplace=False):
         else:
             model = torch.jit.trace(model, example_inputs)
         model = torch.jit.freeze(model.eval())
-    except:
+    except:  # pragma: no cover
         if isinstance(example_inputs, dict):
             model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
         else:
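
_ipex_post_quant_process first attempts a strict torch.jit.trace and, if that raises, retries leniently with strict=False and check_trace=False before freezing. That control flow can be sketched on its own (the toy module is illustrative; example_kwarg_inputs requires a reasonably recent torch):

import torch

def trace_with_fallback(model, example_inputs):
    # Strict trace first; fall back to a lenient trace on failure, then freeze.
    try:
        if isinstance(example_inputs, dict):
            traced = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
        else:
            traced = torch.jit.trace(model, example_inputs)
    except Exception:
        if isinstance(example_inputs, dict):
            traced = torch.jit.trace(
                model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False
            )
        else:
            traced = torch.jit.trace(model, example_inputs, strict=False, check_trace=False)
    return torch.jit.freeze(traced.eval())

frozen = trace_with_fallback(torch.nn.Linear(4, 2).eval(), torch.randn(1, 4))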
