odps/tableschema/arrow_util.go (327 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tableschema import ( "encoding/binary" "fmt" "strings" "time" "github.com/pkg/errors" "github.com/aliyun/aliyun-odps-go-sdk/arrow" "github.com/aliyun/aliyun-odps-go-sdk/arrow/array" "github.com/aliyun/aliyun-odps-go-sdk/odps/data" "github.com/aliyun/aliyun-odps-go-sdk/odps/datatype" ) type arrowOptions struct { ExtendedMode bool TimestampUnit TimeUnit DatetimeUnit TimeUnit } type TimeUnit string const ( Second TimeUnit = "second" Milli TimeUnit = "milli" Micro TimeUnit = "micro" Nano TimeUnit = "nano" ) // ArrowOptions can not be used directly, it can be created by ArrowOptionConfig.XXX type ArrowOptions func(cfg *arrowOptions) func newTypeConvertConfig(opts ...ArrowOptions) *arrowOptions { cfg := &arrowOptions{ ExtendedMode: false, TimestampUnit: Nano, DatetimeUnit: Milli, } for _, opt := range opts { opt(cfg) } return cfg } func withExtendedMode() ArrowOptions { return func(cfg *arrowOptions) { cfg.ExtendedMode = true } } func withTimestampUnit(unit TimeUnit) ArrowOptions { return func(cfg *arrowOptions) { cfg.TimestampUnit = unit } } func withDatetimeUnit(unit TimeUnit) ArrowOptions { return func(cfg *arrowOptions) { cfg.DatetimeUnit = unit } } var ArrowOptionConfig = struct { WithExtendedMode func() ArrowOptions WithTimestampUnit func(unit TimeUnit) ArrowOptions WithDatetimeUnit func(unit TimeUnit) ArrowOptions }{ WithExtendedMode: withExtendedMode, WithTimestampUnit: withTimestampUnit, WithDatetimeUnit: withDatetimeUnit, } // TypeToArrowType convert odps field type to arrow field type // * Storage Type | Arrow Type // * ----------------------+--------------------- // * boolean | boolean // * tinyint | int8 // * smallint | int16 // * int | int32 // * bigint | int64 // * float | float32 // * double | float64 // * char | utf8 // * varchar | utf8 // * string | utf8 // * binary | binary // * date | date32 // * datetime | timestamp(nano) // * timestamp | timestamp(nano) 【注:精度选择功能开发中】 // * interval_day_time | day_time_interval // * interval_year_month | month_interval // * decimal | decimal // * struct | struct // * array | list // * map | map func TypeToArrowType(odpsType datatype.DataType) (arrow.DataType, error) { switch odpsType.ID() { case datatype.BOOLEAN: return arrow.FixedWidthTypes.Boolean, nil case datatype.TINYINT: return arrow.PrimitiveTypes.Int8, nil case datatype.SMALLINT: return arrow.PrimitiveTypes.Int16, nil case datatype.INT: return arrow.PrimitiveTypes.Int32, nil case datatype.BIGINT: return arrow.PrimitiveTypes.Int64, nil case datatype.FLOAT: return arrow.PrimitiveTypes.Float32, nil case datatype.DOUBLE: return arrow.PrimitiveTypes.Float64, nil case datatype.CHAR, datatype.VARCHAR, datatype.STRING, datatype.JSON: return arrow.BinaryTypes.String, nil case datatype.BINARY: return arrow.BinaryTypes.Binary, nil case datatype.DATE: return arrow.FixedWidthTypes.Date32, nil case datatype.DATETIME: return arrow.FixedWidthTypes.Timestamp_ms, nil case datatype.TIMESTAMP: return arrow.FixedWidthTypes.Timestamp_ns, nil case datatype.TIMESTAMP_NTZ: return arrow.FixedWidthTypes.Timestamp_ns, nil case datatype.IntervalDayTime: return arrow.FixedWidthTypes.DayTimeInterval, nil case datatype.IntervalYearMonth: return arrow.FixedWidthTypes.MonthInterval, nil case datatype.DECIMAL: decimal, _ := odpsType.(datatype.DecimalType) return &arrow.Decimal128Type{ Precision: decimal.Precision, Scale: decimal.Scale, }, nil case datatype.STRUCT: structType, _ := odpsType.(datatype.StructType) arrowFields := make([]arrow.Field, len(structType.Fields)) for i, field := range structType.Fields { arrowType, err := TypeToArrowType(field.Type) if err != nil { return arrow.Null, err } arrowFields[i] = arrow.Field{ Name: field.Name, Type: arrowType, } } return arrow.StructOf(arrowFields...), nil case datatype.ARRAY: arrayType, _ := odpsType.(datatype.ArrayType) itemType, err := TypeToArrowType(arrayType.ElementType) if err != nil { return arrow.Null, err } return arrow.ListOf(itemType), nil // case datatype.MAP: // mapType, _ := odpsType.(datatype.MapType) // keyType, err := TypeToArrowType(mapType.KeyType) // if err != nil { // return arrow.Null, err // } // valueType, err := TypeToArrowType(mapType.ValueType) // if err != nil { // return arrow.Null, err // } // return arrow.MapOf(keyType, valueType), nil } return arrow.Null, errors.Errorf("unknown odps data type: %s", odpsType.Name()) } func toMaxComputeData(vector array.Interface, index int, typeInfo datatype.DataType, cfg *arrowOptions) (data.Data, error) { switch typeInfo.ID() { case datatype.BOOLEAN: value := vector.(*array.Boolean).Value(index) return data.Bool(value), nil case datatype.TINYINT: value := vector.(*array.Int8).Value(index) return data.TinyInt(value), nil case datatype.SMALLINT: value := vector.(*array.Int16).Value(index) return data.SmallInt(value), nil case datatype.INT: value := vector.(*array.Int32).Value(index) return data.Int(value), nil case datatype.BIGINT: value := vector.(*array.Int64).Value(index) return data.BigInt(value), nil case datatype.FLOAT: value := vector.(*array.Float32).Value(index) return data.Float(value), nil case datatype.DOUBLE: value := vector.(*array.Float64).Value(index) return data.Double(value), nil case datatype.CHAR, datatype.VARCHAR, datatype.STRING, datatype.JSON: value := vector.(*array.String).Value(index) return data.String(value), nil case datatype.BINARY: value := vector.(*array.Binary).Value(index) return data.Binary(value), nil case datatype.DATE: value := vector.(*array.Date32).Value(index) // Date32 从 epoch(1970-01-01)起的天数 days := int64(value) // 将天数转换为 time.Duration,并加到 epoch 时间上 return data.Date(time.Unix(0, 0).AddDate(0, 0, int(days))), nil case datatype.DATETIME: value := vector.(*array.Timestamp).Value(index) epochTime := int64(value) switch cfg.DatetimeUnit { case Second: return data.DateTime(time.Unix(epochTime, 0)), nil case Milli: return data.DateTime(time.Unix(epochTime/1e3, (epochTime%1e3)*1e6)), nil case Micro: return data.DateTime(time.Unix(epochTime/1e6, (epochTime%1e6)*1e3)), nil case Nano: return data.DateTime(time.Unix(0, epochTime)), nil } case datatype.TIMESTAMP: if cfg.ExtendedMode { sec := vector.(*array.Struct).Field(0).(*array.Int64).Value(index) nano := vector.(*array.Struct).Field(1).(*array.Int32).Value(index) return data.Timestamp(time.Unix(sec, int64(nano))), nil } else { value := vector.(*array.Timestamp).Value(index) epochTime := int64(value) switch cfg.TimestampUnit { case Second: return data.Timestamp(time.Unix(epochTime, 0)), nil case Milli: return data.Timestamp(time.Unix(epochTime/1e3, (epochTime%1e3)*1e6)), nil case Micro: return data.Timestamp(time.Unix(epochTime/1e6, (epochTime%1e6)*1e3)), nil case Nano: return data.Timestamp(time.Unix(0, epochTime)), nil } } case datatype.TIMESTAMP_NTZ: if cfg.ExtendedMode { sec := vector.(*array.Struct).Field(0).(*array.Int64).Value(index) nano := vector.(*array.Struct).Field(1).(*array.Int32).Value(index) return data.TimestampNtz(time.Unix(sec, int64(nano))), nil } else { value := vector.(*array.Timestamp).Value(index) epochTime := int64(value) switch cfg.TimestampUnit { case Second: return data.TimestampNtz(time.Unix(epochTime, 0)), nil case Milli: return data.TimestampNtz(time.Unix(epochTime/1e3, (epochTime%1e3)*1e6)), nil case Micro: return data.TimestampNtz(time.Unix(epochTime/1e6, (epochTime%1e6)*1e3)), nil case Nano: return data.TimestampNtz(time.Unix(0, epochTime)), nil } } case datatype.DECIMAL: decimalType := typeInfo.(datatype.DecimalType) if cfg.ExtendedMode { fixedSizeBinaryVector, extendedMode := vector.(*array.FixedSizeBinary) if extendedMode { val := fixedSizeBinaryVector.Value(index) if len(val) < 8 { return nil, errors.Errorf("Unrecognized Decimal type, val len %d", len(val)) } mSign := val[1] mIntg := val[2] mFrac := val[3] var decimalBuilder strings.Builder if mSign > 0 { decimalBuilder.WriteString("-") } for j := int(mIntg); j > 0; j-- { num := int(binary.LittleEndian.Uint32(val[8+j*4 : 12+j*4])) if j == int(mIntg) { decimalBuilder.WriteString(fmt.Sprintf("%d", num)) } else { decimalBuilder.WriteString(fmt.Sprintf("%09d", num)) } } decimalBuilder.WriteString(".") for j := 0; j < int(mFrac); j++ { num := int(binary.LittleEndian.Uint32(val[8-4*j : 8-4*j+4])) decimalBuilder.WriteString(fmt.Sprintf("%09d", num)) } // trim trailing zeros result := decimalBuilder.String() result = strings.TrimRight(result, "0") if strings.HasSuffix(result, ".") { result = result[:len(result)-1] // 移除多余的小数点 } decimal := data.NewDecimal(int(decimalType.Precision), int(decimalType.Scale), result) return decimal, nil } } value := vector.(*array.Decimal128).Value(index) decimal := data.NewDecimalFromValue(int(decimalType.Precision), int(decimalType.Scale), value.BigInt()) return decimal, nil case datatype.ARRAY: arrayType := typeInfo.(datatype.ArrayType) // 处理 LIST 类型 listCol := vector.(*array.List) // 获取偏移值,包括第 i 个列表对应的偏移量 offsets := listCol.Offsets() start := offsets[index] // 当前列表的起始位置 end := offsets[index+1] // 下一个列表的起始位置 numElements := end - start // 当前列表的元素数量 listData := make([]interface{}, 0, numElements) // 创建容量为当前列表长度的切片 // 获取当前列表的值 childArray := listCol.ListValues() // 获取子列表,这里通常是一个 Array 接口 // 遍历子列表中的实际元素 for j := start; j < end; j++ { elementData, err := toMaxComputeData(childArray, int(j-start), arrayType.ElementType, cfg) if err != nil { return nil, err } listData = append(listData, elementData) } return data.ArrayFromSlice(listData...) case datatype.MAP: mapType := typeInfo.(datatype.MapType) // 处理 MAP 类型 mapCol := vector.(*array.Map) offsets := mapCol.Offsets() start := offsets[index] end := offsets[index+1] mapData := make(map[interface{}]interface{}) keys := mapCol.Keys() values := mapCol.Items() for j := start; j < end; j++ { keyData, err := toMaxComputeData(keys, int(j-start), mapType.KeyType, cfg) if err != nil { return nil, err } valueData, err := toMaxComputeData(values, int(j-start), mapType.ValueType, cfg) if err != nil { return nil, err } mapData[keyData] = valueData } return data.MapFromGoMap(mapData) case datatype.STRUCT: structType := typeInfo.(datatype.StructType) // 处理 STRUCT 类型 structCol := vector.(*array.Struct) structData := data.NewStructWithTyp(structType) for fieldIndex := 0; fieldIndex < structCol.NumField(); fieldIndex++ { fieldType := structType.Fields[fieldIndex] fieldName := fieldType.Name fieldArray := structCol.Field(fieldIndex) fieldValue, err := toMaxComputeData(fieldArray, index, fieldType.Type, cfg) if err != nil { return nil, err } err = structData.SetField(fieldName, fieldValue) if err != nil { return nil, err } } return structData, nil } return nil, fmt.Errorf("unsupported ODPS type: %v", typeInfo.Name()) } // ToMaxComputeRecords 将 Arrow Record Batch 转换为 ODPS Record 列表 func ToMaxComputeRecords(arrowBatch array.Record, columns []Column, opt ...ArrowOptions) ([]data.Record, error) { cfg := newTypeConvertConfig(opt...) odpsRecords := make([]data.Record, 0, int(arrowBatch.NumRows())) // 迭代每一行 for i := 0; i < int(arrowBatch.NumRows()); i++ { // 创建 ODPS Record odpsRecord := make([]data.Data, 0, int(arrowBatch.NumCols())) // 遍历 Arrow Record 中的所有列 for j := 0; j < int(arrowBatch.NumCols()); j++ { col := arrowBatch.Column(j) if col.IsValid(i) { odpsData, err := toMaxComputeData(col, i, columns[j].Type, cfg) if err != nil { return nil, err } odpsRecord = append(odpsRecord, odpsData) } else { // 处理空值 odpsRecord = append(odpsRecord, data.Null) } } odpsRecords = append(odpsRecords, odpsRecord) } return odpsRecords, nil }