Describe the bug
All GPxQ algorithms (GPTQ, GPFQ, Qronos, etc.) rely on modifying a local `weight` tensor that shares memory with `self.layer.weight.data`. This allows `get_quant_weights()` to see modifications. However, for `ConvTranspose2d` layers, this memory sharing is broken.
Root cause: PyTorch memory behavior with transpose + flatten:
```python
if isinstance(self.layer, SUPPORTED_CONV_OP):
    if is_conv_transposed(self.layer):
        # View: same memory, changed strides
        weight = weight.transpose(1, 0)
    # Issue is here for non-contiguous tensors. flatten(1) is only a
    # view for contiguous weights (e.g., Conv2d)
    weight = weight.flatten(1)
```
You cannot:
- Transpose (changing strides to non-contiguous)
- Then flatten across non-contiguous dimensions
- And expect no copy
Flattening across permuted memory is fundamentally incompatible with stride-only views.
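This constraint can be checked directly: `Tensor.view` refuses to flatten across non-contiguous dimensions, while `flatten` silently falls back to a copy. A minimal demonstration (independent of Brevitas):

```python
import torch

t = torch.arange(6).reshape(2, 3)   # contiguous
tt = t.transpose(0, 1)              # view: same storage, permuted strides
assert tt.data_ptr() == t.data_ptr()

# view() refuses to flatten across non-contiguous dimensions...
try:
    tt.view(-1)
except RuntimeError:
    pass  # "view size is not compatible with input tensor's size and stride"

# ...while flatten() silently falls back to a copy
flat = tt.flatten()
assert flat.data_ptr() != t.data_ptr()  # new storage: writes won't propagate
```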
Reproducibility
To Reproduce
```python
>>> import torch.nn as nn

# ConvTranspose2d with transpose - CREATES COPY
>>> layer = nn.ConvTranspose2d(3, 32, 3)
>>> weight = layer.weight.data
>>> weight.data_ptr()
94067644132544
>>> weight = weight.transpose(1, 0)
>>> weight.data_ptr()
94067644132544  # Same - transpose is a view
>>> weight = weight.flatten(1)
>>> weight.data_ptr()
94067752850112  # DIFFERENT - flatten copied!

# ConvTranspose2d without transpose - NO COPY
>>> layer = nn.ConvTranspose2d(3, 32, 3)
>>> weight = layer.weight.data
>>> weight.data_ptr()
94067752850112
>>> weight = weight.flatten(1)
>>> weight.data_ptr()
94067752850112  # Same - contiguous memory preserved
```
Expected behavior
Modifications to the working weight tensor should be visible to `get_quant_weights()`. For `ConvTranspose2d`, this currently doesn't work because the working tensor is detached from the layer's memory.
If known:
- Brevitas version: dev (c2b9e7)
- PyTorch version: 2.6
- Operating System / platform: Linux
Additional context
This affects all GPxQ algorithms for `ConvTranspose2d`, but it is most visible in A2GPTQ and A2GPFQ, where the explicit accumulator bounds checking catches the issue in #1181.
Possible solutions
A non-exhaustive list of possible solutions includes:
Option 1. As a temporary solution, we could unregister `ConvTranspose2d` from the list of supported modules in GPxQ, possibly with a warning.
Option 2. We could wrap the weight in a custom class with Tensor-like function attributes, but brute-force the index mapping inside the wrapper so that the memory is never detached. Something like:

```python
class wrapper:
    ...

if isinstance(self.layer, SUPPORTED_CONV_OP):
    if is_conv_transposed(self.layer):
        weight = wrapper(weight)
    weight = weight.flatten(1)
```
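A minimal write-through sketch of that idea (the class name and usage are hypothetical, not existing Brevitas code): the wrapper keeps a handle to the original `[IC, OC, kH, kW]` tensor, presents the `[OC, IC*kH*kW]` layout GPxQ expects, and copies every write back into the layer's storage.

```python
import torch

class TransposedFlatWeight:
    """Hypothetical write-through stand-in for transpose(1, 0).flatten(1).

    Reads return the flattened [OC, IC*kH*kW] layout; writes are mapped
    back into the original [IC, OC, kH, kW] tensor, so the layer's
    memory is never detached.
    """

    def __init__(self, weight):
        self._w = weight                    # shares memory with the layer
        ic, oc, kh, kw = weight.shape
        self.shape = (oc, ic * kh * kw)

    def _flat(self):
        # materializes a copy; fine for reads
        return self._w.transpose(1, 0).reshape(self.shape[0], -1)

    def __getitem__(self, idx):
        return self._flat()[idx]

    def __setitem__(self, idx, value):
        ic, oc, kh, kw = self._w.shape
        flat = self._flat()
        flat[idx] = value
        # brute-force copy-back into the layer's own storage
        self._w.copy_(flat.view(oc, ic, kh, kw).transpose(1, 0))

layer = torch.nn.ConvTranspose2d(3, 32, 3)
w = TransposedFlatWeight(layer.weight.data)
w[0] = 0.0  # zero the first row of the flattened view
assert torch.all(layer.weight.data[:, 0] == 0)  # visible in the layer
```

The cost is that each `__setitem__` pays a full copy-back, but correctness no longer depends on view semantics.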
Option 3. We could treat either the kernel dimensions or the in_channels dimension as a group_dim, and handle them separately in GPxQ so the memory-stride permutation is never forced. Something like:

```python
if isinstance(self.layer, SUPPORTED_CONV_OP):
    dim = 2 if is_conv_transposed(self.layer) else 1
    # [IC, OC, K * K] for ConvTranspose2d; treat it like groups=IC
    weight = weight.flatten(dim)
```
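This variant can be verified with the same `data_ptr()` test as the repro: flattening only the trailing kernel dimensions needs no stride permutation, so the result stays a view into the layer's storage.

```python
import torch

layer = torch.nn.ConvTranspose2d(3, 32, 3)
weight = layer.weight.data                # [IC, OC, kH, kW] = [3, 32, 3, 3]

w = weight.flatten(2)                     # [IC, OC, kH*kW] = [3, 32, 9]
assert w.data_ptr() == weight.data_ptr()  # still a view, no copy

w[0, 0] = 0.0                             # edits propagate to the layer
assert torch.all(layer.weight.data[0, 0] == 0)
```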
Option 4. Rewrite GPxQ either to work on strided memory, or to not rely on the shared data pointer.
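A sketch of the second variant (explicit copy-back; this is not existing Brevitas code): GPxQ would operate on an owned 2-D copy and restore the layer weight through the inverse reshape once quantization finishes.

```python
import torch

layer = torch.nn.ConvTranspose2d(3, 32, 3)
ic, oc, kh, kw = layer.weight.shape

# Owned working copy in the [OC, IC*kH*kW] layout GPxQ expects
# (flatten already copies here; clone() just makes ownership explicit)
work = layer.weight.data.transpose(1, 0).flatten(1).clone()
work[0] = 0.0  # stand-in for the actual quantization update

# Explicit copy-back replaces the shared-pointer assumption
layer.weight.data.copy_(work.view(oc, ic, kh, kw).transpose(1, 0))
assert torch.all(layer.weight.data[:, 0] == 0)
```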