in train.py [0:0]
import torch
from tqdm import tqdm


def evaluate_model(model, val_loader, device, global_step, max_val_item_count, weight_dtype, disable_pbar):
    # Evaluation phase: average the losses of the four attribute heads
    # (color, lighting, lighting type, composition) over the validation set.
    model.eval()
    val_loss = 0.0
    num_batches = 0
    val_item_count = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Evaluation at step {global_step}", disable=disable_pbar):
            # Each batch is a dict, so len(batch) would count its keys;
            # measure the number of items from one of the label tensors instead.
            val_item_count += len(batch["colors"])
            # Prepare the input and target tensors for each attribute.
            color_inputs, colors = batch["color_inputs"], batch["colors"]
            lighting_inputs, lightings = batch["lighting_inputs"], batch["lightings"]
            lighting_type_inputs, lighting_types = batch["lighting_type_inputs"], batch["lighting_types"]
            composition_inputs, compositions = batch["composition_inputs"], batch["compositions"]
            losses = []
            for inputs, labels in [
                (color_inputs, colors),
                (lighting_inputs, lightings),
                (lighting_type_inputs, lighting_types),
                (composition_inputs, compositions),
            ]:
                losses.append(forward_with_model(model, inputs, labels, weight_dtype=weight_dtype).loss)
            # Mean of the per-attribute losses for this batch.
            loss = torch.stack(losses).mean()
            val_loss += loss.item()
            num_batches += 1
            # Stop once enough validation items have been evaluated.
            if val_item_count > max_val_item_count:
                break
    # val_loss accumulates per-batch mean losses, so normalize by the number
    # of batches rather than by the item count.
    avg_val_loss = val_loss / num_batches
    model.train()
    return avg_val_loss
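For reference, `forward_with_model` is defined elsewhere in train.py; the call site above only tells us it takes the model, inputs, labels, and a `weight_dtype` keyword, and returns an object exposing a `.loss` attribute. Below is a minimal hypothetical sketch of such a helper, assuming a model that computes its own loss when given labels and that mixed precision is handled with `torch.autocast`; the body is an assumption, not the actual implementation.

# Hypothetical sketch only: the real forward_with_model lives elsewhere in
# train.py. Assumes the model computes its own loss when labels are passed.
def forward_with_model(model, inputs, labels, weight_dtype=torch.float32):
    # Note: evaluate_model's `device` argument is unused above, so the helper
    # presumably moves tensors itself; we infer the device from the model.
    device = next(model.parameters()).device
    with torch.autocast(device_type=device.type, dtype=weight_dtype,
                        enabled=weight_dtype != torch.float32):
        return model(inputs.to(device), labels=labels.to(device))

A call from the training loop might then look like this (the interval and hyperparameter values are illustrative, not taken from the file):

# Illustrative call site; eval interval and budget values are made up.
if global_step % 500 == 0:
    avg_val_loss = evaluate_model(
        model, val_loader, device, global_step,
        max_val_item_count=1024, weight_dtype=torch.float16, disable_pbar=False,
    )
    print(f"step {global_step}: val loss {avg_val_loss:.4f}")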