in src/gluonts/nursery/SCott/pts/model/estimator.py [0:0]
def stratified_train_model(
    self, data_package, *, full_batch_size: int = 8192
) -> TrainOutput:
    """Build the stratified-training data loaders, run the trainer, and
    return the resulting transformation, network and predictor.

    Four loaders are created and handed to ``self.trainer`` together with
    ``data_package["group_ratio"]`` (forwarded unchanged):

    * ``anchor_data_loader`` -- ``TransformedGroupedIterableDataset`` over
      ``data_package["group_data"]`` with the training transformation;
    * ``training_data_loader`` -- ``TransformedIterableDataset`` over
      ``data_package["whole_data"]`` with the training transformation;
    * ``validation_data_loader`` -- ``FullBatchDataset`` over
      ``data_package["val_data"]`` with the full-batch transformation;
    * ``full_batch_loader`` -- ``FullBatchDataset`` over
      ``data_package["whole_data"]`` with the full-batch transformation.

    Args:
        data_package: Mapping read with the keys ``"group_data"`` (passed as
            ``list_of_dataset``; presumably one dataset per stratum -- TODO
            confirm), ``"whole_data"``, ``"val_data"`` and ``"group_ratio"``.
        full_batch_size: Batch size used by the two ``FullBatchDataset``
            loaders (validation and full-batch). Defaults to 8192, the value
            that was previously hard-coded here.

    Returns:
        A ``TrainOutput`` with the training transformation, the network that
        was passed to the trainer, and a predictor created from both on
        ``self.trainer.device``.
    """
    trainer = self.trainer

    def _loader(dataset, batch_size):
        # Every loader shares the trainer's worker-count and pinning settings.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=trainer.num_workers,
            pin_memory=trainer.pin_memory,
        )

    transformation = self.create_transformation()
    transformation_full_batch = self.create_transformation(is_full_batch=True)

    anchor_data_loader = _loader(
        TransformedGroupedIterableDataset(
            list_of_dataset=data_package["group_data"],
            is_train=True,
            transform=transformation,
            batch_size=trainer.batch_size,
        ),
        # NOTE(review): presumably each loader batch gathers batch_size
        # samples from every stratum -- confirm against
        # TransformedGroupedIterableDataset.
        batch_size=trainer.batch_size * trainer.num_strata,
    )
    training_data_loader = _loader(
        TransformedIterableDataset(
            dataset=data_package["whole_data"],
            is_train=True,
            transform=transformation,
        ),
        batch_size=trainer.batch_size,
    )
    # Built from "val_data" and passed to the trainer as the validation loader.
    val_data_loader = _loader(
        FullBatchDataset(
            dataset=data_package["val_data"],
            is_train=True,
            transform=transformation_full_batch,
        ),
        batch_size=full_batch_size,
    )
    full_batch_loader = _loader(
        FullBatchDataset(
            dataset=data_package["whole_data"],
            is_train=True,
            transform=transformation_full_batch,
        ),
        batch_size=full_batch_size,
    )

    # Create the training network on the trainer's device so it matches
    # where the trainer runs.
    trained_net = self.create_training_network(trainer.device)
    trainer(
        net=trained_net,
        input_names=get_module_forward_input_names(trained_net),
        data_loaders={
            "training_data_loader": training_data_loader,
            "validation_data_loader": val_data_loader,
            "anchor_data_loader": anchor_data_loader,
            "full_batch_loader": full_batch_loader,
            "group_ratio": data_package["group_ratio"],
        },
    )
    return TrainOutput(
        transformation=transformation,
        trained_net=trained_net,
        predictor=self.create_predictor(
            transformation, trained_net, trainer.device
        ),
    )