def __init__()

in models/dr_constant.py [0:0]


    def __init__(self, config, theta, treatments, dev1_hot, precisions=None, version=1):
        super(DR_Constant_RHS, self).__init__(config, theta, treatments, dev1_hot)

        # Optionally pass a class instance that computes dynamic (neural) precisions. If None, the
        # precisions are expected to be latent variables, assigned in BaseModel.expand_precisions_by_time()
        self.precisions = precisions

        self.n_batch = theta.get_n_batch()
        self.n_iwae = theta.get_n_samples()
        self.n_species = 8

        # Tile treatments, one copy per IWAE sample
        treatments_transformed = torch.clamp(torch.exp(treatments) - 1.0, 1e-12, 1e6)
        c6a, c12a = torch.unbind(treatments_transformed, dim=1)
        c6 = torch.transpose(c6a.repeat([self.n_iwae, 1]), 0, 1)
        c12 = torch.transpose(c12a.repeat([self.n_iwae, 1]), 0, 1)
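        # Shapes: c6a and c12a are [n_batch]; repeat([n_iwae, 1]) gives
        # [n_iwae, n_batch], and transpose(0, 1) yields [n_batch, n_iwae],
        # i.e. one treatment value per IWAE sample.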

        # Clamp growth parameters to avoid numerical overflow
        self.r = torch.clamp(theta.r, 0.0, 4.0)
        self.K = torch.clamp(theta.K, 0.0, 4.0)
        self.tlag = theta.tlag
        self.rc = theta.rc
        self.a530 = theta.a530
        self.a480 = theta.a480

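        # Degradation rates, clamped for stability (presumably for the
        # RFP/YFP/CFP reporters and the two receiver proteins R and S)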
        self.drfp = torch.clamp(theta.drfp, 1e-12, 2.0)
        self.dyfp = torch.clamp(theta.dyfp, 1e-12, 2.0)
        self.dcfp = torch.clamp(theta.dcfp, 1e-12, 2.0)
        self.dR = torch.clamp(theta.dR, 1e-12, 5.0)
        self.dS = torch.clamp(theta.dS, 1e-12, 5.0)

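        # Promoter/production parameters, used without clamping; the names
        # suggest binding and expression constants for the P76 and P81 promoters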
        self.e76 = theta.e76
        self.e81 = theta.e81
        self.aCFP = theta.aCFP
        self.aYFP = theta.aYFP
        self.KGR_76 = theta.KGR_76
        self.KGS_76 = theta.KGS_76
        self.KGR_81 = theta.KGR_81
        self.KGS_81 = theta.KGS_81

        self.aR = theta.aR
        self.aS = theta.aS

        # Activation (Hill function) constants; Hill coefficients and binding
        # constants are clamped to safe ranges
        nR = torch.clamp(theta.nR, 0.5, 3.0)
        nS = torch.clamp(theta.nS, 0.5, 3.0)
        lb = 1e-12
        ub = 1e0
        if version == 1:
            KR6 = torch.clamp(theta.KR6, lb, ub)
            KR12 = torch.clamp(theta.KR12, lb, ub)
            KS6 = torch.clamp(theta.KS6, lb, ub)
            KS12 = torch.clamp(theta.KS12, lb, ub)
            self.fracLuxR = (power(KR6 * c6, nR) + power(KR12 * c12, nR)) / power(1.0 + KR6 * c6 + KR12 * c12, nR)
            self.fracLasR = (power(KS6 * c6, nS) + power(KS12 * c12, nS)) / power(1.0 + KS6 * c6 + KS12 * c12, nS)
        elif version == 2:
            eS6 = torch.clamp(theta.eS6, lb, ub)
            eR12 = torch.clamp(theta.eR12, lb, ub)
            self.fracLuxR = power(c6, nR) + power(eR12 * c12, nR)
            self.fracLasR = power(eS6 * c6, nS) + power(c12, nS)
        else:
            raise ValueError("Unknown version of DR_Constant: %d" % version)
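
For intuition, here is a minimal standalone sketch of the version-1 activation fraction (the saturating Hill-style term above). It assumes power is a module-level helper equivalent to torch.pow; the parameter values are illustrative only, not taken from the model:

    import torch

    def frac_lux_r(c6, c12, KR6, KR12, nR):
        # Bound fraction of LuxR given inducer concentrations c6 and c12,
        # mirroring the version == 1 branch with power -> torch.pow.
        num = torch.pow(KR6 * c6, nR) + torch.pow(KR12 * c12, nR)
        den = torch.pow(1.0 + KR6 * c6 + KR12 * c12, nR)
        return num / den

    c6 = torch.tensor([0.0, 1.0, 100.0, 1e4])
    print(frac_lux_r(c6, torch.zeros_like(c6), KR6=0.1, KR12=0.01, nR=2.0))
    # 0 with no inducer, rising toward 1 as c6 grows; version == 2 instead
    # uses unsaturated power-law terms with crosstalk factors eR12 / eS6.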