This commit is contained in:
Johannes Stelzer 2023-02-20 11:26:04 +01:00
parent 7df09e8d0b
commit 4ce995a023
2 changed files with 39 additions and 21 deletions

View File

@ -101,7 +101,6 @@ class LatentBlending():
self.text_embedding2 = None self.text_embedding2 = None
self.image1_lowres = None self.image1_lowres = None
self.image2_lowres = None self.image2_lowres = None
self.stop_diffusion = False
self.negative_prompt = None self.negative_prompt = None
self.num_inference_steps = self.sdh.num_inference_steps self.num_inference_steps = self.sdh.num_inference_steps
self.noise_level_upscaling = 20 self.noise_level_upscaling = 20
@ -117,7 +116,6 @@ class LatentBlending():
self.parental_crossfeed_range = 0.8 self.parental_crossfeed_range = 0.8
self.parental_crossfeed_power_decay = 0.8 self.parental_crossfeed_power_decay = 0.8
self.branch1_insertion_completed = False
self.set_guidance_scale(guidance_scale) self.set_guidance_scale(guidance_scale)
self.init_mode() self.init_mode()
self.multi_transition_img_first = None self.multi_transition_img_first = None
@ -454,7 +452,6 @@ class LatentBlending():
if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed: if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
stop_criterion_reached = True stop_criterion_reached = True
# FIXME: also undersample here... but how... maybe drop them iteratively?
elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches: elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
stop_criterion_reached = True stop_criterion_reached = True
if is_first_iteration: if is_first_iteration:
@ -528,15 +525,20 @@ class LatentBlending():
def get_spatial_mask_template(self): def get_spatial_mask_template(self):
r"""
Experimental helper function to get a spatial mask template.
"""
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f] shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
C, H, W = shape_latents C, H, W = shape_latents
return np.ones((H, W)) return np.ones((H, W))
def set_spatial_mask(self, img_mask): def set_spatial_mask(self, img_mask):
r""" r"""
Helper function to #FIXME Experimental helper function to set a spatial mask.
The mask forces latents to be overwritten.
Args: Args:
seed: int img_mask:
mask image [0,1]. You can get a template using get_spatial_mask_template
""" """
@ -591,7 +593,8 @@ class LatentBlending():
Depending on the mode, the correct one will be executed. Depending on the mode, the correct one will be executed.
Args: Args:
list_conditionings: List of all conditionings for the diffusion model. list_conditionings: list
List of all conditionings for the diffusion model.
latents_start: torch.FloatTensor latents_start: torch.FloatTensor
Latents that are used for injection Latents that are used for injection
idx_start: int idx_start: int
@ -640,10 +643,33 @@ class LatentBlending():
num_inference_steps: int = 100, num_inference_steps: int = 100,
nmb_max_branches_highres: int = 5, nmb_max_branches_highres: int = 5,
nmb_max_branches_lowres: int = 6, nmb_max_branches_lowres: int = 6,
fixed_seeds: Optional[List[int]] = None,
duration_single_segment = 3, duration_single_segment = 3,
fixed_seeds: Optional[List[int]] = None,
): ):
#FIXME r"""
Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
Args:
dp_img: str
Path to the low-res transition path (as saved in write_imgs_transition)
depth_strength:
Determines how deep the first injection will happen.
Deeper injections will cause (unwanted) formation of new structures,
more shallow values will go into alpha-blendy land.
num_inference_steps:
Number of diffusion steps. Higher values will take more compute time.
nmb_max_branches_highres: int
Number of final branches of the upscaling transition pass. Note this is the number
of branches between each pair of low-res images.
nmb_max_branches_lowres: int
Number of input low-res images, subsampling all transition images written in the low-res pass.
Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
duration_single_segment: float
The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
fixed_seeds: Optional[List[int)]:
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
Otherwise random seeds will be taken.
"""
fp_yml = os.path.join(dp_img, "lowres.yaml") fp_yml = os.path.join(dp_img, "lowres.yaml")
fp_movie = os.path.join(dp_img, "movie_highres.mp4") fp_movie = os.path.join(dp_img, "movie_highres.mp4")
fps = 24 fps = 24

View File

@ -265,7 +265,9 @@ class StableDiffusionHolder:
idx_start: int idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected Index of the diffusion process start and where the latents_for_injection are injected
mixing_coeff: mixing_coeff:
# FIXME spatial_mask mixing coefficients for latent blending
spatial_mask:
experimental feature for enforcing pixels from list_latents_mixing
return_image: Optional[bool] return_image: Optional[bool]
Optionally return image directly Optionally return image directly
@ -352,15 +354,6 @@ class StableDiffusionHolder:
): ):
r""" r"""
Diffusion upscaling version. Diffusion upscaling version.
# FIXME
Args:
??
latents_for_injection: torch.FloatTensor
Latents that are used for injection
idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected
return_image: Optional[bool]
Optionally return image directly
""" """
# Asserts # Asserts
@ -376,7 +369,6 @@ class StableDiffusionHolder:
assert len(list_latents_mixing) == self.num_inference_steps assert len(list_latents_mixing) == self.num_inference_steps
precision_scope = autocast if self.precision == "autocast" else nullcontext precision_scope = autocast if self.precision == "autocast" else nullcontext
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
h = uc_full['c_concat'][0].shape[2] h = uc_full['c_concat'][0].shape[2]
w = uc_full['c_concat'][0].shape[3] w = uc_full['c_concat'][0].shape[3]