def _handle_function()

in odps/df/backends/odpssql/rewriter.py [0:0]


    def _handle_function(self, expr, raw_inputs):
        """Rewrite a function expression so its UDF never sees or emits decimals.

        Python UDFs cannot handle decimal fields. Every decimal input of
        ``expr`` is therefore substituted by ``<input>.astype('string')``.
        If ``expr`` itself produces decimal values, a copy of it is built with
        string output type(s), cast back to decimal, and substituted for
        ``expr``.

        Nothing happens when ``expr`` is a ``Func``, or when no input and no
        output of ``expr`` is decimal.

        :param expr: the function expression to rewrite. A ``SequenceExpr`` or
            ``Scalar`` is handled through its ``dtype``; anything else is
            assumed to expose ``schema`` (presumably a collection expression).
        :param raw_inputs: the input expressions of ``expr``. It is iterated
            several times, so it must be a re-iterable sequence, not a
            one-shot iterator.
        """
        if isinstance(expr, Func):
            return

        def no_output_decimal():
            # True when none of the values produced by ``expr`` is decimal.
            if isinstance(expr, (SequenceExpr, Scalar)):
                return expr.dtype != types.decimal
            return all(t != types.decimal for t in expr.schema.types)

        # Nothing to rewrite: no decimal flows into or out of the function.
        if all(arg.dtype != types.decimal for arg in raw_inputs) and \
                no_output_decimal():
            return

        # Snapshot the inputs *before* substituting the decimal ones. The
        # snapshot (the original inputs) is what gets recorded on ``expr`` below.
        inputs = list(raw_inputs)
        for arg in raw_inputs:
            if arg.dtype == types.decimal:
                self._sub(arg, arg.astype('string'), parents=[expr, ])
        if hasattr(expr, '_raw_inputs'):
            expr._raw_inputs = inputs
        else:
            # An expression without ``_raw_inputs`` must have exactly one input.
            assert len(inputs) == 1
            expr._raw_input = inputs[0]

        attrs = get_attrs(expr)
        attr_values = {attr: getattr(expr, attr, None) for attr in attrs}
        if isinstance(expr, (SequenceExpr, Scalar)):
            outputs_decimal = expr.dtype == types.decimal
            if outputs_decimal:
                # Make the UDF produce strings; the result is cast back below.
                if isinstance(expr, SequenceExpr):
                    attr_values['_data_type'] = types.string
                    attr_values['_source_data_type'] = types.string
                else:
                    attr_values['_value_type'] = types.string
                    attr_values['_source_value_type'] = types.string
            # NOTE(review): ``_new`` (not the plain constructor used for the
            # schema branch) presumably picks the subclass matching the new
            # dtype -- TODO confirm.
            sub = type(expr)._new(**attr_values)
            if outputs_decimal:
                sub = sub.astype('decimal')
        else:
            names = expr.schema.names
            tps = expr.schema.types
            cast_names = set(name for name, tp in zip(names, tps)
                             if tp == types.decimal)
            if cast_names:
                # Declare every decimal output column as string for the UDF.
                new_tps = [tp if tp != types.decimal else types.string
                           for tp in tps]
                attr_values['_schema'] = TableSchema.from_lists(names, new_tps)
            sub = type(expr)(**attr_values)
            if cast_names:
                # Project the rebuilt expression, keeping the column order and
                # casting the string columns back to decimal.
                sub = sub[[sub[name].astype('decimal') if name in cast_names
                           else name for name in names]]

        self._sub(expr, sub)