def restrict_images()

in c3dm/dataset/keypoints_dataset.py [0:0]


	def restrict_images(self):
		"""Filter ``self.db`` in place according to the dataset limit settings.

		The following steps are applied in order:

		1. Sequence restriction (``self.limit_seq_to``):
		   a list/tuple keeps entries whose ``'seq'`` is contained in it,
		   unless it is empty or holds a single negative value (both mean
		   "no restriction"); a positive int keeps entries with
		   ``'seq'`` strictly below it; a non-positive int is a no-op.
		2. Size limit (``self.limit_to``): if positive, keep only the first
		   ``limit_to`` entries.
		3. Subsampling (``self.subsample``): if greater than 1, keep every
		   ``subsample``-th entry.
		4. Confidence thresholding (``self.kp_conf_thr``): if positive and the
		   entries carry ``'kp_conf'``, overwrite each entry's ``'kp_vis'``
		   with the 0/1 mask ``kp_conf > kp_conf_thr``.
		5. Visibility filter (``self.min_visible``): if positive, keep only
		   entries with strictly more than ``min_visible`` keypoints whose
		   ``'kp_vis'`` is positive.

		Raises:
			AssertionError: if ``self.limit_seq_to`` is neither a list/tuple
				nor an int, or if the visibility filter leaves 10 or fewer
				entries.
		"""

		seq_limit = self.limit_seq_to
		print( "limitting dataset to seqs: " + str(seq_limit) )
		if isinstance(seq_limit, (tuple, list)):
			# The explicit length check keeps an empty sequence from raising
			# IndexError on seq_limit[0]; it is treated as "no restriction",
			# the same as a single negative sentinel value.
			if len(seq_limit) > 1 or (len(seq_limit) == 1 and seq_limit[0] >= 0):
				self.db = [f for f in self.db if f['seq'] in seq_limit]
		elif isinstance(seq_limit, int) and not isinstance(seq_limit, bool):
			if seq_limit > 0:
				self.db = [f for f in self.db if f['seq'] < seq_limit]
		else:
			# Explicit raise (rather than `assert False`) so the check is not
			# stripped when Python runs with -O.
			raise AssertionError("bad seq limit type")

		if self.limit_to > 0:
			tgtnum = min(self.limit_to, len(self.db))
			print( "limitting dataset to %d samples" % tgtnum )
			# Deterministically keep the leading entries (no random shuffling).
			self.db = [self.db[i] for i in range(tgtnum)]

		if self.subsample > 1:
			orig_len = len(self.db)
			self.db = [self.db[i] for i in range(0, orig_len, self.subsample)]
			print('db subsampled %d -> %d' % (orig_len, len(self.db)))

		# Guard against an empty db before peeking at the first entry.
		if self.kp_conf_thr > 0. and self.db and 'kp_conf' in self.db[0]:
			for e in self.db:
				conf = torch.FloatTensor(e['kp_conf'])
				# Visibility is replaced by the thresholded confidence mask.
				e['kp_vis'] = (conf > self.kp_conf_thr).float().tolist()

		if self.min_visible > 0:
			len_orig = len(self.db)
			self.db = [
				e for e in self.db
				if (torch.FloatTensor(e['kp_vis']) > 0).float().sum() > self.min_visible
			]
			# Avoid ZeroDivisionError when nothing was left to filter.
			kept_pct = 100. * len(self.db) / float(len_orig) if len_orig else 0.
			print('kept %3.1f %% entries' % kept_pct )
			if not len(self.db) > 10:
				raise AssertionError(
					"too few entries left after visibility filtering: %d"
					% len(self.db)
				)