Labels: bug (Something isn't working)
Description
Python -VV
Python 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0]
Pip Freeze
accelerate==1.7.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
annotated-types @ file:///croot/annotated-types_1709542908624/work
async-timeout==5.0.1
attrs @ file:///croot/attrs_1734533101012/work
autoawq==0.2.9
boto3 @ file:///croot/boto3_1743092127406/work
botocore @ file:///croot/botocore_1743061876997/work
Brotli @ file:///croot/brotli-split_1736182456865/work
certifi @ file:///croot/certifi_1745939216646/work/certifi
charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
click @ file:///croot/click_1744271578095/work
datasets==3.6.0
dill==0.3.8
docstring_parser==0.16
filelock @ file:///croot/filelock_1744281381737/work
fire==0.7.0
frozenlist==1.6.0
fsspec==2025.3.0
gmpy2 @ file:///croot/gmpy2_1738085463648/work
huggingface-hub==0.31.4
idna @ file:///croot/idna_1714398848350/work
Jinja2 @ file:///croot/jinja2_1741710844255/work
jmespath @ file:///croot/jmespath_1700144569655/work
joblib @ file:///croot/joblib_1718217211762/work
jsonschema @ file:///croot/jsonschema_1728486696720/work
jsonschema-specifications @ file:///croot/jsonschema-specifications_1699032386549/work
lit @ file:///croot/llvm-package_1741800901749/work/llvm/utils/lit
MarkupSafe @ file:///croot/markupsafe_1738584038848/work
mistral_common==1.5.4
mistral_inference==1.6.0
mkl-service==2.4.0
mkl_fft @ file:///io/mkl313/mkl_fft_1730824109137/work
mkl_random @ file:///io/mkl313/mkl_random_1730823916628/work
modelscope==1.26.0
mpmath @ file:///croot/mpmath_1690848262763/work
multidict==6.4.4
multiprocess==0.70.16
networkx @ file:///croot/networkx_1737039604450/work
numpy @ file:///croot/numpy_and_numpy_base_1747238018250/work/dist/numpy-2.2.5-cp310-cp310-linux_x86_64.whl#sha256=5a6df30fd2c407292da04652f137bbbe50d4765b8e20b11b2e225dadf36c269e
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
packaging @ file:///croot/packaging_1734472117206/work
pandas==2.2.3
pillow @ file:///croot/pillow_1744613067434/work
propcache==0.3.1
psutil==7.0.0
pyarrow==20.0.0
pydantic @ file:///croot/pydantic_1734736067156/work
pydantic_core @ file:///croot/pydantic-core_1734726052986/work
PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
python-dateutil @ file:///croot/python-dateutil_1716495738603/work
pytz==2025.2
PyYAML @ file:///croot/pyyaml_1728657952215/work
referencing @ file:///croot/referencing_1699012038513/work
regex @ file:///croot/regex_1736540786412/work
requests @ file:///croot/requests_1730999120400/work
rpds-py @ file:///croot/rpds-py_1736541261634/work
s3transfer @ file:///croot/s3transfer_1738245147924/work
sacremoses @ file:///tmp/build/80754af9/sacremoses_1633107328213/work
safetensors @ file:///croot/safetensors_1741361308036/work
sentencepiece @ file:///croot/sentencepiece-split_1742566759237/work/python
simple-parsing==0.1.7
six @ file:///croot/six_1744271502820/work
sympy @ file:///croot/sympy_1738108488918/work
termcolor @ file:///croot/termcolor_1668084651543/work
tiktoken==0.9.0
tokenizers @ file:///croot/tokenizers_1741370336077/work
torch==2.7.0
tqdm @ file:///croot/tqdm_1738943501192/work
transformers==4.52.2
triton==3.3.0
typing-inspection @ file:///croot/typing-inspection_1746023470701/work
typing_extensions @ file:///croot/typing_extensions_1734714854207/work
tzdata==2025.2
urllib3 @ file:///croot/urllib3_1737133630106/work
xformers==0.0.30
xxhash==3.5.0
yarl==1.20.0
zstandard==0.23.0
Reproduction Steps
I'm encountering an issue where hooks registered with register_forward_hook() never fire on the attention and feed-forward network (FFN) modules. When the same kind of hook is registered on the transformer block itself, everything works as expected.
Could anyone provide insight into why this discrepancy occurs?
Thanks!
Here is my code:
import torch
import torch.nn.functional as F
from mistral_inference.transformer import Transformer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest


def custom_generate(tokens, model, max_tokens, temperature, eos_id, device):
    model.eval()
    input_ids = torch.tensor(tokens[0], dtype=torch.long, device=device)
    out_tokens = []
    handles = []

    for i in range(len(model.layers)):
        layer = model.layers[str(i)]

        # Note: these closures capture the loop variable i by reference, so
        # every hook reports the final value of i when it eventually fires.
        def make_attn_hook():
            def hook(module, input, output):
                print(f"[HOOK] Attention layer {i} input shape: {input[0].shape}")
            return hook

        def make_ffn_hook():
            def hook(module, input, output):
                print(f"[HOOK] FFN layer {i} input shape: {input[0].shape}")
            return hook

        def make_transform_hook():
            def hook(module, input, output):
                print(f"[HOOK] Transform layer {i} input shape: {input[0].shape}")
            return hook

        handles.append(layer.attention.register_forward_hook(make_attn_hook()))
        handles.append(layer.feed_forward.register_forward_hook(make_ffn_hook()))
        handles.append(layer.register_forward_hook(make_transform_hook()))

    for _ in range(max_tokens):
        seqlens = [input_ids.shape[0]]
        with torch.no_grad():
            logits = model.forward(input_ids, seqlens=seqlens)
        next_token_logits = logits[-1, :]  # [vocab_size]
        if temperature > 0:
            next_token_probs = F.softmax(next_token_logits / temperature, dim=-1)
            next_token = torch.multinomial(next_token_probs, num_samples=1)
        else:
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
        out_tokens.append(next_token.item())
        if next_token.item() == eos_id:
            break
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    for h in handles:
        h.remove()
    return out_tokens


if __name__ == "__main__":
    mistral_models_path = "/data/Mistral-7B"
    tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tokenizer.model.v3")
    model = Transformer.from_folder(mistral_models_path)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    completion_request = ChatCompletionRequest(
        messages=[UserMessage(content="Use the numbers 1, 2, 3, 4 and the operations +, −, *, / to calculate 24.")]
    )
    tokens = tokenizer.encode_chat_completion(completion_request).tokens
    out_tokens = custom_generate(
        [tokens], model, max_tokens=64, temperature=0.0,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id, device=device,
    )
    result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens)
    print("Generated text:", result)
and the output is:
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
[HOOK] Transform layer 31 input shape: torch.Size([95, 4096])
Generated text: Here are some possible expressions that calculate 24 using the given numbers and operations:
1. (4 * 3) + (4 * 2) + 1
2. 4 * (3 + 2) + 1
3. 4 * 3 + 2 +
Observe that the hooks attached to the attention and FFN modules never fire; only the block-level hooks trigger.
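Incidentally, every hook line in the output above reports layer 31 even though hooks were registered on all layers: the hook closures capture the loop variable i by reference, so by the time any hook runs, i holds its final value. Binding i as a default argument at definition time avoids this. A minimal, torch-free demonstration (the helper names are made up for illustration):

    def make_hooks_late(n):
        """Closures look up `i` when they RUN, so all hooks see its final value."""
        hooks = []
        for i in range(n):
            def hook():
                return i          # resolved at call time, not definition time
            hooks.append(hook)
        return hooks

    def make_hooks_bound(n):
        """A default argument freezes the current value of `i` per iteration."""
        hooks = []
        for i in range(n):
            def hook(i=i):
                return i          # `i` bound when the function is defined
            hooks.append(hook)
        return hooks

    print([h() for h in make_hooks_late(3)])   # [2, 2, 2]
    print([h() for h in make_hooks_bound(3)])  # [0, 1, 2]

The same def hook(module, input, output, i=i) pattern would make each registered hook report its own layer index.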
Expected Behavior
Hooks registered via register_forward_hook() should fire during the forward pass of the attention and FFN modules, just as they do for the transformer blocks.
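A likely explanation for why only the block-level hooks fire: PyTorch dispatches forward hooks from nn.Module.__call__, so a hook only runs when the module is invoked as module(x). If mistral_inference's TransformerBlock calls self.attention.forward(...) and self.feed_forward.forward(...) directly (an assumption worth checking against the installed source), those calls bypass __call__ and therefore every registered hook. A pared-down, torch-free sketch of that dispatch mechanism (MiniModule is a hypothetical stand-in for nn.Module):

    class MiniModule:
        """Simplified analogue of nn.Module's forward-hook dispatch."""
        def __init__(self):
            self._forward_hooks = []

        def register_forward_hook(self, hook):
            self._forward_hooks.append(hook)

        def __call__(self, *args):
            out = self.forward(*args)          # hooks run only on this path
            for hook in self._forward_hooks:
                hook(self, args, out)
            return out

        def forward(self, x):
            return x * 2

    calls = []
    m = MiniModule()
    m.register_forward_hook(lambda mod, inp, out: calls.append(out))

    m(3)          # via __call__ -> hook fires
    m.forward(3)  # direct forward() -> __call__ skipped, hook silent
    print(calls)  # [6]

This matches the observed output: the block itself is invoked as a callable somewhere upstream (so its hook fires), while its submodules are not.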
Additional Context
No response
Suggested Solutions
No response
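Until hooks fire natively on these submodules, one possible workaround is to wrap the submodule's forward method itself, which intercepts the call even when the parent invokes .forward(...) rather than __call__. A sketch (add_forward_probe and ToyAttention are illustrative names, not mistral_inference or PyTorch API):

    def add_forward_probe(module, name, log):
        """Wrap module.forward so every call is recorded, even direct .forward(...) calls."""
        orig_forward = module.forward
        def probed(*args, **kwargs):
            log.append((name, args))          # record the positional inputs
            return orig_forward(*args, **kwargs)
        module.forward = probed               # instance-level patch; class untouched
        def remove():
            module.forward = orig_forward     # undo, like a hook handle's .remove()
        return remove

    class ToyAttention:                       # stand-in for layer.attention
        def forward(self, x):
            return x + 1

    log = []
    attn = ToyAttention()
    remove = add_forward_probe(attn, "attention", log)

    attn.forward(5)   # intercepted: ("attention", (5,)) is logged
    remove()
    attn.forward(5)   # original behavior restored; nothing logged

    print(log)        # [('attention', (5,))]

Applied to the reproduction above, probing layer.attention and layer.feed_forward this way should log inputs regardless of how TransformerBlock invokes them.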