def convert_func()

in evaluation/latex2sympy/latex2sympy2.py [0:0]


def _convert_func_superscript(func):
    """Return the converted superscript attached to ``func``, or ``None``.

    The superscript is converted with ``convert_expr`` when it holds an
    ``expr`` child and with ``convert_atom`` otherwise.
    """
    sup = func.supexpr()
    if not sup:
        return None
    if sup.expr():
        return convert_expr(sup.expr())
    return convert_atom(sup.atom())


def convert_func(func):
    """Convert a parsed function node into the corresponding SymPy object.

    The node is inspected through its accessor methods and dispatched to one
    of the following cases, tested in this order:

    * ``func_normal_single_arg`` -- built-in one-argument functions (trig,
      inverse trig, hyperbolic, ``log``/``ln``, ``exp``, ``floor``, ``ceil``,
      ``det``) and ``\\operatorname{...}`` helpers that act on one argument.
    * ``func_normal_multi_arg`` -- ``gcd``/``lcm``/``max``/``min`` and
      ``\\operatorname{...}`` helpers taking a comma-separated argument list.
    * ``atom_expr_no_supexpr`` -- application of an undefined (user-named)
      function, built with ``sympy.Function``.
    * ``FUNC_INT``, ``FUNC_SQRT``, ``FUNC_SUM``, ``FUNC_PROD``, ``FUNC_LIM``,
      ``EXP_E`` -- integrals, roots, sums, products, limits and exponentials,
      mostly delegated to the matching ``handle_*`` helper.

    Args:
        func: The parser node for the function (presumably an ANTLR parser
            context -- it is only used through its accessor methods here).

    Returns:
        The converted object. This is normally a SymPy expression, but some
        ``\\operatorname`` helpers return whatever the underlying matrix
        method returns, and ``cols``/``rows`` return a Python list. ``None``
        is returned implicitly when none of the cases above matches.

    Raises:
        UnboundLocalError: In the single- and multi-argument cases, if the
            function name is not handled by any branch, so that no result
            was ever assigned.
    """
    if func.func_normal_single_arg():
        if func.L_PAREN():  # function called with parenthesis
            arg = convert_func_arg(func.func_single_arg())
        else:
            arg = convert_func_arg(func.func_single_arg_noparens())

        # Drop the first character of the token text to get the bare name.
        name = func.func_normal_single_arg().start.text[1:]

        # change arc<trig> -> a<trig>  (also arc<hyp> -> a<hyp>)
        if name in ["arcsin", "arccos", "arctan", "arccsc", "arcsec",
                    "arccot", "arcsinh", "arccosh", "arctanh"]:
            name = "a" + name[3:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        # change ar<hyp> -> a<hyp>
        elif name in ["arsinh", "arcosh", "artanh"]:
            name = "a" + name[2:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        elif name == "operatorname":
            operatorname = func.func_normal_single_arg().func_operator_name.getText()

            if operatorname in ["arsinh", "arcosh", "artanh"]:
                operatorname = "a" + operatorname[2:]
                expr = getattr(sympy.functions, operatorname)(arg, evaluate=False)
            elif operatorname in ["arcsinh", "arccosh", "arctanh"]:
                operatorname = "a" + operatorname[3:]
                expr = getattr(sympy.functions, operatorname)(arg, evaluate=False)
            elif operatorname == "floor":
                expr = handle_floor(arg)
            elif operatorname == "ceil":
                expr = handle_ceil(arg)
            elif operatorname == 'eye':
                expr = sympy.eye(arg)
            # The remaining helpers call methods on ``arg`` directly.
            elif operatorname == 'rank':
                expr = sympy.Integer(arg.rank())
            elif operatorname in ['trace', 'tr']:
                expr = arg.trace()
            elif operatorname == 'rref':
                # Keep only the first element of the rref() result.
                expr = arg.rref()[0]
            elif operatorname == 'nullspace':
                expr = arg.nullspace()
            elif operatorname == 'norm':
                expr = arg.norm()
            elif operatorname == 'cols':
                expr = [arg.col(i) for i in range(arg.cols)]
            elif operatorname == 'rows':
                expr = [arg.row(i) for i in range(arg.rows)]
            elif operatorname in ['eig', 'eigen', 'diagonalize']:
                expr = arg.diagonalize()
            elif operatorname in ['eigenvals', 'eigenvalues']:
                expr = arg.eigenvals()
            elif operatorname in ['eigenvects', 'eigenvectors']:
                expr = arg.eigenvects()
            elif operatorname in ['svd', 'SVD']:
                expr = arg.singular_value_decomposition()
        elif name in ["log", "ln"]:
            # An explicit subscript supplies the base; otherwise "log" uses
            # base 10 and "ln" uses base e.
            if func.subexpr():
                if func.subexpr().atom():
                    base = convert_atom(func.subexpr().atom())
                else:
                    base = convert_expr(func.subexpr().expr())
            elif name == "log":
                base = 10
            elif name == "ln":
                base = sympy.E
            expr = sympy.log(arg, base, evaluate=False)
        elif name in ["exp", "exponentialE"]:
            expr = sympy.exp(arg)
        elif name == "floor":
            expr = handle_floor(arg)
        elif name == "ceil":
            expr = handle_ceil(arg)
        elif name == 'det':
            expr = arg.det()

        func_pow = _convert_func_superscript(func)
        should_pow = True

        if name in ["sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", "tanh"]:
            # A superscript of -1 selects the inverse function (a<name>)
            # instead of raising the result to the power -1.
            if func_pow == -1:
                name = "a" + name
                should_pow = False
            expr = getattr(sympy.functions, name)(arg, evaluate=False)

        # Compare against None rather than relying on truthiness, so that a
        # superscript equal to zero is still applied.
        if func_pow is not None and should_pow:
            expr = sympy.Pow(expr, func_pow, evaluate=False)

        return expr

    elif func.func_normal_multi_arg():
        if func.L_PAREN():  # function called with parenthesis
            raw_args = func.func_multi_arg().getText().split(",")
        else:
            # Take the node's text before splitting, as in the branch above.
            raw_args = func.func_multi_arg_noparens().getText().split(",")

        # Each comma-separated piece is converted on its own.
        args = [latex2sympy(raw, VARIABLE_VALUES) for raw in raw_args]
        name = func.func_normal_multi_arg().start.text[1:]

        if name == "operatorname":
            operatorname = func.func_normal_multi_arg().func_operator_name.getText()
            if operatorname in ["gcd", "lcm"]:
                expr = handle_gcd_lcm(operatorname, args)
            elif operatorname == 'zeros':
                expr = sympy.zeros(*args)
            elif operatorname == 'ones':
                expr = sympy.ones(*args)
            elif operatorname == 'diag':
                expr = sympy.diag(*args)
            elif operatorname == 'hstack':
                expr = sympy.Matrix.hstack(*args)
            elif operatorname == 'vstack':
                expr = sympy.Matrix.vstack(*args)
            elif operatorname in ['orth', 'ortho', 'orthogonal', 'orthogonalize']:
                if len(args) == 1:
                    # A single argument is split into its columns, which are
                    # then passed on as the vectors.
                    arg = args[0]
                    expr = sympy.matrices.GramSchmidt([arg.col(i) for i in range(arg.cols)], True)
                else:
                    expr = sympy.matrices.GramSchmidt(args, True)
        elif name in ["gcd", "lcm"]:
            expr = handle_gcd_lcm(name, args)
        elif name in ["max", "min"]:
            # The SymPy functions are capitalised: Max / Min.
            name = name[0].upper() + name[1:]
            expr = getattr(sympy.functions, name)(*args, evaluate=False)

        func_pow = _convert_func_superscript(func)
        # As above: apply the power whenever a superscript exists, even 0.
        if func_pow is not None:
            expr = sympy.Pow(expr, func_pow, evaluate=False)

        return expr
    elif func.atom_expr_no_supexpr():
        # define a function
        f = sympy.Function(func.atom_expr_no_supexpr().getText())
        # args
        args = func.func_common_args().getText().split(",")
        # Drop the empty last piece left by a trailing comma.
        if args[-1] == '':
            args = args[:-1]
        args = [latex2sympy(arg, VARIABLE_VALUES) for arg in args]
        # supexpr
        exponent = _convert_func_superscript(func)
        if func.supexpr():
            return sympy.Pow(f(*args), exponent, evaluate=False)
        return f(*args)
    elif func.FUNC_INT():
        return handle_integral(func)
    elif func.FUNC_SQRT():
        expr = convert_expr(func.base)
        if func.root:
            # Explicit root index r: base ** (1 / r).
            r = convert_expr(func.root)
            return sympy.Pow(expr, 1 / r, evaluate=False)
        else:
            return sympy.Pow(expr, sympy.S.Half, evaluate=False)
    elif func.FUNC_SUM():
        return handle_sum_or_prod(func, "summation")
    elif func.FUNC_PROD():
        return handle_sum_or_prod(func, "product")
    elif func.FUNC_LIM():
        return handle_limit(func)
    elif func.EXP_E():
        return handle_exp(func)