[Compiler Toolkit] Add option for full inductor. #2150
SherlockNoMad
left a comment
lgtm with comments.
bd66bbd to 82834ba
aot_eager for llama3 vs. full inductor: [comparison content not preserved]
yiming0416
left a comment
Can we add commands to run this to compiler_toolkit/README?
Also you can add 8-GPU CI following existing ones in compiler_toolkit/tests/integration_tests.py
82834ba to 0083214
Added to README and integration test.
0083214 to f8e7294
    """

    joint_passes: list[str] = field(default_factory=list)
    passes: list[str] = field(default_factory=list)
Non-blocking: I wonder if we should have better naming here to distinguish joint_passes from the passes applied on partitioned graphs. Should we consider having fwd_passes and bwd_passes (maybe overkill for now)?
I think joint_passes is pretty descriptive for what it is.
I think passes could be renamed to post_partition_passes. We don't yet have an example of a fwd-only or bwd-only pass, so splitting it into two options for fwd and bwd would be overkill.
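To make the naming question concrete, here is a minimal sketch of how the two options might look if `passes` were renamed as suggested. The class name `CompileConfig` and the field name `post_partition_passes` are assumptions for illustration only; the diff itself shows just `joint_passes` and `passes`.

```python
from dataclasses import dataclass, field


@dataclass
class CompileConfig:  # hypothetical class name, not taken from the PR
    # Passes applied to the joint forward+backward graph, before partitioning.
    joint_passes: list[str] = field(default_factory=list)
    # Hypothetical rename of `passes`: applied to the forward and backward
    # graphs after the joint graph has been partitioned.
    post_partition_passes: list[str] = field(default_factory=list)


cfg = CompileConfig(
    joint_passes=["inductor_decomposition"],
    post_partition_passes=["full_inductor_compilation"],
)
print(cfg.joint_passes, cfg.post_partition_passes)
```

Using `field(default_factory=list)` gives every config instance its own empty list, so appending a pass name to one config does not leak into another.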
- Being able to compile fw/bw graphs using compile_fx_inner could help with establishing perf rooflines.
- Full inductor compilation is achieved using compile_fx_inner, however, it requires the graph to have been decomposed using Inductor's default decomposition table. We apply this decomposition as a pass on the joint graph. We need to be careful to suitably unwrap the primals/tangents before running this decomposition.
f8e7294 to 2f8bfb2
Being able to compile fw/bw graphs using compile_fx_inner could help with establishing perf rooflines.

Full inductor compilation is achieved using `compile_fx_inner`, however, it requires the graph to have been decomposed using Inductor's default decomposition table. We apply this decomposition as a pass on the joint graph. We need to be careful to suitably unwrap the primals/tangents before running this decomposition.

Manual testing:

    NGPU=4 \
    CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml \
    TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train \
    ./run_train.sh \
    --model.name $MODEL_NAME \
    --parallelism.data_parallel_shard_degree=2 \
    --parallelism.tensor_parallel_degree=2 \
    --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config \
    --compile.joint_passes inductor_decomposition \
    --compile.passes full_inductor_compilation
Being able to compile the fw/bw graphs using `compile_fx_inner` could help with establishing perf rooflines.

Full Inductor compilation is achieved using `compile_fx_inner`; however, it requires the graph to have been decomposed using Inductor's default decomposition table. We apply this decomposition as a pass on the joint graph, taking care to unwrap the primals/tangents suitably before running it.

Manual testing:

    NGPU=4 \
    CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml \
    TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train \
    ./run_train.sh \
    --model.name $MODEL_NAME \
    --parallelism.data_parallel_shard_degree=2 \
    --parallelism.tensor_parallel_degree=2 \
    --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config \
    --compile.joint_passes inductor_decomposition \
    --compile.passes full_inductor_compilation
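
The ordering constraint in this description (decompose on the joint graph first, then compile each partitioned graph) can be illustrated with a toy model. This is only a sketch: graphs are plain lists of op names, and `DECOMPOSITIONS`, `decompose`, `partition`, and `compile_graph` are hypothetical stand-ins rather than torchtitan or Inductor APIs. The `silu` entry is an illustrative placeholder, not a statement about Inductor's actual decomposition table.

```python
# Toy model: the joint graph is a list of (side, op_name) pairs.
# Hypothetical decomposition table mapping a composite op to primitive ops.
DECOMPOSITIONS = {"silu": ["sigmoid", "mul"]}


def decompose(joint_graph):
    """Joint pass: replace each composite op with its primitive ops."""
    return [
        (side, prim)
        for side, op in joint_graph
        for prim in DECOMPOSITIONS.get(op, [op])
    ]


def partition(joint_graph):
    """Split the joint graph into forward and backward op lists."""
    fw = [op for side, op in joint_graph if side == "fw"]
    bw = [op for side, op in joint_graph if side == "bw"]
    return fw, bw


def compile_graph(graph):
    """Stand-in for a compiler that only accepts decomposed graphs."""
    leftover = [op for op in graph if op in DECOMPOSITIONS]
    if leftover:
        raise ValueError(f"graph still contains composite ops: {leftover}")
    return tuple(graph)  # placeholder for a compiled artifact


joint = [("fw", "matmul"), ("fw", "silu"), ("bw", "silu"), ("bw", "matmul")]
fw, bw = partition(decompose(joint))  # joint pass runs before partitioning
compiled_fw, compiled_bw = compile_graph(fw), compile_graph(bw)
```

In this toy model, skipping `decompose` makes `compile_graph` raise, which mirrors why `inductor_decomposition` is passed as a joint pass alongside `full_inductor_compilation` in the command above.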