in pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py [0:0]
def __init__(
self,
model: torch.nn.Module,
train_data: DataLoader,
optimizer: torch.optim.Optimizer,
save_every: int,
output_model_path: str,
checkpoint_path: str,