cassandra-bigtable-migration-tools/cassandra-bigtable-proxy/responsehandler/responsehandler_utils.go (229 lines of code) (raw):

/* * Copyright (C) 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package responsehandler import ( "fmt" "reflect" "strconv" "github.com/GoogleCloudPlatform/cloud-bigtable-ecosystem/cassandra-bigtable-migration-tools/cassandra-bigtable-proxy/third_party/datastax/proxycore" "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" ) // HandleTimestampMap encodes map data with timestamps as keys and appends it to a message.Row. // It first decodes the map data, determines the map type for the given element type, and then // encodes the map data into a byte slice which is appended to the message.Row. // // Parameters: // - mapData: A map with timestamps as keys and byte slices as values, representing the data to encode. // - mr: A pointer to a message.Row where the encoded data will be appended. // - elementType: A string indicating the data type of the map's values. // - protocolV: The Cassandra protocol version used for encoding. // // Returns: An error if the encoding fails or if the map type retrieval does not succeed. 
func (th *TypeHandler) HandleTimestampMap(mapData map[string]interface{}, mr *message.Row, mapType datatype.MapType, protocolV primitive.ProtocolVersion) error { var bytes []byte var err error detailsField, err := th.decodeMapData(mapData, mapType.GetValueType(), protocolV) if err != nil { return err } bytes, err = proxycore.EncodeType(mapType, protocolV, detailsField) if err != nil { return fmt.Errorf("error encoding map data: %v", err) } *mr = append(*mr, bytes) return nil } // decodeMapData converts map data with string keys to a map with int64 keys, and decodes the values // from byte slices to the specified element type. It interprets the string keys as timestamps. // // Parameters: // - mapData: A map with string keys and byte slice values, representing data to decode. // - elementType: A string specifying the type of the values in the map. // - protocolV: The Cassandra protocol version used for decoding. // // Returns: An interface{} containing the decoded map data, or an error if decoding fails or // // if any key can't be parsed to an int64. func (th *TypeHandler) decodeMapData(mapData map[string]interface{}, elementType datatype.DataType, protocolV primitive.ProtocolVersion) (interface{}, error) { result := make(map[int64]interface{}) for key, value := range mapData { byteArray, ok := value.([]byte) if !ok { return nil, fmt.Errorf("type assertion to []byte failed") } decodedValue, err := th.DecodeValue(byteArray, elementType, protocolV) if err != nil { return nil, err } // Convert key as string to timestamp bigintKey, err := strconv.ParseInt(key, 10, 64) if err != nil { return nil, fmt.Errorf("unable to parse key to int64") } result[bigintKey] = decodedValue } return result, nil } // DecodeValue decodes a byte slice into a specified data type based on the element type and protocol version. // // Parameters: // - byteArray: A byte slice representing the encoded data. // - elementType: A string indicating the desired type to decode the byte slice into. 
// - protocolV: The Cassandra protocol version used for decoding. // // Returns: An interface{} containing the decoded value and an error if the decoding fails // // or if the element type is unsupported. func (th *TypeHandler) DecodeValue(byteArray []byte, elementType datatype.DataType, protocolV primitive.ProtocolVersion) (interface{}, error) { var decodedValue interface{} var err error switch elementType { case datatype.Boolean: decodedValue, err = HandlePrimitiveEncoding(elementType, byteArray, protocolV, false) case datatype.Int: decodedValue, err = HandlePrimitiveEncoding(elementType, byteArray, protocolV, false) case datatype.Bigint: decodedValue, err = proxycore.DecodeType(datatype.Bigint, protocolV, byteArray) case datatype.Float: decodedValue, err = proxycore.DecodeType(datatype.Float, protocolV, byteArray) case datatype.Double: decodedValue, err = proxycore.DecodeType(datatype.Double, protocolV, byteArray) case datatype.Varchar: decodedValue = string(byteArray) case datatype.Timestamp: decodedValue, err = proxycore.DecodeType(datatype.Bigint, protocolV, byteArray) default: return nil, fmt.Errorf("unsupported element type: %v", elementType) } if err != nil { return nil, fmt.Errorf("error decoding value for element type %v: %v", elementType, err) } return decodedValue, nil } // HasDollarSymbolPrefix checks if the first character of a given string is a dollar sign. // // Parameters: // - s: The string to check. // // Returns: A boolean indicating whether the first character is a dollar sign. func HasDollarSymbolPrefix(s string) bool { if len(s) == 0 { return false } return s[0] == '$' } // GetMapKeyForColumn retrieves the map key associated with a specific column from the query metadata. // // Parameters: // - queryMetadata: The QueryMetadata containing information about selected columns. // - column: A string representing the column name for which to find the associated map key. 
// // Returns: A string containing the map key associated with the given column name, or an empty string if not found. func GetMapKeyForColumn(queryMetadata QueryMetadata, column string) string { for _, value := range queryMetadata.SelectedColumns { if value.Name == column || value.Alias == column { return value.MapKey } } return "" } // DecodeAndReturnBool decodes a byte array to a boolean value. // // Parameters: // - btBytes: The byte array to be decoded. // - pv: The Cassandra protocol version. // // Returns: The decoded boolean value and an error if any. func decodeAndReturnBool(value interface{}, pv primitive.ProtocolVersion) (bool, error) { switch v := value.(type) { case []byte: bv, err := proxycore.DecodeType(datatype.Bigint, pv, v) if err != nil { return false, fmt.Errorf("failed to retrieve int in the DecodeAndReturnBool function: %v", err) } bigint := bv.(int64) if bigint > 0 { return true, nil } return false, nil case string: val, err := strconv.ParseInt(v, 10, 64) if err != nil { return false, fmt.Errorf("error converting string to int64: %w", err) } if val > 0 { return true, nil } return false, nil default: return false, fmt.Errorf("unsupported type: %T", v) } } /** * DecodeAndReturnInt is a function that decodes a value to an int32. * * Parameters: * - value: The value to be decoded. * - pv: The Cassandra protocol version. * * Returns: The decoded int32 value and an error if any. 
*/ func decodeAndReturnInt(value interface{}, pv primitive.ProtocolVersion) (int32, error) { switch v := value.(type) { case []byte: val, err := proxycore.DecodeType(datatype.Bigint, pv, v) if err != nil { return 0, fmt.Errorf("failed to retrieve int in the DecodeAndReturnBool function: %v", err) } return int32(val.(int64)), nil case string: val, err := strconv.ParseInt(v, 10, 64) if err != nil { return 0, fmt.Errorf("error converting string to int64: %w", err) } return int32(val), nil case int64: // should be safe to convert to int32 because this value should've been written with int32 constraints return int32(value.(int64)), nil default: return 0, fmt.Errorf("unsupported type: %T", v) } } /** * DecodeAndReturnBigInt is a function that decodes a value to an int64. * * Parameters: * - value: The value to be decoded. * - pv: The Cassandra protocol version. * * Returns: The decoded int64 value and an error if any. */ func decodeAndReturnBigInt(value interface{}, pv primitive.ProtocolVersion) (int64, error) { switch v := value.(type) { case []byte: value, err := proxycore.DecodeType(datatype.Bigint, pv, v) if err != nil { return 0, fmt.Errorf("failed to retrieve int in the DecodeAndReturnBigInt function: %v", err) } return value.(int64), nil case string: val, err := strconv.ParseInt(v, 10, 64) if err != nil { return 0, fmt.Errorf("error converting string to int64: %w", err) } return val, nil case int64: return value.(int64), nil default: return 0, fmt.Errorf("unsupported type: %T", v) } } /** * DecodeAndReturnFloat is a function that decodes a value to an int64. * * Parameters: * - value: The value to be decoded. * - pv: The Cassandra protocol version. * * Returns: The decoded float32 value and an error if any. 
*/
func decodeAndReturnFloat(value interface{}, pv primitive.ProtocolVersion) (float32, error) {
	switch v := value.(type) {
	case []byte:
		decoded, err := proxycore.DecodeType(datatype.Float, pv, v)
		if err != nil {
			return 0, fmt.Errorf("failed to retrieve float in the DecodeAndReturnFloat function: %w", err)
		}
		// Comma-ok assertion: an unexpected decoded type is reported instead of panicking.
		f, ok := decoded.(float32)
		if !ok {
			return 0, fmt.Errorf("unexpected decoded type %T in the DecodeAndReturnFloat function", decoded)
		}
		return f, nil
	case string:
		// bitSize 32 makes the float64 result exactly convertible to float32.
		parsed, err := strconv.ParseFloat(v, 32)
		if err != nil {
			return 0, fmt.Errorf("error converting string to float32: %w", err)
		}
		return float32(parsed), nil
	default:
		return 0, fmt.Errorf("unsupported type: %T", v)
	}
}

// decodeAndReturnDouble decodes a value to a float64.
//
// Parameters:
//   - value: A []byte holding a Double-encoded value or a numeric string.
//   - pv: The Cassandra protocol version.
//
// Returns: The decoded float64 value, or an error if decoding fails or value has an unsupported type.
func decodeAndReturnDouble(value interface{}, pv primitive.ProtocolVersion) (float64, error) {
	switch v := value.(type) {
	case []byte:
		decoded, err := proxycore.DecodeType(datatype.Double, pv, v)
		if err != nil {
			return 0, fmt.Errorf("failed to retrieve double in the DecodeAndReturnDouble function: %w", err)
		}
		d, ok := decoded.(float64)
		if !ok {
			return 0, fmt.Errorf("unexpected decoded type %T in the DecodeAndReturnDouble function", decoded)
		}
		return d, nil
	case string:
		parsed, err := strconv.ParseFloat(v, 64)
		if err != nil {
			return 0, fmt.Errorf("error converting string to float64: %w", err)
		}
		return parsed, nil
	default:
		return 0, fmt.Errorf("unsupported type: %T", v)
	}
}

/**
 * HandlePrimitiveEncoding decodes a value according to its CQL data type and, when encode
 * is true, re-encodes the decoded value with that same data type.
 *
 * Parameters:
 * - dt: The CQL data type of the value.
 * - value: The value to be decoded; the accepted Go types depend on dt.
 * - protocalVersion: The Cassandra protocol version.
 * - encode: When true, return the encoded bytes instead of the decoded Go value.
 *
 * Returns: The decoded (or re-encoded) value and an error if any. A nil value or an
 * empty slice is returned unchanged.
*/ func HandlePrimitiveEncoding(dt datatype.DataType, value interface{}, protocalVersion primitive.ProtocolVersion, encode bool) (interface{}, error) { val := reflect.ValueOf(value) if !val.IsValid() { return value, nil } if val.Kind() == reflect.Slice { if val.Len() == 0 { return value, nil } } var decodedValue interface{} var err error if dt == datatype.Boolean { decodedValue, err = decodeAndReturnBool(value, protocalVersion) } else if dt == datatype.Int { decodedValue, err = decodeAndReturnInt(value, protocalVersion) } else if dt == datatype.Bigint || dt == datatype.Timestamp { decodedValue, err = decodeAndReturnBigInt(value, protocalVersion) } else if dt == datatype.Float { decodedValue, err = decodeAndReturnFloat(value, protocalVersion) } else if dt == datatype.Double { decodedValue, err = decodeAndReturnDouble(value, protocalVersion) } else if dt == datatype.Varchar { byteArray, okByte := value.([]byte) stringValue, okString := value.(string) if !okByte && !okString { return nil, fmt.Errorf("value is not a byte array or string") } if okByte { decodedValue = string(byteArray) } else { decodedValue = stringValue } } else { return nil, fmt.Errorf("unsupported primitive type: %s", dt.String()) } if err != nil { return nil, err } if encode { encoded, _ := proxycore.EncodeType(dt, protocalVersion, decodedValue) return encoded, nil } return decodedValue, nil }