def __init__()

in src/envs/ode.py [0:0]


    def __init__(self, params):
        """Configure the environment from the configuration object ``params``.

        Copies generation / evaluation settings from ``params`` onto the
        instance, then builds:
          * the SymPy-to-token operator map and the operator arity tables,
          * the symbolic variables ``x0 .. x{2*max_degree - 1}`` and a point
            assigning ``eval_value`` to each of them,
          * the token vocabulary (``words``, ``id2word``, ``word2id``),
          * the tree-sampling distribution via ``self.generate_dist``.

        Side effects on ``params``: sets ``params.n_words``,
        ``params.eos_index`` and ``params.pad_index``.

        Raises:
            AssertionError: if ``min_degree < 2``, ``max_degree < min_degree``,
                the trig / arc-trig / log probabilities sum to more than 1
                (beyond float rounding), ``max_int < 1``, ``precision < 2``,
                or the vocabulary contains duplicate words.
        """
        self.max_degree = params.max_degree
        self.min_degree = params.min_degree
        assert self.min_degree >= 2
        assert self.max_degree >= self.min_degree

        self.max_ops = 200

        self.max_int = params.max_int
        self.positive = params.positive
        self.nonnull = params.nonnull
        self.predict_jacobian = params.predict_jacobian
        self.predict_gramian = params.predict_gramian
        self.qualitative = params.qualitative
        self.allow_complex = params.allow_complex
        self.reversed_eval = params.reversed_eval
        self.euclidian_metric = params.euclidian_metric
        self.auxiliary_task = params.auxiliary_task
        self.tau = params.tau
        self.gramian_norm1 = params.gramian_norm1
        self.gramian_tolerance = params.gramian_tolerance

        self.min_expr_len_factor_cspeed = params.min_expr_len_factor_cspeed
        self.max_expr_len_factor_cspeed = params.max_expr_len_factor_cspeed

        self.custom_unary_probs = params.custom_unary_probs
        self.prob_trigs = params.prob_trigs
        self.prob_arc_trigs = params.prob_arc_trigs
        self.prob_logs = params.prob_logs
        # Probability mass left for the remaining unary operators. Float
        # subtraction can undershoot zero by a rounding error even when the
        # three probabilities sum to exactly 1 (e.g. 0.3 + 0.3 + 0.4 leaves
        # about -5.6e-17), so tolerate a tiny negative residue and clamp it.
        prob_tolerance = 1e-9
        self.prob_others = 1.0 - self.prob_trigs - self.prob_arc_trigs - self.prob_logs
        assert self.prob_others >= -prob_tolerance
        self.prob_others = max(self.prob_others, 0.0)

        self.prob_int = params.prob_int
        self.precision = params.precision
        self.jacobian_precision = params.jacobian_precision

        self.max_len = params.max_len
        self.eval_value = params.eval_value
        self.skip_zero_gradient = params.skip_zero_gradient
        self.prob_positive = params.prob_positive

        # Per-degree counters (indices 0..max_degree).
        self.np_positive = np.zeros(self.max_degree + 1, dtype=int)
        self.np_total = np.zeros(self.max_degree + 1, dtype=int)
        # True when a "fourier" task is requested; also selects the operator set.
        self.complex_input = "fourier" in params.tasks

        # Maps SymPy expression heads to the token emitted for them.
        self.SYMPY_OPERATORS = {
            # Elementary functions
            sp.Add: "+",
            sp.Mul: "*",
            sp.Pow: "^",
            sp.exp: "exp",
            sp.log: "ln",
            # sp.Abs: 'abs',
            # sp.sign: 'sign',
            # Trigonometric Functions
            sp.sin: "sin",
            sp.cos: "cos",
            sp.tan: "tan",
            # sp.cot: 'cot',
            # sp.sec: 'sec',
            # sp.csc: 'csc',
            # Trigonometric Inverses
            sp.asin: "asin",
            sp.acos: "acos",
            sp.atan: "atan",
            # sp.acot: 'acot',
            # sp.asec: 'asec',
            # sp.acsc: 'acsc',
            sp.DiracDelta: "delta0",
        }

        # Operator token -> arity, used when "fourier" is not among the tasks.
        self.operators_conv = {
            "+": 2,
            "-": 2,
            "*": 2,
            "/": 2,
            "sqrt": 1,
            "exp": 1,
            "ln": 1,
            "sin": 1,
            "cos": 1,
            "tan": 1,
            "asin": 1,
            "acos": 1,
            "atan": 1,
        }

        self.trig_ops = ["sin", "cos", "tan"]
        self.arctrig_ops = ["asin", "acos", "atan"]
        self.exp_ops = ["exp", "ln"]
        self.other_ops = ["sqrt"]

        # Operator token -> arity, used for "fourier" tasks; adds "^" and "delta0".
        self.operators_lyap = {
            "+": 2,
            "-": 2,
            "*": 2,
            "/": 2,
            "^": 2,
            "sqrt": 1,
            "exp": 1,
            "ln": 1,
            "sin": 1,
            "cos": 1,
            "tan": 1,
            "asin": 1,
            "acos": 1,
            "atan": 1,
            "delta0": 1,
        }

        self.operators = (
            self.operators_lyap if self.complex_input else self.operators_conv
        )
        self.unaries = [op for op, arity in self.operators.items() if arity == 1]
        self.binaries = [op for op, arity in self.operators.items() if arity == 2]
        self.unary = len(self.unaries) > 0
        self.predict_bounds = params.predict_bounds

        assert self.max_int >= 1
        assert self.precision >= 2

        # variables
        # NOTE(review): 2 * max_degree symbols are created; presumably systems
        # may need up to twice max_degree variables — confirm against callers.
        self.variables = OrderedDict(
            {f"x{i}": sp.Symbol(f"x{i}") for i in range(2 * self.max_degree)}
        )

        # Point assigning eval_value to every variable, in variable order.
        self.eval_point = OrderedDict(
            (symbol, self.eval_value) for symbol in self.variables.values()
        )

        # symbols / elements
        self.constants = ["pi", "E"]

        self.symbols = ["I", "INT+", "INT-", "FLOAT+", "FLOAT-", ".", "10^"]
        self.elements = [str(i) for i in range(10)]

        # SymPy elements: variable name -> Symbol (keys are unique by construction)
        self.local_dict = dict(self.variables)

        # vocabulary
        self.words = (
            SPECIAL_WORDS
            + self.constants
            + list(self.variables.keys())
            + list(self.operators.keys())
            + self.symbols
            + self.elements
        )
        self.id2word = {i: s for i, s in enumerate(self.words)}
        self.word2id = {s: i for i, s in self.id2word.items()}
        assert len(self.words) == len(set(self.words))

        # number of words / indices
        self.n_words = params.n_words = len(self.words)
        self.eos_index = params.eos_index = 0
        self.pad_index = params.pad_index = 1
        self.func_separator = "<SPECIAL_3>"  # separate equations in a system
        self.line_separator = "<SPECIAL_4>"  # separate lines in a matrix
        self.list_separator = "<SPECIAL_5>"  # separate elements in a list
        self.mtrx_separator = "<SPECIAL_6>"  # end of a matrix
        self.neg_inf = "<SPECIAL_7>"  # negative infinity
        self.pos_inf = "<SPECIAL_8>"  # positive infinity
        logger.info(f"words: {self.word2id}")

        # initialize distribution for binary and unary-binary trees
        # self.max_ops + 1 should be enough
        self.distrib = self.generate_dist(2 * self.max_ops)