pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py

# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# source: https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multinode.py

import os

import torch
import torch.nn.functional as F
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler


class MyTrainDataset(Dataset):
    """Toy dataset of random (feature, target) pairs."""

    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]


def ddp_setup():
    # LOCAL_RANK and RANK are provided by the launcher (e.g. torchrun).
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        output_model_path: str,
        checkpoint_path: str,
    ) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.output_model_path = output_model_path
        self.checkpoint_path = checkpoint_path
        if os.path.exists(self.get_snapshot_path()):
            print("Loading snapshot")
            self._load_snapshot(self.get_snapshot_path())

        self.model = DDP(self.model, device_ids=[self.local_rank])

    def get_snapshot_path(self):
        return os.path.join(self.checkpoint_path, "model.pt")

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(
            f"[GPU-{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}"
        )
        # Make DistributedSampler reshuffle the data differently each epoch.
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.get_snapshot_path())
        print(f"Epoch {epoch} | Training snapshot saved at {self.get_snapshot_path()}")

    def _save_model(self):
        # Save the unwrapped module so checkpoint keys carry no "module." prefix.
        torch.save(
            self.model.module.state_dict(),
            os.path.join(self.output_model_path, "model.pt"),
        )

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if self.global_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)
        # save model after training, from the rank-0 process only
        if self.global_rank == 0:
            self._save_model()


def load_train_objs():
    train_set = MyTrainDataset(2048)  # load your dataset
    model = torch.nn.Linear(20, 1)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset),
    )


def main(
    save_every: int,
    total_epochs: int,
    batch_size: int,
    output_model_path: str,
    checkpoint_path: str,
):
    ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(
        model, train_data, optimizer, save_every, output_model_path, checkpoint_path
    )
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="simple distributed training job")
    parser.add_argument(
        "--total_epochs", type=int, help="Total epochs to train the model"
    )
    parser.add_argument("--save_every", type=int, help="How often to save a snapshot")
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Input batch size on each device (default: 32)",
    )
    # Environment variable set by the PAI training service: model output path.
    parser.add_argument(
        "--output_model_path",
        default=os.environ.get("PAI_OUTPUT_MODEL"),
        type=str,
        help="Output model path",
    )
    # Environment variable set by the PAI training service: checkpoints output path.
    parser.add_argument(
        "--checkpoint_path",
        default=os.environ.get("PAI_OUTPUT_CHECKPOINTS"),
        type=str,
        help="checkpoints path",
    )
    args = parser.parse_args()

    main(
        args.save_every,
        args.total_epochs,
        args.batch_size,
        args.output_model_path,
        args.checkpoint_path,
    )