helper function moved

This commit is contained in:
Johannes Stelzer 2023-01-12 10:16:31 +01:00
parent ea298030f3
commit 30f4aaaa24
2 changed files with 22 additions and 21 deletions

View File

@ -26,7 +26,7 @@ from PIL import Image
import torch import torch
from movie_util import MovieSaver from movie_util import MovieSaver
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
from stable_diffusion_holder import StableDiffusionHolder from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
import gradio as gr import gradio as gr
@ -35,26 +35,6 @@ import copy
#%% #%%
def compare_dicts(a, b):
"""
Compares two dictionaries a and b and returns a dictionary c, with all
keys,values that have shared keys in a and b but same values in a and b.
The values of a and b are stacked together in the output.
Example:
a = {}; a['bobo'] = 4
b = {}; b['bobo'] = 5
c = dict_compare(a,b)
c = {"bobo",[4,5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c
class BlendingFrontend(): class BlendingFrontend():
def __init__(self, sdh=None): def __init__(self, sdh=None):
if sdh is None: if sdh is None:

View File

@ -1113,6 +1113,25 @@ def get_time(resolution=None):
raise ValueError("bad resolution provided: %s" %resolution) raise ValueError("bad resolution provided: %s" %resolution)
return t return t
def compare_dicts(a, b):
"""
Compares two dictionaries a and b and returns a dictionary c, with all
keys,values that have shared keys in a and b but same values in a and b.
The values of a and b are stacked together in the output.
Example:
a = {}; a['bobo'] = 4
b = {}; b['bobo'] = 5
c = dict_compare(a,b)
c = {"bobo",[4,5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c
def yml_load(fp_yml, print_fields=False): def yml_load(fp_yml, print_fields=False):
""" """
@ -1144,6 +1163,8 @@ if __name__ == "__main__":
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device) sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
xxx
#%% Next let's set up all parameters #%% Next let's set up all parameters
quality = 'medium' quality = 'medium'