Skip to content

Commit ede67f7

Browse files
NickCao and claude committed
[Enhancement] Support RunAI Model Streamer for diffusion weight loading
Add enable_runai_streamer flag to OmniDiffusionConfig so diffusion models can use the runai_model_streamer library for streaming safetensors weights, matching the support already available in the LLM weight loading path. Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Nick Cao <ncao@redhat.com>
1 parent d2b9f9f commit ede67f7

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

vllm_omni/diffusion/data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ class OmniDiffusionConfig:
412412
# Parallel weight loading (for faster diffusion model startup)
413413
enable_multithread_weight_load: bool = True
414414
num_weight_load_threads: int = 4
415+
enable_runai_streamer: bool = False
415416

416417
# Enable sleep mode
417418
enable_sleep_mode: bool = False

vllm_omni/diffusion/model_loader/diffusers_loader.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
filter_files_not_needed_for_inference,
2525
maybe_download_from_modelscope,
2626
multi_thread_safetensors_weights_iterator,
27+
runai_safetensors_weights_iterator,
2728
safetensors_weights_iterator,
2829
)
2930
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -180,13 +181,20 @@ def _get_weights_iterator(self, source: "ComponentSource") -> Generator[tuple[st
180181
)
181182

182183
od_config = self.od_config
184+
use_runai = use_safetensors and od_config is not None and getattr(od_config, "enable_runai_streamer", False)
183185
use_multithread = (
184186
use_safetensors
185187
and od_config is not None
186188
and getattr(od_config, "enable_multithread_weight_load", False)
187189
and self.load_config.safetensors_load_strategy != "torchao"
188190
)
189-
if use_multithread:
191+
if use_runai:
192+
sorted_hf_weights_files = sorted(hf_weights_files, key=_natural_sort_key)
193+
weights_iterator = runai_safetensors_weights_iterator(
194+
sorted_hf_weights_files,
195+
self.load_config.use_tqdm_on_load,
196+
)
197+
elif use_multithread:
190198
num_threads = getattr(od_config, "num_weight_load_threads", 4)
191199
# Keep deterministic shard order before passing to vLLM helper.
192200
sorted_hf_weights_files = sorted(hf_weights_files, key=_natural_sort_key)

vllm_omni/entrypoints/cli/serve.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
294294
default=4,
295295
help="Number of threads for parallel weight loading (default: 4).",
296296
)
297+
omni_config_group.add_argument(
298+
"--enable-runai-streamer",
299+
action="store_true",
300+
default=False,
301+
help="Use RunAI Model Streamer for loading diffusion safetensors weights.",
302+
)
297303

298304
# diffusion model offload parameters
299305
omni_config_group.add_argument(

0 commit comments

Comments (0)