# pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# source: https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multinode.py
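"""Minimal multi-node DDP training script for the PAI training service.

Adapted from the upstream PyTorch DDP tutorial (see the source link above),
with snapshot-based resume and PAI-specific output paths added.

Example launch (illustrative values; torchrun exports the RANK, LOCAL_RANK,
WORLD_SIZE, MASTER_ADDR and MASTER_PORT variables this script relies on):

    torchrun --nnodes=2 --nproc_per_node=8 train_multinode.py --total_epochs 50 --save_every 10
"""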
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
class MyTrainDataset(Dataset):
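    """Toy dataset of random (20-dim feature, 1-dim target) pairs."""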
def __init__(self, size):
self.size = size
self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]
def __len__(self):
return self.size
def __getitem__(self, index):
return self.data[index]
def ddp_setup():
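    """Initialize the default process group and bind this process to its GPU.

    With no explicit arguments, init_process_group reads the rendezvous
    information (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) from environment
    variables, which torchrun and the PAI training service both provide.
    """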
init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
class Trainer:
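    """Runs the training loop for one process (one GPU) of the DDP job.

    Snapshots (model weights plus the last finished epoch) are written to
    ``checkpoint_path`` so an interrupted job can resume; the final model
    is written to ``output_model_path``.
    """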
def __init__(
self,
model: torch.nn.Module,
train_data: DataLoader,
optimizer: torch.optim.Optimizer,
save_every: int,
output_model_path: str,
checkpoint_path: str,
) -> None:
self.local_rank = int(os.environ["LOCAL_RANK"])
self.global_rank = int(os.environ["RANK"])
self.model = model.to(self.local_rank)
self.train_data = train_data
self.optimizer = optimizer
self.save_every = save_every
self.epochs_run = 0
self.output_model_path = output_model_path
self.checkpoint_path = checkpoint_path
if os.path.exists(self.get_snapshot_path()):
print("Loading snapshot")
self._load_snapshot(self.get_snapshot_path())
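        # Wrap with DDP only after any snapshot has been restored, so the
        # snapshot's state_dict keys match the unwrapped module.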
self.model = DDP(self.model, device_ids=[self.local_rank])
def get_snapshot_path(self):
return os.path.join(self.checkpoint_path, "model.pt")
def _load_snapshot(self, snapshot_path):
loc = f"cuda:{self.local_rank}"
snapshot = torch.load(snapshot_path, map_location=loc)
self.model.load_state_dict(snapshot["MODEL_STATE"])
self.epochs_run = snapshot["EPOCHS_RUN"]
print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
def _run_batch(self, source, targets):
self.optimizer.zero_grad()
output = self.model(source)
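        # Toy objective kept from the upstream example; the inputs are
        # random, so the loss value itself is not meaningful.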
loss = F.cross_entropy(output, targets)
loss.backward()
self.optimizer.step()
def _run_epoch(self, epoch):
b_sz = len(next(iter(self.train_data))[0])
print(
f"[GPU-{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}"
)
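        # Re-seed the DistributedSampler so each epoch uses a different
        # shuffle order across ranks.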
self.train_data.sampler.set_epoch(epoch)
for source, targets in self.train_data:
source = source.to(self.local_rank)
targets = targets.to(self.local_rank)
self._run_batch(source, targets)
def _save_snapshot(self, epoch):
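        # self.model is DDP-wrapped; ``.module`` exposes the underlying
        # nn.Module so keys are saved without the "module." prefix.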
snapshot = {
"MODEL_STATE": self.model.module.state_dict(),
"EPOCHS_RUN": epoch,
}
torch.save(snapshot, self.get_snapshot_path())
print(f"Epoch {epoch} | Training snapshot saved at {self.output_model_path}")
    def _save_model(self):
        # Save the unwrapped module's weights so the file loads without the
        # DDP "module." key prefix.
        torch.save(
            self.model.module.state_dict(),
            os.path.join(self.output_model_path, "model.pt"),
        )
def train(self, max_epochs: int):
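        # self.epochs_run is non-zero when resuming from a snapshot.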
for epoch in range(self.epochs_run, max_epochs):
self._run_epoch(epoch)
if self.global_rank == 0 and epoch % self.save_every == 0:
self._save_snapshot(epoch)
        # Save the final model once, from the rank-0 process only.
        if self.global_rank == 0:
            self._save_model()
def load_train_objs():
train_set = MyTrainDataset(2048) # load your dataset
model = torch.nn.Linear(20, 1) # load your model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
return train_set, model, optimizer
def prepare_dataloader(dataset: Dataset, batch_size: int):
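    # shuffle must be False: DistributedSampler both shards the dataset
    # across ranks and shuffles it (re-seeded per epoch via set_epoch).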
return DataLoader(
dataset,
batch_size=batch_size,
pin_memory=True,
shuffle=False,
sampler=DistributedSampler(dataset),
)
def main(
save_every: int,
total_epochs: int,
batch_size: int,
output_model_path: str,
checkpoint_path: str,
):
ddp_setup()
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size)
trainer = Trainer(
model, train_data, optimizer, save_every, output_model_path, checkpoint_path
)
trainer.train(total_epochs)
destroy_process_group()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="simple distributed training job")
    parser.add_argument(
        "--total_epochs", type=int, required=True, help="Total epochs to train the model"
    )
parser.add_argument("--save_every", type=int, help="How often to save a snapshot")
parser.add_argument(
"--batch_size",
default=32,
type=int,
help="Input batch size on each device (default: 32)",
)
    # Default to the environment variable set by the PAI training service
    # for the model save path.
    parser.add_argument(
        "--output_model_path",
        default=os.environ.get("PAI_OUTPUT_MODEL"),
        type=str,
        help="Output model path",
    )
    # Default to the environment variable set by the PAI training service
    # for the checkpoints save path.
    parser.add_argument(
        "--checkpoint_path",
        default=os.environ.get("PAI_OUTPUT_CHECKPOINTS"),
        type=str,
        help="Checkpoints save path",
    )
args = parser.parse_args()
main(
args.save_every,
args.total_epochs,
args.batch_size,
args.output_model_path,
args.checkpoint_path,
)