in expanded_checklist/checklist/core_record.py [0:0]
def flatten_within_group(self, flatten_group):
    """Collapse the per-example "version" dimension of this record in place.

    Rebuilds the parallel containers ``self.data`` / ``self.preds`` /
    ``self.confs`` / ``self.meta`` / ``self.labels`` according to the
    requested mode and reassigns them on ``self``. Returns ``None``.

    Args:
        flatten_group: a ``FlattenGroup`` member (or ``None``). Modes:
            - ``RANDOM_MATCH`` (asserts ``DataShape.UNGROUPED``): expand
              each example into one entry per cross-group combination of
              versions, randomly sampled down to at most
              ``MAX_N_GROUPINGS`` combinations.
            - ``FLATTEN`` (asserts ``DataShape.GROUPED``): within each
              group, splice every version out into its own example,
              replicating the example's label per version.
            - ``AVERAGE`` (asserts ``DataShape.UNGROUPED``): average
              preds/confs across versions, one entry per group; data is
              replaced by the template text.
            - ``FLATTEN_ALL``: discard all group/example structure, one
              entry per innermost element (handles both shapes).

    No-op when ``flatten_group`` is ``None``/``NONE`` or the record
    already looks flat (``self.data[0]`` is a 1-D list).
    """
    if flatten_group in [None, FlattenGroup.NONE] or \
            is_1d_list(self.data[0]):
        return
    # Rebuilt replacements for the parallel per-record containers;
    # assigned back onto self at the end.
    new_preds = []
    new_confs = []
    new_data = []
    new_meta = []
    new_labels = []
    if flatten_group == FlattenGroup.RANDOM_MATCH:
        assert self.data_structure == DataShape.UNGROUPED
        thresh = MAX_N_GROUPINGS
        for eidx in range(len(self.data)):
            # keeping the sampling in the inner loop allows for number
            # of samples for each group to differ from template to template
            # but also causes the samples to differ from temp. to temp.
            sample_idxs = [list(range(len(x))) for x in self.data[eidx]]
            # detect whether there is too many options while avoiding
            # overflow (stop multiplying as soon as the running product
            # exceeds the threshold)
            total_options = 1
            for x in sample_idxs:
                total_options *= len(x)
                if total_options > thresh:
                    break
            if total_options > thresh:
                # Too many combinations to enumerate: rejection-sample
                # distinct index tuples until we have `thresh` of them.
                # Terminates because the full cross product is known to
                # have more than `thresh` distinct tuples.
                groupings = set()
                while len(groupings) < thresh:
                    sample = tuple(
                        [np.random.choice(x) for x in sample_idxs])
                    groupings.add(sample)
            else:
                # Small enough: take the exact cross product of
                # per-group version indices.
                groupings = set(product(*sample_idxs))
            meta = self.meta[eidx]
            for chosen_idxs in groupings:
                # One new flattened entry per chosen combination:
                # pick version `cidx` from group `gidx` everywhere.
                new_data.append(
                    [self.data[eidx][gidx][cidx]
                     for gidx, cidx in enumerate(chosen_idxs)])
                new_preds.append(
                    [self.preds[eidx][gidx][cidx]
                     for gidx, cidx in enumerate(chosen_idxs)])
                new_confs.append(
                    [self.confs[eidx][gidx][cidx]
                     for gidx, cidx in enumerate(chosen_idxs)])
                # Deep-copy so narrowing "SAMPLE" below cannot mutate
                # the shared original meta dict.
                tmp_meta = deepcopy(meta)
                for gidx, cidx in enumerate(chosen_idxs):
                    # NOTE(review): meta is indexed here by integer
                    # group position to get the group name — assumes
                    # the meta dict also maps gidx -> group name;
                    # confirm against where self.meta is built.
                    gname = meta[gidx]
                    # Narrow the per-group sample list down to the
                    # single chosen version.
                    tmp_meta["SAMPLE"][gname] = meta["SAMPLE"][gname][cidx]
                new_meta.append(tmp_meta)
                # Label is per-example and shared by all combinations.
                new_labels.append(self.labels[eidx])
    elif flatten_group == FlattenGroup.FLATTEN:
        assert self.data_structure == DataShape.GROUPED
        for gidx in range(len(self.group_names)):
            new_group_preds, new_group_confs, new_group_data, \
                new_group_labels, new_group_meta = [], [], [], [], []
            for eidx in range(len(self.labels[gidx])):
                # Splice each version of this example out into its own
                # entry within the same group.
                nversions = len(self.preds[gidx][eidx])
                new_group_preds += list(self.preds[gidx][eidx])
                new_group_confs += list(self.confs[gidx][eidx])
                new_group_data += list(self.data[gidx][eidx])
                # Replicate the example label once per version so the
                # containers stay parallel.
                new_group_labels += [self.labels[gidx][eidx]] * nversions
                for nv in range(nversions):
                    # Fresh copy per version; narrow "SAMPLE" and
                    # "GROUP_FILL" to this version only.
                    tmp_meta = deepcopy(self.meta[gidx][eidx])
                    # NOTE(review): integer-indexed meta lookup for the
                    # group name, as in RANDOM_MATCH above — TODO confirm.
                    gname = tmp_meta[gidx]
                    tmp_meta["SAMPLE"] =\
                        {gname: tmp_meta["SAMPLE"][gname][nv]}
                    tmp_meta["GROUP_FILL"] = tmp_meta["GROUP_FILL"][nv]
                    new_group_meta.append(tmp_meta)
            new_preds.append(new_group_preds)
            new_confs.append(new_group_confs)
            new_labels.append(new_group_labels)
            new_meta.append(new_group_meta)
            new_data.append(new_group_data)
    elif flatten_group == FlattenGroup.AVERAGE:
        assert self.data_structure == DataShape.UNGROUPED
        n_groups = len(self.data[0])
        # Labels and meta are untouched by averaging; keep as-is.
        new_labels = self.labels
        new_meta = self.meta
        for eidx in range(len(self.data)):
            meta = self.meta[eidx]
            # Individual version texts no longer apply after averaging;
            # stand in the template text for every group.
            new_data.append([self.meta[eidx]["TEMPLATE"]] * n_groups)
            # Mean over the version axis for each group's preds/confs.
            new_preds.append(
                [np.mean(np.array(self.preds[eidx][gidx]), axis=0)
                 for gidx in range(n_groups)])
            new_confs.append(
                [np.mean(np.array(self.confs[eidx][gidx]), axis=0)
                 for gidx in range(n_groups)])
    elif flatten_group == FlattenGroup.FLATTEN_ALL:
        # this loses all structure; i.e. group divisions and example
        # divisions
        for gidx in range(len(self.data)):
            for eidx in range(len(self.data[gidx])):
                for vidx in range(len(self.data[gidx][eidx])):
                    # Each innermost element becomes its own
                    # single-element entry.
                    new_preds.append([self.preds[gidx][eidx][vidx]])
                    new_confs.append([self.confs[gidx][eidx][vidx]])
                    new_data.append([self.data[gidx][eidx][vidx]])
                    if self.data_structure == DataShape.GROUPED:
                        new_labels.append(self.labels[gidx][eidx])
                        # Same meta-narrowing pattern as FLATTEN above.
                        tmp_meta = deepcopy(self.meta[gidx][eidx])
                        gname = tmp_meta[gidx]
                        tmp_meta["SAMPLE"] =\
                            {gname: tmp_meta["SAMPLE"][gname][vidx]}
                        tmp_meta["GROUP_FILL"] = tmp_meta["GROUP_FILL"][vidx]
                        new_meta.append(tmp_meta)
                    else:
                        # UNGROUPED: the outer index is the example;
                        # labels/meta are shared across versions.
                        new_labels.append(self.labels[gidx])
                        new_meta.append(self.meta[gidx])
    # Commit the rebuilt containers.
    self.labels = new_labels
    self.preds = new_preds
    self.confs = new_confs
    self.meta = new_meta
    self.data = new_data