diff --git a/README.md b/README.md
index bfe0764..070a1f2 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,5 @@
## Fast Style Transfer in [TensorFlow](https://github.com/tensorflow/tensorflow)
+
Add styles from famous paintings to any photo in a fraction of a second! [You can even style videos!](#video-stylization)
@@ -11,7 +12,7 @@ Add styles from famous paintings to any photo in a fraction of a second! [You ca
It takes 100ms on a 2015 Titan X to style the MIT Stata Center (1024×680) like Udnie, by Francis Picabia.
@@ -127,7 +128,7 @@ You will need the following to run the above:
```
### Attributions/Thanks
-- This project could not have happened without the advice (and GPU access) given by [Anish Athalye](http://www.anishathalye.com/).
+- This project could not have happened without the advice (and GPU access) given by [Anish Athalye](http://www.anishathalye.com/).
- The project also borrowed some code from Anish's [Neural Style](https://github.com/anishathalye/neural-style/)
- Some readme/docs formatting was borrowed from Justin Johnson's [Fast Neural Style](https://github.com/jcjohnson/fast-neural-style)
- The image of the Stata Center at the very beginning of the README was taken by [Juan Paulo](https://juanpaulo.me/)
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..86984cd
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,11 @@
+build:
+  python_version: "3.7"
+  gpu: true
+  cuda: "10.1"  # tensorflow-gpu 2.1.0 wheels are built against CUDA 10.1 (cuDNN 7.6)
+  python_packages:
+    - tensorflow-gpu==2.1.0
+    - moviepy==1.0.2
+    - imageio-ffmpeg==0.2.0
+  system_packages:
+    - ffmpeg
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..87c9b93
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,74 @@
+import os
+import tempfile
+from pathlib import Path
+
+import cog
+
+from evaluate import *
+from transform_video import *
+
+
+class Predictor(cog.Predictor):
+    """Cog predictor that applies a pre-trained style to an image or video."""
+
+    def setup(self):
+        """nothing to pre-load"""
+        # no setup here as we need to
+        # dynamically change which checkpoint to load
+        # loading is very quick though!
+
+    @cog.input(
+        "input",
+        type=Path,
+        help="Input file: can be image (jpg, png) or video (mp4). Video processing takes ~100 milliseconds per frame",
+    )
+    @cog.input(
+        "style",
+        type=str,
+        options=["la_muse", "rain_princess", "scream", "udnie", "wave", "wreck"],
+        help="Pre-trained style to apply to input image",
+        default="udnie",
+    )
+    def predict(self, input, style):
+        """Stylize ``input`` with the chosen ``style`` and return the result path.
+
+        Args:
+            input: Path to a .jpg/.jpeg/.png image or an .mp4 video; the
+                extension is matched case-insensitively.
+            style: Name of the checkpoint file under ``pretrained_models``
+                (``<style>.ckpt``).
+
+        Returns:
+            Path to ``output.jpg`` (image input) or ``output.mp4`` (video
+            input) inside a freshly created temporary directory.
+
+        Raises:
+            NotImplementedError: If the input file extension is not supported.
+        """
+        img_extensions = (".jpg", ".jpeg", ".png")
+        video_extensions = (".mp4",)
+
+        # Normalise case so that e.g. "photo.JPG" or "clip.MP4" is accepted.
+        suffix = input.suffix.lower()
+        is_image = suffix in img_extensions
+        if not is_image and suffix not in video_extensions:
+            # Reject early, before any temporary directory is created.
+            raise NotImplementedError("Input file extension not supported")
+
+        checkpoints_dir = "pretrained_models"
+        checkpoint_path = os.path.join(checkpoints_dir, style + ".ckpt")
+        device = "/gpu:0"
+        batch_size = 4
+
+        # One temporary directory per call; only the output that is actually
+        # produced gets a path inside it.
+        output_dir = Path(tempfile.mkdtemp())
+
+        if is_image:
+            output_path = output_dir / "output.jpg"
+            ffwd_to_img(str(input), output_path, checkpoint_path, device=device)
+            return output_path
+
+        output_path = output_dir / "output.mp4"
+        ffwd_video(str(input), str(output_path), checkpoint_path, device, batch_size)
+        return output_path