def flatten_within_group()

in expanded_checklist/checklist/core_record.py [0:0]


    def flatten_within_group(self, flatten_group):
        """Reshape the record's data, preds, confs, labels and meta in place.

        The restructuring is selected by ``flatten_group``:

        * ``None`` / ``FlattenGroup.NONE`` -- nothing changes.
        * ``FlattenGroup.RANDOM_MATCH`` (requires ``DataShape.UNGROUPED``) --
          for each example, one sample index is chosen per group and every
          chosen combination becomes a new entry. All combinations are used
          when there are at most ``MAX_N_GROUPINGS`` of them; otherwise
          ``MAX_N_GROUPINGS`` distinct combinations are drawn at random with
          ``np.random.choice``. Examples with an empty group yield nothing.
        * ``FlattenGroup.FLATTEN`` (requires ``DataShape.GROUPED``) -- within
          each group, the versions of every example are concatenated into a
          single flat list per group.
        * ``FlattenGroup.AVERAGE`` (requires ``DataShape.UNGROUPED``) -- each
          group's preds/confs are replaced by their mean over axis 0; data
          becomes the example's ``meta["TEMPLATE"]`` repeated once per group.
        * ``FlattenGroup.FLATTEN_ALL`` -- every individual version becomes its
          own single-element entry, discarding group and example divisions.

        Nothing happens when ``self.data`` is empty or its first element is
        already a 1-D list (as reported by ``is_1d_list``).

        Args:
            flatten_group: a ``FlattenGroup`` member or ``None``.

        Raises:
            AssertionError: if ``self.data_structure`` does not match the
                layout the chosen mode requires.
            ValueError: if ``flatten_group`` is not one of the modes above.
        """
        if flatten_group in [None, FlattenGroup.NONE] or \
                len(self.data) == 0 or is_1d_list(self.data[0]):
            return

        if flatten_group == FlattenGroup.RANDOM_MATCH:
            flattened = self._flatten_random_match()
        elif flatten_group == FlattenGroup.FLATTEN:
            flattened = self._flatten_per_group()
        elif flatten_group == FlattenGroup.AVERAGE:
            flattened = self._flatten_average()
        elif flatten_group == FlattenGroup.FLATTEN_ALL:
            flattened = self._flatten_all()
        else:
            # Without this, an unrecognised mode would fall through and
            # replace every attribute with an empty list, silently losing
            # the record's contents.
            raise ValueError(
                f"Unsupported flatten_group: {flatten_group!r}")

        self.labels, self.preds, self.confs, self.meta, self.data = flattened

    @staticmethod
    def _version_meta(example_meta, gidx, vidx):
        """Return a deep copy of ``example_meta`` narrowed to one version.

        ``example_meta[gidx]`` is used as the group name (presumably the meta
        maps group index -> group name; TODO confirm). In the copy,
        ``"SAMPLE"`` becomes ``{gname: SAMPLE[gname][vidx]}`` and
        ``"GROUP_FILL"`` is replaced by its ``vidx``-th element.
        """
        tmp_meta = deepcopy(example_meta)
        gname = tmp_meta[gidx]
        tmp_meta["SAMPLE"] = {gname: tmp_meta["SAMPLE"][gname][vidx]}
        tmp_meta["GROUP_FILL"] = tmp_meta["GROUP_FILL"][vidx]
        return tmp_meta

    def _flatten_random_match(self):
        """RANDOM_MATCH: combine one sample per group for every example.

        Returns ``(labels, preds, confs, meta, data)``, each a list with one
        element per chosen combination.
        """
        assert self.data_structure == DataShape.UNGROUPED

        new_labels, new_preds, new_confs, new_meta, new_data = \
            [], [], [], [], []
        thresh = MAX_N_GROUPINGS
        for eidx, example_data in enumerate(self.data):
            # keeping the sampling in the inner loop allows for number
            # of samples for each group to differ from template to template
            # but also causes the samples to differ from temp. to temp.
            sample_idxs = [list(range(len(x))) for x in example_data]

            # A group without samples admits no combination (the product is
            # empty). Checking this up front also keeps np.random.choice from
            # being called on that empty group when the guard below breaks
            # out of its loop before reaching it.
            if any(len(x) == 0 for x in sample_idxs):
                continue

            # detect whether there are too many options, stopping early so
            # the running product never grows needlessly large
            total_options = 1
            for x in sample_idxs:
                total_options *= len(x)
                if total_options > thresh:
                    break

            if total_options > thresh:
                # Draw random combinations until `thresh` distinct ones exist.
                groupings = set()
                while len(groupings) < thresh:
                    groupings.add(
                        tuple([np.random.choice(x) for x in sample_idxs]))
            else:
                groupings = set(product(*sample_idxs))

            meta = self.meta[eidx]
            example_preds = self.preds[eidx]
            example_confs = self.confs[eidx]
            for chosen_idxs in groupings:
                picks = list(enumerate(chosen_idxs))
                new_data.append([example_data[g][c] for g, c in picks])
                new_preds.append([example_preds[g][c] for g, c in picks])
                new_confs.append([example_confs[g][c] for g, c in picks])

                tmp_meta = deepcopy(meta)
                for gidx, cidx in picks:
                    gname = meta[gidx]
                    tmp_meta["SAMPLE"][gname] = meta["SAMPLE"][gname][cidx]
                new_meta.append(tmp_meta)
                new_labels.append(self.labels[eidx])

        return new_labels, new_preds, new_confs, new_meta, new_data

    def _flatten_per_group(self):
        """FLATTEN: concatenate all examples' versions into one list per group.

        Returns ``(labels, preds, confs, meta, data)``, each a list with one
        flat list per group.
        """
        assert self.data_structure == DataShape.GROUPED

        new_labels, new_preds, new_confs, new_meta, new_data = \
            [], [], [], [], []
        for gidx in range(len(self.group_names)):
            group_preds, group_confs, group_data = [], [], []
            group_labels, group_meta = [], []
            for eidx in range(len(self.labels[gidx])):
                nversions = len(self.preds[gidx][eidx])
                group_preds += list(self.preds[gidx][eidx])
                group_confs += list(self.confs[gidx][eidx])
                group_data += list(self.data[gidx][eidx])
                # the example's label is repeated once per version
                group_labels += [self.labels[gidx][eidx]] * nversions
                group_meta.extend(
                    self._version_meta(self.meta[gidx][eidx], gidx, nv)
                    for nv in range(nversions))
            new_preds.append(group_preds)
            new_confs.append(group_confs)
            new_labels.append(group_labels)
            new_meta.append(group_meta)
            new_data.append(group_data)

        return new_labels, new_preds, new_confs, new_meta, new_data

    def _flatten_average(self):
        """AVERAGE: replace each group's preds/confs by their mean over axis 0.

        Returns ``(labels, preds, confs, meta, data)``; labels and meta are the
        record's existing objects, returned unchanged.
        """
        assert self.data_structure == DataShape.UNGROUPED

        # NOTE(review): the group count is taken from the first example and
        # applied to all examples -- assumes every example has the same number
        # of groups; confirm.
        n_groups = len(self.data[0])
        new_preds, new_confs, new_data = [], [], []
        for eidx in range(len(self.data)):
            new_data.append([self.meta[eidx]["TEMPLATE"]] * n_groups)
            new_preds.append(
                [np.mean(np.array(self.preds[eidx][gidx]), axis=0)
                 for gidx in range(n_groups)])
            new_confs.append(
                [np.mean(np.array(self.confs[eidx][gidx]), axis=0)
                 for gidx in range(n_groups)])

        return self.labels, new_preds, new_confs, self.meta, new_data

    def _flatten_all(self):
        """FLATTEN_ALL: make every individual version its own entry.

        This loses all structure, i.e. group divisions and example divisions.
        Each returned preds/confs/data entry is a single-element list.
        """
        grouped = self.data_structure == DataShape.GROUPED
        new_labels, new_preds, new_confs, new_meta, new_data = \
            [], [], [], [], []
        # For GROUPED data the indices are [group][example][version]; for
        # UNGROUPED data they are [example][group][version].
        for oidx in range(len(self.data)):
            for iidx in range(len(self.data[oidx])):
                for vidx in range(len(self.data[oidx][iidx])):
                    new_preds.append([self.preds[oidx][iidx][vidx]])
                    new_confs.append([self.confs[oidx][iidx][vidx]])
                    new_data.append([self.data[oidx][iidx][vidx]])

                    if grouped:
                        new_labels.append(self.labels[oidx][iidx])
                        new_meta.append(self._version_meta(
                            self.meta[oidx][iidx], oidx, vidx))
                    else:
                        new_labels.append(self.labels[oidx])
                        new_meta.append(self.meta[oidx])

        return new_labels, new_preds, new_confs, new_meta, new_data