Skip to content

Commit 6ebd990

Browse files
committed
add a test to check modular index consistency
1 parent 1f6ac1c commit 6ebd990

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import gc
2+
import json
3+
import os
24
import tempfile
35
from typing import Callable
46

@@ -349,6 +351,29 @@ def test_save_from_pretrained(self):
349351

350352
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
351353

354+
def test_modular_index_consistency(self):
355+
pipe = self.get_pipeline()
356+
components_spec = pipe._component_specs
357+
components = sorted(components_spec.keys())
358+
359+
with tempfile.TemporaryDirectory() as tmpdir:
360+
pipe.save_pretrained(tmpdir)
361+
index_file = os.path.join(tmpdir, "modular_model_index.json")
362+
assert os.path.exists(index_file)
363+
364+
with open(index_file) as f:
365+
index_contents = json.load(f)
366+
367+
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
368+
for component in components:
369+
spec = components_spec[component]
370+
for attr in to_check_attrs:
371+
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
372+
for attr in to_check_attrs:
373+
assert component in index_contents, f"{component} should be present in index but isn't."
374+
attr_value_from_index = index_contents[component][2][attr]
375+
assert getattr(spec, attr) == attr_value_from_index
376+
352377
def test_workflow_map(self):
353378
blocks = self.pipeline_blocks_class()
354379
if blocks._workflow_map is None:

0 commit comments

Comments
 (0)