def create_training_data_config()

in src/data_creation_torch.py [0:0]


def _cached_feature_path(features, suffix):
    """Return the path of one cached feature batch file (``<features.name><suffix>.dat`` in ``path_data``)."""
    return os.path.join(path_data, features.name + suffix + ".dat")


def _cached_batch_exists(in_features, out_features, suffix):
    """Return True when both the input and the output cache file for *suffix* exist."""
    return os.path.exists(_cached_feature_path(in_features, suffix)) and \
        os.path.exists(_cached_feature_path(out_features, suffix))


def _remove_cached_batches(in_features, out_features, base_suffix):
    """Delete consecutively numbered cache files ``<base_suffix>0``, ``<base_suffix>1``, ...

    Stops at the first batch number for which neither the input nor the output file exists.
    """
    batch_nr = 0
    while True:
        suffix = f"{base_suffix}{batch_nr}"
        in_path = _cached_feature_path(in_features, suffix)
        out_path = _cached_feature_path(out_features, suffix)
        if not (os.path.exists(in_path) or os.path.exists(out_path)):
            return
        # Re-check existence per path: in/out may share a name (and thus a file).
        for path in (in_path, out_path):
            if os.path.exists(path):
                os.remove(path)
        batch_nr += 1


def _save_cached_batch(X, y, in_features, out_features, suffix, share_targets):
    """Convert accumulated feature rows to tensors and save them as one input/output batch pair."""
    X_tensor = torch.from_numpy(np.array(X))
    # When input and output features are the same object the rows are identical; reuse the tensor.
    y_tensor = X_tensor if share_targets else torch.from_numpy(np.array(y))
    torch.save(X_tensor, _cached_feature_path(in_features, suffix))
    torch.save(y_tensor, _cached_feature_path(out_features, suffix))


def _rebuild_feature_cache(parser, c, in_features, out_features, base_suffix):
    """Re-extract features from the parser's data and rewrite the numbered batch cache files.

    Each file holds ``c.input_size`` rows, except the last which holds the remainder. When
    ``c.input_size <= 0`` all rows are written to a single file (batch 0).
    """
    data = parser.load_numpy(c.normalized_bones)
    _remove_cached_batches(in_features, out_features, base_suffix)

    share_targets = out_features is in_features
    batch_nr = 0
    X = []
    y = []

    for _skeleton, tracks in tqdm(data, desc="extracting features"):
        for _track in tracks:
            # Rank-4 tracks carry an extra leading axis; only its first entry is used.
            track = _track[0] if len(_track.shape) == 4 else _track
            _x = __frames_to_feature_batch(track, in_features, config=c)
            _y = __frames_to_feature_batch(track, out_features, config=c)

            X.extend(_x.reshape((-1, _x.shape[1])))
            y.extend(_y.reshape((-1, _y.shape[1])))

        # Flush full batches. Guarding on input_size > 0 avoids an endless loop for the
        # "whole file is one batch" configuration.
        while c.input_size > 0 and len(X) >= c.input_size:
            _save_cached_batch(X[:c.input_size], y[:c.input_size], in_features, out_features,
                               f"{base_suffix}{batch_nr}", share_targets)
            X = X[c.input_size:]
            y = y[c.input_size:]
            batch_nr += 1

    if X:
        _save_cached_batch(X, y, in_features, out_features, f"{base_suffix}{batch_nr}", share_targets)


