# Copyright 2022 Lunar Ring. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
|
|
dp_git = "/home/lugo/git/"
|
|
sys.path.append(os.path.join(dp_git,'garden4'))
|
|
sys.path.append('util')
|
|
import torch
|
|
torch.backends.cudnn.benchmark = False
|
|
import numpy as np
|
|
import warnings
|
|
warnings.filterwarnings('ignore')
|
|
import time
|
|
import subprocess
|
|
import warnings
|
|
import torch
|
|
from tqdm.auto import tqdm
|
|
from PIL import Image
|
|
# import matplotlib.pyplot as plt
|
|
import torch
|
|
from movie_util import MovieSaver
|
|
import datetime
|
|
from typing import Callable, List, Optional, Union
|
|
import inspect
|
|
from threading import Thread
|
|
torch.set_grad_enabled(False)
|
|
from omegaconf import OmegaConf
|
|
from torch import autocast
|
|
from contextlib import nullcontext
|
|
from ldm.util import instantiate_from_config
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
from einops import repeat, rearrange
|
|
|
|
#%%
|
|
|
|
|
|
def pad_image(input_image, multiple=64, min_blocks=2):
    r"""Edge-pad an image so that both sides are multiples of ``multiple`` pixels.

    Padding is only added at the bottom and on the right, by replicating the
    border pixels (``np.pad(..., mode='edge')``), so the original content stays
    anchored at the top-left corner. Every side is padded to at least
    ``min_blocks * multiple`` pixels (128 px with the defaults).

    Args:
        input_image: PIL.Image.Image
            Image to pad. Both multi-channel (RGB, RGBA, ...) and
            single-channel (e.g. mode "L") images are supported.
        multiple: int
            Output side lengths are rounded up to a multiple of this value.
        min_blocks: int
            Minimum number of ``multiple``-sized blocks per side.

    Returns:
        PIL.Image.Image: the padded image.
    """
    size = np.array(input_image.size)  # PIL reports (width, height)
    n_blocks = np.maximum(min_blocks, np.ceil(size / multiple).astype(int))
    pad_w, pad_h = n_blocks * multiple - size

    arr = np.array(input_image)
    # Pad rows (height) and columns (width); a channel axis, if any, is left as is.
    # A fixed 3-axis spec would fail for 2-D (single-channel) images.
    pad_width = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (arr.ndim - 2)
    return Image.fromarray(np.pad(arr, pad_width, mode='edge'))


def make_batch_inpaint(
        image,
        mask,
        txt,
        device,
        num_samples=1):
    r"""Assemble the conditioning batch for the inpainting model.

    Args:
        image: PIL.Image.Image
            Source image; converted to RGB and rescaled from [0, 255] to [-1, 1].
        mask: PIL.Image.Image
            Mask; converted to grayscale, scaled to [0, 1] and binarized
            (values >= 0.5 become 1, all others 0).
        txt: str
            Prompt, replicated ``num_samples`` times.
        device:
            Device the returned tensors are moved to.
        num_samples: int
            Batch size; every tensor is repeated along its leading axis.

    Returns:
        dict with keys "image", "txt", "mask" and "masked_image".
        "masked_image" equals "image" with all pixels where mask == 1 set to 0.
    """
    rgb = np.array(image.convert("RGB"))[None].transpose(0, 3, 1, 2)  # 1 x C x H x W
    image_t = torch.from_numpy(rgb).to(dtype=torch.float32) / 127.5 - 1.0

    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    # Binarize the mask at 0.5 and add batch/channel axes: 1 x 1 x H x W.
    mask_t = torch.from_numpy((gray[None, None] >= 0.5).astype(np.float32))

    masked_image = image_t * (mask_t < 0.5)

    def to_batch(tensor):
        # Move to the target device and tile the leading singleton axis.
        return repeat(tensor.to(device=device), "1 ... -> n ...", n=num_samples)

    return {
        "image": to_batch(image_t),
        "txt": num_samples * [txt],
        "mask": to_batch(mask_t),
        "masked_image": to_batch(masked_image),
    }


