Skip to content

Commit a58bb08

Browse files
author
Anthony Wu
committed
config updates
1 parent 6f6c228 commit a58bb08

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

tools/cloud-replicate-cog/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ This is a quick demo of `mflux` running on Linux CPUs in Replicate's cloud using
33
# Development
44

55
- `cog build --tag mflux-linux-cpu:test` # produces image cog-cloud-replicate-cog
6-
- `cog predict mflux-linux-cpu:test -i prompt=sunset -i steps=1 -i num_outputs=1 -i seed=$RANDOM`
6+
- `cog predict mflux-linux-cpu:test --setup-timeout 3600 -i prompt=sunset -i steps=1 -i num_outputs=1 -i seed=$RANDOM`
77

88
when running the first time, we may need to allow a longer setup timeout for initial model download:
99

10-
- `cog predict <image> --setup-timeout 3600 ...`
10+
- `cog predict mflux-linux-cpu:test --setup-timeout 3600 ...`

tools/cloud-replicate-cog/cog.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
image: r8.im/anthonywu/mflux-linux-cpu:test
12
build:
23
cog_runtime: true
34
gpu: false

tools/cloud-replicate-cog/generate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mflux.config.config import Config
99
from mflux.flux.flux import Flux1
1010

11-
MODEL_CACHE = "FLUX.1-schnell"
11+
MODEL_CACHE = "/data/FLUX.1-schnell"
1212
MODEL_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-schnell/files.tar"
1313

1414
ASPECT_RATIOS = {
@@ -62,9 +62,9 @@ class TextToImage(BasePredictor):
6262
def setup(self) -> None:
6363
"""Load the model into memory to make running multiple predictions efficient"""
6464
start = time.time()
65-
print("Loading Flux model weights")
66-
if not os.path.exists(MODEL_CACHE):
67-
download_weights(MODEL_URL, ".")
65+
print(f"Loading Flux model weights to {MODEL_CACHE}")
66+
if not Path(MODEL_CACHE).exists():
67+
download_weights(MODEL_URL, MODEL_CACHE)
6868
model_path = find_transformers_parent(MODEL_CACHE)
6969
print(f"model path found at: {model_path}")
7070
self.flux = Flux1(path=model_path, model="schnell")

0 commit comments

Comments
 (0)