resolution fix
This commit is contained in:
parent
62b1ffad38
commit
507b06958d
|
@ -26,9 +26,6 @@ import subprocess
|
||||||
import warnings
|
import warnings
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from diffusers import StableDiffusionInpaintPipeline
|
|
||||||
from diffusers import StableDiffusionPipeline
|
|
||||||
from diffusers.schedulers import DDIMScheduler
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
@ -97,7 +94,6 @@ class StableDiffusionHolder:
|
||||||
fp_ckpt: str = None,
|
fp_ckpt: str = None,
|
||||||
fp_config: str = None,
|
fp_config: str = None,
|
||||||
device: str = None,
|
device: str = None,
|
||||||
set_auto_res: bool = True,
|
|
||||||
height: Optional[int] = None,
|
height: Optional[int] = None,
|
||||||
width: Optional[int] = None,
|
width: Optional[int] = None,
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
|
@ -118,10 +114,13 @@ class StableDiffusionHolder:
|
||||||
self.ddim_eta = 0
|
self.ddim_eta = 0
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
|
|
||||||
if set_auto_res:
|
if height is None and width is None:
|
||||||
assert height is None, "Either enable automatic setting of resolution or specify height/width"
|
|
||||||
assert width is None, "Either enable automatic setting of resolution or specify height/width"
|
|
||||||
self.init_auto_res()
|
self.init_auto_res()
|
||||||
|
else:
|
||||||
|
assert height is not None, "specify both width and height"
|
||||||
|
assert width is not None, "specify both width and height"
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
|
||||||
# Inpainting inits
|
# Inpainting inits
|
||||||
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
|
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
|
||||||
|
|
Loading…
Reference in New Issue