def __init__()

in mapillary_sls/datasets/msls.py [0:0]


    def __init__(self, root_dir, cities = '', nNeg = 5, transform = None, mode = 'train', task = 'im2im', subtask = 'all', seq_length = 1, posDistThr = 10, negDistThr = 25, cached_queries = 1000, cached_negatives = 1000, positive_sampling = True):
        """Index the query and database images of the selected cities.

        For 'train' and 'val', each city's query/database CSVs are read and
        frames are grouped into sequences via ``self.arange_as_seq``. They are
        then filtered: by ``subtask`` in 'val', and panoramas are always
        excluded. Each query sequence is matched to the database sequences
        that have a frame within ``posDistThr`` in easting/northing space.
        Queries without any match are recorded in
        ``self.query_keys_with_no_match`` and skipped. In 'train', the
        database sequences within ``negDistThr`` are also stored as
        non-negatives, and night/sideways query indices are collected.
        For 'test', no coordinates are read; only the subtask filter is
        applied.

        Args:
            root_dir: dataset root; CSVs are read from
                ``root_dir/<'test'|'train_val'>/<city>/<'query'|'database'>/``.
            cities: a key of ``default_cities``, ``''`` for the default cities
                of ``mode``, or a comma-separated list of city names.
            nNeg: stored as ``self.nNeg`` (presumably negatives per triplet;
                used outside this method).
            transform: stored as ``self.transform``.
            mode: one of 'train', 'val', 'test'.
            task: one of 'im2im', 'im2seq', 'seq2im', 'seq2seq'. It decides
                whether queries / database entries are single frames or
                sequences of ``seq_length`` frames.
            subtask: column of ``subtask_index.csv`` used to filter frames
                in 'val' and 'test'.
            seq_length: odd length; 1 for 'im2im' and > 1 for all other tasks.
            posDistThr: radius in easting/northing units (presumably metres)
                within which a database frame is a positive.
            negDistThr: radius within which a database frame is recorded as
                a non-negative ('train' only).
            cached_queries, cached_negatives: stored for use elsewhere.
            positive_sampling: in 'train', compute per-query sampling weights
                via ``__calcSamplingWeights__``; otherwise use uniform weights.

        Raises:
            AssertionError: on an invalid argument combination.
            SystemExit: with exit status 1 when the chosen cities/task/subtask
                yield no query or no database images.
        """

        # validate arguments
        assert mode in ('train', 'val', 'test'), "invalid mode: {}".format(mode)
        assert task in ('im2im', 'im2seq', 'seq2im', 'seq2seq'), "invalid task: {}".format(task)
        assert subtask in ('all', 's2w', 'w2s', 'o2n', 'n2o', 'd2n', 'n2d'), "invalid subtask: {}".format(subtask)
        assert seq_length % 2 == 1, "seq_length must be odd"
        assert (task == 'im2im' and seq_length == 1) or (task != 'im2im' and seq_length > 1), \
            "seq_length must be 1 for 'im2im' and > 1 for sequence tasks"

        if cities in default_cities:
            self.cities = default_cities[cities]
        elif cities == '':
            self.cities = default_cities[mode]
        else:
            self.cities = cities.split(',')

        self.qIdx = []
        self.qImages = []
        self.pIdx = []
        self.nonNegIdx = []
        self.dbImages = []
        self.sideways = []
        self.night = []

        # hyper-parameters
        self.nNeg = nNeg
        self.margin = 0.1
        self.posDistThr = posDistThr
        self.negDistThr = negDistThr
        self.cached_queries = cached_queries
        self.cached_negatives = cached_negatives

        # flags
        self.cache = None
        self.exclude_panos = True
        self.mode = mode
        self.subtask = subtask

        # other
        self.transform = transform
        self.query_keys_with_no_match = []

        # define sequence length based on task
        if task == 'im2im':
            seq_length_q, seq_length_db = 1, 1
        elif task == 'seq2seq':
            seq_length_q, seq_length_db = seq_length, seq_length
        elif task == 'seq2im':
            seq_length_q, seq_length_db = seq_length, 1
        else: # im2seq
            seq_length_q, seq_length_db = 1, seq_length

        def to_object_array(arrays):
            # Pack variable-length index arrays into a 1-D object array with one
            # entry per query. np.asarray() raises ValueError on a ragged list
            # (NumPy >= 1.24) and builds a 2-D matrix on an equal-length one.
            packed = np.empty(len(arrays), dtype=object)
            for i, arr in enumerate(arrays):
                packed[i] = arr
            return packed

        def unique_neighbours(neighbour_lists):
            # Flatten per-frame neighbour index arrays into sorted unique indices.
            # The explicit int dtype keeps an empty result usable as an index.
            flat = [n for neighbours in neighbour_lists for n in neighbours]
            return np.unique(np.asarray(flat, dtype=int))

        def frames_to_uniq_idx(frame_idxs, uniq_frame_idxs):
            # positions inside uniq_frame_idxs of the given frame ids
            return np.where(np.isin(uniq_frame_idxs, frame_idxs))[0]

        def frames_to_seq_idx(frame_idxs, seq_idxs):
            # sorted unique first-axis positions of seq_idxs (i.e. sequences)
            # that contain at least one of frame_idxs
            return np.unique(np.where(np.isin(seq_idxs, frame_idxs))[0])

        # load data
        for city in self.cities:
            print("=====> {}".format(city))

            subdir = 'test' if city in default_cities['test'] else 'train_val'
            query_dir = join(root_dir, subdir, city, 'query')
            db_dir = join(root_dir, subdir, city, 'database')

            # offsets of this city's entries inside the global (all-cities) lists
            _lenQ = len(self.qImages)
            _lenDb = len(self.dbImages)

            # when GPS / UTM is available
            if self.mode in ['train', 'val']:
                qData = pd.read_csv(join(query_dir, 'postprocessed.csv'), index_col = 0)
                qDataRaw = pd.read_csv(join(query_dir, 'raw.csv'), index_col = 0)
                dbData = pd.read_csv(join(db_dir, 'postprocessed.csv'), index_col = 0)
                dbDataRaw = pd.read_csv(join(db_dir, 'raw.csv'), index_col = 0)

                # arrange frames into (task-dependent) sequences
                qSeqKeys, qSeqIdxs = self.arange_as_seq(qData, query_dir, seq_length_q)
                dbSeqKeys, dbSeqIdxs = self.arange_as_seq(dbData, db_dir, seq_length_db)

                # filter based on subtasks
                if self.mode in ['val']:
                    q_subtask = pd.read_csv(join(query_dir, 'subtask_index.csv'), index_col = 0)
                    db_subtask = pd.read_csv(join(db_dir, 'subtask_index.csv'), index_col = 0)

                    # keep only sequences whose frames belong to the subtask
                    val_frames = np.where(q_subtask[self.subtask])[0]
                    qSeqKeys, qSeqIdxs = self.filter(qSeqKeys, qSeqIdxs, val_frames)

                    val_frames = np.where(db_subtask[self.subtask])[0]
                    dbSeqKeys, dbSeqIdxs = self.filter(dbSeqKeys, dbSeqIdxs, val_frames)

                # filter out panorama frames
                if self.exclude_panos:
                    panos_frames = np.where((qDataRaw['pano'] == False).values)[0]
                    qSeqKeys, qSeqIdxs = self.filter(qSeqKeys, qSeqIdxs, panos_frames)

                    panos_frames = np.where((dbDataRaw['pano'] == False).values)[0]
                    dbSeqKeys, dbSeqIdxs = self.filter(dbSeqKeys, dbSeqIdxs, panos_frames)

                unique_qSeqIdx = np.unique(qSeqIdxs)
                unique_dbSeqIdx = np.unique(dbSeqIdxs)

                # this city has no query/database images for the chosen task/subtask
                if len(unique_qSeqIdx) == 0 or len(unique_dbSeqIdx) == 0:
                    continue

                self.qImages.extend(qSeqKeys)
                self.dbImages.extend(dbSeqKeys)

                # restrict the tables to the frames that survived filtering;
                # row i of utmQ / utmDb corresponds to unique_qSeqIdx[i] / unique_dbSeqIdx[i]
                qData = qData.loc[unique_qSeqIdx]
                dbData = dbData.loc[unique_dbSeqIdx]

                utmQ = qData[['easting', 'northing']].values.reshape(-1, 2)
                utmDb = dbData[['easting', 'northing']].values.reshape(-1, 2)

                # database frames within posDistThr of each query frame
                neigh = NearestNeighbors(algorithm = 'brute')
                neigh.fit(utmDb)
                _, I = neigh.radius_neighbors(utmQ, self.posDistThr)

                if self.mode == 'train':
                    # database frames within negDistThr are never used as negatives
                    _, nI = neigh.radius_neighbors(utmQ, self.negDistThr)

                night, sideways, index = qData['night'].values, (qData['view_direction'] == 'Sideways').values, qData.index
                for q_seq_idx in range(len(qSeqKeys)):

                    q_frame_idxs = qSeqIdxs[q_seq_idx]
                    q_uniq_frame_idx = frames_to_uniq_idx(q_frame_idxs, unique_qSeqIdx)

                    p_uniq_frame_idxs = unique_neighbours(I[q_uniq_frame_idx])

                    # no positive in range: remember the query key and skip it
                    if len(p_uniq_frame_idxs) == 0:
                        query_key = qSeqKeys[q_seq_idx].split('/')[-1][:-4]
                        self.query_keys_with_no_match.append(query_key)
                        continue

                    p_seq_idx = frames_to_seq_idx(unique_dbSeqIdx[p_uniq_frame_idxs], dbSeqIdxs)

                    self.pIdx.append(p_seq_idx + _lenDb)
                    self.qIdx.append(q_seq_idx + _lenQ)

                    # in training there is a second (larger) threshold for non-negatives
                    if self.mode == 'train':

                        n_uniq_frame_idxs = unique_neighbours(nI[q_uniq_frame_idx])
                        n_seq_idx = frames_to_seq_idx(unique_dbSeqIdx[n_uniq_frame_idxs], dbSeqIdxs)

                        self.nonNegIdx.append(n_seq_idx + _lenDb)

                        # gather meta which is useful for positive sampling
                        in_query = np.isin(index, q_frame_idxs)
                        if sum(night[in_query]) > 0: self.night.append(len(self.qIdx)-1)
                        if sum(sideways[in_query]) > 0: self.sideways.append(len(self.qIdx)-1)

            # when GPS / UTM / pano info is not available
            elif self.mode in ['test']:

                q_subtask = pd.read_csv(join(query_dir, 'subtask_index.csv'), index_col = 0)
                db_subtask = pd.read_csv(join(db_dir, 'subtask_index.csv'), index_col = 0)

                # arrange in sequences
                qSeqKeys, qSeqIdxs = self.arange_as_seq(q_subtask, query_dir, seq_length_q)
                dbSeqKeys, dbSeqIdxs = self.arange_as_seq(db_subtask, db_dir, seq_length_db)

                # filter query and database based on subtask
                val_frames = np.where(q_subtask[self.subtask])[0]
                qSeqKeys, qSeqIdxs = self.filter(qSeqKeys, qSeqIdxs, val_frames)

                val_frames = np.where(db_subtask[self.subtask])[0]
                dbSeqKeys, dbSeqIdxs = self.filter(dbSeqKeys, dbSeqIdxs, val_frames)

                self.qImages.extend(qSeqKeys)
                self.dbImages.extend(dbSeqKeys)

                # every test query is kept
                self.qIdx.extend(range(_lenQ, _lenQ + len(qSeqKeys)))

        # no query/database images at all for the chosen cities/task/subtask
        if len(self.qImages) == 0 or len(self.dbImages) == 0:
            print("Exiting...")
            print("A combination of cities, task and subtask have been chosen, where there are no query/database images.")
            print("Try choosing a different subtask or more cities")
            # non-zero status: this is a failure, not a normal termination
            sys.exit(1)

        # cast to np.arrays for indexing during training
        self.qIdx = np.asarray(self.qIdx)
        self.qImages = np.asarray(self.qImages)
        self.pIdx = to_object_array(self.pIdx)
        self.nonNegIdx = to_object_array(self.nonNegIdx)
        self.dbImages = np.asarray(self.dbImages)
        self.sideways = np.asarray(self.sideways)
        self.night = np.asarray(self.night)

        # decide device type ( important for triplet mining )
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.threads = 8
        self.bs = 24

        if self.mode == 'train':

            # for now always 1-1 lookup: one (initially empty) index array per query
            self.negCache = to_object_array([np.empty((0,), dtype=int) for _ in range(len(self.qIdx))])

            # calculate weights for positive sampling
            if positive_sampling:
                self.__calcSamplingWeights__()
            else:
                self.weights = np.ones(len(self.qIdx)) / float(len(self.qIdx))