-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstable_diffusion_handler.py
More file actions
106 lines (91 loc) · 3.6 KB
/
stable_diffusion_handler.py
File metadata and controls
106 lines (91 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
import zipfile
import orjson
from abc import ABC
import diffusers
import torch
from diffusers import DiffusionPipeline
import numpy as np
from ts.torch_handler.base_handler import BaseHandler
# Module-level logger; the diffusers version is recorded once, at import time.
logger = logging.getLogger(name=__name__)
logger.log(logging.INFO, "Diffusers version %s", diffusers.__version__)
class DiffusersHandler(BaseHandler, ABC):
    """
    Diffusers handler class for text to image generation.

    Loads a ``DiffusionPipeline`` from ``model.zip`` in the model directory,
    turns each request's prompt into an image, and returns one KServe
    v2-style JSON response per generated image.
    """

    # Generation settings forwarded to the pipeline call in ``inference``.
    # Override on a subclass or on the instance to change them without
    # editing the method.
    guidance_scale = 7.5
    num_inference_steps = 50
    image_height = 768
    image_width = 768

    # numpy dtype name -> KServe v2 "datatype" tag used in the response.
    _V2_DATATYPES = {
        "bool": "BOOL",
        "uint8": "UINT8",
        "int32": "INT32",
        "int64": "INT64",
        "float16": "FP16",
        "float32": "FP32",
        "float64": "FP64",
    }

    def __init__(self):
        # Run the base-class initialiser so whatever state BaseHandler sets
        # up is not skipped, then mark this handler as not yet loaded.
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Extract and load the Stable Diffusion pipeline.

        Args:
            ctx (context): TorchServe context object. ``ctx.manifest`` is
                stored on the handler, and ``ctx.system_properties`` supplies
                ``model_dir`` (where ``model.zip`` lives) and ``gpu_id``.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")
        # Use the assigned GPU only when CUDA is present and a GPU id was
        # given; otherwise fall back to CPU.
        self.device = torch.device(
            "cuda:" + str(gpu_id)
            if torch.cuda.is_available() and gpu_id is not None
            else "cpu"
        )
        # The pipeline files arrive as model.zip inside model_dir. Extract
        # them to model_dir/model and load the pipeline from there. Further
        # setup config can be added here.
        archive_path = model_dir + "/model.zip"
        extract_dir = model_dir + "/model"
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)
        self.pipe = DiffusionPipeline.from_pretrained(extract_dir)
        self.pipe.to(self.device)
        logger.info("Diffusion model from path %s loaded successfully", model_dir)
        self.initialized = True

    @staticmethod
    def _extract_prompt(data):
        """Return the raw prompt from a single request dict (may be bytes or None)."""
        body = data.get("body")
        # KServe v2 layout: {"inputs": [{"data": [prompt, ...], ...}]}.
        # Any other body shape (raw bytes/str, missing keys, empty lists)
        # falls through to the plain TorchServe fields below instead of raising.
        try:
            prompt = body["inputs"][0]["data"][0]
        except (KeyError, IndexError, TypeError):
            prompt = None
        if prompt is None:
            prompt = data.get("data")
        if prompt is None:
            prompt = body
        return prompt

    def preprocess(self, requests):
        """Extract the text prompt from every request in the batch.

        Args:
            requests (list): Batch of TorchServe request dicts. The prompt is
                read from the KServe v2 path ``body["inputs"][0]["data"][0]``
                when present, otherwise from the ``data`` field, otherwise
                from ``body`` itself.

        Returns:
            list: One prompt per request, in request order. Byte prompts are
            decoded as UTF-8.
        """
        inputs = []
        for data in requests:
            input_text = self._extract_prompt(data)
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            logger.info("Received text: '%s'", input_text)
            inputs.append(input_text)
        return inputs

    def inference(self, inputs):
        """Generate images for the given prompts.

        Args:
            inputs (list): Prompts produced by ``preprocess``.

        Returns:
            list: The pipeline's ``.images`` output. With the pipeline's
            default settings this is presumably one image per prompt, in
            prompt order. ``postprocess`` relies on that one-to-one mapping.
        """
        # Text-to-image generation with the handler's configured settings.
        inferences = self.pipe(
            inputs,
            guidance_scale=self.guidance_scale,
            num_inference_steps=self.num_inference_steps,
            height=self.image_height,
            width=self.image_width,
        ).images
        logger.info("Generated image: '%s'", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Convert each generated image into a KServe v2-style JSON response.

        A handler must return one response per request in the batch, so every
        image is encoded (previously only the first one was, which broke
        batches larger than one).

        Args:
            inference_output (list): Images returned by ``inference``.

        Returns:
            list: One ``orjson``-encoded ``bytes`` payload per image, holding
            the pixel array under ``outputs[0].data`` together with its shape
            and v2 datatype tag.
        """
        responses = []
        for image in inference_output:
            img_nparray = np.array(image)
            dtype_name = img_nparray.dtype.name
            response = {
                "id": "42",
                "outputs": [
                    {
                        "name": "output0",
                        # Report the array's real dtype. Image pixels come out
                        # as uint8, so the old "fp16" label was wrong.
                        "datatype": self._V2_DATATYPES.get(
                            dtype_name, dtype_name.upper()
                        ),
                        "shape": list(img_nparray.shape),
                        "data": img_nparray.tolist(),
                    }
                ],
            }
            responses.append(orjson.dumps(response))
        return responses