def split_train_val()

in videoalignment/datasets.py [0:0]


    def split_train_val(self):
        """Select the videos and overlapping pairs for the current phase.

        Videos are grouped by the first ``/``-separated component of their
        ``"video"`` field (presumably the category directory -- confirm
        against the dataset layout).  The distinct groups are sorted and cut
        into consecutive folds of ``len(groups) // self.n_folds`` groups each;
        the fold selected by ``self.args.fold_index`` is the validation split
        and every other group goes to the training split.  When the number of
        groups is not a multiple of ``self.n_folds`` the leftover groups at
        the end of the sorted list are never used for validation.

        Reads ``self.phase``, ``self.gt_all_videos``,
        ``self.gt_all_overlapping_pairs``, ``self.n_folds`` and
        ``self.args.fold_index``.  Sets ``self.videos`` and
        ``self.overlapping_pairs``:

        * ``"train"`` / ``"val"``: the videos of the matching split, and the
          pairs whose first video belongs to that split.
        * ``"all"``: the full ground-truth lists, unfiltered.
        * any other phase: nothing is assigned.
        """
        phase = self.phase

        if phase == "all":
            # No split is needed, so the fold configuration is not consulted.
            self.videos = self.gt_all_videos
            self.overlapping_pairs = self.gt_all_overlapping_pairs
            return
        if phase not in ("train", "val"):
            # Unknown phases leave the instance untouched.
            return

        # Group the videos by category in a single pass, keeping their
        # original relative order inside each category.
        videos_by_category = {}
        for video in self.gt_all_videos:
            category = video["video"].split("/")[0]
            videos_by_category.setdefault(category, []).append(video)

        # Sort the categories: iteration order of a set of strings depends on
        # per-process hash randomization, so slicing it directly would give a
        # different (and possibly overlapping) train/val split on every run.
        categories = sorted(videos_by_category)

        fold_size = len(categories) // self.n_folds
        val_start = fold_size * self.args.fold_index
        val_stop = fold_size * (self.args.fold_index + 1)

        if phase == "val":
            selected = categories[val_start:val_stop]
        else:
            selected = categories[:val_start] + categories[val_stop:]

        videos = [
            video
            for category in selected
            for video in videos_by_category[category]
        ]

        # Membership is tested by equality against the list on purpose: the
        # video entries are subscripted like dicts and may not be hashable.
        self.videos = videos
        self.overlapping_pairs = [
            pair
            for pair in self.gt_all_overlapping_pairs
            if pair["videos"][0] in videos
        ]