in odps/df/backends/odpssql/rewriter.py [0:0]
def _handle_function(self, expr, raw_inputs):
    """Route decimal inputs and outputs of a user function through strings.

    Since Python UDF cannot support decimal field, every decimal input is
    replaced with a string cast of itself.  If the output is decimal, the
    output type is replaced with string as well, and the rebuilt expression
    is cast back to decimal afterwards.

    :param expr: the function expression to rewrite; ``Func`` instances are
        left untouched
    :param raw_inputs: the input expressions of ``expr``; each one is read
        through its ``dtype`` attribute
    :return: None, the rewrite is applied through ``self._sub``
    """
    if isinstance(expr, Func):
        return

    # Column-like results carry a single ``dtype``; everything else is
    # described by a schema with one type per field.
    column_like = isinstance(expr, (SequenceExpr, Scalar))

    # Nothing to do when neither the inputs nor the output are decimal.
    # The output is only inspected once every input is known to be clean.
    if all(arg.dtype != types.decimal for arg in raw_inputs):
        if column_like:
            output_clean = expr.dtype != types.decimal
        else:
            output_clean = all(
                tp != types.decimal for tp in expr.schema.types)
        if output_clean:
            return

    # Snapshot the inputs before any of them gets substituted.
    untouched_inputs = list(raw_inputs)
    for arg in raw_inputs:
        if arg.dtype == types.decimal:
            self._sub(arg, arg.astype('string'), parents=[expr, ])

    # Put the snapshot back as the raw input(s) of the expression.
    # NOTE(review): presumably ``_sub`` also touches the raw inputs and this
    # restores the pre-substitution view -- confirm against ``_sub``.
    if not hasattr(expr, '_raw_inputs'):
        assert len(untouched_inputs) == 1
        expr._raw_input = untouched_inputs[0]
    else:
        expr._raw_inputs = untouched_inputs

    # Collect the current attribute values so the expression can be rebuilt.
    attr_values = {}
    for attr in get_attrs(expr):
        attr_values[attr] = getattr(expr, attr, None)

    if column_like:
        decimal_output = expr.dtype == types.decimal
        if decimal_output:
            # Declare the result as string instead of decimal.
            if isinstance(expr, SequenceExpr):
                type_keys = ('_data_type', '_source_data_type')
            else:
                type_keys = ('_value_type', '_source_value_type')
            for key in type_keys:
                attr_values[key] = types.string
        replacement = type(expr)._new(**attr_values)
        if decimal_output:
            # Cast the string result back to the type callers expect.
            replacement = replacement.astype('decimal')
    else:
        names = expr.schema.names
        tps = expr.schema.types
        cast_names = set()
        if any(tp == types.decimal for tp in tps):
            string_tps = []
            for name, tp in zip(names, tps):
                if tp != types.decimal:
                    string_tps.append(tp)
                else:
                    string_tps.append(types.string)
                    cast_names.add(name)
            if cast_names:
                attr_values['_schema'] = TableSchema.from_lists(
                    names, string_tps)
        replacement = type(expr)(**attr_values)
        if cast_names:
            # Re-select every field, casting the rewritten ones back to
            # decimal and passing the others through by name.
            fields = [
                replacement[name].astype('decimal')
                if name in cast_names else name
                for name in names
            ]
            replacement = replacement[fields]

    self._sub(expr, replacement)