def load_train_objs()

in pai-python-sdk/training/pytorch_ddp/train_src/train_multinode.py [0:0]


def load_train_objs():
    """Create the objects the training run needs.

    Returns:
        A 3-tuple ``(train_set, model, optimizer)``:

        * a ``MyTrainDataset`` built with the argument ``2048``
          (presumably the number of samples -- confirm against the
          dataset class),
        * a ``torch.nn.Linear`` layer mapping 20 input features to 1 output,
        * a ``torch.optim.SGD`` optimizer over that layer's parameters with
          a learning rate of ``1e-3``.
    """
    # Objects are built in the same order as before (dataset, model,
    # optimizer) so any random-number consumption stays unchanged.
    dataset = MyTrainDataset(2048)  # swap in your own dataset here
    net = torch.nn.Linear(in_features=20, out_features=1)  # swap in your own model here
    sgd = torch.optim.SGD(params=net.parameters(), lr=1e-3)
    return dataset, net, sgd