
Broken interop between GPxQ and ConvTranspose2d #1479

@i-colbert

Description


Describe the bug

All GPxQ algorithms (GPTQ, GPFQ, Qronos, etc.) rely on modifying a local weight tensor that shares memory with self.layer.weight.data. This allows get_quant_weights() to see modifications. However, for ConvTranspose2d layers, this memory sharing is broken.

Root cause: PyTorch memory behavior with transpose + flatten:

if isinstance(self.layer, SUPPORTED_CONV_OP):
    if is_conv_transposed(self.layer):
        # View - same memory, changed strides
        weight = weight.transpose(1, 0) 
    # Issue is here for non-contiguous tensors; flatten(1) is only a view for contiguous ones (e.g., Conv2d)
    weight = weight.flatten(1)  

You cannot:

  • Transpose (changing strides to non-contiguous)
  • Then flatten across non-contiguous dimensions
  • And expect no copy

Flattening across permuted memory is fundamentally incompatible with stride-only views.
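
To make this concrete: PyTorch refuses to build a stride-only view across the permuted dimensions, and flatten() then falls back to a copy. A standalone check (not Brevitas code), in the same spirit as the repro below:

>>> import torch.nn as nn
>>> w = nn.ConvTranspose2d(3, 32, 3).weight.data   # [IC, OC, KH, KW] = [3, 32, 3, 3]
>>> wt = w.transpose(1, 0)                         # view: [32, 3, 3, 3], non-contiguous
>>> wt.is_contiguous()
False
>>> wt.view(32, -1)   # a stride-only view across the permuted dims is impossible
RuntimeError: view size is not compatible with input tensor's size and stride ...
>>> wt.flatten(1).data_ptr() == w.data_ptr()       # so flatten() silently copies
False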

Reproducibility

  • Can be reproduced consistently.
  • Difficult to reproduce.
  • Unable to reproduce.

To Reproduce

>>> import torch.nn as nn

# ConvTranspose2d with transpose - CREATES COPY
>>> layer = nn.ConvTranspose2d(3, 32, 3)
>>> weight = layer.weight.data
>>> weight.data_ptr()
94067644132544
>>> weight = weight.transpose(1, 0)
>>> weight.data_ptr()
94067644132544  # Same - transpose is a view
>>> weight = weight.flatten(1)
>>> weight.data_ptr()
94067752850112  # DIFFERENT - flatten copied!

# ConvTranspose2d without transpose - NO COPY
>>> layer = nn.ConvTranspose2d(3, 32, 3)
>>> weight = layer.weight.data
>>> weight.data_ptr()
94067752850112
>>> weight = weight.flatten(1)
>>> weight.data_ptr()
94067752850112  # Same - contiguous memory preserved

Expected behavior

Modifications to the working weight tensor should be visible to get_quant_weights(). For ConvTranspose2d this currently fails because the working tensor no longer shares memory with the layer's weight.
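
Concretely, the invariant GPxQ depends on (and which holds for regular convolutions) is that in-place edits to the flattened working tensor are visible through the layer's own weight; a minimal illustration:

>>> layer = nn.Conv2d(3, 32, 3)                  # regular conv: flatten(1) is a view
>>> weight = layer.weight.data.flatten(1)
>>> weight[0, 0] = 123.0
>>> layer.weight.data[0, 0, 0, 0].item()
123.0
>>> # For ConvTranspose2d, the transpose(1, 0) + flatten(1) path breaks this chain.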

If known:

  • Brevitas version: dev (c2b9e7)
  • PyTorch version: 2.6
  • Operating System / platform: Linux

Additional context
This affects all GPxQ algorithms for ConvTranspose2d, but it is most visible in A2GPTQ and A2GPFQ, where the explicit accumulator bounds checking catches the issue (see #1181).

Possible solutions

A non-exhaustive list of possible solutions includes:

Option 1. As a temporary solution, we could unregister ConvTranspose2d from the list of supported modules in GPxQ, possibly with a warning.
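
A rough sketch of Option 1 (reusing the SUPPORTED_CONV_OP / is_conv_transposed names from the snippet above; the exact hook in Brevitas may differ):

import warnings

if isinstance(self.layer, SUPPORTED_CONV_OP) and is_conv_transposed(self.layer):
    warnings.warn(
        f"GPxQ does not currently support {type(self.layer).__name__}; skipping this layer.")
    return  # skip rather than operate on a detached copy of the weight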

Option 2. We could wrap the weight in a custom class that exposes Tensor-like attributes and methods, but brute-forces the index mapping inside the wrapper so that the underlying memory is never detached. Something like:

class wrapper:
    ...
 
if isinstance(self.layer, SUPPORTED_CONV_OP):
    if is_conv_transposed(self.layer):
        weight = wrapper(weight)
    weight = weight.flatten(1)  
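
One possible shape for such a wrapper, only as an illustration of the brute-forced index mapping (it handles just the single-column access pattern and would need to cover every pattern GPxQ actually uses):

class TransposedWeightView:
    """Expose a ConvTranspose2d weight of shape [IC, OC, KH, KW] as a virtual
    [OC, IC * KH * KW] matrix, writing through a transposed *view* so that
    every update lands directly in the layer's own storage."""

    def __init__(self, weight_4d):
        self.view4d = weight_4d.transpose(1, 0)   # pure view: [OC, IC, KH, KW]
        oc, ic, kh, kw = self.view4d.shape
        self.shape = (oc, ic * kh * kw)
        self._inner = (ic, kh, kw)

    def _unflatten(self, col):
        ic, kh, kw = self._inner
        return col // (kh * kw), (col % (kh * kw)) // kw, col % kw

    def __getitem__(self, key):
        rows, col = key
        i, j, k = self._unflatten(col)
        return self.view4d[rows, i, j, k]

    def __setitem__(self, key, value):
        rows, col = key
        i, j, k = self._unflatten(col)
        self.view4d[rows, i, j, k] = value        # in-place write, visible to the layer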

Option 3. We could treat either the kernel dimension or the in_channels dimension as a group_dim and handle it separately in GPxQ, so that the memory-breaking stride permutation is never forced. Something like:

if isinstance(self.layer, SUPPORTED_CONV_OP):
    dim = 2 if is_conv_transposed(self.layer) else 1
    weight = weight.flatten(dim)  # would be [IC, OC, K * K] for ConvTranspose2d, and we treat it like groups=IC
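
Because the leading dimensions keep their original order here, this flatten stays a pure view; mirroring the data_ptr() check from the repro:

>>> layer = nn.ConvTranspose2d(3, 32, 3)
>>> weight = layer.weight.data            # [IC, OC, KH, KW] = [3, 32, 3, 3]
>>> weight.flatten(2).data_ptr() == weight.data_ptr()
True  # still the layer's own memory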

Option 4. Rewrite GPxQ to either work on strided memory or not rely on the shared data pointer.
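
For the "not rely on the shared data pointer" half of Option 4, one sketch (assuming the working copy keeps the [OC, IC * KH * KW] layout produced by the transpose + flatten above) is an explicit write-back once GPxQ finishes:

# after the GPxQ loop has updated the (copied) 2D working weight
ic, oc, kh, kw = self.layer.weight.shape                 # ConvTranspose2d layout
restored = weight.view(oc, ic, kh, kw).transpose(1, 0)   # back to [IC, OC, KH, KW]
self.layer.weight.data.copy_(restored)                   # explicit write-back into layer memory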
