def eval()

in src/eval.py [0:0]


def eval(parser, c, update, plot=False, *, max_samples=200000, compare_input_acc=False, eval_joint_acc=True,
         eval_bone_length=False, eval_jitter=False, enable_timing=False):
    """Evaluate ``c.model`` on the test split of ``parser``'s dataset and print accuracy reports.

    For every batch, the model output and the ground truth are inverse-scaled and turned into
    poses via ``c.out_features.features_to_pose().solve_batch``. Pose accuracies are collected
    per batch and summarised once all data has been processed.

    Side effect: if ``c.y_scaler`` is ``None``, new scalers are fitted on the first data chunk
    and stored on ``c`` as ``c.x_scaler`` / ``c.y_scaler``.

    Args:
        parser: dataset parser. Used by the data loader and the bone helpers, and
            ``parser.name()`` selects the dataset.
        c: configuration object providing the model, scalers, feature descriptors, skeleton,
            device and batch size.
        update: forwarded unchanged to ``create_training_data_config``.
        plot: ``False`` disables plotting. ``True`` plots every batch. An integer ``N`` plots
            only the ``N``-th batch (1-based, counted across all data chunks).
        max_samples: maximum number of samples taken from each chunk yielded by the loader.
        compare_input_acc: also report the accuracy of simply reusing the input pose.
        eval_joint_acc: collect and print per-joint accuracies.
        eval_bone_length: collect and print bone-length accuracies.
        eval_jitter: run ``eval_joint_acceleration`` on every batch.
        enable_timing: print per-batch inference time. This relies on ``torch.cuda.Event``,
            so it needs CUDA.
    """
    par_indices = c.skeleton.parent_idx_vector()
    bone_lengths_ = torch.from_numpy(get_bone_length_dataset(par_indices, parser)).to(c.device)
    bone_offsets = get_bone_offsets_dataset_normalized(parser).to(c.device)

    # Per-batch pose accuracies of the model prediction.
    acc = []
    untracked_acc = []
    tracked_acc = []
    # Per-batch pose accuracies of the "reuse last known pos" baseline (compare_input_acc only).
    acc_repeat = []
    untracked_acc_repeat = []
    tracked_acc_repeat = []
    joint_occ = get_empty_joint_acc_dict(c)
    joint_acc_all = get_empty_joint_acc_dict(c)
    joint_acc_untracked = get_empty_joint_acc_dict(c)
    joint_acc_tracked = get_empty_joint_acc_dict(c)
    bone_length_acc_ratio = get_empty_joint_acc_dict(c)
    bone_length_acc_distance = get_empty_joint_acc_dict(c)

    joint_occ_input = get_empty_joint_acc_dict(c)
    joint_acc_all_input = get_empty_joint_acc_dict(c)
    joint_acc_untracked_input = get_empty_joint_acc_dict(c)
    joint_acc_tracked_input = get_empty_joint_acc_dict(c)

    # Plot handle returned by plot_evaluation and passed back in on the next call.
    vispy_plot = None

    for X_val, y_val in create_training_data_config(parser, c, update, save_all_to_mem=True, shuffle_all=False, dataset_name=parser.name(), test_set=True):
        X_val = X_val[:max_samples]
        y_val = y_val[:max_samples]

        if c.y_scaler is None:
            # Recurrent inputs are fitted on their last time step only.
            X_init = X_val[:, -1] if c.model.recurrent else X_val
            c.x_scaler = c.scale(X_init, group_size=3, idx_end=c.in_features.features_wo_occlusions())
            c.y_scaler = c.scale(y_val, group_size=3, idx_end=c.out_features.features_wo_occlusions())

        y_val = c.y_scaler.transform(y_val)
        X_val = c.x_scaler.transform_recurrent(X_val) if c.model.recurrent else c.x_scaler.transform(X_val)

        batch_size = X_val.shape[0] if c.batch_size is None else c.batch_size

        # NOTE(review): floor division drops a trailing partial batch (at least one batch always runs).
        for batch in range(max(X_val.shape[0] // batch_size, 1)):
            X_batch = X_val[batch * batch_size: (batch + 1) * batch_size].to(c.device)
            y_batch = y_val[batch * batch_size: (batch + 1) * batch_size].to(c.device)
            with torch.no_grad():
                if enable_timing:
                    start = torch.cuda.Event(enable_timing=True)
                    end = torch.cuda.Event(enable_timing=True)
                    start.record()
                out = c.model(X_batch)
                out = c.model.post_loss_activation(out)
                if enable_timing:
                    end.record()
                    torch.cuda.synchronize()
                    print(f"inference duration: {start.elapsed_time(end)}")

            # The mask is taken from the still-scaled input, before it is inverse-transformed below.
            # Presumably it flags tracked (non-occluded) joints -- TODO confirm get_joint_mask semantics.
            if c.model.recurrent:
                mask = get_joint_mask(X_batch[:, -1], c) > 0.00001
                X_batch[:, -1] = c.x_scaler.inv_transform(X_batch[:, -1])
            else:
                mask = get_joint_mask(X_batch, c) > 0.00001
                X_batch = c.x_scaler.inv_transform(X_batch)

            y_batch = c.y_scaler.inv_transform(y_batch)
            out = c.y_scaler.inv_transform(out)

            # Current (last) input frame; used as the "body" reference for pose solving.
            X_batch_0 = X_batch[:, -1] if c.model.recurrent else X_batch

            ref_pose = c.out_features.features_to_pose().solve_batch(y_batch, ref_positions=None, mask=mask, bone_lengths=bone_lengths_, body=X_batch_0,
                                                                     bone_offsets=bone_offsets)
            pred_pose = c.out_features.features_to_pose().solve_batch(out, ref_positions=None, mask=None, bone_lengths=bone_lengths_, body=X_batch_0,
                                                                      bone_offsets=bone_offsets)

            if eval_jitter:
                eval_joint_acceleration(X_batch, y_batch, out, c, ref_pose, pred_pose, mask)

            print("use predicted pos:    ", end="")
            _acc, _untracked_acc, _tracked_acc = print_pose_accuracy(mask, ref_pose, pred_pose, c)
            if compare_input_acc:
                print("reuse last known pos: ", end="")
                # Baseline: solve a pose directly from the input features and score it against the reference.
                in_pose = c.in_features.features_to_pose().solve_batch(X_batch_0, ref_positions=None, mask=mask, bone_lengths=bone_lengths_,
                                                                       body=X_batch_0, bone_offsets=bone_offsets)
                _acc_repeat, _untracked_acc_repeat, _tracked_acc_repeat = print_pose_accuracy(mask, ref_pose, in_pose, c)
                acc_repeat.append(_acc_repeat)
                untracked_acc_repeat.append(_untracked_acc_repeat)
                tracked_acc_repeat.append(_tracked_acc_repeat)
                # Keep only the feature columns before the occlusion columns.
                X_batch_wo_mask = X_batch_0[:, :c.in_features.features_wo_occlusions()]
                if eval_joint_acc:
                    append_batch_occ_acc(joint_occ_input, joint_acc_all_input, joint_acc_untracked_input, joint_acc_tracked_input, mask, y_batch,
                                         X_batch_wo_mask, c, bone_lengths_, X_batch_0, bone_offsets=bone_offsets)

            acc.append(_acc)
            untracked_acc.append(_untracked_acc)
            tracked_acc.append(_tracked_acc)

            if eval_joint_acc:
                append_batch_occ_acc(joint_occ, joint_acc_all, joint_acc_untracked, joint_acc_tracked, mask, y_batch, out, c, bone_lengths_, X_batch_0,
                                     bone_offsets)

            if eval_bone_length:
                append_bone_length_acc(bone_length_acc_ratio, bone_length_acc_distance, mask, out, c, bone_lengths_, X_batch_0, bone_offsets)

            if plot is True or plot == 1:
                vispy_plot = plot_evaluation(X_batch_0, y_batch, out, parser, c, untracked_only=False, parent_indices=par_indices, vispy=vispy_plot,
                                             bone_lengths=bone_lengths_, bone_offsets=bone_offsets, mask=mask)
            # An integer `plot` counts down to the batch to show. Booleans are plain on/off flags and
            # must not be decremented (bool is an int subclass, so `True - 1` would silently disable plotting).
            if not isinstance(plot, bool):
                plot -= 1

    print("\n\n-------- predicted ---------")
    print_err("prediction ", acc, untracked_acc, tracked_acc)
    if eval_joint_acc:
        print_joint_occ_acc(joint_occ, joint_acc_all, joint_acc_untracked, joint_acc_tracked)

    if compare_input_acc:
        print("\n\n-------- input ---------")
        print_err("input ", acc_repeat, untracked_acc_repeat, tracked_acc_repeat)
        if eval_joint_acc:
            print_joint_occ_acc(joint_occ_input, joint_acc_all_input, joint_acc_untracked_input, joint_acc_tracked_input)

    if eval_bone_length:
        print("\n\n-------- bone lengths ---------")
        print_bone_length_acc(bone_length_acc_ratio, bone_length_acc_distance, c, bone_lengths_)