cassandra-bigtable-migration-tools/cassandra-bigtable-proxy/collectiondecoder/collectiondecoder.go (357 lines of code) (raw):

/* * Copyright (C) 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package collectiondecoder import ( "bytes" "encoding/binary" "fmt" "time" "github.com/GoogleCloudPlatform/cloud-bigtable-ecosystem/cassandra-bigtable-migration-tools/cassandra-bigtable-proxy/third_party/datastax/proxycore" "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/primitive" ) // DecodeCollection decodes a collection (list, set, or map) of elements from the provided byte array. // It reads the collection length, then iterates over the encoded elements to decode them based on their data type. // Parameters: // - encoded: The byte array containing the encoded collection data. // - version: The Cassandra protocol version used for encoding. // - dt: The data type of the collection elements. // Returns: // - An interface{} containing the decoded collection as a slice or map of the appropriate type. // - An error if any decoding step fails. func DecodeCollection(dt datatype.DataType, version primitive.ProtocolVersion, encoded []byte) (interface{}, error) { reader := bytes.NewReader(encoded) var length int32 // Read the collection length (4 bytes) err := binary.Read(reader, binary.BigEndian, &length) if err != nil { return nil, err } switch dt.GetDataTypeCode() { case primitive.DataTypeCodeList: listType := dt.(datatype.ListType) return decodeListOrSet(listType.GetElementType(), version, reader, length) case primitive.DataTypeCodeSet: setType := dt.(datatype.SetType) return decodeListOrSet(setType.GetElementType(), version, reader, length) case primitive.DataTypeCodeMap: mapType := dt.(datatype.MapType) return decodeMap(mapType.GetValueType(), version, reader, mapType.GetKeyType(), length) default: return nil, fmt.Errorf("unsupported collection type: %v", dt.GetDataTypeCode()) } } // decodeListOrSet decodes a list or set of elements from the provided byte reader. // It reads each element's length and value, then decodes the value based on the specified element data type. // Parameters: // - reader: A byte reader positioned at the start of the encoded elements. // - version: The Cassandra protocol version used for encoding. // - elementType: The data type of the elements in the list or set. // - length: The number of elements in the collection. // Returns: // - An interface{} containing the decoded elements as a slice of the appropriate type. // - An error if any decoding step fails. func decodeListOrSet(elementType datatype.DataType, version primitive.ProtocolVersion, reader *bytes.Reader, length int32) (interface{}, error) { decodedElements := make([]interface{}, length) for i := int32(0); i < length; i++ { var elementLength int32 err := binary.Read(reader, binary.BigEndian, &elementLength) if err != nil { return nil, err } elementValue := make([]byte, elementLength) _, err = reader.Read(elementValue) if err != nil { return nil, err } decodedValue, err := proxycore.DecodeType(elementType, version, elementValue) if err != nil { return nil, err } decodedElements[i] = decodedValue } return ConvertToTypedSlice(decodedElements, elementType) } // decodeMap decodes a map of key-value pairs from the provided byte reader. // It reads each key's and value's length and value, then decodes them based on their respective data types. // Parameters: // - reader: A byte reader positioned at the start of the encoded key-value pairs. // - version: The Cassandra protocol version used for encoding. // - keyType: The data type of the map keys. // - valueType: The data type of the map values. // - length: The number of key-value pairs in the map. // Returns: // - An interface{} containing the decoded map as a map[interface{}]interface{} with keys and values of the appropriate types. // - An error if any decoding step fails. func decodeMap(valueType datatype.DataType, version primitive.ProtocolVersion, reader *bytes.Reader, keyType datatype.DataType, length int32) (interface{}, error) { decodedMap := make(map[interface{}]interface{}, length) for i := int32(0); i < length; i++ { var keyLength int32 err := binary.Read(reader, binary.BigEndian, &keyLength) if err != nil { return nil, err } keyValue := make([]byte, keyLength) _, err = reader.Read(keyValue) if err != nil { return nil, err } decodedKey, err := proxycore.DecodeType(keyType, version, keyValue) if err != nil { return nil, err } var valueLength int32 err = binary.Read(reader, binary.BigEndian, &valueLength) if err != nil { return nil, err } value := make([]byte, valueLength) _, err = reader.Read(value) if err != nil { return nil, err } decodedValue, err := proxycore.DecodeType(valueType, version, value) if err != nil { return nil, err } decodedMap[decodedKey] = decodedValue } return ConvertToTypedMap(decodedMap, keyType, valueType) } // ConvertToTypedSlice converts a slice of interface{} elements to a specific type based on the provided data type. // It performs type assertions and creates a new slice of the appropriate type for the elements. // Parameters: // - decodedElements: A slice of interface{} containing the decoded elements. // - dt: The data type of the elements in the slice. // Returns: // - An interface{} containing the converted slice of the appropriate type. // - An error if the type conversion fails or if the data type is unsupported. func ConvertToTypedSlice(decodedElements []interface{}, dt datatype.DataType) (interface{}, error) { switch dt.GetDataTypeCode() { case primitive.DataTypeCodeAscii, primitive.DataTypeCodeVarchar: typedCollection := make([]string, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(string) } return typedCollection, nil case primitive.DataTypeCodeBigint, primitive.DataTypeCodeCounter: typedCollection := make([]int64, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(int64) } return typedCollection, nil case primitive.DataTypeCodeInt: typedCollection := make([]int32, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(int32) } return typedCollection, nil case primitive.DataTypeCodeFloat: typedCollection := make([]float32, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(float32) } return typedCollection, nil case primitive.DataTypeCodeDouble: typedCollection := make([]float64, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(float64) } return typedCollection, nil case primitive.DataTypeCodeBoolean: typedCollection := make([]bool, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(bool) } return typedCollection, nil case primitive.DataTypeCodeTimestamp, primitive.DataTypeCodeDate, primitive.DataTypeCodeTime: typedCollection := make([]int64, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(int64) } return typedCollection, nil case primitive.DataTypeCodeUuid, primitive.DataTypeCodeTimeuuid: typedCollection := make([]string, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(string) } return typedCollection, nil case primitive.DataTypeCodeDecimal, primitive.DataTypeCodeVarint: // Assuming decimal and varint are represented as strings typedCollection := make([]string, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(string) } return typedCollection, nil case primitive.DataTypeCodeBlob, primitive.DataTypeCodeInet: // Assuming blob and inet are represented as byte slices typedCollection := make([][]byte, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.([]byte) } return typedCollection, nil case primitive.DataTypeCodeSmallint: typedCollection := make([]int16, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(int16) } return typedCollection, nil case primitive.DataTypeCodeTinyint: typedCollection := make([]int8, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(int8) } return typedCollection, nil case primitive.DataTypeCodeDuration: // Assuming duration is represented as int64 typedCollection := make([]int64, len(decodedElements)) for i, elem := range decodedElements { typedCollection[i] = elem.(int64) } return typedCollection, nil default: // Unsupported type return nil, fmt.Errorf("unsupported data type: %v", dt.GetDataTypeCode()) } } // ConvertToTypedMap converts a map of interface{} keys and values to a specific type based on the provided key and value data types. // It performs type assertions and creates a new map of the appropriate types for the keys and values. // Parameters: // - decodedMap: A map[interface{}]interface{} containing the decoded key-value pairs. // - keyType: The data type of the map keys. // - valueType: The data type of the map values. // Returns: // - An interface{} containing the converted map of the appropriate types for keys and values. // - An error if the type conversion fails or if the key or value data type is unsupported. func ConvertToTypedMap(decodedMap map[interface{}]interface{}, keyType, valueType datatype.DataType) (interface{}, error) { switch keyType.GetDataTypeCode() { case primitive.DataTypeCodeAscii, primitive.DataTypeCodeVarchar: switch valueType.GetDataTypeCode() { case primitive.DataTypeCodeBigint, primitive.DataTypeCodeCounter: typedMap := make(map[string]int64, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(int64) } return typedMap, nil case primitive.DataTypeCodeInt: typedMap := make(map[string]int32, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(int32) } return typedMap, nil case primitive.DataTypeCodeFloat: typedMap := make(map[string]float32, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(float32) } return typedMap, nil case primitive.DataTypeCodeDouble: typedMap := make(map[string]float64, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(float64) } return typedMap, nil case primitive.DataTypeCodeBoolean: typedMap := make(map[string]bool, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(bool) } return typedMap, nil case primitive.DataTypeCodeTimestamp, primitive.DataTypeCodeDate, primitive.DataTypeCodeTime: typedMap := make(map[string]int64, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(int64) } return typedMap, nil case primitive.DataTypeCodeUuid, primitive.DataTypeCodeTimeuuid: typedMap := make(map[string]string, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(string) } return typedMap, nil case primitive.DataTypeCodeDecimal, primitive.DataTypeCodeVarint: typedMap := make(map[string]string, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(string) } return typedMap, nil case primitive.DataTypeCodeBlob, primitive.DataTypeCodeInet: typedMap := make(map[string][]byte, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.([]byte) } return typedMap, nil case primitive.DataTypeCodeSmallint: typedMap := make(map[string]int16, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(int16) } return typedMap, nil case primitive.DataTypeCodeTinyint: typedMap := make(map[string]int8, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(int8) } return typedMap, nil case primitive.DataTypeCodeDuration: typedMap := make(map[string]int64, len(decodedMap)) for key, value := range decodedMap { typedMap[key.(string)] = value.(int64) } return typedMap, nil default: return nil, fmt.Errorf("unsupported map value data type: %v", valueType.GetDataTypeCode()) } case primitive.DataTypeCodeTimestamp: switch valueType.GetDataTypeCode() { case primitive.DataTypeCodeAscii, primitive.DataTypeCodeVarchar: typedMap := make(map[time.Time]string, len(decodedMap)) for key, value := range decodedMap { timestampKey, ok := key.(time.Time) if !ok { return nil, fmt.Errorf("invalid key type, expect time.Time but got %T", key) } stringValue, ok := value.(string) if !ok { return nil, fmt.Errorf("invalid value type, expect string but got %T", value) } typedMap[timestampKey] = stringValue } return typedMap, nil case primitive.DataTypeCodeInt: typedMap := make(map[time.Time]int32, len(decodedMap)) for key, value := range decodedMap { timestampKey, ok := key.(time.Time) if !ok { return nil, fmt.Errorf("invalid key type, expect time.Time but got %T", key) } int32value, ok := value.(int32) if !ok { return nil, fmt.Errorf("invalid value type, expect int but got %T", value) } typedMap[timestampKey] = int32value } return typedMap, nil case primitive.DataTypeCodeBigint: typedMap := make(map[time.Time]int64, len(decodedMap)) for key, value := range decodedMap { timestampKey, ok := key.(time.Time) if !ok { return nil, fmt.Errorf("invalid key type, expect time.Time but got %T", key) } int64value, ok := value.(int64) if !ok { return nil, fmt.Errorf("invalid value type, expect bigint but got %T", value) } typedMap[timestampKey] = int64value } return typedMap, nil case primitive.DataTypeCodeFloat: typedMap := make(map[time.Time]float32, len(decodedMap)) for key, value := range decodedMap { timestampKey, ok := key.(time.Time) if !ok { return nil, fmt.Errorf("invalid key type, expect time.Time but got %T", key) } float32value, ok := value.(float32) if !ok { return nil, fmt.Errorf("invalid value type, expect float but got %T", value) } typedMap[timestampKey] = float32value } return typedMap, nil case primitive.DataTypeCodeDouble: typedMap := make(map[time.Time]float64, len(decodedMap)) for key, value := range decodedMap { timestampKey, ok := key.(time.Time) if !ok { return nil, fmt.Errorf("invalid key type, expect time.Time but got %T", key) } float64value, ok := value.(float64) if !ok { return nil, fmt.Errorf("invalid value type, double string but got %T", value) } typedMap[timestampKey] = float64value } return typedMap, nil case primitive.DataTypeCodeBoolean: typedMap := make(map[time.Time]bool, len(decodedMap)) for key, value := range decodedMap { timestampKey, ok := key.(time.Time) if !ok { return nil, fmt.Errorf("invalid key type, expect time.Time but got %T", key) } boolValue, ok := value.(bool) if !ok { return nil, fmt.Errorf("invalid value type, expect boolean but got %T", value) } typedMap[timestampKey] = boolValue } return typedMap, nil case primitive.DataTypeCodeTimestamp: typedMap := make(map[time.Time]time.Time, len(decodedMap)) for key, value := range decodedMap { timestampKey, ok := key.(time.Time) if !ok { return nil, fmt.Errorf("invalid key type, expect time.Time but got %T", key) } timeValue, ok := value.(time.Time) if !ok { return nil, fmt.Errorf("invalid value type, expect timestamp but got %T", value) } typedMap[timestampKey] = timeValue } return typedMap, nil default: return nil, fmt.Errorf("unsupported map value data type: %v", valueType.GetDataTypeCode()) } default: return nil, fmt.Errorf("unsupported map key data type: %v", keyType.GetDataTypeCode()) } }