14
14
CFG = Config ()
15
15
16
16
17
- def generate_image (prompt : str ) -> str :
17
+ def generate_image (prompt : str , size : int = 256 ) -> str :
18
18
"""Generate an image from a prompt.
19
19
20
20
Args:
21
21
prompt (str): The prompt to use
22
+ size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
22
23
23
24
Returns:
24
25
str: The filename of the image
@@ -27,11 +28,14 @@ def generate_image(prompt: str) -> str:
27
28
28
29
# DALL-E
29
30
if CFG .image_provider == "dalle" :
30
- return generate_image_with_dalle (prompt , filename )
31
- elif CFG .image_provider == "sd" :
31
+ return generate_image_with_dalle (prompt , filename , size )
32
+ # HuggingFace
33
+ elif CFG .image_provider == "huggingface" :
32
34
return generate_image_with_hf (prompt , filename )
33
- else :
34
- return "No Image Provider Set"
35
+ # SD WebUI
36
+ elif CFG .image_provider == "sdwebui" :
37
+ return generate_image_with_sd_webui (prompt , filename , size )
38
+ return "No Image Provider Set"
35
39
36
40
37
41
def generate_image_with_hf (prompt : str , filename : str ) -> str :
@@ -45,13 +49,16 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
45
49
str: The filename of the image
46
50
"""
47
51
API_URL = (
48
- "https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4 "
52
+ f "https://api-inference.huggingface.co/models/{ CFG . huggingface_image_model } "
49
53
)
50
54
if CFG .huggingface_api_token is None :
51
55
raise ValueError (
52
56
"You need to set your Hugging Face API token in the config file."
53
57
)
54
- headers = {"Authorization" : f"Bearer { CFG .huggingface_api_token } " }
58
+ headers = {
59
+ "Authorization" : f"Bearer { CFG .huggingface_api_token } " ,
60
+ "X-Use-Cache" : "false" ,
61
+ }
55
62
56
63
response = requests .post (
57
64
API_URL ,
@@ -81,10 +88,18 @@ def generate_image_with_dalle(prompt: str, filename: str) -> str:
81
88
"""
82
89
openai .api_key = CFG .openai_api_key
83
90
91
+ # Check for supported image sizes
92
+ if size not in [256 , 512 , 1024 ]:
93
+ closest = min ([256 , 512 , 1024 ], key = lambda x : abs (x - size ))
94
+ print (
95
+ f"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. Setting to { closest } , was { size } ."
96
+ )
97
+ size = closest
98
+
84
99
response = openai .Image .create (
85
100
prompt = prompt ,
86
101
n = 1 ,
87
- size = "256x256 " ,
102
+ size = f" { size } x { size } " ,
88
103
response_format = "b64_json" ,
89
104
)
90
105
@@ -96,3 +111,53 @@ def generate_image_with_dalle(prompt: str, filename: str) -> str:
96
111
png .write (image_data )
97
112
98
113
return f"Saved to disk:{ filename } "
114
+
115
+
116
+ def generate_image_with_sd_webui (
117
+ prompt : str ,
118
+ filename : str ,
119
+ size : int = 512 ,
120
+ negative_prompt : str = "" ,
121
+ extra : dict = {},
122
+ ) -> str :
123
+ """Generate an image with Stable Diffusion webui.
124
+ Args:
125
+ prompt (str): The prompt to use
126
+ filename (str): The filename to save the image to
127
+ size (int, optional): The size of the image. Defaults to 256.
128
+ negative_prompt (str, optional): The negative prompt to use. Defaults to "".
129
+ extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
130
+ Returns:
131
+ str: The filename of the image
132
+ """
133
+ # Create a session and set the basic auth if needed
134
+ s = requests .Session ()
135
+ if CFG .sd_webui_auth :
136
+ username , password = CFG .sd_webui_auth .split (":" )
137
+ s .auth = (username , password or "" )
138
+
139
+ # Generate the images
140
+ response = requests .post (
141
+ f"{ CFG .sd_webui_url } /sdapi/v1/txt2img" ,
142
+ json = {
143
+ "prompt" : prompt ,
144
+ "negative_prompt" : negative_prompt ,
145
+ "sampler_index" : "DDIM" ,
146
+ "steps" : 20 ,
147
+ "cfg_scale" : 7.0 ,
148
+ "width" : size ,
149
+ "height" : size ,
150
+ "n_iter" : 1 ,
151
+ ** extra ,
152
+ },
153
+ )
154
+
155
+ print (f"Image Generated for prompt:{ prompt } " )
156
+
157
+ # Save the image to disk
158
+ response = response .json ()
159
+ b64 = b64decode (response ["images" ][0 ].split ("," , 1 )[0 ])
160
+ image = Image .open (io .BytesIO (b64 ))
161
+ image .save (path_in_workspace (filename ))
162
+
163
+ return f"Saved to disk:{ filename } "
0 commit comments