in mapillary_sls/datasets/msls.py [0:0]
def __init__(self, root_dir, cities = '', nNeg = 5, transform = None, mode = 'train', task = 'im2im', subtask = 'all', seq_length = 1, posDistThr = 10, negDistThr = 25, cached_queries = 1000, cached_negatives = 1000, positive_sampling = True):
    """Index the query and database images of the selected cities.

    For every city the per-city CSV files below ``root_dir`` are read, the
    frames are arranged into sequences with ``self.arange_as_seq`` and
    filtered with ``self.filter``. In 'train' and 'val' mode the positives
    of every query are looked up with a radius search on the
    easting/northing columns; in 'test' mode only the image lists are built.

    Args:
        root_dir: Dataset root; files are read from
            ``root_dir/<train_val|test>/<city>/<query|database>/``.
        cities: A key of ``default_cities``, an empty string (use
            ``default_cities[mode]``) or a comma-separated list of cities.
        nNeg: Stored as ``self.nNeg`` (presumably the number of negatives
            per query -- not used in this method).
        transform: Stored as ``self.transform``; not used in this method.
        mode: One of 'train', 'val', 'test'.
        task: One of 'im2im', 'im2seq', 'seq2im', 'seq2seq'; selects which
            side (query / database) is arranged as sequences.
        subtask: One of 'all', 's2w', 'w2s', 'o2n', 'n2o', 'd2n', 'n2d';
            column of ``subtask_index.csv`` used for filtering in 'val'
            and 'test' mode.
        seq_length: Odd sequence length; must be 1 for 'im2im' and greater
            than 1 for every other task.
        posDistThr: Radius of the positive search (same unit as the
            easting/northing columns, presumably meters).
        negDistThr: Radius inside which database images are treated as
            non-negatives; only used in 'train' mode.
        cached_queries: Stored as ``self.cached_queries``.
        cached_negatives: Stored as ``self.cached_negatives``.
        positive_sampling: In 'train' mode, call
            ``self.__calcSamplingWeights__()`` when true, otherwise use
            uniform weights.

    Raises:
        AssertionError: If ``mode``, ``task``, ``subtask`` or
            ``seq_length`` is invalid.
        SystemExit: If no query or database image remains after filtering.
    """

    def flat_unique(neighbour_lists):
        # Sorted unique entries of a sequence of index arrays. The explicit
        # integer dtype keeps an empty result usable as an index array
        # (np.unique([]) would be float64).
        flat = [n for neighbours in neighbour_lists for n in neighbours]
        return np.unique(np.asarray(flat, dtype = np.intp))

    def as_index_array(index_lists):
        # Per-query index arrays usually differ in length. Recent NumPy
        # releases refuse to build an array from such ragged input, so the
        # 1-D object array is created explicitly in that case. Equal-length
        # (or empty) input keeps the plain np.asarray result.
        if len({len(entry) for entry in index_lists}) <= 1:
            return np.asarray(index_lists)
        ragged = np.empty(len(index_lists), dtype = object)
        for i, entry in enumerate(index_lists):
            ragged[i] = entry
        return ragged

    # initializing
    assert mode in ('train', 'val', 'test')
    assert task in ('im2im', 'im2seq', 'seq2im', 'seq2seq')
    assert subtask in ('all', 's2w', 'w2s', 'o2n', 'n2o', 'd2n', 'n2d')
    assert seq_length % 2 == 1
    assert (task == 'im2im' and seq_length == 1) or (task != 'im2im' and seq_length > 1)

    if cities in default_cities:
        self.cities = default_cities[cities]
    elif cities == '':
        self.cities = default_cities[mode]
    else:
        self.cities = cities.split(',')

    self.qIdx = []
    self.qImages = []
    self.pIdx = []
    self.nonNegIdx = []
    self.dbImages = []
    self.sideways = []
    self.night = []

    # hyper-parameters
    self.nNeg = nNeg
    self.margin = 0.1
    self.posDistThr = posDistThr
    self.negDistThr = negDistThr
    self.cached_queries = cached_queries
    self.cached_negatives = cached_negatives

    # flags
    self.cache = None
    self.exclude_panos = True
    self.mode = mode
    self.subtask = subtask

    # other
    self.transform = transform
    self.query_keys_with_no_match = []

    # define sequence length based on task
    if task == 'im2im':
        seq_length_q, seq_length_db = 1, 1
    elif task == 'seq2seq':
        seq_length_q, seq_length_db = seq_length, seq_length
    elif task == 'seq2im':
        seq_length_q, seq_length_db = seq_length, 1
    else: # im2seq
        seq_length_q, seq_length_db = 1, seq_length

    # load data
    for city in self.cities:
        print("=====> {}".format(city))

        subdir = 'test' if city in default_cities['test'] else 'train_val'
        query_dir = join(root_dir, subdir, city, 'query')
        database_dir = join(root_dir, subdir, city, 'database')

        # number of images from the cities processed so far; used as offset
        # so that indices stay valid in the concatenated image lists
        _lenQ = len(self.qImages)
        _lenDb = len(self.dbImages)

        # when GPS / UTM is available
        if self.mode in ['train', 'val']:

            # load query data
            qData = pd.read_csv(join(query_dir, 'postprocessed.csv'), index_col = 0)
            qDataRaw = pd.read_csv(join(query_dir, 'raw.csv'), index_col = 0)

            # load database data
            dbData = pd.read_csv(join(database_dir, 'postprocessed.csv'), index_col = 0)
            dbDataRaw = pd.read_csv(join(database_dir, 'raw.csv'), index_col = 0)

            # arrange based on task
            qSeqKeys, qSeqIdxs = self.arange_as_seq(qData, query_dir, seq_length_q)
            dbSeqKeys, dbSeqIdxs = self.arange_as_seq(dbData, database_dir, seq_length_db)

            # filter based on subtasks
            if self.mode in ['val']:
                qIdx = pd.read_csv(join(query_dir, 'subtask_index.csv'), index_col = 0)
                dbIdx = pd.read_csv(join(database_dir, 'subtask_index.csv'), index_col = 0)

                # find all the sequences where the center frame belongs to a subtask
                val_frames = np.where(qIdx[self.subtask])[0]
                qSeqKeys, qSeqIdxs = self.filter(qSeqKeys, qSeqIdxs, val_frames)

                val_frames = np.where(dbIdx[self.subtask])[0]
                dbSeqKeys, dbSeqIdxs = self.filter(dbSeqKeys, dbSeqIdxs, val_frames)

            # filter based on panorama data: keep the frames that are not panoramas
            if self.exclude_panos:
                panos_frames = np.where((qDataRaw['pano'] == False).values)[0]
                qSeqKeys, qSeqIdxs = self.filter(qSeqKeys, qSeqIdxs, panos_frames)

                panos_frames = np.where((dbDataRaw['pano'] == False).values)[0]
                dbSeqKeys, dbSeqIdxs = self.filter(dbSeqKeys, dbSeqIdxs, panos_frames)

            unique_qSeqIdx = np.unique(qSeqIdxs)
            unique_dbSeqIdx = np.unique(dbSeqIdxs)

            # if a combination of city, task and subtask is chosen, where there are
            # no query/database images, then continue to next city
            if len(unique_qSeqIdx) == 0 or len(unique_dbSeqIdx) == 0:
                continue

            self.qImages.extend(qSeqKeys)
            self.dbImages.extend(dbSeqKeys)

            # keep only the frames that are part of a remaining sequence; row i
            # of the reduced frames corresponds to unique_*SeqIdx[i]
            qData = qData.loc[unique_qSeqIdx]
            dbData = dbData.loc[unique_dbSeqIdx]

            # utm coordinates
            utmQ = qData[['easting', 'northing']].values.reshape(-1, 2)
            utmDb = dbData[['easting', 'northing']].values.reshape(-1, 2)

            # find positive images for training
            neigh = NearestNeighbors(algorithm = 'brute')
            neigh.fit(utmDb)
            _, posNeighbours = neigh.radius_neighbors(utmQ, self.posDistThr)

            if self.mode == 'train':
                _, nonNegNeighbours = neigh.radius_neighbors(utmQ, self.negDistThr)

            night = qData['night'].values
            sideways = (qData['view_direction'] == 'Sideways').values
            index = qData.index

            for q_seq_idx, q_seq_key in enumerate(qSeqKeys):

                # frames of this query sequence and their rows in utmQ
                q_frame_idxs = qSeqIdxs[q_seq_idx]
                q_uniq_frame_idx = np.where(np.isin(unique_qSeqIdx, q_frame_idxs))[0]

                # rows of utmDb within posDistThr of any frame of the query
                p_uniq_frame_idxs = flat_unique(posNeighbours[q_uniq_frame_idx])

                # the query has no positive: remember its key and move on
                if len(p_uniq_frame_idxs) == 0:
                    query_key = q_seq_key.split('/')[-1][:-4]
                    self.query_keys_with_no_match.append(query_key)
                    continue

                # database sequences that contain at least one positive frame
                p_frames = unique_dbSeqIdx[p_uniq_frame_idxs]
                p_seq_idx = np.unique(np.where(np.isin(dbSeqIdxs, p_frames))[0])

                self.pIdx.append(p_seq_idx + _lenDb)
                self.qIdx.append(q_seq_idx + _lenQ)

                # in training we have two thresholds, one for finding positives and
                # one for finding images that we are certain are negatives.
                if self.mode == 'train':
                    n_uniq_frame_idxs = flat_unique(nonNegNeighbours[q_uniq_frame_idx])
                    n_frames = unique_dbSeqIdx[n_uniq_frame_idxs]
                    n_seq_idx = np.unique(np.where(np.isin(dbSeqIdxs, n_frames))[0])

                    self.nonNegIdx.append(n_seq_idx + _lenDb)

                    # gather meta which is useful for positive sampling; the stored
                    # value is the position of this query inside self.qIdx
                    in_query = np.isin(index, q_frame_idxs)
                    if sum(night[in_query]) > 0:
                        self.night.append(len(self.qIdx) - 1)
                    if sum(sideways[in_query]) > 0:
                        self.sideways.append(len(self.qIdx) - 1)

        # when GPS / UTM / pano info is not available
        elif self.mode in ['test']:

            # load images for subtask
            qIdx = pd.read_csv(join(query_dir, 'subtask_index.csv'), index_col = 0)
            dbIdx = pd.read_csv(join(database_dir, 'subtask_index.csv'), index_col = 0)

            # arrange in sequences
            qSeqKeys, qSeqIdxs = self.arange_as_seq(qIdx, query_dir, seq_length_q)
            dbSeqKeys, dbSeqIdxs = self.arange_as_seq(dbIdx, database_dir, seq_length_db)

            # filter query based on subtask
            val_frames = np.where(qIdx[self.subtask])[0]
            qSeqKeys, qSeqIdxs = self.filter(qSeqKeys, qSeqIdxs, val_frames)

            # filter database based on subtask
            val_frames = np.where(dbIdx[self.subtask])[0]
            dbSeqKeys, dbSeqIdxs = self.filter(dbSeqKeys, dbSeqIdxs, val_frames)

            self.qImages.extend(qSeqKeys)
            self.dbImages.extend(dbSeqKeys)

            # add query index
            self.qIdx.extend(list(range(_lenQ, len(qSeqKeys) + _lenQ)))

    # if a combination of cities, task and subtask is chosen, where there are
    # no query/database images, then exit
    if len(self.qImages) == 0 or len(self.dbImages) == 0:
        print("Exiting...")
        print("A combination of cities, task and subtask have been chosen, where there are no query/database images.")
        print("Try choosing a different subtask or more cities")
        sys.exit()

    # cast to np.arrays for indexing during training
    self.qIdx = np.asarray(self.qIdx)
    self.qImages = np.asarray(self.qImages)
    self.pIdx = as_index_array(self.pIdx)
    self.nonNegIdx = as_index_array(self.nonNegIdx)
    self.dbImages = np.asarray(self.dbImages)
    self.sideways = np.asarray(self.sideways)
    self.night = np.asarray(self.night)

    # decide device type ( important for triplet mining )
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.threads = 8
    self.bs = 24

    if mode == 'train':

        # for now always 1-1 lookup.
        self.negCache = np.asarray([np.empty((0,), dtype=int)]*len(self.qIdx))

        # calculate weights for positive sampling
        if positive_sampling:
            self.__calcSamplingWeights__()
        else:
            self.weights = np.ones(len(self.qIdx)) / float(len(self.qIdx))