in code/train.py [0:0]
def _get_train_data_loader(batch_size, training_dir):
    """Build the distributed training DataLoader for the DeepLoc CSV.

    Reads ``deeploc_per_protein_train.csv`` from ``training_dir``, wraps its
    ``sequence`` and ``location`` columns in a ``ProteinSequenceDataset``
    (using the module-level ``tokenizer`` and ``MAX_LEN``), and shards it
    across processes with a ``DistributedSampler``.

    Args:
        batch_size: Number of samples per batch yielded by the loader
            (per process).
        training_dir: Directory containing ``deeploc_per_protein_train.csv``.

    Returns:
        A ``DataLoader`` over the training dataset whose sampler is a
        ``DistributedSampler`` for the current rank.

    Note:
        Calls ``dist.get_world_size()`` / ``dist.get_rank()``, so the
        ``torch.distributed`` process group is expected to be initialised
        before this function is called.
    """
    frame = pd.read_csv(
        os.path.join(training_dir, "deeploc_per_protein_train.csv")
    )

    train_data = ProteinSequenceDataset(
        sequence=frame.sequence.to_numpy(),
        targets=frame.location.to_numpy(),
        tokenizer=tokenizer,
        max_len=MAX_LEN,
    )

    # The sampler must be built over the same Dataset object the DataLoader
    # iterates, so the indices it emits are always valid for ``train_data``
    # (previously the raw DataFrame was passed here).
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
    )

    # ``shuffle`` stays False because DataLoader does not allow shuffle=True
    # together with an explicit sampler; ordering is left to the sampler.
    return DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler,
    )