in src/data_manager.py [0:0]
def init_transform(root, samples, class_to_idx, seed,
                   keep_file=keep_file,
                   training=training):
    """Select the subset of samples whose labels are kept for training.

    Two selection modes are supported:

    * Keep-file mode: used when ``training`` is truthy, ``keep_file`` is
      not ``None`` and the file exists on disk. Each non-blank line of the
      file is taken as an image file name; the class name is the part of
      the name before the first underscore, and the sample path is built as
      ``root/<class_name>/<image_name>``.
    * Coin-flip mode: used in every other case (including a ``keep_file``
      that does not exist). One Bernoulli draw with probability
      ``unlabel_prob`` is made per entry of ``samples`` from a generator
      seeded with ``seed``; the entry is kept when the draw is 0.

    ``unlabel_prob`` and ``logger`` are not parameters; they are read from
    the enclosing scope, as are the default values of ``keep_file`` and
    ``training``.

    Args:
        root: Directory under which per-class image folders live.
        samples: Iterable of indexable items where ``item[0]`` is the
            sample (presumably a file path -- confirm against caller) and
            ``item[1]`` is its target.
        class_to_idx: Mapping from class name to target index.
        seed: Seed for the coin-flip generator.
        keep_file: Path of a text file listing the images to keep, or
            ``None``.
        training: Whether the dataset is being prepared for training.

    Returns:
        A tuple ``(targets, samples)`` of NumPy arrays built from the kept
        targets and the kept ``(path, target)`` pairs.

    Raises:
        KeyError: If a class name parsed from ``keep_file`` is missing
            from ``class_to_idx``.
    """
    new_targets, new_samples = [], []
    if training and (keep_file is not None) and os.path.exists(keep_file):
        logger.info(f'Using {keep_file}')
        with open(keep_file, 'r') as rfile:
            for line in rfile:
                # Strip the line terminator (including '\r' from CRLF
                # files) and skip blank lines, which would otherwise be
                # looked up as a class name and raise KeyError.
                img = line.strip()
                if not img:
                    continue
                class_name = img.split('_')[0]
                target = class_to_idx[class_name]
                new_samples.append(
                    (os.path.join(root, class_name, img), target))
                new_targets.append(target)
    else:
        logger.info('flipping coin to keep labels')
        g = torch.Generator()
        g.manual_seed(seed)
        # Loop-invariant: build the probability tensor once. The draws
        # stay one scalar bernoulli call per sample so the random sequence
        # for a given seed is unchanged.
        unlabel_prob_tensor = torch.tensor(unlabel_prob)
        for sample in samples:
            # A draw of 0 (probability 1 - unlabel_prob) keeps the label.
            if torch.bernoulli(unlabel_prob_tensor, generator=g) == 0:
                target = sample[1]
                new_samples.append((sample[0], target))
                new_targets.append(target)
    # NOTE(review): np.array over (str path, int target) pairs is expected
    # to coerce both columns to a string dtype -- confirm callers rely on
    # that before changing the return format.
    return np.array(new_targets), np.array(new_samples)