in src/envs/ode.py [0:0]
def __init__(self, params):
    """Set up the ODE environment from the parsed ``params`` namespace.

    Copies the generation / evaluation settings from ``params``, builds the
    operator tables, the SymPy variables ``x0 .. x{2 * max_degree - 1}`` and
    the token vocabulary, then precomputes ``self.distrib`` via
    ``self.generate_dist(2 * self.max_ops)``.

    Side effects on ``params``: sets ``params.n_words``, ``params.eos_index``
    and ``params.pad_index``.

    Raises:
        AssertionError: if ``min_degree < 2``, ``max_degree < min_degree``,
            ``prob_trigs + prob_arc_trigs + prob_logs`` exceeds 1 (beyond
            floating-point rounding), ``max_int < 1`` or ``precision < 2``.
    """
    # --- sizes and generator limits ---
    self.max_degree = params.max_degree
    self.min_degree = params.min_degree
    assert self.min_degree >= 2, "min_degree must be at least 2"
    assert self.max_degree >= self.min_degree, "max_degree must be >= min_degree"
    self.max_ops = 200
    self.max_int = params.max_int
    self.positive = params.positive
    self.nonnull = params.nonnull
    self.predict_jacobian = params.predict_jacobian
    self.predict_gramian = params.predict_gramian
    self.qualitative = params.qualitative
    self.allow_complex = params.allow_complex
    self.reversed_eval = params.reversed_eval
    self.euclidian_metric = params.euclidian_metric
    self.auxiliary_task = params.auxiliary_task
    self.tau = params.tau
    self.gramian_norm1 = params.gramian_norm1
    self.gramian_tolerance = params.gramian_tolerance
    self.min_expr_len_factor_cspeed = params.min_expr_len_factor_cspeed
    self.max_expr_len_factor_cspeed = params.max_expr_len_factor_cspeed
    self.custom_unary_probs = params.custom_unary_probs
    self.prob_trigs = params.prob_trigs
    self.prob_arc_trigs = params.prob_arc_trigs
    self.prob_logs = params.prob_logs
    # Remaining probability mass after trig, arc-trig and log unaries.
    # The subtraction is done in floating point, so probabilities that sum to
    # exactly 1 (e.g. 0.3 + 0.3 + 0.4) can leave a tiny negative residue
    # (about -5.6e-17). Tolerate that rounding error and clamp it to 0 instead
    # of rejecting a valid configuration.
    prob_others = 1.0 - self.prob_trigs - self.prob_arc_trigs - self.prob_logs
    assert prob_others >= -1e-9, (
        "prob_trigs + prob_arc_trigs + prob_logs must not exceed 1"
    )
    self.prob_others = max(0.0, prob_others)
    self.prob_int = params.prob_int
    self.precision = params.precision
    self.jacobian_precision = params.jacobian_precision
    self.max_len = params.max_len
    self.eval_value = params.eval_value
    self.skip_zero_gradient = params.skip_zero_gradient
    self.prob_positive = params.prob_positive
    # Integer counters indexed by degree 0 .. max_degree.
    self.np_positive = np.zeros(self.max_degree + 1, dtype=int)
    self.np_total = np.zeros(self.max_degree + 1, dtype=int)
    # The "fourier" task also selects the extended operator table below.
    self.complex_input = "fourier" in params.tasks

    # --- SymPy class -> token name ---
    self.SYMPY_OPERATORS = {
        # Elementary functions
        sp.Add: "+",
        sp.Mul: "*",
        sp.Pow: "^",
        sp.exp: "exp",
        sp.log: "ln",
        # sp.Abs: 'abs',
        # sp.sign: 'sign',
        # Trigonometric Functions
        sp.sin: "sin",
        sp.cos: "cos",
        sp.tan: "tan",
        # sp.cot: 'cot',
        # sp.sec: 'sec',
        # sp.csc: 'csc',
        # Trigonometric Inverses
        sp.asin: "asin",
        sp.acos: "acos",
        sp.atan: "atan",
        # sp.acot: 'acot',
        # sp.asec: 'asec',
        # sp.acsc: 'acsc',
        sp.DiracDelta: "delta0",
    }

    # --- operator tables: token -> arity ---
    # Dict order is significant: it fixes the token ids assigned below.
    self.operators_conv = {
        "+": 2,
        "-": 2,
        "*": 2,
        "/": 2,
        "sqrt": 1,
        "exp": 1,
        "ln": 1,
        "sin": 1,
        "cos": 1,
        "tan": 1,
        "asin": 1,
        "acos": 1,
        "atan": 1,
    }
    self.trig_ops = ["sin", "cos", "tan"]
    self.arctrig_ops = ["asin", "acos", "atan"]
    self.exp_ops = ["exp", "ln"]
    self.other_ops = ["sqrt"]
    self.operators_lyap = {
        "+": 2,
        "-": 2,
        "*": 2,
        "/": 2,
        "^": 2,
        "sqrt": 1,
        "exp": 1,
        "ln": 1,
        "sin": 1,
        "cos": 1,
        "tan": 1,
        "asin": 1,
        "acos": 1,
        "atan": 1,
        "delta0": 1,
    }
    # Reuse the flag computed above rather than re-testing params.tasks.
    self.operators = (
        self.operators_lyap if self.complex_input else self.operators_conv
    )
    self.unaries = [op for op, arity in self.operators.items() if arity == 1]
    self.binaries = [op for op, arity in self.operators.items() if arity == 2]
    self.unary = len(self.unaries) > 0
    self.predict_bounds = params.predict_bounds
    assert self.max_int >= 1, "max_int must be at least 1"
    assert self.precision >= 2, "precision must be at least 2"

    # --- variables: x0 .. x{2 * max_degree - 1} ---
    n_vars = 2 * self.max_degree
    self.variables = OrderedDict(
        (f"x{i}", sp.Symbol(f"x{i}")) for i in range(n_vars)
    )
    # Every variable is mapped to the same value, eval_value.
    self.eval_point = OrderedDict(
        (symbol, self.eval_value) for symbol in self.variables.values()
    )

    # symbols / elements
    self.constants = ["pi", "E"]
    self.symbols = ["I", "INT+", "INT-", "FLOAT+", "FLOAT-", ".", "10^"]
    self.elements = [str(i) for i in range(10)]
    # SymPy elements: variable name -> Symbol (names are unique dict keys).
    self.local_dict = dict(self.variables)

    # --- vocabulary (order defines token ids) ---
    self.words = (
        SPECIAL_WORDS
        + self.constants
        + list(self.variables.keys())
        + list(self.operators.keys())
        + self.symbols
        + self.elements
    )
    self.id2word = {i: s for i, s in enumerate(self.words)}
    self.word2id = {s: i for i, s in self.id2word.items()}
    assert len(self.words) == len(set(self.words))
    # number of words / indices (also written back onto params)
    # NOTE(review): assumes the first two SPECIAL_WORDS entries are the EOS
    # and PAD tokens — confirm against the SPECIAL_WORDS definition.
    self.n_words = params.n_words = len(self.words)
    self.eos_index = params.eos_index = 0
    self.pad_index = params.pad_index = 1
    self.func_separator = "<SPECIAL_3>"  # separate equations in a system
    self.line_separator = "<SPECIAL_4>"  # separate lines in a matrix
    self.list_separator = "<SPECIAL_5>"  # separate elements in a list
    self.mtrx_separator = "<SPECIAL_6>"  # end of a matrix
    self.neg_inf = "<SPECIAL_7>"  # negative infinity
    self.pos_inf = "<SPECIAL_8>"  # positive infinity
    logger.info(f"words: {self.word2id}")
    # initialize distribution for binary and unary-binary trees
    # self.max_ops + 1 should be enough
    self.distrib = self.generate_dist(2 * self.max_ops)