Hello DeepEp and HybridEp team,
I found some bugs in HybridEp:
- During SFT training, each rank may have a different `max_num_of_tokens_per_rank` in `HybridEPBuffer.__init__()`. In `update_template_config` calls, the `num_of_tokens_per_rank` derived from the hidden state also differs across ranks. As a result, some ranks wait at an all_gather while others skip it, and training hangs. The hang is in `bool HybridEPBuffer::update_buffer`: ranks with `need_reallocate=false` skip the all_gather in `allgather_obj.update(buffer_config);`.
- In `dispatch_with_permute`, `metadata_preprocessing` receives a `routing_map`, but the length of `routing_map` differs across ranks. This leads to different input lengths for the `all_gather_into_tensor` operation in `Executor::allgather_routing_map`, causing precision issues.
Fix:
- I round the buffer size up with `max_num_of_tokens_per_rank = next_power_of_two(max_num_of_tokens_per_rank)` so that the length is always `max_length`.
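For reference, a minimal sketch of the `next_power_of_two` helper used above (the name comes from the fix; this particular implementation is my assumption):

```python
def next_power_of_two(n: int) -> int:
    # Smallest power of two >= n, for n >= 1.
    # (n - 1).bit_length() gives the number of bits needed to
    # represent n - 1, so shifting 1 by that amount rounds n up
    # to the next power of two (and leaves exact powers unchanged).
    return 1 << (n - 1).bit_length()
```

Rounding to a power of two keeps the buffer length stable across small fluctuations in token counts, so ranks only reallocate (and re-enter the all_gather) together when the count crosses a power-of-two boundary.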
- Fix the precision issue by padding `routing_map` to `max_num_of_tokens_per_rank`:

  ```python
  pad_rows = config.max_num_of_tokens_per_rank - num_of_tokens_per_rank
  routing_map = torch.nn.functional.pad(routing_map, (0, 0, 0, pad_rows), value=False)
  ```
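To make the padding fix concrete, here is a self-contained sketch of the same idea (the `pad_routing_map` wrapper and the toy shapes are mine, assuming `routing_map` is a `(num_tokens, num_experts)` boolean tensor):

```python
import torch
import torch.nn.functional as F

def pad_routing_map(routing_map: torch.Tensor, max_num_of_tokens_per_rank: int) -> torch.Tensor:
    """Pad a (num_tokens, num_experts) boolean routing map with all-False rows
    so that every rank contributes the same number of rows to the all_gather."""
    pad_rows = max_num_of_tokens_per_rank - routing_map.size(0)
    # The pad tuple is read from the last dimension backwards:
    # (0, 0) leaves the expert dimension untouched,
    # (0, pad_rows) appends pad_rows rows at the bottom.
    return F.pad(routing_map, (0, 0, 0, pad_rows), value=False)

# Example: a rank with 3 real tokens and 4 experts, padded to 8 rows.
rm = torch.zeros(3, 4, dtype=torch.bool)
rm[0, 1] = True
padded = pad_routing_map(rm, 8)
```

Because the padded rows are all `False`, they route no tokens to any expert, so the extra rows only equalize the input length of `all_gather_into_tensor` without changing the dispatch result.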
Training runs normally after these two fixes.