Skip to content

Commit dc91905

Browse files
committed
use fixture for tmp_path in modular tests.
1 parent 94457fd commit dc91905

File tree

3 files changed

+51
-61
lines changed

3 files changed

+51
-61
lines changed

tests/modular_pipelines/flux/test_modular_pipeline_flux.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import random
17-
import tempfile
1817

1918
import numpy as np
2019
import PIL
@@ -129,18 +128,16 @@ def get_dummy_inputs(self, seed=0):
129128

130129
return inputs
131130

132-
def test_save_from_pretrained(self):
131+
def test_save_from_pretrained(self, tmp_path):
133132
pipes = []
134133
base_pipe = self.get_pipeline().to(torch_device)
135134
pipes.append(base_pipe)
136135

137-
with tempfile.TemporaryDirectory() as tmpdirname:
138-
base_pipe.save_pretrained(tmpdirname)
139-
140-
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
141-
pipe.load_components(torch_dtype=torch.float32)
142-
pipe.to(torch_device)
143-
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
136+
base_pipe.save_pretrained(tmp_path)
137+
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
138+
pipe.load_components(torch_dtype=torch.float32)
139+
pipe.to(torch_device)
140+
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
144141

145142
pipes.append(pipe)
146143

@@ -212,18 +209,16 @@ def get_dummy_inputs(self, seed=0):
212209

213210
return inputs
214211

215-
def test_save_from_pretrained(self):
212+
def test_save_from_pretrained(self, tmp_path):
216213
pipes = []
217214
base_pipe = self.get_pipeline().to(torch_device)
218215
pipes.append(base_pipe)
219216

220-
with tempfile.TemporaryDirectory() as tmpdirname:
221-
base_pipe.save_pretrained(tmpdirname)
222-
223-
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
224-
pipe.load_components(torch_dtype=torch.float32)
225-
pipe.to(torch_device)
226-
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
217+
base_pipe.save_pretrained(tmp_path)
218+
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
219+
pipe.load_components(torch_dtype=torch.float32)
220+
pipe.to(torch_device)
221+
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
227222

228223
pipes.append(pipe)
229224

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import gc
22
import json
33
import os
4-
import tempfile
54
from typing import Callable
65

76
import pytest
@@ -330,16 +329,15 @@ def test_components_auto_cpu_offload_inference_consistent(self):
330329

331330
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
332331

333-
def test_save_from_pretrained(self):
332+
def test_save_from_pretrained(self, tmp_path):
334333
pipes = []
335334
base_pipe = self.get_pipeline().to(torch_device)
336335
pipes.append(base_pipe)
337336

338-
with tempfile.TemporaryDirectory() as tmpdirname:
339-
base_pipe.save_pretrained(tmpdirname)
340-
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
341-
pipe.load_components(torch_dtype=torch.float32)
342-
pipe.to(torch_device)
337+
base_pipe.save_pretrained(tmp_path)
338+
pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device)
339+
pipe.load_components(torch_dtype=torch.float32)
340+
pipe.to(torch_device)
343341

344342
pipes.append(pipe)
345343

@@ -351,32 +349,31 @@ def test_save_from_pretrained(self):
351349

352350
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
353351

354-
def test_modular_index_consistency(self):
352+
def test_modular_index_consistency(self, tmp_path):
355353
pipe = self.get_pipeline()
356354
components_spec = pipe._component_specs
357355
components = sorted(components_spec.keys())
358356

359-
with tempfile.TemporaryDirectory() as tmpdir:
360-
pipe.save_pretrained(tmpdir)
361-
index_file = os.path.join(tmpdir, "modular_model_index.json")
362-
assert os.path.exists(index_file)
363-
364-
with open(index_file) as f:
365-
index_contents = json.load(f)
366-
367-
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
368-
for k in compulsory_keys:
369-
assert k in index_contents
370-
371-
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
372-
for component in components:
373-
spec = components_spec[component]
374-
for attr in to_check_attrs:
375-
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
376-
for attr in to_check_attrs:
377-
assert component in index_contents, f"{component} should be present in index but isn't."
378-
attr_value_from_index = index_contents[component][2][attr]
379-
assert getattr(spec, attr) == attr_value_from_index
357+
pipe.save_pretrained(tmp_path)
358+
index_file = os.path.join(tmp_path, "modular_model_index.json")
359+
assert os.path.exists(index_file)
360+
361+
with open(index_file) as f:
362+
index_contents = json.load(f)
363+
364+
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
365+
for k in compulsory_keys:
366+
assert k in index_contents
367+
368+
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
369+
for component in components:
370+
spec = components_spec[component]
371+
for attr in to_check_attrs:
372+
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
373+
for attr in to_check_attrs:
374+
assert component in index_contents, f"{component} should be present in index but isn't."
375+
attr_value_from_index = index_contents[component][2][attr]
376+
assert getattr(spec, attr) == attr_value_from_index
380377

381378
def test_workflow_map(self):
382379
blocks = self.pipeline_blocks_class()

tests/modular_pipelines/test_modular_pipelines_custom_blocks.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import json
1616
import os
17-
import tempfile
1817
from collections import deque
1918
from typing import List
2019

@@ -153,25 +152,24 @@ def test_custom_block_output(self):
153152
output_prompt = output.values["output_prompt"]
154153
assert output_prompt.startswith("Modular diffusers + ")
155154

156-
def test_custom_block_saving_loading(self):
155+
def test_custom_block_saving_loading(self, tmp_path):
157156
custom_block = DummyCustomBlockSimple()
158157

159-
with tempfile.TemporaryDirectory() as tmpdir:
160-
custom_block.save_pretrained(tmpdir)
161-
assert any("modular_config.json" in k for k in os.listdir(tmpdir))
158+
custom_block.save_pretrained(tmp_path)
159+
assert any("modular_config.json" in k for k in os.listdir(tmp_path))
162160

163-
with open(os.path.join(tmpdir, "modular_config.json"), "r") as f:
164-
config = json.load(f)
165-
auto_map = config["auto_map"]
166-
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
161+
with open(os.path.join(tmp_path, "modular_config.json"), "r") as f:
162+
config = json.load(f)
163+
auto_map = config["auto_map"]
164+
assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"}
167165

168-
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
169-
# This is why, we have to separately save the Python script here.
170-
code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py")
171-
with open(code_path, "w") as f:
172-
f.write(CODE_STR)
166+
# For now, the Python script that implements the custom block has to be manually pushed to the Hub.
167+
# This is why, we have to separately save the Python script here.
168+
code_path = os.path.join(tmp_path, "test_modular_pipelines_custom_blocks.py")
169+
with open(code_path, "w") as f:
170+
f.write(CODE_STR)
173171

174-
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True)
172+
loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmp_path, trust_remote_code=True)
175173

176174
pipe = loaded_custom_block.init_pipeline()
177175
prompt = "Diffusers is nice"

0 commit comments

Comments
 (0)