in train_helpers.py [0:0]
def restore_params(model, path, local_rank, mpi_size, map_ddp=True, map_cpu=False):
    """Load a saved state dict and apply it to ``model``.

    ``path`` is first handed to ``distributed_maybe_download`` together with
    ``local_rank`` and ``mpi_size``; whatever that helper returns is what
    ``torch.load`` reads.

    Args:
        model: Object exposing ``load_state_dict``; it is updated in place.
        path: Checkpoint location, forwarded to ``distributed_maybe_download``.
        local_rank: Forwarded unchanged to ``distributed_maybe_download``.
        mpi_size: Forwarded unchanged to ``distributed_maybe_download``.
        map_ddp: When true, a leading ``'module.'`` is removed from every key
            that starts with it; other keys are kept as they are.
        map_cpu: When true, ``torch.load`` is called with
            ``map_location='cpu'``; otherwise ``map_location`` is ``None``.

    Returns:
        None.
    """
    checkpoint_file = distributed_maybe_download(path, local_rank, mpi_size)
    location = 'cpu' if map_cpu else None
    weights = torch.load(checkpoint_file, map_location=location)

    if map_ddp:
        # NOTE(review): 'module.' is presumably the prefix added when the
        # model was saved from inside a DistributedDataParallel wrapper --
        # confirm against the code that writes the checkpoints.
        prefix = 'module.'
        skip = len(prefix)
        weights = {
            (name[skip:] if name.startswith(prefix) else name): weights[name]
            for name in weights
        }

    model.load_state_dict(weights)