diff --git a/latent_blending.py b/latent_blending.py
index 12ea8ea..4a5fbf1 100644
--- a/latent_blending.py
+++ b/latent_blending.py
@@ -123,6 +123,8 @@ class LatentBlending():
         self.multi_transition_img_first = None
         self.multi_transition_img_last = None
         self.dt_per_diff = 0
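+        # Optional spatial mask in latent space, set via set_spatial_mask()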
+        self.spatial_mask = None
         
         self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
         
@@ -277,6 +278,9 @@ class LatentBlending():
         self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
         self.tree_idx_injection = [0, 0]
         
+        # Temporary workaround: clear the spatial mask so it only affects compute_latents2, not the transition branches. WIP.
+        self.spatial_mask = None
+        
         # Set up branching scheme (dependent on provided compute time)
         list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
 
@@ -296,6 +301,118 @@ class LatentBlending():
         return self.tree_final_imgs
                 
 
+    def compute_latents1(self, return_image=False):
+        r"""
+        Runs a diffusion trajectory for the first image
+        Args:
+            return_image: bool
+                whether to return an image or the list of latents
+        """
+        print("starting compute_latents1")
+        list_conditionings = self.get_mixed_conditioning(0)
+        t0 = time.time()
+        latents_start = self.get_noise(self.seed1)
+        list_latents1 = self.run_diffusion(
+            list_conditionings, 
+            latents_start = latents_start,
+            idx_start = 0
+            )
+        t1 = time.time()
+        self.dt_per_diff = (t1-t0) / self.num_inference_steps
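+        # Average wall-clock time per diffusion step, used when budgeting the time-based branching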
+        self.tree_latents[0] = list_latents1
+        if return_image:
+            return self.sdh.latent2image(list_latents1[-1])
+        else:
+            return list_latents1
+    
+    def compute_latents2(self, return_image=False):
+        r"""
+        Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
+        Args:
+            return_image: bool
+                whether to return an image or the list of latents
+        """
+        print("starting compute_latents2")
+        list_conditionings = self.get_mixed_conditioning(1)
+        latents_start = self.get_noise(self.seed2)
+        # Influence from branch1
+        if self.branch1_influence > 0.0:
+            # Set up the mixing_coeffs
+            idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
+            mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))     
+            mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
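+            # Example (hypothetical settings): num_inference_steps=30, branch1_influence=0.5,
+            # branch1_max_depth_influence=0.3, branch1_influence_decay=0.2 yields idx_mixing_stop=9,
+            # so mixing_coeffs ramps 0.5 -> 0.1 over the first 9 steps, followed by 21 zeros.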
+            list_latents_mixing = self.tree_latents[0]
+            list_latents2 = self.run_diffusion(
+                list_conditionings, 
+                latents_start = latents_start,
+                idx_start = 0,
+                list_latents_mixing = list_latents_mixing,
+                mixing_coeffs = mixing_coeffs
+                )
+        else:
+            list_latents2 = self.run_diffusion(list_conditionings, latents_start = latents_start)
+        self.tree_latents[-1] = list_latents2
+        
+        if return_image:
+            return self.sdh.latent2image(list_latents2[-1])
+        else:
+            return list_latents2
+
+
+    def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):    
+        r"""
+        Runs a diffusion trajectory, using the latents from the respective parents
+        Args:
+            fract_mixing: float
+                the fraction along the transition axis [0, 1]
+            b_parent1: int
+                index of parent1 to be used
+            b_parent2: int
+                index of parent2 to be used
+            idx_injection: int
+                the index, in terms of diffusion steps, at which the latent injection starts.
+        """
+        list_conditionings = self.get_mixed_conditioning(fract_mixing)
+        fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1]) 
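+        # Rescale fract_mixing to [0, 1] between the two parents, e.g. parents at
+        # fractions 0.2 and 0.6 with fract_mixing=0.3 give (0.3-0.2)/(0.6-0.2) = 0.25.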
+        
+        list_latents_parental_mix = []
+        for i in range(self.num_inference_steps):
+            latents_p1 = self.tree_latents[b_parent1][i]
+            latents_p2 = self.tree_latents[b_parent2][i]
+            if latents_p1 is None or latents_p2 is None:
+                latents_parental = None
+            else:
+                latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
+            list_latents_parental_mix.append(latents_parental)
+
+        idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence))
+        mixing_coeffs = idx_injection*[self.parental_influence]
+        nmb_mixing = idx_mixing_stop - idx_injection
+        if nmb_mixing > 0:
+            mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing)))     
+        mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
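+        # Example (hypothetical settings): num_inference_steps=30, idx_injection=5,
+        # parental_influence=0.3, parental_max_depth_influence=0.5, parental_influence_decay=0.1:
+        # 5 steps at 0.3, a ramp 0.3 -> 0.03 over the next 10 steps, then 15 zeros.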
+        
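+        # Resume the diffusion from the parental mix, one step before the injection index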
+        latents_start = list_latents_parental_mix[idx_injection-1]
+        list_latents = self.run_diffusion(
+            list_conditionings, 
+            latents_start = latents_start,
+            idx_start = idx_injection,
+            list_latents_mixing = list_latents_parental_mix,
+            mixing_coeffs = mixing_coeffs
+            )
+        
+        return list_latents
+
     def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
         r"""
         Sets up the branching scheme dependent on the time that is granted for compute.
@@ -419,110 +536,42 @@ class LatentBlending():
         self.tree_fracts.insert(b_parent1+1, fract_mixing)
         self.tree_idx_injection.insert(b_parent1+1, idx_injection)
             
-        
-    def compute_latents1(self, return_image=False):
-        r"""
-        Runs a diffusion trajectory for the first image
-        Args:
-            return_image: bool
-                whether to return an image or the list of latents
-        """
-        print("starting compute_latents1")
-        list_conditionings = self.get_mixed_conditioning(0)
-        t0 = time.time()
-        latents_start = self.get_noise(self.seed1)
-        list_latents1 = self.run_diffusion(
-            list_conditionings, 
-            latents_start = latents_start,
-            idx_start = 0
-            )
-        t1 = time.time()
-        self.dt_per_diff = (t1-t0) / self.num_inference_steps
-        self.tree_latents[0] = list_latents1
-        if return_image:
-            return self.sdh.latent2image(list_latents1[-1])
-        else:
-            return list_latents1
     
-    def compute_latents2(self, return_image=False):
-        r"""
-        Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
-        Args:
-            return_image: bool
-                whether to return an image or the list of latents
-        """
-        print("starting compute_latents2")
-        list_conditionings = self.get_mixed_conditioning(1)
-        latents_start = self.get_noise(self.seed2)
-        # Influence from branch1
-        if self.branch1_influence > 0.0:
-            # Set up the mixing_coeffs
-            idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
-            mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))     
-            mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
-            list_latents_mixing = self.tree_latents[0]
-            list_latents2 = self.run_diffusion(
-                list_conditionings, 
-                latents_start = latents_start,
-                idx_start = 0,
-                list_latents_mixing = list_latents_mixing,
-                mixing_coeffs = mixing_coeffs
-                )
-        else:
-            list_latents2 = self.run_diffusion(list_conditionings, latents_start)
-        self.tree_latents[-1] = list_latents2
-        
-        if return_image:
-            return self.sdh.latent2image(list_latents2[-1])
-        else:
-            return list_latents2
-
-
-    def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):    
-        r"""
-        Runs a diffusion trajectory, using the latents from the respective parents
-        Args:
-            fract_mixing: float
-                the fraction along the transition axis [0, 1]
-            b_parent1: int
-                index of parent1 to be used
-            b_parent2: int
-                index of parent2 to be used
-            idx_injection: int
-                the index in terms of diffusion steps, where the next insertion will start.
-        """
-        list_conditionings = self.get_mixed_conditioning(fract_mixing)
-        fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1]) 
-        # idx_reversed = self.num_inference_steps - idx_injection
-        
-        list_latents_parental_mix = []
-        for i in range(self.num_inference_steps):
-            latents_p1 = self.tree_latents[b_parent1][i]
-            latents_p2 = self.tree_latents[b_parent2][i]
-            if latents_p1 is None or latents_p2 is None:
-                latents_parental = None
-            else:
-                latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
-            list_latents_parental_mix.append(latents_parental)
-
-        idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence))
-        mixing_coeffs = idx_injection*[self.parental_influence]
-        nmb_mixing = idx_mixing_stop - idx_injection
-        if nmb_mixing > 0:
-            mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing)))     
-        mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
-        
-        latents_start = list_latents_parental_mix[idx_injection-1]
-        list_latents = self.run_diffusion(
-            list_conditionings, 
-            latents_start = latents_start,
-            idx_start = idx_injection,
-            list_latents_mixing = list_latents_parental_mix,
-            mixing_coeffs = mixing_coeffs
-            )
-        
-        return list_latents
+    def get_spatial_mask_template(self):
+        r"""
+        Returns a spatial mask template (all ones) with the (H, W) dimensions of the latent space.
+        """
+        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
+        C, H, W = shape_latents
+        return np.ones((H, W))
     
+    def set_spatial_mask(self, img_mask):
+        r"""
+        Sets a spatial mask that localizes where the mixing latents are blended in.
+        Args:
+            img_mask: 2D array or image with values in [0, 1] and shape (H, W) of the
+                latent space; 1 keeps the freshly diffused latents, 0 takes the mixing latents.
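+        Example (hypothetical values; `lb` denotes a LatentBlending instance):
+            mask = lb.get_spatial_mask_template()
+            mask[:, mask.shape[1]//2:] = 0  # right half follows the mixing latents
+            lb.set_spatial_mask(mask)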
+        """
+        
+        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
+        C, H, W = shape_latents
+        img_mask = np.asarray(img_mask)
+        assert len(img_mask.shape) == 2, "Currently, only 2D masks are supported"
+        img_mask = np.clip(img_mask, 0, 1)
+        assert img_mask.shape[0] == H, f"Your mask needs to be of dimension {H} x {W}"
+        assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
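+        # Broadcast the (H, W) mask to the latent shape (1, C, H, W)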
+        spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
+        spatial_mask = torch.unsqueeze(spatial_mask, 0)
+        spatial_mask = spatial_mask.repeat((C,1,1))
+        spatial_mask = torch.unsqueeze(spatial_mask, 0)
+        
+        self.spatial_mask = spatial_mask
+        
         
     def get_noise(self, seed):
         r"""
@@ -585,6 +616,7 @@ class LatentBlending():
                 idx_start = idx_start,
                 list_latents_mixing = list_latents_mixing,
                 mixing_coeffs = mixing_coeffs,
+                spatial_mask = self.spatial_mask,
                 return_image = return_image,
                 )
         
diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py
index 8bcf28e..fcd720d 100644
--- a/stable_diffusion_holder.py
+++ b/stable_diffusion_holder.py
@@ -218,37 +218,6 @@ class StableDiffusionHolder:
         if len(self.negative_prompt) > 1:
             self.negative_prompt = [self.negative_prompt[0]]
 
-    def init_inpainting(
-            self, 
-            image_source: Union[Image.Image, np.ndarray] = None, 
-            mask_image: Union[Image.Image, np.ndarray] = None, 
-            init_empty: Optional[bool] = False,
-        ):
-        r"""
-        Initializes inpainting with a source and maks image.
-        Args:
-            image_source: Union[Image.Image, np.ndarray]
-                Source image onto which the mask will be applied.
-            mask_image: Union[Image.Image, np.ndarray]
-                Mask image, value = 0 will stay untouched, value = 255 subjet to diffusion
-            init_empty: Optional[bool]:
-                Initialize inpainting with an empty image and mask, effectively disabling inpainting,
-                useful for generating a first image for transitions using diffusion.
-        """
-        if not init_empty:
-            assert image_source is not None, "init_inpainting: you need to provide image_source"
-            assert mask_image is not None, "init_inpainting: you need to provide mask_image"
-            if type(image_source) == np.ndarray:
-                image_source = Image.fromarray(image_source)
-            self.image_source = image_source
-            
-            if type(mask_image) == np.ndarray:
-                mask_image = Image.fromarray(mask_image)
-            self.mask_image = mask_image
-        else:
-            self.mask_image  = self.mask_empty
-            self.image_source  = self.image_empty
-
 
     def get_text_embedding(self, prompt):
         c = self.model.get_learned_conditioning(prompt)
@@ -282,6 +251,7 @@ class StableDiffusionHolder:
             idx_start: int = 0, 
             list_latents_mixing = None, 
             mixing_coeffs = 0.0,
+            spatial_mask = None,
             return_image: Optional[bool] = False,
         ):
         r"""
@@ -295,7 +265,9 @@
             idx_start: int
                 Index of the diffusion process start and where the latents_for_injection are injected
-            mixing_coeff:
-                # FIXME
+            mixing_coeffs:
+                Per-step coefficient(s) for blending the latents with list_latents_mixing
+            spatial_mask:
+                Optional latent-space mask; 1 keeps the current latents, 0 takes the mixing latents
             return_image: Optional[bool]
                 Optionally return image directly
             
@@ -345,6 +317,11 @@
                     if i > 0 and list_mixing_coeffs[i]>0:
                         latents_mixtarget = list_latents_mixing[i-1].clone()
                         latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
+
+                    # Spatial mask: where the mask is 0, take the latents from the
+                    # mixing trajectory; where it is 1, keep the freshly diffused latents.
+                    if i > 0 and spatial_mask is not None and list_latents_mixing is not None:
+                        latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
                     
                     index = total_steps - i - 1
                     ts = torch.full((1,), step, device=self.device, dtype=torch.long)