def print_val_loss_tmp()

in src/eval.py [0:0]


def print_val_loss_tmp(epoch, c: configs.Train_Config, criterion, X, y, train_loss, bone_lengths_, c2: configs.Train_Config = None, train_loss2=None):
    """Run an autoregressive validation pass and print a one-line loss summary.

    The first ``c.max_batch_size`` frames of ``X`` and ``y`` are cloned onto
    ``c.device``. Every frame from ``c.sequence_length`` on is then predicted one
    at a time by ``c.model``. Each prediction is written back into a copy of
    ``y`` (``prev_out``), so later steps consume the model's own earlier outputs.
    If ``c2`` is given, its model runs right after ``c.model`` at every step, and
    its output overwrites the frame just predicted.

    The predictions are scored with ``criterion`` (once per config) and with an
    RMSE split by the tiled joint mask. The summary is printed on the current
    console line: it is prefixed with a carriage return and has no trailing newline.

    Args:
        epoch: Epoch number; only used in the printed message.
        c: Primary config (supplies model, y_scaler, device, sequence_length,
            max_batch_size and the in/out feature layout).
        criterion: Loss callable, invoked with the keyword arguments ``input``,
            ``target``, ``mask``, ``config``, ``bone_lengths``, ``prev_out``, ``body``.
        X: Validation inputs, indexed frame-wise along dim 0.
        y: Validation targets, aligned with ``X`` along dim 0.
        train_loss: Training loss printed next to the validation loss.
        bone_lengths_: Forwarded unchanged to ``criterion``.
        c2: Optional second config whose model refines each prediction.
        train_loss2: Training loss of ``c2``. It is printed only when ``c2`` is
            given and this value is not ``None``.
    """
    seq_len = c.sequence_length

    if c.model.recurrent:
        # NOTE(review): the hidden state is sized from the *untruncated* X, but the
        # loop below feeds the model a batch of one frame at a time. Confirm that
        # init_hidden does not require these sizes to match.
        c.model.init_hidden(X.shape[0])

    X = torch.clone(X[:c.max_batch_size]).to(c.device)
    y = torch.clone(y[:c.max_batch_size]).to(c.device)
    # prev_out is seeded with the ground truth so the first input window is known.
    # Every frame from seq_len on is overwritten with a prediction below.
    prev_out = torch.clone(y).to(c.device)

    with torch.no_grad():
        for i in range(seq_len, len(X)):
            X_cat = torch.cat([X[i:i + 1], prev_out[None, i - seq_len:i]], dim=2)
            # NOTE(review): predictions are inverse-transformed here and again after
            # the loop, while the seed frames in prev_out are not. Confirm that the
            # scaler is meant to be applied twice to predicted frames.
            prev_out[i] = torch.clone(c.y_scaler.inv_transform(c.model(X_cat)))

            if c2 is not None:
                if c2.model.recurrent:
                    # Window shifted by one so it ends with the frame just predicted.
                    X_cat2 = torch.cat([X[i:i + 1], prev_out[None, i - seq_len + 1:i + 1]], dim=2)
                else:
                    X_cat2 = torch.cat([X[i:i + 1, -1], prev_out[None, i], prev_out[None, i - 1]], dim=1)
                prev_out[i] = torch.clone(c2.y_scaler.inv_transform(c2.model(X_cat2)))

    # Drop the seed frames; only predicted frames are scored.
    X = X[seq_len:]
    y = y[seq_len:]
    out = c.y_scaler.inv_transform(prev_out[seq_len:])
    y_scaled = c.y_scaler.inv_transform(torch.clone(y))

    if c.model.recurrent:
        # Reduce each input window to its last entry along dim 1.
        X = X[:, -1]

    joint_mask = get_joint_mask(X, c) > 0.00001
    mask = torch_tile(joint_mask, 1, c.in_features.features_wo_occlusions() // (joint_mask.shape[-1]))

    # For each predicted frame, the seq_len frames that precede it (oldest first).
    prev_out_multi_frame = _stack_prev_out_windows(prev_out, seq_len)
    loss = criterion(input=out, target=y_scaled, mask=mask, config=c, bone_lengths=bone_lengths_, prev_out=prev_out_multi_frame, body=X)

    if c2 is not None:
        # Same final predictions (already refined by c2); only the config differs.
        loss2 = criterion(input=out, target=y_scaled, mask=mask, config=c2, bone_lengths=bone_lengths_, prev_out=prev_out_multi_frame, body=X)

    n_out = c.out_features.features_wo_occlusions()
    y_scaled_wo_mask = y_scaled[:, :n_out]
    out_wo_mask = out[:, :n_out]
    # Entries where the tiled joint mask is True are reported as "untracked",
    # the remaining entries as "tracked".
    tracked_inputs = torch.logical_not(mask)
    untracked_loss = out_wo_mask[mask] - y_scaled_wo_mask[mask]
    tracked_loss = out_wo_mask[tracked_inputs] - y_scaled_wo_mask[tracked_inputs]
    # An empty selection gives a mean of nan, which is printed as-is.
    untracked_rmse = torch.sqrt(torch.mean(torch.pow(untracked_loss, 2)))
    tracked_rmse = torch.sqrt(torch.mean(torch.pow(tracked_loss, 2)))

    show_train_loss2 = c2 is not None and train_loss2 is not None
    print(
        f"\repoch: {epoch}, train loss: {train_loss:.5f} {f'c2 train loss {train_loss2:.5f}' if show_train_loss2 else ''} "
        f"val loss: {loss:.5f} {f'c2 val loss {loss2:.5f}' if c2 is not None else ''}, "
        f"untracked val rmse = {untracked_rmse:.4f}, tracked val rmse = {tracked_rmse:.4f}",
        end="")


def _stack_prev_out_windows(prev_out, window):
    """Stack ``window``-frame sliding windows of ``prev_out`` along a new dim 1.

    Row ``k`` of the result is ``prev_out[k:k + window]`` (oldest frame first),
    i.e. the ``window`` frames that precede frame ``window + k``. The result has
    ``len(prev_out) - window`` rows, or none if ``prev_out`` is shorter than
    ``window``. Assumes ``window >= 1``, because ``torch.stack`` rejects an empty list.
    """
    n_windows = max(len(prev_out) - window, 0)
    # Slice j (offset j) supplies column j of every window.
    return torch.stack([prev_out[j:j + n_windows] for j in range(window)], dim=1)