Thanks for providing an open-source implementation of DeepSeek R1. It is simple and easy to understand.
The benchmark results you have shared for TPU-v5e-64 report a throughput of 71 tokens/sec:
https://github.com/jax-ml/jax-llm-examples/tree/main/deepseek_r1_jax#inference-performance-results
The LMSys + SGLang implementation of DeepSeek R1 reports a per-device throughput of 5600 tokens/sec/GPU.
- Why is there such a significant difference in throughput? (A first-order approximation is enough.)
- Is your throughput number tokens/sec or tokens/sec/TPU?
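For context on why the second question matters, here is a quick back-of-envelope sketch of the gap under both possible readings of the 71 tokens/sec figure (this is my own arithmetic, not anything from the repo):

```python
# Back-of-envelope comparison of the two reported throughput figures.
# Assumption: it is unclear whether 71 tokens/sec on TPU-v5e-64 is
# aggregate across all 64 chips or per chip, so both readings are shown.
tpu_chips = 64
tpu_throughput = 71       # tokens/sec as reported in the repo README
gpu_per_device = 5600     # tokens/sec/GPU from the LMSys + SGLang numbers

# Reading A: 71 tokens/sec is the aggregate over all 64 chips.
per_chip_if_aggregate = tpu_throughput / tpu_chips   # ~1.1 tokens/sec/chip
gap_a = gpu_per_device / per_chip_if_aggregate       # ~5000x per device

# Reading B: 71 tokens/sec is already a per-chip number.
gap_b = gpu_per_device / tpu_throughput              # ~79x per device

print(f"Reading A: {per_chip_if_aggregate:.2f} tok/s/chip, gap ~{gap_a:.0f}x")
print(f"Reading B: gap ~{gap_b:.0f}x")
```

Either way the gap is large, which is why a first-order explanation (batch size, sharding strategy, hardware FLOPs, etc.) would be helpful.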