def __init__()

in compert/data.py [0:0]


    def __init__(self,
                 fname,
                 perturbation_key,
                 dose_key,
                 cell_type_key,
                 split_key='split'):
        """Load a dataset with ``sc.read`` and build its tensor encodings.

        Besides the columns named by the arguments, the loaded object must
        provide ``obs['cov_drug_dose_name']``, ``obs['control']`` and
        ``uns['rank_genes_groups_cov']``; they are read unconditionally.

        Args:
            fname: Path handed to ``sc.read``.
            perturbation_key: ``obs`` column holding the drug name of each
                observation; combined drugs are joined with ``"+"``.
            dose_key: ``obs`` column holding the dose(s) of each observation.
                Its string form is split on ``"+"`` and every piece is
                parsed with ``float``; piece ``j`` weights drug ``j``.
            cell_type_key: ``obs`` column holding the cell type label.
            split_key: ``obs`` column whose values ``'train'``, ``'test'``
                and ``'ood'`` define the corresponding index lists.

        Raises:
            IndexError: If an observation lists more doses than drugs.
            ValueError: If a dose piece cannot be parsed as a float.
        """

        def _dense_one_hot_encoder():
            # scikit-learn renamed ``sparse`` to ``sparse_output``; try the
            # new spelling first and fall back for older releases.
            try:
                return OneHotEncoder(sparse_output=False)
            except TypeError:
                return OneHotEncoder(sparse=False)

        data = sc.read(fname)

        self.perturbation_key = perturbation_key
        self.dose_key = dose_key
        self.cell_type_key = cell_type_key

        # Accept both sparse and already-dense expression matrices.
        expression = data.X
        if hasattr(expression, "toarray"):
            expression = expression.toarray()
        self.genes = torch.Tensor(np.asarray(expression))

        self.var_names = data.var_names

        self.pert_categories = np.array(data.obs['cov_drug_dose_name'].values)

        self.de_genes = data.uns['rank_genes_groups_cov']
        self.ctrl = data.obs['control'].values
        # Read the column straight from ``obs`` instead of slicing the whole
        # dataset just to look at one annotation.
        is_control = data.obs['control'] == 1
        self.ctrl_name = list(
            np.unique(data.obs.loc[is_control, self.perturbation_key]))

        self.drugs_names = np.array(data.obs[perturbation_key].values)
        self.dose_names = np.array(data.obs[dose_key].values)

        # Unique single drugs, sorted so the order is the same on every run
        # (iteration order of a set of strings is not).
        self.drugs_names_unique = np.array(sorted({
            name
            for combo in self.drugs_names
            for name in combo.split("+")
        }))

        # save encoder for a comparison with Mo's model
        # later we need to remove this part
        encoder_drug = _dense_one_hot_encoder()
        encoder_drug.fit(self.drugs_names_unique.reshape(-1, 1))

        # One-hot row of every single drug; reused below for ``drug_dict``.
        atomic_ohe = encoder_drug.transform(
            self.drugs_names_unique.reshape(-1, 1))
        self.atomic_drugs_dict = dict(zip(self.drugs_names_unique, atomic_ohe))

        def _dose_weighted_encoding(combo, doses):
            # Sum of one-hot drug rows, each scaled by its parsed dose.
            # NOTE: if there are fewer doses than drugs, the trailing drugs
            # contribute nothing (same as the previous implementation).
            rows = encoder_drug.transform(
                np.array(combo.split("+")).reshape(-1, 1))
            pieces = doses.split("+")
            encoding = float(pieces[0]) * rows[0]
            for j, piece in enumerate(pieces[1:], start=1):
                encoding += float(piece) * rows[j]
            return encoding

        # get drug combinations
        # Observations sharing the same (drugs, doses) pair share one
        # encoding, so the encoder runs once per distinct pair rather than
        # once per observation.
        encoding_cache = {}
        drugs = []
        for combo, dose in zip(self.drugs_names, self.dose_names):
            key = (combo, str(dose))
            if key not in encoding_cache:
                encoding_cache[key] = _dose_weighted_encoding(*key)
            drugs.append(encoding_cache[key])
        # Stack into one ndarray first: building a tensor from a list of
        # ndarrays is very slow.
        self.drugs = torch.Tensor(np.array(drugs))

        self.cell_types_names = np.array(data.obs[cell_type_key].values)
        self.cell_types_names_unique = np.unique(self.cell_types_names)

        encoder_ct = _dense_one_hot_encoder()
        encoder_ct.fit(self.cell_types_names_unique.reshape(-1, 1))

        # The historical attribute name contains a Cyrillic "с" (U+0441)
        # instead of a Latin "c". It is kept so existing references keep
        # working; an ASCII-spelled alias points at the same dict.
        self.atomic_сovars_dict = dict(zip(
            list(self.cell_types_names_unique),
            encoder_ct.transform(self.cell_types_names_unique.reshape(-1, 1))))
        self.atomic_covars_dict = self.atomic_сovars_dict

        self.cell_types = torch.Tensor(encoder_ct.transform(
            self.cell_types_names.reshape(-1, 1))).float()

        self.num_cell_types = len(self.cell_types_names_unique)
        self.num_genes = self.genes.shape[1]
        self.num_drugs = len(self.drugs_names_unique)

        # Positional indices of the observations in each named subset.
        split = data.obs[split_key]
        self.indices = {
            "all": list(range(len(self.genes))),
            "control": np.where(data.obs['control'] == 1)[0].tolist(),
            "treated": np.where(data.obs['control'] != 1)[0].tolist(),
            "train": np.where(split == 'train')[0].tolist(),
            "test": np.where(split == 'test')[0].tolist(),
            "ood": np.where(split == 'ood')[0].tolist()
        }

        # Map each one-hot column index back to the drug it stands for.
        self.drug_dict = dict(
            zip(atomic_ohe.argmax(axis=1), self.drugs_names_unique))