def make_batch_superres(
        image,
        txt,
        device,
        num_samples=1,
):
    r"""Assemble the conditioning batch for the x4 upscaling model.

    Args:
        image: PIL.Image.Image
            Low-resolution input; converted to RGB and rescaled to [-1, 1].
        txt: str
            Prompt, replicated ``num_samples`` times.
        device:
            Device the low-resolution tensor is moved to.
        num_samples: int
            Batch size.

    Returns:
        dict with "lr" (num_samples x C x H x W tensor) and "txt" (list of prompts).
    """
    pixels = torch.from_numpy(np.array(image.convert("RGB"))).to(dtype=torch.float32)
    low_res = rearrange(pixels / 127.5 - 1.0, 'h w c -> 1 c h w')
    return {
        "lr": repeat(low_res.to(device=device), "1 ... -> n ...", n=num_samples),
        "txt": num_samples * [txt],
    }


def make_noise_augmentation(model, batch, noise_level=None):
    r"""Run the model's low-scale augmentation on the low-resolution image.

    Args:
        model:
            Upscaling model exposing ``low_scale_key`` and ``low_scale_model``.
        batch: dict
            Batch as built by ``make_batch_superres``; the entry
            ``batch[model.low_scale_key]`` is used.
        noise_level:
            Passed through to ``model.low_scale_model``.

    Returns:
        The (augmented image, noise level) pair produced by ``low_scale_model``.
    """
    low_res = batch[model.low_scale_key].to(memory_format=torch.contiguous_format).float()
    augmented, level = model.low_scale_model(low_res, noise_level)
    return augmented, level


