in videoalignment/datasets.py [0:0]
def split_train_val(self):
    """Split ground-truth videos and overlapping pairs into train/val folds.

    Channels (the first path component of each ``v["video"]`` name) are
    sorted and partitioned into ``self.n_folds`` contiguous folds of size
    ``len(channels) // self.n_folds``; fold ``self.args.fold_index`` is the
    validation set and every other channel (including the remainder when
    the channel count is not divisible by ``n_folds``) goes to train.

    Depending on ``self.phase`` ("train", "val" or "all"), sets
    ``self.videos`` and ``self.overlapping_pairs``; any other phase leaves
    both attributes untouched.
    """
    # sorted() — not list(set(...)) — so the fold boundaries are
    # deterministic: iterating a raw set of strings varies between runs
    # (hash randomization), which would make the split irreproducible.
    cc = sorted({v["video"].split("/")[0] for v in self.gt_all_videos})
    fold_size = len(cc) // self.n_folds
    lo = fold_size * self.args.fold_index
    hi = fold_size * (self.args.fold_index + 1)
    val_cc = cc[lo:hi]
    train_cc = cc[:lo] + cc[hi:]

    # Group videos by channel in one pass (instead of rescanning
    # gt_all_videos once per channel), then emit them in channel order so
    # the output stays grouped by channel exactly as before.
    by_channel = {}
    for v in self.gt_all_videos:
        by_channel.setdefault(v["video"].split("/")[0], []).append(v)
    videos_train = [v for c in train_cc for v in by_channel.get(c, [])]
    videos_val = [v for c in val_cc for v in by_channel.get(c, [])]

    # NOTE(review): a pair is assigned by its FIRST video only; if the
    # second video belongs to the other split this may leak across
    # train/val — confirm this is intentional.
    pairs_train = [
        p for p in self.gt_all_overlapping_pairs if p["videos"][0] in videos_train
    ]
    pairs_val = [
        p for p in self.gt_all_overlapping_pairs if p["videos"][0] in videos_val
    ]

    if self.phase == "train":
        self.videos = videos_train
        self.overlapping_pairs = pairs_train
    elif self.phase == "val":
        self.videos = videos_val
        self.overlapping_pairs = pairs_val
    elif self.phase == "all":
        self.videos = self.gt_all_videos
        self.overlapping_pairs = self.gt_all_overlapping_pairs