in models/dr_constant.py [0:0]
def __init__(self, config, theta, treatments, dev1_hot, precisions=None, version=1):
    """Cache clipped parameters and treatment-dependent activation fractions.

    Args:
        config: forwarded unchanged to the base-class constructor.
        theta: parameter container. Must provide ``get_n_batch()``,
            ``get_n_samples()`` and every parameter attribute read below
            (``r``, ``K``, ``tlag``, ..., plus ``KR6``/``KR12``/``KS6``/``KS12``
            for version 1 or ``eS6``/``eR12`` for version 2).
        treatments: tensor unbound along dim 1 into exactly two treatments
            (c6, c12). Each value is mapped through ``exp(x) - 1`` before use,
            so inputs are presumably log(1 + concentration) -- TODO confirm
            against the data pipeline.
        dev1_hot: forwarded unchanged to the base-class constructor.
        precisions: optional class instance supplying dynamic (neural)
            precisions. If None, precisions are expected to be latent
            variables assigned by ``BaseModel.expand_precisions_by_time()``.
        version: activation-fraction formulation to use, 1 or 2.

    Raises:
        ValueError: if ``version`` is neither 1 nor 2.
    """
    super(DR_Constant_RHS, self).__init__(config, theta, treatments, dev1_hot)
    self.precisions = precisions
    self.n_batch = theta.get_n_batch()
    self.n_iwae = theta.get_n_samples()
    self.n_species = 8

    # Undo the exp-transform of the treatments, then clamp: the lower bound keeps
    # later power() calls away from exact zero, the upper bound caps huge values.
    treatments_transformed = torch.clamp(torch.exp(treatments) - 1.0, 1e-12, 1e6)
    c6a, c12a = torch.unbind(treatments_transformed, dim=1)
    # Tile each treatment once per IWAE sample. For a 1-D c6a of length n_batch
    # this yields shape (n_batch, n_iwae).
    c6 = torch.transpose(c6a.repeat([self.n_iwae, 1]), 0, 1)
    c12 = torch.transpose(c12a.repeat([self.n_iwae, 1]), 0, 1)

    # Rate-like parameters are clipped to avoid overflow during integration.
    self.r = torch.clamp(theta.r, 0.0, 4.0)
    self.K = torch.clamp(theta.K, 0.0, 4.0)
    self.tlag = theta.tlag
    self.rc = theta.rc
    self.a530 = theta.a530
    self.a480 = theta.a480
    # Positive lower bound keeps these strictly above zero.
    self.drfp = torch.clamp(theta.drfp, 1e-12, 2.0)
    self.dyfp = torch.clamp(theta.dyfp, 1e-12, 2.0)
    self.dcfp = torch.clamp(theta.dcfp, 1e-12, 2.0)
    self.dR = torch.clamp(theta.dR, 1e-12, 5.0)
    self.dS = torch.clamp(theta.dS, 1e-12, 5.0)
    self.e76 = theta.e76
    self.e81 = theta.e81
    self.aCFP = theta.aCFP
    self.aYFP = theta.aYFP
    self.KGR_76 = theta.KGR_76
    self.KGS_76 = theta.KGS_76
    self.KGR_81 = theta.KGR_81
    self.KGS_81 = theta.KGS_81
    self.aR = theta.aR
    self.aS = theta.aS

    # Exponents for the activation fractions, clipped to a moderate range.
    nR = torch.clamp(theta.nR, 0.5, 3.0)
    nS = torch.clamp(theta.nS, 0.5, 3.0)
    lb = 1e-12
    ub = 1e0
    if version == 1:
        KR6 = torch.clamp(theta.KR6, lb, ub)
        KR12 = torch.clamp(theta.KR12, lb, ub)
        KS6 = torch.clamp(theta.KS6, lb, ub)
        KS12 = torch.clamp(theta.KS12, lb, ub)
        # ((K6*c6)^n + (K12*c12)^n) / (1 + K6*c6 + K12*c12)^n
        self.fracLuxR = (power(KR6 * c6, nR) + power(KR12 * c12, nR)) / power(1.0 + KR6 * c6 + KR12 * c12, nR)
        self.fracLasR = (power(KS6 * c6, nS) + power(KS12 * c12, nS)) / power(1.0 + KS6 * c6 + KS12 * c12, nS)
    elif version == 2:
        eS6 = torch.clamp(theta.eS6, lb, ub)
        eR12 = torch.clamp(theta.eR12, lb, ub)
        # Unnormalised sum of powers; c6 (resp. c12) enters with unit weight here.
        self.fracLuxR = power(c6, nR) + power(eR12 * c12, nR)
        self.fracLasR = power(eS6 * c6, nS) + power(c12, nS)
    else:
        # %r on a 1-tuple formats any value (str, float, tuple) without itself
        # raising, unlike %d; ValueError still satisfies `except Exception`.
        raise ValueError("Unknown version of DR_Constant: %r" % (version,))