in sat/diffusion_video.py [0:0]
def log_conditionings(self, batch: Dict, n: int) -> Dict:
    """
    Render the conditioning inputs of a batch as text images for logging.

    For every embedder in ``self.conditioner.embedders`` whose ``input_key``
    is selected for logging, the first ``n`` entries of ``batch[input_key]``
    are converted to strings and drawn with ``log_txt_as_img``.

    Supported conditionings:
      * 1-D tensor   -> one string per element (e.g. class labels)
      * 2-D tensor   -> each row joined with "x" (e.g. size / crop conds)
      * list of str  -> used as-is (e.g. text prompts)

    Args:
        batch: Mapping that contains ``self.input_key`` and every logged
            embedder ``input_key``.
        n: Number of leading entries of each conditioning to log.

    Returns:
        Dict mapping each logged ``input_key`` to the value returned by
        ``log_txt_as_img``. Empty when ``self.no_cond_log`` is set or when
        no embedder key is selected by ``self.log_keys``.

    Raises:
        NotImplementedError: If a selected conditioning is a tensor with more
            than two (or zero) dimensions, a list whose first element is not
            a string, or any other unsupported type.
    """
    # shape[3:] has to unpack into exactly two values, so the input is 5-D.
    # Presumably (batch, frames, channels, height, width) -- TODO confirm.
    image_h, image_w = batch[self.input_key].shape[3:]
    log = dict()

    # The flag does not depend on the embedder, so test it once up front.
    if self.no_cond_log:
        return log

    canvas_size = (image_h, image_w)
    for embedder in self.conditioner.embedders:
        key = embedder.input_key
        # log_keys of None means "log every conditioning".
        if self.log_keys is not None and key not in self.log_keys:
            continue

        x = batch[key][:n]
        if isinstance(x, torch.Tensor):
            if x.dim() == 1:
                # class-conditional, convert integer to string
                captions = [str(value) for value in x.tolist()]
                font_size = image_h // 4
            elif x.dim() == 2:
                # size and crop cond and the like
                captions = ["x".join(str(value) for value in row) for row in x.tolist()]
                font_size = image_h // 20
            else:
                raise NotImplementedError(
                    f"Cannot log conditioning {key!r}: only 1-D and 2-D tensors "
                    f"are supported, got a {x.dim()}-D tensor."
                )
        elif isinstance(x, (List, ListConfig)):
            # NOTE(review): an empty plain list (e.g. n == 0) raises IndexError
            # on x[0]; callers are expected to pass at least one entry.
            if not isinstance(x[0], str):
                raise NotImplementedError(
                    f"Cannot log conditioning {key!r}: only lists of strings "
                    f"are supported, got elements of type {type(x[0]).__name__}."
                )
            captions = x
            font_size = image_h // 20
        else:
            raise NotImplementedError(
                f"Cannot log conditioning {key!r}: unsupported type {type(x).__name__}."
            )

        log[key] = log_txt_as_img(canvas_size, captions, size=font_size)
    return log