class StableDiffusionHolder:
    r"""Wrapper around a Stable Diffusion checkpoint for step-wise DDIM sampling.

    Holds the model, a DDIMSampler and the sampling settings (seed, guidance
    scale, number of steps, output resolution). The ``run_diffusion_*``
    methods return the latent after every executed DDIM step, so intermediate
    latents can later be re-injected via ``latents_for_injection`` /
    ``idx_start``.
    """

    def __init__(self,
                 fp_ckpt: str = None,
                 fp_config: str = None,
                 num_inference_steps: int = 30,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 device: str = None,
                 precision: str = 'autocast',
                 ):
        r"""
        Args:
            fp_ckpt: str
                Path to the model checkpoint (a file containing a "state_dict").
            fp_config: str
                Path to the matching OmegaConf model config.
            num_inference_steps: int
                Number of diffusion steps; the DDIM schedule is built with
                ``num_inference_steps - 1`` as ``ddim_num_steps``.
            height, width: Optional[int]
                Output resolution in pixels. Give both or neither; with neither,
                it is derived from the checkpoint filename (see init_auto_res).
            device: str
                Torch device; defaults to "cuda" if available, else "cpu".
            precision: str
                "autocast" runs sampling under ``autocast("cuda")``; any other
                value runs it without autocast.
        """
        self.seed = 42
        self.guidance_scale = 5.0

        if device is None:
            self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        else:
            self.device = device
        self.precision = precision
        self.init_model(fp_ckpt, fp_config)

        self.f = 8  # downsampling factor, most often 8 or 16
        self.C = 4  # number of latent channels
        self.ddim_eta = 0
        self.num_inference_steps = num_inference_steps

        if height is None and width is None:
            self.init_auto_res()
        else:
            assert height is not None, "specify both width and height"
            assert width is not None, "specify both width and height"
            self.height = height
            self.width = width

        # Inpainting inits. NumPy image arrays are laid out as (rows, cols),
        # i.e. (height, width), so that is the order used here.
        self.mask_empty = Image.fromarray(np.full((self.height, self.width), 255, dtype=np.uint8))
        self.image_empty = Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8))

    def init_model(self, fp_ckpt, fp_config):
        r"""Instantiate the model from ``fp_config``, load ``fp_ckpt`` and build the sampler.

        The checkpoint is loaded onto the CPU first and the model is then moved
        to ``self.device``, so checkpoints saved on a GPU also load on machines
        without that GPU.
        """
        assert fp_ckpt is not None and os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
        assert fp_config is not None and os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
        self.fp_ckpt = fp_ckpt

        config = OmegaConf.load(fp_config)
        self.model = instantiate_from_config(config.model)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(fp_ckpt, map_location="cpu")["state_dict"]
        # strict=False: missing/unexpected keys in the checkpoint are ignored.
        self.model.load_state_dict(state_dict, strict=False)

        self.model = self.model.to(self.device)
        self.sampler = DDIMSampler(self.model)

    def init_auto_res(self):
        r"""Set the resolution from the checkpoint filename.

        768x768 if the checkpoint path contains "768", otherwise 512x512.
        """
        if '768' in self.fp_ckpt:
            self.height = 768
            self.width = 768
        else:
            self.height = 512
            self.width = 512

    def init_inpainting(
            self,
            image_source: Union[Image.Image, np.ndarray] = None,
            mask_image: Union[Image.Image, np.ndarray] = None,
            init_empty: Optional[bool] = False,
    ):
        r"""
        Initializes inpainting with a source and mask image.

        Args:
            image_source: Union[Image.Image, np.ndarray]
                Source image onto which the mask will be applied.
            mask_image: Union[Image.Image, np.ndarray]
                Mask image; value = 0 stays untouched, value = 255 is subject to diffusion.
            init_empty: Optional[bool]
                Initialize inpainting with an empty image and a full mask, effectively
                disabling inpainting; useful for generating a first image for
                transitions using diffusion.
        """
        if init_empty:
            self.mask_image = self.mask_empty
            self.image_source = self.image_empty
            return

        assert image_source is not None, "init_inpainting: you need to provide image_source"
        assert mask_image is not None, "init_inpainting: you need to provide mask_image"
        if isinstance(image_source, np.ndarray):
            image_source = Image.fromarray(image_source)
        self.image_source = image_source

        if isinstance(mask_image, np.ndarray):
            mask_image = Image.fromarray(mask_image)
        self.mask_image = mask_image

    def get_text_embedding(self, prompt):
        r"""Return the model's learned conditioning (text embedding) for ``prompt``."""
        c = self.model.get_learned_conditioning(prompt)
        return c

    @torch.no_grad()
    def get_cond_upscaling(self, image, text_embedding, noise_level):
        r"""
        Build the conditional and unconditional conditioning for the x4 upscaling model.

        Args:
            image: PIL.Image.Image
                Low-resolution input. It is edge-padded with ``pad_image`` so
                that its sides become multiples of 64 pixels.
            text_embedding:
                Conditioning from ``get_text_embedding``.
            noise_level: int
                Noise augmentation level for the low-resolution image.

        Returns:
            (cond, uc_full): dicts with keys "c_concat", "c_crossattn" and "c_adm"
            for ``run_diffusion_upscaling``. Both share the same augmented
            image and noise level; only the cross-attention input differs
            (``text_embedding`` vs. the empty prompt).
        """
        image = pad_image(image)  # pad sides to multiples of 64
        noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
        batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)

        x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)

        cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
        # unconditional conditioning: same image, empty prompt
        uc_cross = self.model.get_unconditional_conditioning(1, "")
        uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}

        return cond, uc_full

    def _precision_scope(self):
        r"""Return the autocast (or null) context selected by ``self.precision``."""
        precision_scope = autocast if self.precision == "autocast" else nullcontext
        return precision_scope("cuda")

    @staticmethod
    def _check_injection_args(latents_for_injection, idx_start):
        r"""Reject injection latents that the sampling loop could never inject.

        Latents are injected at loop iteration ``idx_start``; a negative index is
        never reached, so the given latents would be silently ignored.

        Raises:
            ValueError: if ``latents_for_injection`` is given with ``idx_start < 0``.
        """
        if latents_for_injection is not None and idx_start < 0:
            raise ValueError(
                f"latents_for_injection was given but idx_start={idx_start}; "
                "idx_start must be >= 0 for the latents to be injected")

    def _run_ddim_loop(self, cond, uc, shape_latents, latents_for_injection,
                       idx_start, idx_stop, return_image):
        r"""
        DDIM sampling loop shared by all ``run_diffusion_*`` methods.

        Must be called inside the precision and EMA scopes. Builds the DDIM
        schedule, draws the initial latent from a generator seeded with
        ``self.seed`` and walks the schedule, recording a clone of the latent
        after every executed step.

        Args:
            cond, uc:
                Conditioning and unconditional conditioning for
                ``DDIMSampler.p_sample_ddim``.
            shape_latents:
                (C, H, W) of the latent; the batch size is 1.
            latents_for_injection, idx_start, idx_stop, return_image:
                See ``run_diffusion_standard``.

        Returns:
            The list of per-step latents, or, if ``return_image`` is set and
            ``idx_stop`` was not reached, the decoded image of the last latent.
        """
        do_inject_latents = latents_for_injection is not None
        generator = torch.Generator(device=self.device).manual_seed(int(self.seed))

        self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
        batch_size = 1
        latents = torch.randn((batch_size, *shape_latents), generator=generator, device=self.device)

        timesteps = self.sampler.ddim_timesteps
        time_range = np.flip(timesteps)  # iterate from the noisiest timestep down
        total_steps = timesteps.shape[0]

        # collect latents
        list_latents_out = []
        for i, step in enumerate(time_range):
            if do_inject_latents:
                # Skip the steps before the injection point, then replace the
                # current latent by the injected one.
                if i < idx_start:
                    continue
                elif i == idx_start:
                    latents = latents_for_injection.clone()

            if i == idx_stop:
                # Early stop always returns the latents list, even with return_image.
                return list_latents_out

            index = total_steps - i - 1
            ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
            latents, _pred_x0 = self.sampler.p_sample_ddim(
                latents, cond, ts, index=index, use_original_steps=False,
                quantize_denoised=False, temperature=1.0,
                noise_dropout=0.0, score_corrector=None,
                corrector_kwargs=None,
                unconditional_guidance_scale=self.guidance_scale,
                unconditional_conditioning=uc,
                dynamic_threshold=None)
            list_latents_out.append(latents.clone())

        if return_image:
            return self.latent2image(latents)
        return list_latents_out

    @torch.no_grad()
    def run_diffusion_standard(
            self,
            text_embeddings: torch.FloatTensor,
            latents_for_injection: torch.FloatTensor = None,
            idx_start: int = -1,
            idx_stop: int = -1,
            return_image: Optional[bool] = False
    ):
        r"""
        Runs standard text-conditioned diffusion and returns the latent after every step.

        Adaptations allow to supply
            a) a starting index for diffusion,
            b) a stopping index for diffusion,
            c) latent representations that are injected at the starting index.

        Args:
            text_embeddings: torch.FloatTensor
                Text embeddings used for diffusion
            latents_for_injection: torch.FloatTensor
                Latents that are used for injection
            idx_start: int
                Index of the diffusion process start and where the latents_for_injection
                are injected. Steps before it are skipped and not recorded, so the
                returned list then starts with the output of step ``idx_start``.
            idx_stop: int
                Index of the diffusion process end (exclusive); when reached, the
                latents collected so far are returned.
            return_image: Optional[bool]
                Optionally return the decoded image of the final latent instead.

        Raises:
            ValueError: if latents_for_injection is given with a negative idx_start.
        """
        self._check_injection_args(latents_for_injection, idx_start)

        with self._precision_scope():
            with self.model.ema_scope():
                if self.guidance_scale != 1.0:
                    uc = self.model.get_learned_conditioning([""])
                else:
                    uc = None
                shape_latents = [self.C, self.height // self.f, self.width // self.f]
                return self._run_ddim_loop(text_embeddings, uc, shape_latents,
                                           latents_for_injection, idx_start, idx_stop, return_image)

    @torch.no_grad()
    def run_diffusion_inpaint(
            self,
            text_embeddings: torch.FloatTensor,
            latents_for_injection: torch.FloatTensor = None,
            idx_start: int = -1,
            idx_stop: int = -1,
            return_image: Optional[bool] = False
    ):
        r"""
        Runs inpaint-based diffusion. Returns a list of latents that were computed.

        Uses ``self.image_source`` and ``self.mask_image`` (see init_inpainting).
        Adaptations allow to supply
            a) a starting index for diffusion,
            b) a stopping index for diffusion,
            c) latent representations that are injected at the starting index.
        Furthermore the intermediate latents are collected and returned.

        Adapted from diffusers (https://github.com/huggingface/diffusers)

        Args:
            text_embeddings: torch.FloatTensor
                Text embeddings used for diffusion
            latents_for_injection: torch.FloatTensor
                Latents that are used for injection
            idx_start: int
                Index of the diffusion process start and where the latents_for_injection are injected
            idx_stop: int
                Index of the diffusion process end (exclusive).
            return_image: Optional[bool]
                Optionally return image directly

        Raises:
            ValueError: if latents_for_injection is given with a negative idx_start.
        """
        self._check_injection_args(latents_for_injection, idx_start)

        with self._precision_scope():
            with self.model.ema_scope():
                batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
                latent_hw = [self.height // self.f, self.width // self.f]

                c_cat = []
                for ck in self.model.concat_keys:
                    cc = batch[ck].float()
                    if ck != self.model.masked_image_key:
                        # Non-image concat inputs (presumably the mask) are resized
                        # to the latent resolution.
                        cc = torch.nn.functional.interpolate(cc, size=latent_hw)
                    else:
                        # The masked image is encoded into latent space.
                        cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
                    c_cat.append(cc)
                c_cat = torch.cat(c_cat, dim=1)

                # cond
                cond = {"c_concat": [c_cat], "c_crossattn": [text_embeddings]}
                # uncond cond: same concat input, empty prompt
                uc_cross = self.model.get_unconditional_conditioning(1, "")
                uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}

                shape_latents = [self.model.channels, latent_hw[0], latent_hw[1]]
                return self._run_ddim_loop(cond, uc_full, shape_latents,
                                           latents_for_injection, idx_start, idx_stop, return_image)

    @torch.no_grad()
    def run_diffusion_upscaling(
            self,
            cond,
            uc_full,
            latents_for_injection: torch.FloatTensor = None,
            idx_start: int = -1,
            idx_stop: int = -1,
            return_image: Optional[bool] = False
    ):
        r"""
        Runs diffusion with the x4 upscaling model. Returns a list of latents that were computed.

        Args:
            cond: dict
                Conditioning from ``get_cond_upscaling``.
            uc_full: dict
                Unconditional conditioning from ``get_cond_upscaling``; the latent
                height/width are taken from ``uc_full['c_concat'][0]``.
            latents_for_injection: torch.FloatTensor
                Latents that are used for injection
            idx_start: int
                Index of the diffusion process start and where the latents_for_injection are injected
            idx_stop: int
                Index of the diffusion process end (exclusive).
            return_image: Optional[bool]
                Optionally return image directly

        Raises:
            ValueError: if latents_for_injection is given with a negative idx_start.
        """
        self._check_injection_args(latents_for_injection, idx_start)

        h = uc_full['c_concat'][0].shape[2]
        w = uc_full['c_concat'][0].shape[3]

        with self._precision_scope():
            with self.model.ema_scope():
                shape_latents = [self.model.channels, h, w]
                return self._run_ddim_loop(cond, uc_full, shape_latents,
                                           latents_for_injection, idx_start, idx_stop, return_image)

    @torch.no_grad()
    def latent2image(
            self,
            latents: torch.FloatTensor
    ):
        r"""
        Returns an image provided a latent representation from diffusion.

        Args:
            latents: torch.FloatTensor
                Result of the diffusion process.

        Returns:
            np.ndarray (uint8, layout H x W x C) for the first batch element;
            decoder output is mapped from [-1, 1] to [0, 255] and truncated.
        """
        x_sample = self.model.decode_first_stage(latents)
        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
        image = x_sample.astype(np.uint8)
        return image


if __name__ == "__main__":
    # Interactive scratch experiments for the x4 upscaler, meant to be run cell by
    # cell (#%%). `self` is used as a plain variable name so that snippets can be
    # moved in and out of StableDiffusionHolder methods unchanged.
    #
    # Injection convention: list_samples[k] holds the latent *after* step k, and
    # run_diffusion_* injects latents_for_injection *before* step idx_start.
    # To resume a run at step s, inject list_samples[s - 1] with idx_start=s.

    fp_ckpt = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
    fp_config = 'configs/x4-upscaling.yaml'
    num_inference_steps = 100
    self = StableDiffusionHolder(fp_ckpt, fp_config, num_inference_steps=num_inference_steps)

    dp_test = '/home/lugo/latentblending/test1'
    dp_high = os.path.join(dp_test, 'high')
    size_lowres = (32*20, 32*12)
    promptA = "photo of a an ancient castle surrounded by a forest"
    promptB = "photo of a beautiful island on the horizon, blue sea with waves"
    noise_level = 20  # gradio min=0, max=350, value=20

    # Stop here when the whole file is executed; run the cells below interactively.
    raise SystemExit("Model loaded; run the remaining #%% cells interactively.")

    #%% image A
    image = Image.open(os.path.join(dp_test, 'img_0007.jpg')).resize(size_lowres)
    text_embeddingA = self.get_text_embedding(promptA)
    condA, uc_fullA = self.get_cond_upscaling(image, text_embeddingA, noise_level)

    list_samplesA = self.run_diffusion_upscaling(condA, uc_fullA)
    image_result = Image.fromarray(self.latent2image(list_samplesA[-1]))
    image_result.save(os.path.join(dp_high, 'imgA.jpg'))

    #%% image B
    from latent_blending import interpolate_linear, interpolate_spherical
    image = Image.open(os.path.join(dp_test, 'img_0006.jpg')).resize(size_lowres)
    text_embeddingA = self.get_text_embedding(promptA)
    text_embeddingB = self.get_text_embedding(promptB)
    text_embedding = interpolate_linear(text_embeddingA, text_embeddingB, 1/8)
    condB, uc_fullB = self.get_cond_upscaling(image, text_embedding, noise_level)

    list_samplesB = self.run_diffusion_upscaling(condB, uc_fullB)
    image_result = Image.fromarray(self.latent2image(list_samplesB[-1]))
    image_result.save(os.path.join(dp_high, 'imgB.jpg'))

    #%% reality check: resume run A from one of its own intermediate latents.
    # This should reproduce imgA.jpg. Injecting list_samplesA[50] at idx_start=50
    # would apply step 50 twice, hence list_samplesA[idx_start - 1].
    # A's conditioning is reused as well. NOTE(review): get_cond_upscaling applies
    # noise augmentation via the model's low_scale_model, which may be random, so
    # recomputing the conditioning could also change the result - confirm.
    idx_start = 50
    latents_inject = list_samplesA[idx_start - 1]
    list_samplesAx = self.run_diffusion_upscaling(condA, uc_fullA, latents_inject, idx_start=idx_start)
    image_result = Image.fromarray(self.latent2image(list_samplesAx[-1]))
    image_result.save(os.path.join(dp_high, 'imgA_restart.jpg'))

    #%% mix in the middle! which uc_full should be taken?
    # expA: take the conditioning of A
    idx_start = 90
    latents_inject = interpolate_spherical(list_samplesA[idx_start - 1], list_samplesB[idx_start - 1], 0.5)

    image = Image.open(os.path.join(dp_test, 'img_0007.jpg')).resize(size_lowres)
    text_embeddingA = self.get_text_embedding(promptA)
    cond, uc_full = self.get_cond_upscaling(image, text_embeddingA, noise_level)

    list_samples = self.run_diffusion_upscaling(cond, uc_full, latents_inject, idx_start=idx_start)
    image_result = Image.fromarray(self.latent2image(list_samples[-1]))
    image_result.save(os.path.join(dp_high, 'img_mix_expA_late.jpg'))

    #%% mix in the middle! which uc_full should be taken?
    # expB: take the conditioning of B
    idx_start = 90
    latents_inject = interpolate_spherical(list_samplesA[idx_start - 1], list_samplesB[idx_start - 1], 0.5)

    image = Image.open(os.path.join(dp_test, 'img_0006.jpg')).resize(size_lowres)
    text_embeddingA = self.get_text_embedding(promptA)
    text_embeddingB = self.get_text_embedding(promptB)
    text_embedding = interpolate_linear(text_embeddingA, text_embeddingB, 1/8)
    cond, uc_full = self.get_cond_upscaling(image, text_embedding, noise_level)

    list_samples = self.run_diffusion_upscaling(cond, uc_full, latents_inject, idx_start=idx_start)
    image_result = Image.fromarray(self.latent2image(list_samples[-1]))
    image_result.save(os.path.join(dp_high, 'img_mix_expB_late.jpg'))

    #%% lets blend the uc_full too!
    # expC: blend latents and the image conditioning of A and B
    idx_start = 50
    text_embeddingA = self.get_text_embedding(promptA)
    text_embeddingB = self.get_text_embedding(promptB)
    text_embedding = interpolate_linear(text_embeddingA, text_embeddingB, 1/8)
    imageA = Image.open(os.path.join(dp_test, 'img_0007.jpg')).resize(size_lowres)
    imageB = Image.open(os.path.join(dp_test, 'img_0006.jpg')).resize(size_lowres)

    for fract_mix in np.linspace(0, 1, 20):
        latents_inject = interpolate_spherical(list_samplesA[idx_start - 1], list_samplesB[idx_start - 1], fract_mix)

        # Fresh conditionings each iteration, since their c_concat entries are overwritten.
        cond_mix, uc_full_mix = self.get_cond_upscaling(imageA, text_embedding, noise_level)
        cond_B, uc_full_B = self.get_cond_upscaling(imageB, text_embedding, noise_level)
        cond_mix['c_concat'][0] = interpolate_spherical(cond_mix['c_concat'][0], cond_B['c_concat'][0], fract_mix)
        uc_full_mix['c_concat'][0] = interpolate_spherical(uc_full_mix['c_concat'][0], uc_full_B['c_concat'][0], fract_mix)

        list_samples = self.run_diffusion_upscaling(cond_mix, uc_full_mix, latents_inject, idx_start=idx_start)
        image_result = Image.fromarray(self.latent2image(list_samples[-1]))
        image_result.save(os.path.join(dp_high, f'img_mix_expC_{fract_mix}_start{idx_start}.jpg'))

    #%% collect the expC results
    # Sorted so the order (and the entry dropped by pop) is deterministic.
    list_imgs = sorted(fn for fn in os.listdir(dp_high) if "expC" in fn)
    list_imgs.pop(0)  # NOTE(review): drops the first file; intent unclear - confirm.
    lx = [Image.open(os.path.join(dp_high, fn)) for fn in list_imgs]

    #%% inpainting sanity check (disabled): resuming from an intermediate latent
    # should reproduce the original result.
    if False:
        num_inference_steps = 20  # Number of diffusion iterations

        # fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
        # fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
        # fp_ckpt= "../stable_diffusion_models/ckpt/512-base-ema.ckpt"
        # fp_config = '../stablediffusion/configs//stable-diffusion/v2-inference.yaml'
        fp_ckpt = "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
        fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
        sdh = StableDiffusionHolder(fp_ckpt, fp_config, num_inference_steps)

        image_source = Image.fromarray((255*np.random.rand(512, 512, 3)).astype(np.uint8))
        mask = 255*np.ones([512, 512], dtype=np.uint8)
        mask[0:50, 0:50] = 0
        mask = Image.fromarray(mask)

        sdh.init_inpainting(image_source, mask)
        text_embedding = sdh.get_text_embedding("photo of a strange house, surreal painting")
        list_latents = sdh.run_diffusion_inpaint(text_embedding)

        idx_inject = 3
        img_orig = sdh.latent2image(list_latents[-1])
        list_inject = sdh.run_diffusion_inpaint(text_embedding, list_latents[idx_inject], idx_start=idx_inject+1)
        img_inject = sdh.latent2image(list_inject[-1])

        # uint8 subtraction wraps around; take the absolute difference in a wider type.
        img_diff = np.abs(img_orig.astype(np.int16) - img_inject.astype(np.int16)).astype(np.uint8)
        import matplotlib.pyplot as plt
        plt.imshow(np.concatenate((img_orig, img_inject, img_diff), axis=1))