in glide_text2im/clip/model_creation.py [0:0]
def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]:
    """
    Build a guidance callback that returns gradients toward ``prompts``.

    The prompt embeddings are computed once here, with gradient tracking
    disabled, and captured by the returned closure.

    :param prompts: text prompts passed to ``self.text_embeddings``.
    :param grad_scale: default multiplier applied to the returned gradient.
    :return: a function ``(x, t, grad_scale=grad_scale, **kwargs)``. It returns
             the gradient of ``exp(self.logit_scale) * sum(z_t * z_i)`` with
             respect to ``x``, multiplied by ``grad_scale``. Here ``z_i`` is
             ``self.image_embeddings(x, t)``. Extra keyword arguments are
             accepted and ignored.
    """
    with torch.no_grad():
        text_emb = self.text_embeddings(prompts)

    def guidance(x, t, grad_scale=grad_scale, **kwargs):
        # The caller may be running under torch.no_grad(). Re-enable autograd
        # locally so the similarity can be differentiated w.r.t. the input.
        with torch.enable_grad():
            # Work on a detached leaf copy so the caller's graph is untouched.
            x_in = x.detach().requires_grad_(True)
            image_emb = self.image_embeddings(x_in, t)
            similarity = (text_emb * image_emb).sum()
            objective = torch.exp(self.logit_scale) * similarity
            (x_grad,) = torch.autograd.grad(objective, x_in)
        return x_grad.detach() * grad_scale

    return guidance