diff --git a/examples/post_training_quantization/imagenet1k/deit/qconfig.yaml b/examples/post_training_quantization/imagenet1k/deit/qconfig.yaml
index 61f68a6..bdee4db 100644
--- a/examples/post_training_quantization/imagenet1k/deit/qconfig.yaml
+++ b/examples/post_training_quantization/imagenet1k/deit/qconfig.yaml
@@ -15,4 +15,8 @@ A:
     BIT: 8
   OBSERVER:
     TYPE: MINMAX
-    LAYOUT: NCHW
+    LAYOUT: NLC
+  SPECIFIC: [{
+      "patch_embed_proj": ["OBSERVER.LAYOUT", "NCHW"],
+      "head": ["OBSERVER.LAYOUT", "NCHW"],
+    }]
diff --git a/sparsebit/quantization/quant_config.py b/sparsebit/quantization/quant_config.py
index b46532d..5ec39d1 100644
--- a/sparsebit/quantization/quant_config.py
+++ b/sparsebit/quantization/quant_config.py
@@ -9,6 +9,7 @@
 _C.SCHEDULE = CN()
 _C.SCHEDULE.FUSE_BN = False  # use ``with torch.no_grad()`` if it's enabled
+_C.SCHEDULE.BIAS_CORRECTION = False
 _C.SCHEDULE.BN_TUNING = False
 _C.SCHEDULE.DISABLE_UNNECESSARY_QUANT = True
diff --git a/sparsebit/quantization/quant_model.py b/sparsebit/quantization/quant_model.py
index 582f9ae..7fd2138 100644
--- a/sparsebit/quantization/quant_model.py
+++ b/sparsebit/quantization/quant_model.py
@@ -185,7 +185,9 @@ def prepare_calibration(self):
         from sparsebit.quantization.tools.calibration import CalibrationRunner

         self.eval()
-        self.calibration_runner = CalibrationRunner(self.model)
+        self.calibration_runner = CalibrationRunner(
+            self.model, self.cfg.SCHEDULE.BIAS_CORRECTION
+        )
         self.calibration_runner.prepare_calibration()

     def calc_qparams(self, asym=False, w_quant=False, a_quant=False):
diff --git a/sparsebit/quantization/tools/calibration.py b/sparsebit/quantization/tools/calibration.py
index 8a94706..7d44fe0 100644
--- a/sparsebit/quantization/tools/calibration.py
+++ b/sparsebit/quantization/tools/calibration.py
@@ -1,5 +1,6 @@
 import copy
 import torch
+import torch.nn as nn
 from functools import partial
 from sparsebit.quantization.modules import QuantOpr
@@ -9,8 +10,9 @@ class CalibrationRunner(object):
-    def __init__(self, model):
+    def __init__(self, model, bias_correction=False):
         self.model = fx_symbolic_trace(model)
+        self.bias_correction = bias_correction

     def prepare_calibration(self):
         input_names_cache = set(
@@ -96,6 +98,9 @@ def layerwise_calibration(self, device, asym=False, w_quant=False, a_quant=False
                     )
                     self.builder.qstorage.set_output(node.target, quant_outputs)
                     self.builder.qstorage.finish_node(node.target)
+                    # bias correction
+                    if self.bias_correction:
+                        self.run_bias_correction(batch_num, node, device)
             # pop the outputs of nodes whose out-degree=0
             self.builder.storage.finish_node(node.target)
@@ -153,3 +158,52 @@ def module_forward(
         if isinstance(module, QuantOpr):
             module.set_quant(w_quant=False, a_quant=False)
         return outputs
+
+    def run_bias_correction(self, batch_num, node, device):
+        # locate the submodule that the fx node refers to
+        module = self.model
+        for n in node.target.split("."):
+            module = getattr(module, n)
+        if isinstance(module, QuantOpr) and getattr(module, "weight_quantizer", None):
+            for inp_node in node.all_input_nodes:
+                inp_tensors = self.builder.storage.get_output(inp_node.target)
+                float_outputs = torch.Tensor([])
+                quant_outputs = torch.Tensor([])
+                float_outputs_cached = self.builder.storage.get_output(node.target)
+                # collect per-channel outputs with and without weight quant
+                for idx in range(batch_num):
+                    inp_tensor = inp_tensors[idx].to(device)
+                    with torch.no_grad():
+                        float_output = (
+                            float_outputs_cached[idx]
+                            .transpose(module.input_quantizer.qdesc._ch_axis, 0)
+                            .flatten(1)
+                        )
+                        module.set_quant(True, False)
+                        quant_output = (
+                            module(inp_tensor)
+                            .cpu()
+                            .transpose(module.input_quantizer.qdesc._ch_axis, 0)
+                            .flatten(1)
+                        )
+                        module.set_quant(False, False)
+                    float_outputs = torch.cat(
+                        (float_outputs, float_output.detach()), 1
+                    )
+                    quant_outputs = torch.cat(
+                        (quant_outputs, quant_output.detach()), 1
+                    )
+                float_output_mean = float_outputs.mean(-1)
+                quant_output_mean = quant_outputs.mean(-1)
+                # expected quantization-induced shift per output channel
+                bias = quant_output_mean - float_output_mean
+                if module.bias is None:
+                    module.bias = nn.Parameter(
+                        data=torch.zeros(
+                            module.weight.size(0),
+                            dtype=torch.float32,
+                            device=device,
+                        ),
+                        requires_grad=False,
+                    )
+                module.bias.data = module.bias.data - bias.to(device)
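
Notes on the qconfig.yaml change: DeiT activations flow through the transformer blocks as (batch, tokens, channels) tensors, so the default activation-observer layout switches from NCHW to NLC; patch_embed_proj (a Conv2d producing NCHW feature maps) and head (which sees 2-D (batch, channels) tensors, where the channel axis is dim 1 as in NCHW) keep the old layout via the two SPECIFIC overrides. The new SCHEDULE.BIAS_CORRECTION option defaults to False, so existing configs are unaffected; turning it on for this example should only need the usual yacs-style key in qconfig.yaml. A sketch, assuming the surrounding keys match the example config:

    SCHEDULE:
      BIAS_CORRECTION: True        # new flag from this patch; defaults to False
    A:
      OBSERVER:
        TYPE: MINMAX
        LAYOUT: NLC
      SPECIFIC: [{
          "patch_embed_proj": ["OBSERVER.LAYOUT", "NCHW"],
          "head": ["OBSERVER.LAYOUT", "NCHW"],
        }]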
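
Why bias correction: fake-quantizing the weights shifts the expected value of each output channel, so E[y_quant] != E[y_float], and that systematic shift is what run_bias_correction estimates over the cached calibration batches and subtracts from the layer bias. A minimal, self-contained sketch of the same idea on a plain nn.Linear; fake_quant here is a hypothetical stand-in for sparsebit's weight quantizer, not the library API:

    import torch
    import torch.nn as nn

    def fake_quant(w, bits=8):
        # illustrative symmetric per-tensor fake-quantization
        qmax = 2 ** (bits - 1) - 1
        scale = w.abs().max() / qmax
        return (w / scale).round().clamp(-qmax - 1, qmax) * scale

    @torch.no_grad()
    def bias_correct(layer, calib_batches):
        w_float = layer.weight.data.clone()
        float_means, quant_means = [], []
        for x in calib_batches:
            layer.weight.data = w_float
            y_float = layer(x)                       # float reference output
            layer.weight.data = fake_quant(w_float)  # weight-quantized output
            y_quant = layer(x)
            # reduce every dim except the output-channel (last) dim
            float_means.append(y_float.flatten(0, -2).mean(0))
            quant_means.append(y_quant.flatten(0, -2).mean(0))
        layer.weight.data = w_float                  # restore float weights
        shift = torch.stack(quant_means).mean(0) - torch.stack(float_means).mean(0)
        if layer.bias is None:
            layer.bias = nn.Parameter(
                torch.zeros(layer.out_features), requires_grad=False
            )
        layer.bias.data -= shift                     # cancel the expected shift

    # usage on random calibration data
    layer = nn.Linear(16, 4, bias=False)
    bias_correct(layer, [torch.randn(32, 16) for _ in range(8)])

Unlike the patch, the sketch re-runs the float forward instead of reusing cached outputs, and averages per-batch means instead of concatenating samples; for equal-sized batches the estimated shift is the same.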
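
On the transpose/flatten in run_bias_correction: per-channel means are computed by moving the channel axis (qdesc._ch_axis, which follows the observer LAYOUT) to dim 0 and flattening everything else, which is why the per-module layout overrides above matter. A quick check of that identity, with a DeiT-like activation shape assumed for illustration:

    import torch

    x = torch.randn(4, 197, 768)   # (batch, tokens, channels): NLC, ch_axis = 2
    per_channel = x.transpose(2, 0).flatten(1).mean(-1)      # shape (768,)
    assert torch.allclose(per_channel, x.mean(dim=(0, 1)), atol=1e-6)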