in imnet_resnet50_scratch/train.py [0:0]
def _init_state(self) -> None:
"""
Initialize the state and load it from an existing checkpoint if any
"""
torch.manual_seed(0)
np.random.seed(0)
print("Create data loaders", flush=True)
input_size_image = self._train_cfg.input_size
test_size = input_size_image
print(f"Input size : {input_size_image}")
print(f"Test size : {test_size}")
print(f"Initial LR : {self._train_cfg.lr}")
transf = get_transforms(input_size=input_size_image, test_size=test_size, kind='full', crop=True, need=('train', 'val'), backbone=None)
transform_train = transf['train']
transform_test = transf['val']
train_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/train', transform=transform_train)
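# Repeated-augmentation sampler: shards the train set across num_tasks processes and draws each sample with 3 repetitions per epoch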
train_sampler = RASampler(
train_set, self._train_cfg.num_tasks, self._train_cfg.global_rank, len(train_set), self._train_cfg.batch_per_gpu, repetitions=3, len_factor=2.0, shuffle=True, drop_last=False
)
self._train_loader = torch.utils.data.DataLoader(
train_set,
batch_size=self._train_cfg.batch_per_gpu,
num_workers=self._train_cfg.workers - 1,
sampler=train_sampler,
)
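# Validation data: plain ImageFolder over the val split, evaluated without a distributed sampler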
test_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/val', transform=transform_test)
self._test_loader = torch.utils.data.DataLoader(
test_set, batch_size=self._train_cfg.batch_per_gpu, shuffle=False, num_workers=self._train_cfg.workers - 1,  # sampler=test_sampler  (note: intentionally not used for now)
)
print(f"Total batch_size: {self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks}", flush=True)
print("Create distributed model", flush=True)
model = models.resnet50(pretrained=False)
model.cuda(self._train_cfg.local_rank)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[self._train_cfg.local_rank], output_device=self._train_cfg.local_rank
)
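# Linear LR scaling rule: base LR is scaled by 8 * global batch size (batch_per_gpu * num_tasks) relative to a reference batch of 512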
linear_scaled_lr = 8.0 * self._train_cfg.lr * self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks / 512.0
optimizer = optim.SGD(model.parameters(), lr=linear_scaled_lr, momentum=0.9, weight_decay=1e-4)
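# StepLR with the default gamma=0.1: decay the learning rate by 10x every 30 epochs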
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30)
self._state = TrainerState(
epoch=0, accuracy=0.0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
)
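# Resume from a previous checkpoint for this job if one exists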
checkpoint_fn = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id), "checkpoint.pth")
if os.path.isfile(checkpoint_fn):
print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
self._state = TrainerState.load(checkpoint_fn, default=self._state)