def __init__()

in compert/api.py [0:0]


    def __init__(self, datasets, model):
        """
        Wrap a pre-trained ComPert model together with per-split lookup tables.

        Parameters
        ----------
        datasets : dict
            Mapping from split name to ComPertDataset. Must contain the key
            ``'training'``; keys, vocabularies, one-hot tables and the dose
            range are read from that split. For every split, the measured
            perturbation categories are recorded. If an ``'ood'`` split is
            present, its measured doses are appended to the training ones in
            ``self.measured_points['all']``.
        model : ComPertModel
            Pre-trained ComPert model.
        """
        dataset = datasets['training']
        self.perturbation_key = dataset.perturbation_key
        self.dose_key = dataset.dose_key
        self.covars_key = dataset.covars_key

        # Dose range over the strictly positive entries of the training drugs.
        positive_doses = dataset.drugs[dataset.drugs > 0]
        self.min_dose = positive_doses.min().item()
        self.max_dose = positive_doses.max().item()

        self.model = model
        self.var_names = dataset.var_names

        self.unique_perts = list(dataset.perts_dict.keys())
        self.unique_covars = list(dataset.covars_dict.keys())
        # Legacy attribute name spelled with a Cyrillic 'с' (U+0441) instead of
        # an ASCII 'c'; kept as an alias so existing callers keep working.
        self.unique_сovars = self.unique_covars
        self.num_drugs = dataset.num_drugs

        self.perts_dict = dataset.perts_dict
        self.covars_dict = dataset.covars_dict

        self.drug_ohe = torch.Tensor(list(dataset.perts_dict.values()))
        self.covars_ohe = torch.LongTensor(list(dataset.covars_dict.values()))

        # Placeholders; this constructor only initialises them to None.
        self.emb_covars = None
        self.emb_perts = None
        self.comb_emb = None
        self.control_cat = None

        # Per split:
        #   seen_covars_perts[split]        -> sorted unique pert categories
        #   num_measured_points[split][cat] -> number of entries with that category
        #   measured_points[split][cov][drug] -> list of measured doses
        self.seen_covars_perts = {}
        self.measured_points = {}
        self.num_measured_points = {}
        for split, split_data in datasets.items():
            # One pass gives both the unique categories and their counts.
            categories, counts = np.unique(
                split_data.pert_categories, return_counts=True)
            self.seen_covars_perts[split] = categories
            self.num_measured_points[split] = {}
            self.measured_points[split] = {}
            for pert, num_points in zip(categories, counts):
                self.num_measured_points[split][pert] = int(num_points)

                # Categories are split into exactly three '_'-separated parts.
                # NOTE(review): an extra '_' inside a covariate or drug name
                # makes this unpacking raise ValueError.
                cov, drug, dose = pert.split('_')
                # Doses containing '+' (presumably drug combinations) are kept
                # as strings; all others are parsed as floats.
                if '+' not in dose:
                    dose = float(dose)
                self.measured_points[split].setdefault(cov, {}).setdefault(
                    drug, []).append(dose)

        # 'all' starts as an independent copy of the training points; OOD doses
        # are appended after the training ones. A covariate or drug seen only in
        # the OOD split gets its own entry instead of raising KeyError.
        self.measured_points['all'] = copy.deepcopy(
            self.measured_points['training'])
        for cov, drugs in self.measured_points.get('ood', {}).items():
            merged_cov = self.measured_points['all'].setdefault(cov, {})
            for drug, doses in drugs.items():
                merged_cov[drug] = merged_cov.get(drug, []) + list(doses)