|
1 | 1 | import gc |
| 2 | +import json |
| 3 | +import os |
2 | 4 | import tempfile |
3 | 5 | from typing import Callable |
4 | 6 |
|
@@ -349,6 +351,29 @@ def test_save_from_pretrained(self): |
349 | 351 |
|
350 | 352 | assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 |
351 | 353 |
|
| 354 | + def test_modular_index_consistency(self): |
| 355 | + pipe = self.get_pipeline() |
| 356 | + components_spec = pipe._component_specs |
| 357 | + components = sorted(components_spec.keys()) |
| 358 | + |
| 359 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 360 | + pipe.save_pretrained(tmpdir) |
| 361 | + index_file = os.path.join(tmpdir, "modular_model_index.json") |
| 362 | + assert os.path.exists(index_file) |
| 363 | + |
| 364 | + with open(index_file) as f: |
| 365 | + index_contents = json.load(f) |
| 366 | + |
| 367 | + to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"} |
| 368 | + for component in components: |
| 369 | + spec = components_spec[component] |
| 370 | + for attr in to_check_attrs: |
| 371 | + if getattr(spec, "pretrained_model_name_or_path", None) is not None: |
| 372 | + for attr in to_check_attrs: |
| 373 | + assert component in index_contents, f"{component} should be present in index but isn't." |
| 374 | + attr_value_from_index = index_contents[component][2][attr] |
| 375 | + assert getattr(spec, attr) == attr_value_from_index |
| 376 | + |
352 | 377 | def test_workflow_map(self): |
353 | 378 | blocks = self.pipeline_blocks_class() |
354 | 379 | if blocks._workflow_map is None: |
|
0 commit comments