in compert/api.py [0:0]
def __init__(self, datasets, model):
    """
    Build the API wrapper from a collection of dataset splits and a model.

    Parameters
    ----------
    datasets : dict
        Mapping from split name to ComPertDataset. Must contain a
        ``'training'`` split, which supplies the observation keys, the
        perturbation/covariate vocabularies and the dose range. If an
        ``'ood'`` split is present, its measured points are merged with the
        training ones into ``self.measured_points['all']``; otherwise
        ``'all'`` is a copy of the training points.
    model : ComPertModel
        Pre-trained ComPert model.

    Notes
    -----
    ``self.measured_points`` has the layout
    ``{split: {covariate: {drug: [dose, ...]}}}``, built by splitting each
    ``pert_categories`` label of the form ``"<covariate>_<drug>_<dose>"``.
    Doses are floats, except combination doses containing ``'+'``, which are
    kept as strings. ``self.num_measured_points[split][label]`` is the number
    of entries of that split's ``pert_categories`` equal to ``label``.
    """
    dataset = datasets['training']
    self.perturbation_key = dataset.perturbation_key
    self.dose_key = dataset.dose_key
    self.covars_key = dataset.covars_key
    # Dose range over the strictly positive entries of the training drug matrix.
    positive_doses = dataset.drugs[dataset.drugs > 0]
    self.min_dose = positive_doses.min().item()
    self.max_dose = positive_doses.max().item()
    self.model = model
    self.var_names = dataset.var_names
    self.unique_perts = list(dataset.perts_dict.keys())
    self.unique_covars = list(dataset.covars_dict.keys())
    # The historical attribute name is spelled with a Cyrillic "с" (U+0441).
    # Keep it as an alias so existing code that reads it keeps working.
    setattr(self, 'unique_\u0441ovars', self.unique_covars)
    self.num_drugs = dataset.num_drugs
    self.perts_dict = dataset.perts_dict
    self.covars_dict = dataset.covars_dict
    self.drug_ohe = torch.Tensor(list(dataset.perts_dict.values()))
    self.covars_ohe = torch.LongTensor(list(dataset.covars_dict.values()))
    self.emb_covars = None
    self.emb_perts = None
    self.comb_emb = None
    self.control_cat = None

    self.seen_covars_perts = {}
    self.measured_points = {}
    self.num_measured_points = {}
    for split, split_data in datasets.items():
        # One pass per split: sorted unique labels plus how often each occurs.
        perts, counts = np.unique(split_data.pert_categories, return_counts=True)
        self.seen_covars_perts[split] = perts
        self.measured_points[split] = {}
        self.num_measured_points[split] = {}
        for pert, count in zip(perts, counts):
            self.num_measured_points[split][pert] = int(count)
            cov, drug, dose = pert.split('_')
            # Combination doses such as "0.1+0.5" cannot be parsed as a float
            # and are kept as strings.
            if '+' not in dose:
                dose = float(dose)
            self.measured_points[split].setdefault(cov, {}).setdefault(
                drug, []).append(dose)

    # 'all' = training points plus any out-of-distribution points.
    # Covariates or drugs that only occur in 'ood' are added as new entries
    # instead of raising KeyError; the training lists are never mutated.
    training_points = self.measured_points['training']
    all_points = copy.deepcopy(training_points)
    for cov, drug_doses in self.measured_points.get('ood', {}).items():
        cov_points = all_points.setdefault(cov, {})
        for drug, doses in drug_doses.items():
            # List concatenation builds a fresh list, so no aliasing.
            cov_points[drug] = training_points.get(cov, {}).get(drug, []) + doses
    self.measured_points['all'] = all_points