in part_generator.py [0:0]
def sample_partial_test(self, n):
    """Draw ``n`` random samples and rasterize them for partial-completion testing.

    Sample indices are drawn independently with ``np.random.randint``, so the
    same sample may be picked more than once. Each sample is a JSON file from
    ``self.paths``. It must contain an ``'input_parts'`` mapping, indexed by
    ``self.id_to_part[i]`` for ``i`` in ``range(self.n_part)``, and a
    ``'target_part'`` entry.

    Args:
        n: number of samples to draw. With ``n == 0``, ``torch.stack``
            receives an empty list and raises.

    Returns:
        A tuple ``(samples, samples_partial, samples_partonly)`` of tensors.
        Each is stacked along a new leading dimension of size ``n``:

        * ``samples``: raster of all input parts plus the target part.
        * ``samples_partial``: each input part's raster, followed by the raster
          of all input parts combined, concatenated along dim 0.
        * ``samples_partonly``: raster of the target part alone.
    """
    # Draw every index first so RNG consumption matches one randint call per sample.
    sample_ids = [np.random.randint(len(self)) for _ in range(n)]

    samples = []
    samples_partial = []
    samples_partonly = []
    for sample_id in sample_ids:
        # Context manager closes each handle promptly instead of leaving it to GC.
        with open(self.paths[sample_id], encoding="utf-8") as f:
            sample_json = json.load(f)
        input_parts_json = sample_json['input_parts']
        target_part_json = sample_json['target_part']

        img_partial_test = []
        # All input parts' data concatenated in part-id order. `+=` extends this
        # local list only; the loaded JSON lists are not mutated.
        vector_input_part = []
        for i in range(self.n_part):
            part = input_parts_json[self.id_to_part[i]]
            vector_input_part += part
            img_partial_test.append(self.processed_part_to_raster(part, side=self.image_size))
        # Final entry: every input part drawn together in one raster.
        img_partial_test.append(self.processed_part_to_raster(vector_input_part, side=self.image_size))
        samples_partial.append(torch.cat(img_partial_test, 0))

        img_partonly_test = self.processed_part_to_raster(target_part_json, side=self.image_size)
        img_test = self.processed_part_to_raster(vector_input_part + target_part_json, side=self.image_size)
        samples.append(img_test)
        samples_partonly.append(img_partonly_test)
    return torch.stack(samples), torch.stack(samples_partial), torch.stack(samples_partonly)