def train_config()

in src/train.py [0:0]


def train_config(parser: BvhConverter, c: Train_Config, update: bool):
    """
    :param parser: a parser that implements "load_numpy"
    :param c: Definition of model, skeleton, training params and feature sets
    :param update: renew training features - if False, last features will be loaded from disk
    :param plot: plot prediction after training
    :return:
    """

    criterion = c.get_loss_fun()
    par_indices = c.skeleton.parent_idx_vector()
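    # static skeleton geometry from the dataset, passed to the loss in every batch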
    bone_lengths_ = torch.from_numpy(get_bone_length_dataset(par_indices, parser)).to(c.device)
    bone_offsets = get_bone_offsets_dataset_normalized(parser).to(c.device)

    for epoch in range(c.epochs):
        print("")
        for _X, _y in create_training_data_config(parser, c, update, save_all_to_mem=True, shuffle_all=True,
                                                  dataset_name=parser.name(), test_set=False):
            X = _X.to(c.device)
            y = _y.to(c.device)

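            # fit the input/output scalers once, on the first chunk (for recurrent
            # models the fit uses the last frame of each input window)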
            if c.y_scaler is None:
                X_init = X[:, -1] if c.model.recurrent else X
                c.x_scaler = c.scale(X_init, group_size=3, idx_end=c.in_features.features_wo_occlusions())
                c.y_scaler = c.scale(y, group_size=3, idx_end=c.out_features.features_wo_occlusions())

            y = c.y_scaler.transform(y)
            X = c.x_scaler.transform_recurrent(X) if c.model.recurrent else c.x_scaler.transform(X)

            if c.random_noise > 0:
                X += (torch.rand_like(X) - 0.5) * c.random_noise
                y += (torch.rand_like(y) - 0.5) * c.random_noise

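            # 80/20 train/validation split within this chunk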
            X_train = X[:X.shape[0] * 8 // 10]
            y_train = y[:y.shape[0] * 8 // 10]
            X_val = X[X_train.shape[0]:]
            y_val = y[y_train.shape[0]:]

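            # one full-chunk batch when c.batch_size is None; frames beyond end_idx
            # (the remainder after full batches) are dropped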
            batch_size = X_train.shape[0] if c.batch_size is None else c.batch_size
            X_train, y_train = shuffle_X_y(X_train, y_train)
            batches = max(X_train.shape[0] // batch_size, 1)
            end_idx = X_train.shape[0] - (X_train.shape[0] % batch_size)

            for batch in range(batches):
                c.model.train()
                c.optimizer.zero_grad()
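                # strided slicing yields `batches` interleaved batches of batch_size samples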
                X_train_batch = X_train[batch:end_idx:batches]
                y_train_batch = y_train[batch:end_idx:batches]

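                # per-batch augmentation: random rotation, random translation and mask noise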
                if c.rotate_random > 0.0001:
                    X_train_batch, y_train_batch = rotate_features_random(X_train_batch, c, y_train_batch)

                if c.translate_random > 0.0001:
                    X_train_batch, y_train_batch = translate_features_random(X_train_batch, c, y_train_batch)

                if c.mask_noise is not None:
                    X_train_batch, y_train_batch = add_mask_noise(X_train_batch, c, y_train_batch)

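                # recurrent models get a freshly initialized hidden state for every batch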
                if c.model.recurrent:
                    c.model.init_hidden(batch_size)

                out = c.model(X_train_batch)

                # inputs for the current frame (last time step for recurrent models)
                X_current_frame = X_train_batch[:, -1] if c.model.recurrent else X_train_batch
                joint_mask = get_joint_mask(X_current_frame, c)
                float_mask = get_float_mask(X_current_frame, c)
                # tile the per-joint mask across the features that come before the occlusion mask
                loss_mask = torch_tile(joint_mask, 1, c.in_features.features_wo_occlusions() // joint_mask.shape[-1])

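                # masked loss; bone lengths/offsets and the current input frame are passed
                # through to the criterion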
                loss = criterion(input=out, target=y_train_batch, mask=loss_mask, config=c, bone_lengths=bone_lengths_, body=X_current_frame,
                                 bone_offsets=bone_offsets, float_mask=float_mask)
                loss.backward()
                c.optimizer.step()
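                # post-loss activation: `out` is only reused by the debug visualization below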
                out = c.model.post_loss_activation(out)

                print(f"\rloss={loss.item()}", end="")

                # debug-only visualization of the predicted pose, disabled by default
                combine_out_finger_in_body = False

                if combine_out_finger_in_body:
                    pred_pos = c.out_features.features_to_pose().solve_batch(out.detach(), body=X_current_frame, mask=joint_mask, ref_positions=y_train_batch,
                                                                             bone_offsets=bone_offsets)
                    pred_pos = pred_pos.reshape((len(out), -1, 3))

                    from plots import Vispy3DScatter
                    vispy = Vispy3DScatter()
                    vispy.plot_skeleton_with_bones(pred_pos.cpu().numpy()[:, :, [0, 2, 1]], c.skeleton, fps=parser.dataset_fps / 30, speed=0.2)

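                # after the last batch of the chunk, report validation loss and pose accuracy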
                c.model.eval()
                if batch == batches - 1:
                    print_val_loss(epoch, c, criterion, X_val[:c.max_batch_size], y_val[:c.max_batch_size], loss, bone_lengths_, bone_offsets)
                    eval_pose_accuracy(X_val, c, y_val, bone_lengths_, bone_offsets)

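            # from now on, load the stored features instead of rebuilding them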
        update = False

    c.model.save()
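
# Hypothetical usage sketch (not part of src/train.py): the constructor arguments
# for BvhConverter and Train_Config below are assumptions and will differ in the
# real code base; only the train_config(...) signature is taken from above.
#
#   parser = BvhConverter(...)            # any parser implementing load_numpy()
#   c = Train_Config(...)                 # model, skeleton, optimizer, feature sets
#   train_config(parser, c, update=True)  # build features, train, and save the model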