# fit() — excerpt from src/rime/models/graph_conv.py


    def fit(self, *V_arr):
        """Train a GraphConv model over one task per input split.

        Extracts labelled data, graphs, and sampling proposals for each
        split in ``V_arr``, trains a fresh ``_GraphConv`` with PyTorch
        Lightning, then caches the learned item embeddings and biases on
        ``self``.  Returns ``self`` for call chaining.
        """
        extracted = [self._extract_task(k, V) for k, V in enumerate(V_arr)]
        labels, G_list, user_proposal, item_proposal, prior_score = zip(*extracted)

        print("GraphConv label sizes", [len(part) for part in labels])
        dataset = np.vstack(labels)

        # Only the embedding width is needed downstream, so one row suffices.
        if "embedding" in V_arr[0].user_in_test:
            first_row = V_arr[0].user_in_test['embedding'].iloc[:1]
            self._model_kw["user_embeddings"] = np.vstack(first_row)
        model = _GraphConv(None, len(self._padded_item_list), **self._model_kw)

        train_set, valid_set = default_random_split(dataset)
        # Spawn dataloader worker processes only for large datasets.
        n_workers = 4 if len(dataset) > 1e4 else 0

        trainer = Trainer(
            max_epochs=self.max_epochs, gpus=int(torch.cuda.is_available()),
            log_every_n_steps=1, callbacks=[model._checkpoint, LearningRateMonitor()])

        # Attach per-task training context directly on the model.
        model.G_list = G_list
        model.user_proposal = user_proposal
        model.item_proposal = item_proposal
        if self.sample_with_prior:
            model.prior_score = [auto_cast_lazy_score(p) for p in prior_score]
        else:
            model.prior_score = [None] * len(prior_score)

        trainer.fit(
            model,
            DataLoader(train_set, self.batch_size, shuffle=True, num_workers=n_workers),
            DataLoader(valid_set, self.batch_size, num_workers=n_workers))
        model._load_best_checkpoint("best")

        # Strip training-only attachments before caching the model.
        # NOTE(review): prior_score_T is presumably set inside _GraphConv
        # during training — confirm, since delattr would raise otherwise.
        for attr in ['G_list', 'user_proposal', 'item_proposal', 'prior_score', 'prior_score_T']:
            delattr(model, attr)

        all_items = torch.arange(len(self._padded_item_list))
        self.item_embeddings = model.item_encoder(all_items).detach().numpy()
        self.item_biases = model.item_bias_vec(all_items).detach().numpy().ravel()
        self.model = model
        return self