in hugegraph-ml/src/hugegraph_ml/tasks/node_classify_with_sample.py [0:0]
def __init__(
    self,
    graph: DGLGraph,
    model: nn.Module,
    *,
    num_partitions: int = 100,
    batch_size: int = 100,
    cache_path: str = "cluster_gcn.pkl",
):
    """Set up a Cluster-GCN style node-classification task on ``graph``.

    ``dgl.dataloading.ClusterGCNSampler`` splits the graph into
    ``num_partitions`` clusters. The data loader iterates over the partition
    ids ``0 .. num_partitions - 1`` in shuffled batches of ``batch_size``.

    Args:
        graph: Input graph. It is validated with ``self._check_graph()``
            before any partitioning work is done.
        model: Model to train on the sampled subgraphs.
        num_partitions: Number of clusters to split the graph into. Also the
            number of items the data loader iterates over. Defaults to 100.
        batch_size: Number of partitions combined into one mini-batch.
            Defaults to 100. With the defaults (``batch_size ==
            num_partitions``), each epoch is a single batch covering all
            partitions.
        cache_path: Cache file for the partition result, passed through to
            ``ClusterGCNSampler``. Defaults to DGL's own default,
            ``"cluster_gcn.pkl"``.

    Raises:
        ValueError: If ``num_partitions`` or ``batch_size`` is less than 1.
            Anything raised by ``self._check_graph()`` for an invalid graph
            also propagates.
    """
    if num_partitions < 1:
        raise ValueError(f"num_partitions must be >= 1, got {num_partitions}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    self.graph = graph
    self._model = model
    self.gpu = -1
    self._device = "cpu"
    self._early_stopping = None
    self._is_trained = False
    self.num_partitions = num_partitions
    self.batch_size = batch_size

    # Validate before building the sampler. ClusterGCNSampler partitions the
    # graph when it is constructed, so an invalid graph should be rejected
    # before that work is done.
    # NOTE(review): assumes _check_graph only inspects graph/model state and
    # does not read self.sampler / self.dataloader -- confirm.
    self._check_graph()

    # Pass cache_path explicitly so callers using different graphs or
    # partition counts can avoid sharing one partition cache file.
    self.sampler = dgl.dataloading.ClusterGCNSampler(
        self.graph,
        self.num_partitions,
        cache_path=cache_path,
    )
    # The loader's "indices" are partition ids. Each batch is a set of
    # partition ids that the sampler turns into one subgraph.
    self.dataloader = dgl.dataloading.DataLoader(
        self.graph,
        torch.arange(self.num_partitions).to(self._device),
        self.sampler,
        device=self._device,
        batch_size=self.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_uva=False,
    )