diff --git a/gradio_ui.py b/gradio_ui.py
index 33116b2..200abe5 100644
--- a/gradio_ui.py
+++ b/gradio_ui.py
@@ -32,23 +32,35 @@ torch.set_grad_enabled(False)
 import gradio as gr
 import copy
 from dotenv import find_dotenv, load_dotenv
+import shutil
 
+"""
+never hit compute trans -> multi movie add fail
+
+"""
 
 
 #%%
 
 class BlendingFrontend():
     def __init__(self, sdh=None):
+        self.num_inference_steps = 30
         if sdh is None:
             self.use_debug = True
+            self.height = 768
+            self.width = 768
         else:
             self.use_debug = False
             self.lb = LatentBlending(sdh)
-            
-        self.share = True
-        self.num_inference_steps = 30
+            self.lb.sdh.num_inference_steps = self.num_inference_steps
+            self.height = self.lb.sdh.height
+            self.width = self.lb.sdh.width
+        
+        self.init_save_dir()
+        self.save_empty_image()
+        self.share = False
         self.depth_strength = 0.25
-        self.seed1 = 42
+        self.seed1 = 420
         self.seed2 = 420
         self.guidance_scale = 4.0
         self.guidance_scale_mid_damper = 0.5
@@ -72,16 +84,13 @@ class BlendingFrontend():
         self.current_timestamp = None
         self.recycle_img1 = False
         self.recycle_img2 = False
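+        # New state for the multi-movie workflow: paths of the two keyframe images,
+        # the index of the transition currently shown in the "Multi Transition" tab,
+        # the list of transitions selected for concatenation, and the last shown previews.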
+        self.fp_img1 = None
+        self.fp_img2 = None
+        self.multi_idx_current = -1
+        self.multi_list_concat = []
+        self.list_imgs_shown_last = 5*[self.fp_img_empty]
+        self.nmb_trans_stack = 6
         
-        if not self.use_debug:
-            self.lb.sdh.num_inference_steps = self.num_inference_steps
-            self.height = self.lb.sdh.height
-            self.width = self.lb.sdh.width
-        else:
-            self.height = 768
-            self.width = 768
-            
-        self.init_save_dir()
         
         
     def init_save_dir(self):
@@ -89,12 +98,18 @@ class BlendingFrontend():
         try:
             self.dp_out = os.getenv("dp_out")
         except Exception as e:
+            print(f"did not find .env file. using local folder. {e}")
             self.dp_out = ""
+        # os.getenv returns None (rather than raising) if dp_out is not set
+        if self.dp_out is None:
+            self.dp_out = ""
+        self.dp_imgs = os.path.join(self.dp_out, "imgs")
+        os.makedirs(self.dp_imgs, exist_ok=True)
+        self.dp_movies = os.path.join(self.dp_out, "movies")
+        os.makedirs(self.dp_movies, exist_ok=True)
+        
         
         
         # make dummy image
     def save_empty_image(self):
-        self.fp_img_empty = os.path.join(self.dp_out, 'empty.jpg')
+        self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
         Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
         
         
@@ -140,22 +155,21 @@ class BlendingFrontend():
     def compute_img1(self, *args):
         list_ui_elem = args
         self.setup_lb(list_ui_elem)
-        fp_img1 = os.path.join(self.dp_out, f"img1_{get_time('second')}.jpg")
+        self.fp_img1 = os.path.join(self.dp_imgs, f"img1_{get_time('second')}.jpg")
         img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
-        img1.save(fp_img1)
-        self.save_empty_image()
+        img1.save(self.fp_img1)
         self.recycle_img1 = True
         self.recycle_img2 = False
-        return [fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
+        return [self.fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
     
     def compute_img2(self, *args):
         list_ui_elem = args
         self.setup_lb(list_ui_elem)
-        fp_img2 = os.path.join(self.dp_out, f"img2_{get_time('second')}.jpg")
+        self.fp_img2 = os.path.join(self.dp_imgs, f"img2_{get_time('second')}.jpg")
         img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
-        img2.save(fp_img2)
+        img2.save(self.fp_img2)
         self.recycle_img2 = True
-        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
+        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img2]
         
     def compute_transition(self, *args):
         
@@ -199,7 +213,7 @@ class BlendingFrontend():
         self.current_timestamp = get_time('second')
         self.list_fp_imgs_current = []
         for i in range(len(list_imgs_preview)):
-            fp_img = f"img_preview_{i}_{self.current_timestamp}.jpg"
+            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
             list_imgs_preview[i].save(fp_img)
             self.list_fp_imgs_current.append(fp_img)
         
@@ -207,51 +221,38 @@ class BlendingFrontend():
         imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
 
         # Save as movie
-        fp_movie = self.get_fp_movie(self.current_timestamp)
-        if os.path.isfile(fp_movie):
-            os.remove(fp_movie)
-        ms = MovieSaver(fp_movie, fps=self.fps)
+        self.fp_movie = os.path.join(self.dp_movies, f"movie_{self.current_timestamp}.mp4") 
+        if os.path.isfile(self.fp_movie):
+            os.remove(self.fp_movie)
+        ms = MovieSaver(self.fp_movie, fps=self.fps)
         for img in tqdm(imgs_transition_ext):
             ms.write_frame(img)
         ms.finalize()
         print("DONE SAVING MOVIE! SENDING BACK...")
         
         # Assemble Output, updating the preview images and le movie
-        list_return = self.list_fp_imgs_current + [fp_movie]
+        list_return = self.list_fp_imgs_current + [self.fp_movie]
         return list_return
 
-    def get_fp_movie(self, timestamp, is_stacked=False):
-        if not is_stacked:
-            fn = f"movie_{timestamp}.mp4"
-        else:
-            fn = f"movie_stacked_{timestamp}.mp4"
-        fp = os.path.join(self.dp_out, fn)
-        return fp
-            
     
     def stack_forward(self, prompt2, seed2):
         # Save preview images, prompts and seeds into dictionary for stacking
-        dp_out = os.path.join(self.dp_out, get_time('second'))
-        self.lb.write_imgs_transition(dp_out)
+        # self.list_imgs_shown_last = self.get_multi_trans_imgs_preview(f"lowres_{self.current_timestamp}")[0:5]
+        timestamp_section = get_time('second')
+        self.lb.write_imgs_transition(os.path.join(self.dp_out, f"lowres_{timestamp_section}"))
+        self.lb.write_imgs_transition(os.path.join(self.dp_out, "lowres_current"))
+        shutil.copyfile(self.fp_movie, os.path.join(self.dp_out, f"lowres_{timestamp_section}", "movie.mp4"))
+        
         self.lb.swap_forward()
-        list_out = [self.list_fp_imgs_current[-1]]
+        list_out = [self.fp_img2] 
         list_out.extend([self.fp_img_empty]*4)
         list_out.append(prompt2)
         list_out.append(seed2)
         list_out.append("")
         list_out.append(np.random.randint(0, 10000000))
+        
         return list_out
-        
-
-    def stack_movie(self):
-        # collect all that are in...
-        list_fp_movies = []
-
-        list_fp_movies.append(self.get_fp_movie(timestamp))
-        
-        fp_stacked = self.get_fp_movie(get_time('second'), True)
-        concatenate_movies(fp_stacked, list_fp_movies)
-        return fp_stacked
+    
         
 
     def get_state_dict(self):
@@ -264,7 +265,90 @@ class BlendingFrontend():
             state_dict[v] = getattr(self, v)
         return state_dict
     
-   
+    
+    def get_list_all_stacked(self):
+        list_all = os.listdir(os.path.join(self.dp_out))
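+        # Folders written by stack_forward are named "lowres_<timestamp>"; matching the
+        # prefix "lowres_2" (timestamps presumably start with the year) excludes "lowres_current".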
+        list_all = [l for l in list_all if l[:8]=="lowres_2"]
+        list_all.sort()
+        return list_all
+        
+    def multi_trans_show_older(self):
+        list_all = self.get_list_all_stacked()
+        if len(list_all) == 0:
+            return 5*[self.fp_img_empty] + [""]
+        if self.multi_idx_current == -1:
+            self.multi_idx_current = len(list_all) - 1
+        else:
+            self.multi_idx_current -= 1
+    
+        if self.multi_idx_current < 0:
+            self.multi_idx_current  = 0
+        dn = list_all[self.multi_idx_current]
+        return self.get_multi_trans_imgs_preview(dn)
+    
+    def multi_trans_show_newer(self):
+        list_all = self.get_list_all_stacked()
+        if len(list_all) == 0:
+            return 5*[self.fp_img_empty] + [""]
+        if self.multi_idx_current == -1:
+            self.multi_idx_current = len(list_all) - 1
+        else:
+            self.multi_idx_current += 1
+    
+        if self.multi_idx_current >= len(list_all):
+            self.multi_idx_current  = len(list_all) - 1
+        dn = list_all[self.multi_idx_current]
+        return self.get_multi_trans_imgs_preview(dn)
+    
+    def get_multi_trans_imgs_preview(self, dn):
+        dp_show = os.path.join(self.dp_out, dn)
+        list_imgs_transition = os.listdir(dp_show)
+        list_imgs_transition = [l for l in list_imgs_transition if l[:11]=="lowres_img_"]
+        list_imgs_transition.sort()
+    
+        idx_img_prev = np.round(np.linspace(0, len(list_imgs_transition)-1, 5)).astype(np.int32)
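+        # e.g. with 20 transition frames, idx_img_prev = [0, 5, 10, 14, 19]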
+        list_imgs_preview = []
+        for j in idx_img_prev:
+            list_imgs_preview.append(os.path.join(dp_show, list_imgs_transition[j]))
+        
+        list_out = list_imgs_preview
+        list_out.append(dn[7:])
+        
+        return list_out
+       
+    def multi_append(self):
+        list_all = self.get_list_all_stacked()
+        dn = list_all[self.multi_idx_current]
+        self.multi_list_concat.append(dn)
+        list_short = [dn[7:] for dn in self.multi_list_concat]
+        str_out = "\n".join(list_short)
+        return str_out
+       
+    def multi_reset(self):
+        self.multi_list_concat = []
+        str_out = ""
+        return str_out
+       
+    def multi_concat(self):
+        # Make new output directory
+        dp_multi = os.path.join(self.dp_out, f"multi_{get_time('second')}")       
+        os.makedirs(dp_multi, exist_ok=False)
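+        # Resulting layout (illustrative): dp_multi/000_lowres_<ts>/, 001_lowres_<ts>/, ...
+        # each holding the copied low-res frames, plus one concatenated movie.mp4 in dp_multi.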
+        
+        # Copy all low-res folders (prepending multi001_xxxx), however leave out the movie.mp4
+        # also collect all movie.mp4
+        list_fp_movies = []
+        for i, dn in enumerate(self.multi_list_concat):
+            dp_source = os.path.join(self.dp_out, dn)
+            dp_sequence = os.path.join(dp_multi, f"{str(i).zfill(3)}_{dn}")
+            os.makedirs(dp_sequence, exist_ok=False)
+            list_source = os.listdir(dp_source)
+            list_source = [l for l in list_source if not l.endswith(".mp4")]
+            for fn in list_source:
+                shutil.copyfile(os.path.join(dp_source, fn), os.path.join(dp_sequence, fn))
+            list_fp_movies.append(os.path.join(dp_source, "movie.mp4"))
+            
+        # Concatenate movies and save
+        fp_final = os.path.join(dp_multi, "movie.mp4")
+        concatenate_movies(fp_final, list_fp_movies)
+        return fp_final
+        
+                
 
 def get_img_rand():
     return (255*np.random.rand(self.height,self.width,3)).astype(np.uint8)
@@ -286,120 +370,150 @@ def generate_list_output(
     
     return list_output
 
-
         
 if __name__ == "__main__":    
     
-    # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt" 
-    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt" 
-    sdh = StableDiffusionHolder(fp_ckpt)
-    
-    self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick
+    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt" 
+    # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt" 
+    self = BlendingFrontend(StableDiffusionHolder(fp_ckpt)) # Yes this is possible in python and yes it is an awesome trick
     # self = BlendingFrontend(None) # Yes this is possible in python and yes it is an awesome trick
     
     dict_ui_elem = {}
     
     with gr.Blocks() as demo:
-        with gr.Row():
-            prompt1 = gr.Textbox(label="prompt 1")
-            prompt2 = gr.Textbox(label="prompt 2")
-        
-        with gr.Row():
-            duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True) 
-            duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True) 
-            height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
-            width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True) 
+        with gr.Tab("Single Transition"):
+            with gr.Row():
+                prompt1 = gr.Textbox(label="prompt 1")
+                prompt2 = gr.Textbox(label="prompt 2")
             
-        with gr.Accordion("Advanced Settings (click to expand)", open=False):
-
-            with gr.Accordion("Diffusion settings", open=True):
-                with gr.Row():
-                    num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
-                    guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True) 
-                    negative_prompt = gr.Textbox(label="negative prompt")          
-            
-            with gr.Accordion("Seeds control", open=True):
-                with gr.Row():
-                    seed1 = gr.Number(420, label="seed 1", interactive=True)
-                    b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
-                    seed2 = gr.Number(420, label="seed 2", interactive=True)
-                    b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
-                    
-            with gr.Accordion("Crossfeeding for last image", open=True):
-                with gr.Row():
-                    branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True) 
-                    branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True) 
-                    branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True) 
-
-            with gr.Accordion("Transition settings", open=True):
-                with gr.Row():
-                    depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True) 
-                    guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True) 
-                    parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True) 
-                    parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True) 
-                    parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True) 
-        
+            with gr.Row():
+                duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True) 
+                duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True) 
+                height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
+                width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True) 
                 
-        with gr.Row():
-            b_compute1 = gr.Button('compute first image', variant='primary')
-            b_compute_transition = gr.Button('compute transition', variant='primary')
-            b_compute2 = gr.Button('compute last image', variant='primary')
-        
-        with gr.Row():
-            img1 = gr.Image(label="1/5")
-            img2 = gr.Image(label="2/5")
-            img3 = gr.Image(label="3/5")
-            img4 = gr.Image(label="4/5")
-            img5 = gr.Image(label="5/5")
-        
-        with gr.Row():
-            vid_transition = gr.Video()
+            with gr.Accordion("Advanced Settings (click to expand)", open=False):
+    
+                with gr.Accordion("Diffusion settings", open=True):
+                    with gr.Row():
+                        num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
+                        guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True) 
+                        negative_prompt = gr.Textbox(label="negative prompt")          
+                
+                with gr.Accordion("Seeds control", open=True):
+                    with gr.Row():
+                        seed1 = gr.Number(self.seed1, label="seed 1", interactive=True)
+                        b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
+                        seed2 = gr.Number(self.seed2, label="seed 2", interactive=True)
+                        b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
+                        
+                with gr.Accordion("Crossfeeding for last image", open=True):
+                    with gr.Row():
+                        branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True) 
+                        branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True) 
+                        branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True) 
+    
+                with gr.Accordion("Transition settings", open=True):
+                    with gr.Row():
+                        depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True) 
+                        guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True) 
+                        parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True) 
+                        parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True) 
+                        parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True) 
+            
+                    
+            with gr.Row():
+                b_compute1 = gr.Button('compute first image', variant='primary')
+                b_compute_transition = gr.Button('compute transition', variant='primary')
+                b_compute2 = gr.Button('compute last image', variant='primary')
+            
+            with gr.Row():
+                img1 = gr.Image(label="1/5")
+                img2 = gr.Image(label="2/5")
+                img3 = gr.Image(label="3/5")
+                img4 = gr.Image(label="4/5")
+                img5 = gr.Image(label="5/5")
+            
+            with gr.Row():
+                vid_transition = gr.Video()
+                
+            with gr.Row():
+                b_stackforward = gr.Button('multi-movie start next segment (move last image -> first image)')
+            
+            # Collect all UI elements in a list to easily pass as inputs
+            dict_ui_elem["prompt1"] = prompt1
+            dict_ui_elem["negative_prompt"] = negative_prompt
+            dict_ui_elem["prompt2"] = prompt2
+             
+            dict_ui_elem["duration_compute"] = duration_compute
+            dict_ui_elem["duration_video"] = duration_video
+            dict_ui_elem["height"] = height
+            dict_ui_elem["width"] = width
+             
+            dict_ui_elem["depth_strength"] = depth_strength
+            dict_ui_elem["branch1_influence"] = branch1_influence
+            dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence
+            dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay
+            
+            dict_ui_elem["num_inference_steps"] = num_inference_steps
+            dict_ui_elem["guidance_scale"] = guidance_scale
+            dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
+            dict_ui_elem["seed1"] = seed1
+            dict_ui_elem["seed2"] = seed2
+            
+            dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence
+            dict_ui_elem["parental_influence"] = parental_influence
+            dict_ui_elem["parental_influence_decay"] = parental_influence_decay
+            
+            # Convert to list, as gradio doesn't seem to accept dicts
+            list_ui_elem = []
+            list_ui_keys = []
+            for k in dict_ui_elem.keys():
+                list_ui_elem.append(dict_ui_elem[k])
+                list_ui_keys.append(k)
+            self.list_ui_keys = list_ui_keys
+            
+            b_newseed1.click(self.randomize_seed1, outputs=seed1)
+            b_newseed2.click(self.randomize_seed2, outputs=seed2)
+            b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
+            b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
+            b_compute_transition.click(self.compute_transition, 
+                                        inputs=list_ui_elem,
+                                        outputs=[img2, img3, img4, vid_transition])
+            
+            b_stackforward.click(self.stack_forward, 
+                          inputs=[prompt2, seed2], 
+                          outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
+            
+        with gr.Tab("Multi Transition"):
+            with gr.Row():
+                multi_img1_prev = gr.Image(value=self.list_imgs_shown_last[0], label="1/5")
+                multi_img2_prev = gr.Image(value=self.list_imgs_shown_last[1], label="2/5")
+                multi_img3_prev = gr.Image(value=self.list_imgs_shown_last[2], label="3/5")
+                multi_img4_prev = gr.Image(value=self.list_imgs_shown_last[3], label="4/5")
+                multi_img5_prev = gr.Image(value=self.list_imgs_shown_last[4], label="5/5")
+            
+            with gr.Row():
+                b_older = gr.Button("show older")
+                b_newer = gr.Button("show newer")
+                text_timestamp = gr.Textbox(label="created", interactive=False)
+                b_append = gr.Button("append this transition")
+             
+            with gr.Row():
+                text_all_timestamps = gr.Textbox(label="movie list", interactive=False)
+            with gr.Row():
+                b_reset = gr.Button("reset")
+                b_concat = gr.Button("merge together", variant='primary')
+            
+            with gr.Row():
+                vid_multi = gr.Video()
+            
+            
+            b_older.click(self.multi_trans_show_older, inputs=[], outputs=[multi_img1_prev, multi_img2_prev, multi_img3_prev, multi_img4_prev, multi_img5_prev, text_timestamp])
+            b_newer.click(self.multi_trans_show_newer, inputs=[], outputs=[multi_img1_prev, multi_img2_prev, multi_img3_prev, multi_img4_prev, multi_img5_prev, text_timestamp])
+            b_append.click(self.multi_append, inputs=[], outputs=[text_all_timestamps])
+            b_reset.click(self.multi_reset, inputs=[], outputs=[text_all_timestamps])
+            b_concat.click(self.multi_concat, inputs=[], outputs=[vid_multi])
+            
             
-        with gr.Row():
-            b_stackforward = gr.Button('multi-movie start next segment (move last image -> first image)')
-        
-        # Collect all UI elemts in list to easily pass as inputs
-        dict_ui_elem["prompt1"] = prompt1
-        dict_ui_elem["negative_prompt"] = negative_prompt
-        dict_ui_elem["prompt2"] = prompt2
-         
-        dict_ui_elem["duration_compute"] = duration_compute
-        dict_ui_elem["duration_video"] = duration_video
-        dict_ui_elem["height"] = height
-        dict_ui_elem["width"] = width
-         
-        dict_ui_elem["depth_strength"] = depth_strength
-        dict_ui_elem["branch1_influence"] = branch1_influence
-        dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence
-        dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay
-        
-        dict_ui_elem["num_inference_steps"] = num_inference_steps
-        dict_ui_elem["guidance_scale"] = guidance_scale
-        dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
-        dict_ui_elem["seed1"] = seed1
-        dict_ui_elem["seed2"] = seed2
-        
-        dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence
-        dict_ui_elem["parental_influence"] = parental_influence
-        dict_ui_elem["parental_influence_decay"] = parental_influence_decay
-        
-        # Convert to list, as gradio doesn't seem to accept dicts
-        list_ui_elem = []
-        list_ui_keys = []
-        for k in dict_ui_elem.keys():
-            list_ui_elem.append(dict_ui_elem[k])
-            list_ui_keys.append(k)
-        self.list_ui_keys = list_ui_keys
-        
-        b_newseed1.click(self.randomize_seed1, outputs=seed1)
-        b_newseed2.click(self.randomize_seed2, outputs=seed2)
-        b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
-        b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
-        b_compute_transition.click(self.compute_transition, 
-                                    inputs=list_ui_elem,
-                                    outputs=[img2, img3, img4, vid_transition])
-        
-        b_stackforward.click(self.stack_forward, 
-                      inputs=[prompt2, seed2], 
-                      outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
     demo.launch(share=self.share, inbrowser=True, inline=False)
diff --git a/latent_blending.py b/latent_blending.py
index ef100b1..ca32f15 100644
--- a/latent_blending.py
+++ b/latent_blending.py
@@ -231,36 +231,36 @@ class LatentBlending():
         
         if quality == 'lowest':
             num_inference_steps = 12
-            nmb_branches_final = 5
+            nmb_max_branches = 5
         elif quality == 'low':
             num_inference_steps = 15
-            nmb_branches_final = nmb_frames//16
+            nmb_max_branches = nmb_frames//16
         elif quality == 'medium':
             num_inference_steps = 30
-            nmb_branches_final = nmb_frames//8
+            nmb_max_branches = nmb_frames//8
         elif quality == 'high':
             num_inference_steps = 60
-            nmb_branches_final = nmb_frames//4
+            nmb_max_branches = nmb_frames//4
         elif quality == 'ultra':
             num_inference_steps = 100
-            nmb_branches_final = nmb_frames//2
+            nmb_max_branches = nmb_frames//2
         elif quality == 'upscaling_step1':
             num_inference_steps = 40
-            nmb_branches_final = 12
+            nmb_max_branches = 12
         elif quality == 'upscaling_step2':
             num_inference_steps = 100
-            nmb_branches_final = 6
+            nmb_max_branches = 6
         else: 
             raise ValueError(f"quality = '{quality}' not supported")
             
-        self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
+        self.autosetup_branching(depth_strength, num_inference_steps, nmb_max_branches)
         
        
     def autosetup_branching(
             self, 
             depth_strength: float = 0.65,
             num_inference_steps: int = 30,
-            nmb_branches_final: int = 20,
+            nmb_max_branches: int = 20,
             nmb_mindist: int = 3,
         ):
         r"""
@@ -273,7 +273,7 @@ class LatentBlending():
                 more shallow values will go into alpha-blendy land.
             num_inference_steps: int
                 Number of diffusion steps. Higher values will take more compute time.
-            nmb_branches_final (int): The number of diffusion-generated images 
+            nmb_max_branches (int): The number of diffusion-generated images 
                 at the end of the inference.
             nmb_mindist (int): The minimum number of diffusion steps 
                 between two injections.
@@ -285,7 +285,7 @@ class LatentBlending():
         
         list_injection_idx = [0]
         list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int))
-        list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_branches_final), nmb_injections+1)).astype(int)
+        list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_max_branches), nmb_injections+1)).astype(int)
         
         # Cleanup. There should be at least nmb_mindist diffusion steps between each injection and list_nmb_branches increases
         list_nmb_branches_clean = [list_nmb_branches[0]]
@@ -294,7 +294,7 @@ class LatentBlending():
             if idx_injection - list_injection_idx_clean[-1] >= nmb_mindist and nmb_branches > list_nmb_branches_clean[-1]:
                 list_nmb_branches_clean.append(nmb_branches)
                 list_injection_idx_clean.append(idx_injection)
-        list_nmb_branches_clean[-1] = nmb_branches_final
+        list_nmb_branches_clean[-1] = nmb_max_branches
                 
         list_injection_idx_clean = [int(l) for l in list_injection_idx_clean]
         list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean]
@@ -394,8 +394,36 @@ class LatentBlending():
             recycle_img2: Optional[bool] = False, 
             num_inference_steps: Optional[int] = 30,
             depth_strength: Optional[float] = 0.3,
+            t_compute_max_allowed: Optional[float] = None,
+            nmb_max_branches: Optional[int] = None,
             fixed_seeds: Optional[List[int]] = None,
         ):
+        r"""
+        Function for computing transitions.
+        Returns a list of transition images using spherical latent blending.
+        Args:
+            recycle_img1: Optional[bool]:
+                Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
+            recycle_img2: Optional[bool]:
+                Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
+            num_inference_steps:
+                Number of diffusion steps. Higher values will take more compute time.
+            depth_strength:
+                Determines how deep the first injection will happen. 
+                Deeper injections will cause (unwanted) formation of new structures,
+                more shallow values will go into alpha-blendy land.
+            t_compute_max_allowed:
+                The maximum time (in seconds) allowed for computation. Higher values
+                give better results but take longer. Provide either t_compute_max_allowed
+                or nmb_max_branches, not both.
+            nmb_max_branches: int
+                The maximum number of branches to be computed. Higher values give better
+                results. Use this if you want results that are reproducible independently
+                of your hardware's speed. Provide either t_compute_max_allowed or
+                nmb_max_branches, not both.
+            fixed_seeds: Optional[List[int]]:
+                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
+                Otherwise random seeds will be taken.
+            
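+        Example (illustrative sketch; `lb` stands for an already initialized
+        LatentBlending instance with both prompts set via set_prompt1/set_prompt2):
+            imgs_transition = lb.run_transition(
+                num_inference_steps=30,
+                depth_strength=0.3,
+                t_compute_max_allowed=30,
+                fixed_seeds=[420, 421])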
+        """
         
         # Sanity checks first
         assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
@@ -412,7 +440,8 @@ class LatentBlending():
             self.seed2 = fixed_seeds[1]
         
         # Ensure correct num_inference_steps in holder
-        self.sdh.num_inference_steps = self.num_inference_steps
+        self.num_inference_steps = num_inference_steps
+        self.sdh.num_inference_steps = num_inference_steps
         
         # Compute / Recycle first image
         if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
@@ -433,11 +462,61 @@ class LatentBlending():
         self.tree_idx_injection = [0, 0]
         
         # Set up branching scheme (dependent on provided compute time)
+        list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
+
+        # Run iteratively, starting with the longest trajectory. 
+        # Always inserting new branches where they are needed most according to image similarity
+        for s_idx in tqdm(range(len(list_idx_injection))):
+            nmb_stems = list_nmb_stems[s_idx]
+            idx_injection = list_idx_injection[s_idx]
+            
+            for i in range(nmb_stems):
+                fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
+                self.set_guidance_mid_dampening(fract_mixing)
+                list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
+                self.insert_into_tree(fract_mixing, idx_injection, list_latents)
+                # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
+            
+        return self.tree_final_imgs
+                
+
+    def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
+        r"""
+        Sets up the branching scheme dependent on the time that is granted for compute.
+        The scheme uses an estimation derived from the first image's computation speed.
+        Either provide t_compute_max_allowed or nmb_max_branches
+        Args:
+            depth_strength:
+                Determines how deep the first injection will happen. 
+                Deeper injections will cause (unwanted) formation of new structures,
+                more shallow values will go into alpha-blendy land.
+            t_compute_max_allowed: float
+                The maximum time (in seconds) allowed for computation. Higher values give
+                better results but take longer. Use this if you want to fix the waiting time.
+            nmb_max_branches: int
+                The maximum number of branches to be computed. Higher values give better
+                results. Use this if you want results that are reproducible independently
+                of your hardware's speed.
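+        Example (illustrative): with num_inference_steps=30 and depth_strength=0.3 the first
+        injection index is round(30*0.3)=9 and further injections are tried every 3 steps,
+        so list_idx_injection = [9, 12, 15, 18, 21, 24, 27]; list_nmb_stems then holds how
+        many branches are grown at each of these injection depths.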
+        """
         idx_injection_base = int(round(self.num_inference_steps*depth_strength))
         list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
         list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
         t_compute = 0
-        while t_compute < self.t_compute_max_allowed:
+        
+        if nmb_max_branches is None:
+            assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
+            stop_criterion = "t_compute_max_allowed"
+        elif t_compute_max_allowed is None:
+            assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
+            stop_criterion = "nmb_max_branches"
+            nmb_max_branches -= 2 # discounting the outer frames
+        else:
+            raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
+            
+        stop_criterion_reached = False
+        is_first_iteration = True
+        
+        while not stop_criterion_reached:
             list_compute_steps = self.num_inference_steps - list_idx_injection
             list_compute_steps *= list_nmb_stems
             t_compute = np.sum(list_compute_steps) * self.dt_per_diff  + 0.15*np.sum(list_nmb_stems)
@@ -449,23 +528,21 @@ class LatentBlending():
                     break
             if not increase_done:
                 list_nmb_stems[-1] += 1
-            # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
             
-        # Run iteratively, always inserting new branches where they are needed most
-        for s_idx in tqdm(range(len(list_idx_injection))):
-            nmb_stems = list_nmb_stems[s_idx]
-            idx_injection = list_idx_injection[s_idx]
-            
-            for i in range(nmb_stems):
-                fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
-                self.set_guidance_mid_dampening(fract_mixing)
-                # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
-                list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
-                self.insert_into_tree(fract_mixing, idx_injection, list_latents)
-            
-        return self.tree_final_imgs
+            if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
+                stop_criterion_reached = True
+                # FIXME: also undersample here... but how... maybe drop them iteratively?
+            elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
+                stop_criterion_reached = True
+                if is_first_iteration:
+                    # Need to undersample.
+                    list_idx_injection = np.linspace(list_idx_injection[0], list_idx_injection[-1], nmb_max_branches).astype(np.int32)
+                    list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
+            else:
+                is_first_iteration = False
                 
-
+            # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
+        return list_idx_injection, list_nmb_stems
 
     def get_mixing_parameters(self, idx_injection):
         r"""
@@ -581,7 +658,7 @@ class LatentBlending():
                 whether to return an image or the list of latents
         """
         print("starting compute_latents1")
-        list_conditionings = [self.text_embedding1]
+        list_conditionings = self.get_mixed_conditioning(0)
         t0 = time.time()
         latents_start = self.get_noise(self.seed1)
         list_latents1 = self.run_diffusion(
@@ -604,7 +681,8 @@ class LatentBlending():
             return_image: bool
                 whether to return an image or the list of latents
         """
-        list_conditionings = [self.text_embedding2]
+        print("starting compute_latents2")
+        list_conditionings = self.get_mixed_conditioning(1)
         latents_start = self.get_noise(self.seed2)
         # Influence from branch1
         if self.branch1_influence > 0.0:
@@ -630,178 +708,24 @@ class LatentBlending():
             return list_latents2
         
     def get_noise(self, seed):
-        generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
-        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
-        C, H, W = shape_latents
-        return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
-        
-    
-    def run_transition_legacy(
-            self, 
-            recycle_img1: Optional[bool] = False, 
-            recycle_img2: Optional[bool] = False, 
-            fixed_seeds: Optional[List[int]] = None,
-            premature_stop: Optional[int] = np.inf,
-        ):
         r"""
-        Old legacy function for computing transitions.
-        Returns a list of transition images using spherical latent blending.
+        Helper function to get noise given seed.
         Args:
-            recycle_img1: Optional[bool]:
-                Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
-            recycle_img2: Optional[bool]:
-                Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
-            fixed_seeds: Optional[List[int)]:
-                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
-                Otherwise random seeds will be taken.
-            premature_stop: Optional[int]:
-                Stop the computation after premature_stop frames have been computed in the transition
+            seed: int
             
         """
-        # Sanity checks first
-        assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
-        assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
-        assert self.list_injection_idx is not None, 'Set the branching structure before, by calling autosetup_branching or setup_branching'
+        generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
+        if self.mode == 'standard':
+            shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
+            C, H, W = shape_latents
+        elif self.mode == 'upscale':
+            w = self.image1_lowres.size[0]
+            h = self.image1_lowres.size[1]
+            shape_latents = [self.sdh.model.channels, h, w]
+            C, H, W = shape_latents
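+            # Note (assumption): in 'upscale' mode the diffusion runs on a latent grid
+            # sized like the low-res input image; the upscaling model then generates
+            # the high-resolution output from it.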
         
-        if fixed_seeds is not None:
-            if fixed_seeds == 'randomize':
-                fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
-            else:
-                assert len(fixed_seeds)==2, "Supply a list with len = 2"
-        
-            self.seed1 = fixed_seeds[0]
-            self.seed2 = fixed_seeds[1]
-        
-        # Process interruption variable
-        self.stop_diffusion = False
-        
-        # Ensure correct num_inference_steps in holder
-        self.sdh.num_inference_steps = self.num_inference_steps
-        
-        # Make a backup for future reference
-        self.list_nmb_branches_prev = self.list_nmb_branches[:]
-        self.list_injection_idx_prev = self.list_injection_idx[:]
-        
-        # Split the first block if there is branch1 crossfeeding
-        if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
-            assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
-            self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
-            self.branch1_max_depth_influence = np.clip(self.branch1_max_depth_influence, 0, 1)
-            self.list_nmb_branches.insert(1, 2)
-            idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_max_depth_influence))
-            self.list_injection_idx_ext.insert(1, idx_crossfeed)
-            self.tree_fracts.insert(1, self.tree_fracts[0])
-            self.tree_status.insert(1, self.tree_status[0])
-            self.tree_latents.insert(1, self.tree_latents[0])
-            self.branch1_insertion_completed = True
-            
-            
-        
-        # Pre-define entire branching tree structures
-        self.tree_final_imgs = [None]*self.list_nmb_branches[-1]
-        nmb_blocks_time = len(self.list_injection_idx_ext)-1
-        if not recycle_img1 and not recycle_img2:
-            self.init_tree_struct()
-        else:
-            self.tree_final_imgs = [None]*self.list_nmb_branches[-1]
-            for t_block in range(nmb_blocks_time):
-                nmb_branches = self.list_nmb_branches[t_block]
-                for idx_branch in range(nmb_branches):
-                    self.tree_status[t_block][idx_branch] = 'untouched'
-                if recycle_img1:
-                    self.tree_status[t_block][0] = 'computed'
-                    self.tree_final_imgs[0] = self.sdh.latent2image(self.tree_latents[-1][0][-1])
-                    self.tree_final_imgs_timing[0] = 0
-                if recycle_img2:
-                    self.tree_status[t_block][-1] = 'computed'
-                    self.tree_final_imgs[-1] = self.sdh.latent2image(self.tree_latents[-1][-1][-1])
-                    self.tree_final_imgs_timing[-1] = 0
-                    
-        # setup compute order: goal: try to get last branch computed asap. 
-        # first compute the right keyframe. needs to be there in any case
-        list_compute = []
-        list_local_stem = []
-        for t_block in range(nmb_blocks_time - 1, -1, -1):
-            if self.tree_status[t_block][0] == 'untouched':
-                self.tree_status[t_block][0] = 'prefetched'
-                list_local_stem.append([t_block, 0])
-        list_compute.extend(list_local_stem[::-1]) 
-        
-        # setup compute order: start from last leafs (the final transition images) and work way down. what parents do they need?
-        for idx_leaf in range(1, self.list_nmb_branches[-1]):
-            list_local_stem = []
-            t_block = nmb_blocks_time - 1
-            t_block_prev = t_block - 1
-            self.tree_status[t_block][idx_leaf] = 'prefetched'
-            list_local_stem.append([t_block, idx_leaf])
-            idx_leaf_deep = idx_leaf
-            
-            for t_block in range(nmb_blocks_time-1, 0, -1):
-                t_block_prev = t_block - 1
-                fract_mixing = self.tree_fracts[t_block][idx_leaf_deep]
-                list_fract_mixing_prev = self.tree_fracts[t_block_prev]
-                b_parent1, b_parent2 = get_closest_idx(fract_mixing, list_fract_mixing_prev)
-                assert self.tree_status[t_block_prev][b_parent1] != 'untouched', 'Branch destruction??? This should never happen!'
-                if self.tree_status[t_block_prev][b_parent2] == 'untouched':
-                    self.tree_status[t_block_prev][b_parent2] = 'prefetched'
-                    list_local_stem.append([t_block_prev, b_parent2])
-                idx_leaf_deep = b_parent2
-            list_compute.extend(list_local_stem[::-1])        
-            
-        # Diffusion computations start here
-        time_start = time.time()
-        for t_block, idx_branch in tqdm(list_compute, desc="computing transition", smoothing=0.01):
-            if self.stop_diffusion:
-                print("run_transition: process interrupted")
-                return self.tree_final_imgs
-            if idx_branch > premature_stop:
-                print(f"run_transition: premature_stop criterion reached. returning tree with {premature_stop} branches")
-                return self.tree_final_imgs
-            
-            # print(f"computing t_block {t_block} idx_branch {idx_branch}")
-            idx_stop = self.list_injection_idx_ext[t_block+1]
-            fract_mixing = self.tree_fracts[t_block][idx_branch]
-            
-            list_conditionings = self.get_mixed_conditioning(fract_mixing)
-            self.set_guidance_mid_dampening(fract_mixing)
-            # print(f"fract_mixing {fract_mixing} guid {self.sdh.guidance_scale}")
-            if t_block == 0:
-                if fixed_seeds is not None:
-                    if idx_branch == 0:
-                        self.set_seed(fixed_seeds[0])
-                    elif idx_branch == self.list_nmb_branches[0] -1:
-                        self.set_seed(fixed_seeds[1])
-                
-                list_latents = self.run_diffusion(list_conditionings, idx_stop=idx_stop)
-                
-                # Inject latents from first branch for very first block
-                if idx_branch==1 and self.branch1_influence > 0:
-                    fract_base_influence = np.clip(self.branch1_influence, 0, 1)
-                    for i in range(len(list_latents)):
-                        list_latents[i] = interpolate_spherical(list_latents[i], self.tree_latents[0][0][i], fract_base_influence)
-            else:
-                # find parents latents
-                b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts[t_block-1])
-                latents1 = self.tree_latents[t_block-1][b_parent1][-1]
-                if fract_mixing == 0:
-                    latents2 = latents1
-                else:
-                    latents2 = self.tree_latents[t_block-1][b_parent2][-1]
-                idx_start = self.list_injection_idx_ext[t_block]
-                fract_mixing_parental = (fract_mixing - self.tree_fracts[t_block-1][b_parent1]) / (self.tree_fracts[t_block-1][b_parent2] - self.tree_fracts[t_block-1][b_parent1]) 
-                latents_for_injection = interpolate_spherical(latents1, latents2, fract_mixing_parental)
-                list_latents = self.run_diffusion(list_conditionings, latents_for_injection, idx_start=idx_start, idx_stop=idx_stop)
-            
-            self.tree_latents[t_block][idx_branch] = list_latents
-            self.tree_status[t_block][idx_branch] = 'computed'
-            
-            # Convert latents to image directly for the last t_block
-            if t_block == nmb_blocks_time-1:
-                self.tree_final_imgs[idx_branch] = self.sdh.latent2image(list_latents[-1])
-                self.tree_final_imgs_timing[idx_branch] = time.time() - time_start
-            
-        return self.tree_final_imgs
-                
+        return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
+    
 
     def run_multi_transition(
             self,
@@ -906,24 +830,31 @@ class LatentBlending():
                 return_image = return_image,
                 )
         
+        elif self.mode == 'upscale':
+            cond = list_conditionings[0]
+            uc_full = list_conditionings[1]
+            return self.sdh.run_diffusion_upscaling(
+                cond, 
+                uc_full, 
+                latents_start=latents_start, 
+                idx_start=idx_start, 
+                list_latents_mixing = list_latents_mixing,
+                mixing_coeffs = mixing_coeffs,
+                return_image=return_image)
+        
         # elif self.mode == 'inpaint':
         #     text_embeddings = list_conditionings[0]
         #     assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
         #     assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
         #     return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
-        
-        # # FIXME LONG LINE and bad args
-        # elif self.mode == 'upscale':
-        #     cond = list_conditionings[0]
-        #     uc_full = list_conditionings[1]
-        #     return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
 
+    # FIXME. new transition engine
     def run_upscaling_step1(
             self, 
             dp_img: str,
             depth_strength: float = 0.65,
             num_inference_steps: int = 30,
-            nmb_branches_final: int = 10,
+            nmb_max_branches: int = 10,
             fixed_seeds: Optional[List[int]] = None,
             ):
         r"""
@@ -932,6 +863,7 @@ class LatentBlending():
             dp_img: 
                 Path to directory where the low-res images and yaml will be saved to.
                 This directory cannot exist and will be created here.
+            FIXME: the 'quality' description below is outdated; the function now takes
+                depth_strength, num_inference_steps and nmb_max_branches instead.
             quality: str 
                 Determines how many diffusion steps are being made + how many branches in total.
                 We suggest to leave it with upscaling_step1 which has 10 final branches.
@@ -951,7 +883,6 @@ class LatentBlending():
             fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
 
         # Run latent blending
-        self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
         imgs_transition = self.run_transition(fixed_seeds=fixed_seeds)
         self.write_imgs_transition(dp_img, imgs_transition)
 
@@ -962,13 +893,14 @@ class LatentBlending():
             self, 
             dp_img: str,
             depth_strength: float = 0.65,
-            num_inference_steps: int = 30,
-            nmb_branches_final: int = 10,
+            num_inference_steps: int = 100,
+            nmb_max_branches_highres: int = 5,
+            nmb_max_branches_lowres: int = 6,
             fixed_seeds: Optional[List[int]] = None,
             ):
         
         fp_yml = os.path.join(dp_img, "lowres.yaml")
-        fp_movie = os.path.join(dp_img, "movie.mp4")
+        fp_movie = os.path.join(dp_img, "movie_highres.mp4")
         fps = 24
         ms = MovieSaver(fp_movie, fps=fps)
         assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
@@ -978,8 +910,9 @@ class LatentBlending():
         nmb_images_lowres = dict_stuff['nmb_images']
         prompt1 = dict_stuff['prompt1']
         prompt2 = dict_stuff['prompt2']
+        idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres-1, nmb_max_branches_lowres)).astype(np.int32)
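+        # e.g. with 10 low-res images and nmb_max_branches_lowres=6, idx_img_lowres = [0, 2, 4, 5, 7, 9]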
         imgs_lowres = []
-        for i in range(nmb_images_lowres):
+        for i in idx_img_lowres:
             fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
             assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
             imgs_lowres.append(Image.open(fp_img_lowres))
@@ -989,13 +922,12 @@ class LatentBlending():
         text_embeddingA = self.sdh.get_text_embedding(prompt1)
         text_embeddingB = self.sdh.get_text_embedding(prompt2)
         
-        self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
-        
+        # FIXME: have a total length for the whole video section
         duration_single_trans = 3
-        list_fract_mixing = np.linspace(0, 1, nmb_images_lowres-1)
+        list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres-1)
         
-        for i in range(nmb_images_lowres-1):
-            print(f"Starting movie segment {i+1}/{nmb_images_lowres-1}")
+        for i in range(nmb_max_branches_lowres-1):
+            print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
             
             self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
             self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i])
@@ -1008,7 +940,15 @@ class LatentBlending():
             
             self.set_image1(imgs_lowres[i])
             self.set_image2(imgs_lowres[i+1])
-            list_imgs = self.run_transition(recycle_img1=recycle_img1)
+            
+            list_imgs = self.run_transition(
+                recycle_img1 = recycle_img1,
+                recycle_img2 = False,
+                num_inference_steps = num_inference_steps, 
+                depth_strength = depth_strength,
+                nmb_max_branches = nmb_max_branches_highres,
+                )
+            
             list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_trans)
             
             # Save movie frame
@@ -1075,11 +1015,12 @@ class LatentBlending():
         return self.sdh.get_text_embedding(prompt)
     
 
-    def write_imgs_transition(self, dp_img, imgs_transition):
+    def write_imgs_transition(self, dp_img):
         r"""
         Writes the transition images into the folder dp_img.
         """
-        os.makedirs(dp_img)
+        imgs_transition = self.tree_final_imgs
+        os.makedirs(dp_img, exist_ok=True)
         for i, img in enumerate(imgs_transition):
             img_leaf = Image.fromarray(img)
             img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
@@ -1090,6 +1031,7 @@ class LatentBlending():
         
     def save_statedict(self, fp_yml):
         # Dump everything relevant into yaml
+        imgs_transition = self.tree_final_imgs
         state_dict = self.get_state_dict()
         state_dict['nmb_images'] = len(imgs_transition)
         yml_save(fp_yml, state_dict)
@@ -1098,7 +1040,9 @@ class LatentBlending():
         state_dict = {}
         grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
                      'num_inference_steps', 'depth_strength', 'guidance_scale',
-                     'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt']
+                     'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
+                     'branch1_influence', 'branch1_max_depth_influence', 'branch1_influence_decay',
+                     'parental_influence', 'parental_max_depth_influence', 'parental_influence_decay']
         for v in grab_vars:
             if hasattr(self, v):
                 if v == 'seed1' or v == 'seed2':
@@ -1107,9 +1051,11 @@ class LatentBlending():
                     state_dict[v] = float(getattr(self, v))
                     
                 else:
-                    state_dict[v] = getattr(self, v)
+                    try:
+                        state_dict[v] = getattr(self, v)
+                    except Exception:
+                        # Skip attributes that cannot be read
+                        pass
                 
-            
         return state_dict
         
     def randomize_seed(self):
@@ -1163,8 +1109,7 @@ class LatentBlending():
         as in run_multi_transition()
         """ 
         # Move over all latents
-        for t_block in range(len(self.tree_latents)):
-            self.tree_latents[t_block][0] = self.tree_latents[t_block][-1]
+        self.tree_latents[0] = self.tree_latents[-1]
         
         # Move over prompts and text embeddings
         self.prompt1 = self.prompt2
diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py
index 235e319..8bcf28e 100644
--- a/stable_diffusion_holder.py
+++ b/stable_diffusion_holder.py
@@ -285,8 +285,7 @@ class StableDiffusionHolder:
             return_image: Optional[bool] = False,
         ):
         r"""
-        Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
-        Depending on the mode, the correct one will be executed.
+        Diffusion standard version. 
         
         Args:
             text_embeddings: torch.FloatTensor 
@@ -363,6 +362,99 @@ class StableDiffusionHolder:
                     return self.latent2image(latents)
                 else:
                     return list_latents_out
+                
+                
+    @torch.no_grad()
+    def run_diffusion_upscaling(
+            self, 
+            cond,
+            uc_full,
+            latents_start: torch.FloatTensor, 
+            idx_start: int = -1, 
+            list_latents_mixing = None, 
+            mixing_coeffs = 0.0,
+            return_image: Optional[bool] = False
+        ):
+        r"""
+        Diffusion upscaling version. 
+        # FIXME
+        Args:
+            ??
+            latents_for_injection: torch.FloatTensor 
+                Latents that are used for injection
+            idx_start: int
+                Index of the diffusion process start and where the latents_for_injection are injected
+            return_image: Optional[bool]
+                Optionally return image directly
+        """
+ 
+        # Input checks: expand a scalar mixing coefficient into a per-step list
+        if type(mixing_coeffs) == float:
+            list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
+        elif type(mixing_coeffs) == list:
+            assert len(mixing_coeffs) == self.num_inference_steps
+            list_mixing_coeffs = mixing_coeffs
+        else:
+            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
+        
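+        # If any mixing is requested, one mixing latent per inference step must be provided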
+        if np.sum(list_mixing_coeffs) > 0:
+            assert len(list_latents_mixing) == self.num_inference_steps
+        
+        precision_scope = autocast if self.precision == "autocast" else nullcontext
+        generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
+        
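+        # Latent spatial size follows the low-res conditioning image in uc_full['c_concat']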
+        h = uc_full['c_concat'][0].shape[2]        
+        w = uc_full['c_concat'][0].shape[3]  
+        
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+
+                shape_latents = [self.model.channels, h, w]
+    
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+                b = size[0]
+                
+                latents = latents_start.clone()
+    
+                timesteps = self.sampler.ddim_timesteps
+    
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+                
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    # Set the right starting latents
+                    if i < idx_start:
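+                        # Steps before idx_start are skipped; pad with None to keep list indices aligned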
+                        list_latents_out.append(None)
+                        continue
+                    elif i == idx_start:
+                        latents = latents_start.clone()
+                    
+                    # Mix the latents: blend with the provided mixing latents via spherical interpolation
+                    if i > 0 and list_mixing_coeffs[i]>0:
+                        latents_mixtarget = list_latents_mixing[i-1].clone()
+                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
+                    
+                    # print(f"diffusion iter {i}")
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
+                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
+                                              quantize_denoised=False, temperature=1.0,
+                                              noise_dropout=0.0, score_corrector=None,
+                                              corrector_kwargs=None,
+                                              unconditional_guidance_scale=self.guidance_scale,
+                                              unconditional_conditioning=uc_full,
+                                              dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+    
+                if return_image:        
+                    return self.latent2image(latents)
+                else:
+                    return list_latents_out                    
 
     @torch.no_grad()
     def run_diffusion_inpaint(
@@ -473,93 +565,6 @@ class StableDiffusionHolder:
                     return self.latent2image(latents)
                 else:
                     return list_latents_out
-                
-    @torch.no_grad()
-    def run_diffusion_upscaling(
-            self, 
-            cond,
-            uc_full,
-            latents_for_injection: torch.FloatTensor = None, 
-            idx_start: int = -1, 
-            idx_stop: int = -1, 
-            return_image: Optional[bool] = False
-        ):
-        r"""
-        Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
-        Depending on the mode, the correct one will be executed.
-        
-        Args:
-            ??
-            latents_for_injection: torch.FloatTensor 
-                Latents that are used for injection
-            idx_start: int
-                Index of the diffusion process start and where the latents_for_injection are injected
-            idx_stop: int
-                Index of the diffusion process end.
-            return_image: Optional[bool]
-                Optionally return image directly
-        """
- 
-        
-        if latents_for_injection is None:
-            do_inject_latents = False
-        else:
-            do_inject_latents = True    
-        
-        precision_scope = autocast if self.precision == "autocast" else nullcontext
-        generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
-        
-        h = uc_full['c_concat'][0].shape[2]        
-        w = uc_full['c_concat'][0].shape[3]  
-        
-        with precision_scope("cuda"):
-            with self.model.ema_scope():
-                
-
-                shape_latents = [self.model.channels, h, w]
-    
-                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
-                C, H, W = shape_latents
-                size = (1, C, H, W)
-                b = size[0]
-                
-                latents = torch.randn(size, generator=generator, device=self.device)
-    
-                timesteps = self.sampler.ddim_timesteps
-    
-                time_range = np.flip(timesteps)
-                total_steps = timesteps.shape[0]
-                
-                # collect latents
-                list_latents_out = []
-                for i, step in enumerate(time_range):
-                    if do_inject_latents:
-                        # Inject latent at right place
-                        if i < idx_start:
-                            continue
-                        elif i == idx_start:
-                            latents = latents_for_injection.clone()
-                    
-                    if i == idx_stop:
-                        return list_latents_out
-                    
-                    # print(f"diffusion iter {i}")
-                    index = total_steps - i - 1
-                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
-                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
-                                              quantize_denoised=False, temperature=1.0,
-                                              noise_dropout=0.0, score_corrector=None,
-                                              corrector_kwargs=None,
-                                              unconditional_guidance_scale=self.guidance_scale,
-                                              unconditional_conditioning=uc_full,
-                                              dynamic_threshold=None)
-                    latents, pred_x0 = outs
-                    list_latents_out.append(latents.clone())
-    
-                if return_image:        
-                    return self.latent2image(latents)
-                else:
-                    return list_latents_out                    
 
     @torch.no_grad()
     def latent2image(