sd 2.1
This commit is contained in:
parent
b99d18f7b3
commit
ab500dd288
|
@ -32,8 +32,8 @@ torch.set_grad_enabled(False)
|
||||||
|
|
||||||
#%% First let us spawn a stable diffusion holder
|
#%% First let us spawn a stable diffusion holder
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
|
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
|
||||||
fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
|
fp_config = 'configs/v2-inference-v.yaml'
|
||||||
|
|
||||||
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
|
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,7 @@ from functools import partial
|
||||||
import itertools
|
import itertools
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||||
# from pytorch_lightning.utilities.distributed import rank_zero_only
|
|
||||||
from omegaconf import ListConfig
|
from omegaconf import ListConfig
|
||||||
|
|
||||||
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||||
|
@ -391,7 +390,7 @@ class DDPM(pl.LightningModule):
|
||||||
elif self.parameterization == "v":
|
elif self.parameterization == "v":
|
||||||
target = self.get_v(x_start, noise, t)
|
target = self.get_v(x_start, noise, t)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
|
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
|
||||||
|
|
||||||
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,9 @@ try:
|
||||||
except:
|
except:
|
||||||
XFORMERS_IS_AVAILBLE = False
|
XFORMERS_IS_AVAILBLE = False
|
||||||
|
|
||||||
|
# CrossAttn precision handling
|
||||||
|
import os
|
||||||
|
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
@ -167,7 +170,14 @@ class CrossAttention(nn.Module):
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
|
# force cast to fp32 to avoid overflowing
|
||||||
|
if _ATTN_PRECISION =="fp32":
|
||||||
|
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||||
|
q, k = q.float(), k.float()
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
else:
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
del q, k
|
del q, k
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
|
|
Before Width: | Height: | Size: 431 KiB After Width: | Height: | Size: 431 KiB |
Loading…
Reference in New Issue