def create_training_data_config(parser, c: Train_Config, update, save_all_to_mem=True, shuffle_all=True, dataset_name="", test_set=False):
    """Yield ``(X, y)`` training batches built from cached feature files, creating the cache if needed.

    Features are cached on disk as numbered files ``<features.name>_<dataset_name>[_normalized]<n>.dat``
    in ``path_data``. They are (re)extracted from ``parser.load_numpy(c.normalized_bones)`` when
    *update* is true or when batch 0 is missing.

    Args:
        parser: object whose ``load_numpy(normalized_bones)`` returns an iterable of ``(skeleton, tracks)``.
        c: training config. Reads ``in_features``, ``out_features``, ``normalized_bones``, ``input_size``,
            ``model.recurrent`` and ``sequence_length``/``sequence_distance``/``sequence_skip``.
        update: force re-extraction of the feature cache.
        save_all_to_mem: concatenate every cached batch in memory and yield from the concatenation.
            If false, rows are streamed into fixed-size batches of ``input_size`` rows.
        shuffle_all: yield strided (interleaved, not random) batches from the in-memory data.
            Forces *save_all_to_mem*.
        dataset_name: required. Used in the cache file names.
        test_set: when loading into memory without *shuffle_all*, only the last 20% of rows are yielded.

    Yields:
        ``(X, y)`` tensor pairs, or the result of ``to_recurrent_feature*`` for recurrent models.
        In streaming mode the yielded tensors are a reused buffer that is overwritten by later
        batches; copy them if they must outlive the next iteration.

    Raises:
        ValueError: *dataset_name* is empty.
        RuntimeError: feature extraction produced no cached batches.
        (Both are raised on the first ``next()`` because this is a generator.)
    """
    in_features = c.in_features.data_to_features()
    out_features = c.out_features.data_to_features()

    if not dataset_name:
        raise ValueError("ERROR: no dataset_name set")

    if shuffle_all:
        save_all_to_mem = True

    _suffix = "_" + dataset_name + ("_normalized" if c.normalized_bones else "")

    # Build the cache at most once. Previously a missing cache recursed, which dropped
    # `test_set` and never terminated when extraction produced no data.
    if update or not _cached_batch_exists(in_features, out_features, f"{_suffix}0"):
        _rebuild_feature_cache(parser, c, in_features, out_features, _suffix)
        if not _cached_batch_exists(in_features, out_features, f"{_suffix}0"):
            raise RuntimeError(f"ERROR: no training data was extracted for dataset '{dataset_name}'")

    batch_nr = 0
    batch_samples = 0  # rows currently filled in the streaming buffer
    X_batch = None
    y_batch = None
    X_all = []
    y_all = []
    suffix = f"{_suffix}{batch_nr}"
    while _cached_batch_exists(in_features, out_features, suffix):
        X = torch.load(_cached_feature_path(in_features, suffix))
        y = torch.load(_cached_feature_path(out_features, suffix))
        input_size = c.input_size if c.input_size > 0 else len(X)
        assert (len(X.shape) == 2)

        if save_all_to_mem:
            X_all.append(X)
            y_all.append(y)
        else:
            if X_batch is None:
                X_batch = torch.empty((input_size, X.shape[1]), dtype=torch.float32)
                y_batch = torch.empty((input_size, y.shape[1]), dtype=torch.float32)

            # Copy this file's rows into the buffer, yielding every time it fills up. Leftover
            # rows carry over to the next file; a final partial buffer is never yielded.
            offset = 0
            while offset < len(X):
                take = min(input_size - batch_samples, len(X) - offset)
                X_batch[batch_samples:batch_samples + take] = X[offset:offset + take]
                y_batch[batch_samples:batch_samples + take] = y[offset:offset + take]
                batch_samples += take
                offset += take

                if batch_samples == input_size:
                    if c.model.recurrent:
                        yield to_recurrent_feature(X_batch, y_batch, c.sequence_length, c.sequence_distance,
                                                   c.sequence_skip)
                    else:
                        yield X_batch, y_batch
                    batch_samples = 0

        batch_nr += 1
        suffix = f"{_suffix}{batch_nr}"

    if save_all_to_mem and X_all:
        X_all = torch.cat(X_all, dim=0)
        y_all = torch.cat(y_all, dim=0)

        if shuffle_all:
            # Batch i takes rows i, i + batches, i + 2*batches, ... (trailing remainder dropped).
            batches = len(X_all) // input_size
            last_entry = len(X_all) - (len(X_all) % input_size)
            if c.model.recurrent:
                for i in range(0, batches, c.sequence_distance):
                    yield to_recurrent_feature_index(X_all, y_all, c.sequence_length, c.sequence_skip, i, last_entry, batches)
            else:
                for i in range(batches):
                    yield X_all[i:last_entry:batches], y_all[i:last_entry:batches]
        else:
            if test_set:
                # Hold out the last 20% of rows as the test split.
                X_all = X_all[X_all.shape[0] * 8 // 10:]
                y_all = y_all[y_all.shape[0] * 8 // 10:]
            if c.model.recurrent:
                yield to_recurrent_feature(X_all, y_all, c.sequence_length, c.sequence_distance, c.sequence_skip)
            else:
                yield X_all, y_all