# evaluate_model — extracted from train.py

def evaluate_model(model, val_loader, device, global_step, max_val_item_count, weight_dtype, disable_pbar):
    """Run a capped evaluation pass and return the mean validation loss.

    For each batch, the four task losses (color, lighting, lighting type,
    composition) are averaged; those per-batch means are then averaged
    across batches. Evaluation stops early once ``max_val_item_count``
    items have been consumed. The model is restored to train mode before
    returning.

    Args:
        model: model under evaluation (switched to eval mode here).
        val_loader: iterable of batch dicts containing paired
            ``*_inputs`` / label entries for the four tasks.
        device: unused in this function; kept for caller compatibility.
        global_step: current training step, shown in the progress bar.
        max_val_item_count: soft cap on the number of validation items.
        weight_dtype: dtype forwarded to ``forward_with_model``.
        disable_pbar: if True, suppress the tqdm progress bar.

    Returns:
        float: average validation loss per batch (0.0 for an empty loader).
    """
    model.eval()
    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        val_item_count = 0
        for batch in tqdm(val_loader, desc=f"Evaluation at step {global_step}", disable=disable_pbar):
            # BUG FIX: len(batch) counted the dict's keys (a constant),
            # not the number of items in the batch; use the label
            # tensor's batch size instead so the cap is meaningful.
            val_item_count += len(batch["colors"])

            # Paired (inputs, labels) for each of the four prediction tasks.
            task_pairs = [
                (batch["color_inputs"], batch["colors"]),
                (batch["lighting_inputs"], batch["lightings"]),
                (batch["lighting_type_inputs"], batch["lighting_types"]),
                (batch["composition_inputs"], batch["compositions"]),
            ]
            losses = [
                forward_with_model(model, inputs, labels, weight_dtype=weight_dtype).loss
                for inputs, labels in task_pairs
            ]

            # Mean over the four task losses for this batch.
            val_loss += torch.stack(losses).mean().item()
            num_batches += 1

            if val_item_count > max_val_item_count:
                break

        # BUG FIX: the original divided a sum of per-batch MEAN losses by
        # the ITEM count, conflating batch- and item-level scales; average
        # over batches instead, and guard against an empty loader.
        avg_val_loss = val_loss / num_batches if num_batches else 0.0

    model.train()
    return avg_val_loss