def stratified_train_model()

in src/gluonts/nursery/SCott/pts/model/estimator.py [0:0]


    def stratified_train_model(
        self, data_package, *, full_batch_size: int = 8192
    ) -> TrainOutput:
        """Build the stratified-training data loaders, train a network on
        them and wrap the result in a predictor.

        Parameters
        ----------
        data_package
            Mapping read with the keys ``"group_data"`` (list of per-group
            datasets used for the anchor loader), ``"whole_data"`` (the full
            training dataset, used for both the streaming training loader and
            the full-batch loader), ``"val_data"`` (validation dataset) and
            ``"group_ratio"`` (forwarded unchanged to the trainer).
        full_batch_size
            Batch size of the two full-batch loaders (validation and
            whole-data). Defaults to 8192, the value previously hard-coded.

        Returns
        -------
        TrainOutput
            The training transformation, the trained network, and a predictor
            created from them on ``self.trainer.device``.
        """
        transformation = self.create_transformation()
        transformation_full_batch = self.create_transformation(
            is_full_batch=True
        )

        def make_loader(dataset, batch_size):
            # Every loader shares the trainer's worker and memory-pinning
            # settings; only the dataset and batch size differ.
            return DataLoader(
                dataset,
                batch_size=batch_size,
                num_workers=self.trainer.num_workers,
                pin_memory=self.trainer.pin_memory,
            )

        # Stratified "anchor" loader over the per-group datasets. A loader
        # batch holds batch_size * num_strata items -- presumably one
        # batch_size slice per stratum; TODO confirm against
        # TransformedGroupedIterableDataset.
        anchor_data_loader = make_loader(
            TransformedGroupedIterableDataset(
                list_of_dataset=data_package["group_data"],
                is_train=True,
                transform=transformation,
                batch_size=self.trainer.batch_size,
            ),
            self.trainer.batch_size * self.trainer.num_strata,
        )

        # Regular mini-batch loader over the whole training set.
        training_data_loader = make_loader(
            TransformedIterableDataset(
                dataset=data_package["whole_data"],
                is_train=True,
                transform=transformation,
            ),
            self.trainer.batch_size,
        )

        # NOTE(review): the validation data is built with is_train=True,
        # presumably so the transformation still emits targets for loss
        # evaluation -- confirm this is intended.
        validation_data_loader = make_loader(
            FullBatchDataset(
                dataset=data_package["val_data"],
                is_train=True,
                transform=transformation_full_batch,
            ),
            full_batch_size,
        )

        # Full-batch view of the whole training set, passed to the trainer
        # as "full_batch_loader".
        full_batch_loader = make_loader(
            FullBatchDataset(
                dataset=data_package["whole_data"],
                is_train=True,
                transform=transformation_full_batch,
            ),
            full_batch_size,
        )

        # ensure that the training network is created on the same device
        trained_net = self.create_training_network(self.trainer.device)

        self.trainer(
            net=trained_net,
            input_names=get_module_forward_input_names(trained_net),
            data_loaders={
                "training_data_loader": training_data_loader,
                "validation_data_loader": validation_data_loader,
                "anchor_data_loader": anchor_data_loader,
                "full_batch_loader": full_batch_loader,
                "group_ratio": data_package["group_ratio"],
            },
        )

        return TrainOutput(
            transformation=transformation,
            trained_net=trained_net,
            predictor=self.create_predictor(
                transformation, trained_net, self.trainer.device
            ),
        )