@@ -15,4 +15,8 @@ A:
     BIT: 8
   OBSERVER:
     TYPE: MINMAX
-    LAYOUT: NCHW
+    LAYOUT: NLC
+    SPECIFIC: [{
+        "patch_embed_proj": ["OBSERVER.LAYOUT", "NCHW"],
+        "head": ["OBSERVER.LAYOUT", "NCHW"],
+    }]
1 change: 1 addition & 0 deletions sparsebit/quantization/quant_config.py
@@ -9,6 +9,7 @@

 _C.SCHEDULE = CN()
 _C.SCHEDULE.FUSE_BN = False  # use ``with torch.no_grad()`` if it's enabled
+_C.SCHEDULE.BIAS_CORRECTION = False
 _C.SCHEDULE.BN_TUNING = False
 _C.SCHEDULE.DISABLE_UNNECESSARY_QUANT = True

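The new SCHEDULE.BIAS_CORRECTION option defaults to False, so existing configs keep their current behavior. A minimal sketch of switching it on through the usual yacs override mechanism; the standalone config node built here is illustrative and not the project's real qconfig entry point:

    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.SCHEDULE = CN()
    cfg.SCHEDULE.BIAS_CORRECTION = False  # default, matching the line added above
    # enable bias correction for a calibration run, e.g. via an override list or YAML
    cfg.merge_from_list(["SCHEDULE.BIAS_CORRECTION", True])
    assert cfg.SCHEDULE.BIAS_CORRECTION is True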
4 changes: 3 additions & 1 deletion sparsebit/quantization/quant_model.py
@@ -185,7 +185,9 @@ def prepare_calibration(self):
         from sparsebit.quantization.tools.calibration import CalibrationRunner

         self.eval()
-        self.calibration_runner = CalibrationRunner(self.model)
+        self.calibration_runner = CalibrationRunner(
+            self.model, self.cfg.SCHEDULE.BIAS_CORRECTION
+        )
         self.calibration_runner.prepare_calibration()

     def calc_qparams(self, asym=False, w_quant=False, a_quant=False):
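prepare_calibration() now simply forwards the configured flag into the runner. A sketch of the equivalent direct call; the placeholder nn.Sequential is only there to illustrate the new signature and stands in for the traced model QuantModel would normally pass:

    import torch.nn as nn
    from sparsebit.quantization.tools.calibration import CalibrationRunner

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # placeholder float model (assumption)
    runner = CalibrationRunner(model, bias_correction=True)  # second argument is the new flag
    runner.prepare_calibration()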
53 changes: 52 additions & 1 deletion sparsebit/quantization/tools/calibration.py
@@ -1,5 +1,6 @@
 import copy
 import torch
+import torch.nn as nn
 from functools import partial

 from sparsebit.quantization.modules import QuantOpr
@@ -9,8 +10,9 @@


 class CalibrationRunner(object):
-    def __init__(self, model):
+    def __init__(self, model, bias_correction=False):
         self.model = fx_symbolic_trace(model)
+        self.bias_correction = bias_correction

     def prepare_calibration(self):
         input_names_cache = set(
@@ -96,6 +98,9 @@ def layerwise_calibration(self, device, asym=False, w_quant=False, a_quant=False
                )
                self.builder.qstorage.set_output(node.target, quant_outputs)
                self.builder.qstorage.finish_node(node.target)
+                # bias correction
+                if self.bias_correction:
+                    self.run_bias_correction(batch_num, node, device)
             # pop the outputs of nodes whose out-degree=0
             self.builder.storage.finish_node(node.target)

@@ -153,3 +158,49 @@ def module_forward(
         if isinstance(module, QuantOpr):
             module.set_quant(w_quant=False, a_quant=False)
         return outputs
+
+    def run_bias_correction(self, batch_num, node, device):
+        module = self.model
+        for n in node.target.split("."):
+            module = getattr(module, n)
+        if isinstance(module, QuantOpr) and getattr(module, "weight_quantizer", None):
+            for inp_node in node.all_input_nodes:
+                inp_tensors = self.builder.storage.get_output(inp_node.target)
+                float_outputs = torch.Tensor([])
+                quant_outputs = torch.Tensor([])
+                float_outputs_cached = self.builder.storage.get_output(node.target)
+                for idx in range(batch_num):
+                    inp_tensor = inp_tensors[idx].cuda()
+                    with torch.no_grad():
+                        float_output = (
+                            float_outputs_cached[idx]
+                            .transpose(module.input_quantizer.qdesc._ch_axis, 0)
+                            .flatten(1)
+                        )
+                        module.set_quant(True, False)
+                        quant_output = (
+                            module(inp_tensor)
+                            .cpu()
+                            .transpose(module.input_quantizer.qdesc._ch_axis, 0)
+                            .flatten(1)
+                        )
+                        module.set_quant(False, False)
+                        float_outputs = torch.cat(
+                            (float_outputs, float_output.detach()), 1
+                        )
+                        quant_outputs = torch.cat(
+                            (quant_outputs, quant_output.detach()), 1
+                        )
+                float_output_mean = float_outputs.mean(-1)
+                quant_output_mean = quant_outputs.mean(-1)
+                bias = quant_output_mean - float_output_mean
+                if module.bias is None:
+                    module.bias = nn.Parameter(
+                        data=torch.zeros(
+                            module.weight.size(0),
+                            dtype=torch.float32,
+                            device=device,
+                        ),
+                        requires_grad=False,
+                    )
+                module.bias.data = module.bias.data - bias.cuda()
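For reference, the correction above amounts to matching per-output-channel means between the float outputs and the weight-quantized outputs collected over the calibration batches, then folding the difference into the bias. A minimal standalone sketch of the same idea on a single Conv2d; fake_quant_weight is an illustrative stand-in for a weight quantizer, not sparsebit API:

    import torch
    import torch.nn as nn

    def bias_correction(layer: nn.Conv2d, calib_batch: torch.Tensor, fake_quant_weight) -> None:
        """Shift the bias so the quantized output mean matches the float output mean."""
        with torch.no_grad():
            float_out = layer(calib_batch)               # float reference output
            w_fp = layer.weight.data.clone()
            layer.weight.data = fake_quant_weight(w_fp)  # simulate weight quantization
            quant_out = layer(calib_batch)
            layer.weight.data = w_fp                     # restore float weights
            # per-output-channel means over batch and spatial dimensions
            float_mean = float_out.transpose(0, 1).flatten(1).mean(-1)
            quant_mean = quant_out.transpose(0, 1).flatten(1).mean(-1)
            delta = quant_mean - float_mean
            if layer.bias is None:
                layer.bias = nn.Parameter(torch.zeros(layer.out_channels), requires_grad=False)
            layer.bias.data -= delta

    # toy usage with a round-to-nearest "quantizer"
    layer = nn.Conv2d(3, 8, 3, bias=False)
    x = torch.randn(4, 3, 16, 16)
    bias_correction(layer, x, lambda w: torch.round(w * 64) / 64)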