def _get_converter()

in pylib/cqlshlib/copyutil.py [0:0]


    def _get_converter(self, cql_type):
        """
        Return a function that converts a string into a value that can be passed
        into BoundStatement.bind() for the given cql type. See cassandra.cqltypes
        for more details.

        The converter is chosen by cql_type.typename from the table built at the end of
        this method; type names missing from that table go to convert_unknown, which
        handles user types and reversed types and otherwise returns the string unchanged.
        Every converter is called as fn(value, ct=<cql type>).
        """
        unprotect = self.unprotect

        def convert(t, v):
            """
            Convert a nullable value v of cql type t: a value equal to self.nullval
            (after unprotect) becomes self.get_null_val().
            """
            v = unprotect(v)
            if v == self.nullval:
                return self.get_null_val()
            return converters.get(t.typename, convert_unknown)(v, ct=t)

        def convert_mandatory(t, v):
            """
            Convert a value v of cql type t that may not be null (used below for collection
            elements, map keys, tuple fields and vector coordinates). Raises ParseError if v
            equals self.nullval, unless t is a varchar type.
            """
            v = unprotect(v)
            # we can't distinguish between empty strings and null values in csv. Null values are not supported in
            # collections, so it must be an empty string.
            if v == self.nullval and not issubclass(t, VarcharType):
                raise ParseError('Empty values are not allowed')
            return converters.get(t.typename, convert_unknown)(v, ct=t)

        def convert_blob(v, **_):
            # skip the two-character prefix (presumably '0x') and decode the remaining hex digits
            if sys.version_info.major >= 3:
                return bytes.fromhex(v[2:])
            else:
                return BlobType(v[2:].decode("hex"))

        def convert_text(v, **_):
            return str(v)

        def convert_uuid(v, **_):
            return UUID(v)

        def convert_bool(v, **_):
            # case-insensitive comparison with self.boolean_styles[0], the spelling that maps to True;
            # any other string converts to False
            return v.lower() == self.boolean_styles[0].lower()

        def get_convert_integer_fcn(adapter=int):
            """
            Return an integer conversion function that applies adapter to the value, first stripping
            self.thousands_sep when one is configured. The separator check is done once here, so the
            returned function does no per-value branching.
            """
            if self.thousands_sep:
                return lambda v, ct=cql_type: adapter(v.replace(self.thousands_sep, ''))
            else:
                return lambda v, ct=cql_type: adapter(v)

        def get_convert_decimal_fcn(adapter=float):
            """
            Return a decimal conversion function that applies adapter to the value after removing
            self.thousands_sep and replacing self.decimal_sep with '.'; separators that are not
            configured are skipped. As above, the checks are done once here rather than per value.
            """
            empty_str = ''
            dot_str = '.'
            if self.thousands_sep and self.decimal_sep:
                return lambda v, ct=cql_type: \
                    adapter(v.replace(self.thousands_sep, empty_str).replace(self.decimal_sep, dot_str))
            elif self.thousands_sep:
                return lambda v, ct=cql_type: adapter(v.replace(self.thousands_sep, empty_str))
            elif self.decimal_sep:
                return lambda v, ct=cql_type: adapter(v.replace(self.decimal_sep, dot_str))
            else:
                return lambda v, ct=cql_type: adapter(v)

        def split(val, sep=','):
            """
            Split "val" into a list of values whenever the separator "sep" is found, but
            ignore separators inside parentheses or single quotes, except for the two
            outermost parentheses, which will be ignored. This method is called when parsing composite
            types, "val" should be at least 2 characters long, the first char should be an
            open parenthesis and the last char should be a matching closing parenthesis. We could also
            check exactly which parenthesis type depending on the caller, but I don't want to enforce
            too many checks that don't necessarily provide any additional benefits, and risk breaking
            data that could previously be imported, even if strictly speaking it is incorrect CQL.
            For example, right now we accept sets that start with '[' and ']', I don't want to break this
            by enforcing '{' and '}' in a minor release.

            Raises ParseError if val is shorter than 2 characters or its first and last characters
            are not a matching pair of brackets.
            """
            def is_open_paren(cc):
                return cc == '{' or cc == '[' or cc == '('

            def is_close_paren(cc):
                return cc == '}' or cc == ']' or cc == ')'

            def paren_match(c1, c2):
                return (c1 == '{' and c2 == '}') or (c1 == '[' and c2 == ']') or (c1 == '(' and c2 == ')')

            if len(val) < 2 or not paren_match(val[0], val[-1]):
                raise ParseError('Invalid composite string, it should start and end with matching parentheses: {}'
                                 .format(val))

            ret = []
            last = 1  # start of the current element, just past the outer opening bracket
            level = 0  # bracket nesting depth; 1 means directly inside the outer brackets
            quote = False  # inside a single-quoted string, where brackets and separators are literal
            for i, c in enumerate(val):
                if c == '\'':
                    quote = not quote
                elif not quote:
                    if is_open_paren(c):
                        level += 1
                    elif is_close_paren(c):
                        level -= 1
                    elif c == sep and level == 1:
                        ret.append(val[last:i])
                        last = i + 1

            # the final element runs up to (but excluding) the outer closing bracket;
            # an empty trailing element, e.g. in "()", is not added
            if last < len(val) - 1:
                ret.append(val[last:-1])

            return ret

        # this should match all possible CQL and CQLSH datetime formats.
        # The date/time separator may be whitespace, a literal T (ISO 8601) or the quoted 'T'
        # that was historically accepted here.
        p = re.compile(r"(\d{4})-(\d{2})-(\d{2})\s?(?:T|'T')?"  # YYYY-MM-DD[( |T)]
                       + r"(?:(\d{2}):(\d{2})(?::(\d{2})(?:\.(\d{1,6}))?))?"  # [HH:MM[:SS[.NNNNNN]]]
                       + r"(?:([+\-])(\d{2}):?(\d{2}))?")  # [(+|-)HH[:]MM]]

        def convert_datetime(val, **_):
            """
            Convert val to milliseconds since the epoch. Tried in order: self.date_time_format,
            the CQL formats matched by p, and finally a plain integer (already in milliseconds).
            Raises ValueError if none of these apply.
            """
            try:
                dtval = datetime.datetime.strptime(val, self.date_time_format)
                return dtval.timestamp() * 1000
            except ValueError:
                pass  # if it's not in the default format we try CQL formats

            # NOTE: p.match() is not anchored at the end, so any text after the longest
            # matching prefix is ignored
            m = p.match(val)
            if not m:
                try:
                    # in case of overflow COPY TO prints dates as milliseconds from the epoch, see
                    # deserialize_date_fallback_int in cqlsh.py
                    return int(val)
                except ValueError:
                    raise ValueError("can't interpret %r as a date with format %s or as int" % (val,
                                                                                                self.date_time_format))

            # https://docs.python.org/3/library/time.html#time.struct_time
            tval = time.struct_time((int(m.group(1)), int(m.group(2)), int(m.group(3)),  # year, month, day
                                    int(m.group(4)) if m.group(4) else 0,  # hour
                                    int(m.group(5)) if m.group(5) else 0,  # minute
                                    int(m.group(6)) if m.group(6) else 0,  # second
                                    0, 1, -1))  # day of week, day of year, dst-flag

            # convert sub-seconds (a number between 1 and 6 digits) to milliseconds
            milliseconds = 0 if not m.group(7) else int(m.group(7)) * pow(10, 3 - len(m.group(7)))

            # utc_offset is the value's offset from UTC in seconds EAST of UTC (ISO 8601 sign convention)
            if m.group(8):
                utc_offset = (int(m.group(9)) * 3600 + int(m.group(10)) * 60) * int(m.group(8) + '1')
            else:
                # no explicit offset: treat the value as local standard time. time.timezone is in seconds
                # WEST of UTC, hence the negation. NOTE: daylight saving time is not taken into account.
                utc_offset = -time.timezone

            # timegm() reads tval as if it were UTC, so the offset must be subtracted to obtain the actual
            # UTC instant (e.g. 10:00+01:00 is 09:00 UTC); then scale seconds to millis for the raw value
            return ((timegm(tval) - utc_offset) * 1000) + milliseconds

        def convert_date(v, **_):
            return Date(v)

        def convert_time(v, **_):
            return Time(v)

        def convert_tuple(val, ct=cql_type):
            return tuple(convert_mandatory(t, v) for t, v in zip(ct.subtypes, split(val)))

        def convert_list(val, ct=cql_type):
            return tuple(convert_mandatory(ct.subtypes[0], v) for v in split(val))

        def convert_set(val, ct=cql_type):
            return frozenset(convert_mandatory(ct.subtypes[0], v) for v in split(val))

        def convert_map(val, ct=cql_type):
            """
            See ImmutableDict above for a discussion of why a special object is needed here.

            Each top-level entry "key:value" is wrapped in braces so that split() can break it
            on ':' while respecting nested brackets and quotes. Keys are mandatory, values nullable.
            """
            split_format_str = '{%s}'
            sep = ':'
            return ImmutableDict(frozenset((convert_mandatory(ct.subtypes[0], v[0]), convert(ct.subtypes[1], v[1]))
                                 for v in [split(split_format_str % vv, sep=sep) for vv in split(val)]))

        def convert_vector(val, ct=cql_type):
            """
            Convert a bracketed list of coordinates to a list, requiring exactly ct.vector_size
            non-null elements; raises ParseError on a length mismatch.
            """
            string_coordinates = split(val)
            if len(string_coordinates) != ct.vector_size:
                raise ParseError("The length of given vector value '%d' is not equal to the vector size from the type definition '%d'"
                                 % (len(string_coordinates), ct.vector_size))
            return [convert_mandatory(ct.subtype, v) for v in string_coordinates]

        def convert_user_type(val, ct=cql_type):
            """
            A user type is a dictionary except that we must convert each key into
            an attribute, so we are using named tuples. It must also be hashable,
            so we cannot use dictionaries. Maybe there is a way to instantiate ct
            directly but I could not work it out.
            Also note that it is possible that the subfield names in the csv are in the
            wrong order, so we must sort them according to ct.fieldnames, see CASSANDRA-12959.
            Fields absent from the csv value are set to self.get_null_val().
            """
            split_format_str = '{%s}'
            sep = ':'
            vals = [split(split_format_str % vv, sep=sep) for vv in split(val)]
            dict_vals = dict((unprotect(v[0]), v[1]) for v in vals)
            sorted_converted_vals = [(n, convert(t, dict_vals[n]) if n in dict_vals else self.get_null_val())
                                     for n, t in zip(ct.fieldnames, ct.subtypes)]
            ret_type = namedtuple(ct.typename, [v[0] for v in sorted_converted_vals])
            return ret_type(*tuple(v[1] for v in sorted_converted_vals))

        def convert_single_subtype(val, ct=cql_type):
            # delegate to the converter of the wrapped type (used for frozen<...> and reversed types)
            return converters.get(ct.subtypes[0].typename, convert_unknown)(val, ct=ct.subtypes[0])

        def convert_unknown(val, ct=cql_type):
            # fallback for type names not in the converters table
            if issubclass(ct, UserType):
                return convert_user_type(val, ct=ct)
            elif issubclass(ct, ReversedType):
                return convert_single_subtype(val, ct=ct)

            printdebugmsg("Unknown type %s (%s) for val %s" % (ct, ct.typename, val))
            return val

        converters = {
            'blob': convert_blob,
            'decimal': get_convert_decimal_fcn(adapter=Decimal),
            'uuid': convert_uuid,
            'boolean': convert_bool,
            'tinyint': get_convert_integer_fcn(),
            'ascii': convert_text,
            'float': get_convert_decimal_fcn(),
            'double': get_convert_decimal_fcn(),
            'bigint': get_convert_integer_fcn(adapter=int),
            'int': get_convert_integer_fcn(),
            'varint': get_convert_integer_fcn(),
            'inet': convert_text,
            'counter': get_convert_integer_fcn(adapter=int),
            'timestamp': convert_datetime,
            'timeuuid': convert_uuid,
            'date': convert_date,
            'smallint': get_convert_integer_fcn(),
            'time': convert_time,
            'text': convert_text,
            'varchar': convert_text,
            'list': convert_list,
            'set': convert_set,
            'map': convert_map,
            'tuple': convert_tuple,
            'frozen': convert_single_subtype,
            VectorType.typename: convert_vector,
        }

        return converters.get(cql_type.typename, convert_unknown)