odps/tunnel/record_protoc_reader.go (357 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tunnel
import (
"io"
"net/http"
"time"
"github.com/pkg/errors"
"github.com/aliyun/aliyun-odps-go-sdk/odps/data"
"github.com/aliyun/aliyun-odps-go-sdk/odps/datatype"
"github.com/aliyun/aliyun-odps-go-sdk/odps/tableschema"
)
// RecordProtocReader decodes records from the body of an HTTP response through
// a ProtocStreamReader. While reading it maintains a CRC for the record in
// progress and a second CRC accumulated over every finished record's CRC, so
// that Read can compare both against the values carried in the stream.
type RecordProtocReader struct {
// httpRes is the response whose Body is decoded by Read and closed by Close.
httpRes *http.Response // TODO: change this to an io.ReadCloser, as ArrowStreamStream does
// protocReader performs the low-level tag and value decoding.
protocReader *ProtocStreamReader
// columns is the schema of the records; Read stores the value for column
// tag N in slot N-1 of the record.
columns []tableschema.Column
// shouldTransformDate is stored by the constructor.
// NOTE(review): no method visible in this file appears to read it yet.
shouldTransformDate bool
// recordCrc is the checksum of the record currently being read; Read resets
// it after each record has been verified.
recordCrc Crc32CheckSum
crcOfCrc Crc32CheckSum // checksum accumulated over every record's CRC
// count is the number of records Read has returned so far; Read compares it
// with the record count found in the stream trailer.
count int64
}
// newRecordProtocReader builds a reader that decodes records from httpRes.Body
// using columns as the record schema. Both checksums start fresh and the
// record count starts at zero. httpRes must be non-nil because its Body is
// dereferenced here.
func newRecordProtocReader(httpRes *http.Response, columns []tableschema.Column, shouldTransformDate bool) RecordProtocReader {
	var reader RecordProtocReader

	reader.httpRes = httpRes
	reader.protocReader = NewProtocStreamReader(httpRes.Body)
	reader.columns = columns
	reader.shouldTransformDate = shouldTransformDate
	reader.recordCrc = NewCrc32CheckSum()
	reader.crcOfCrc = NewCrc32CheckSum()

	return reader
}
// HttpRes returns the HTTP response this reader was constructed with.
func (r *RecordProtocReader) HttpRes() *http.Response {
	res := r.httpRes
	return res
}
// Read decodes and returns the next record from the stream.
//
// Tags are consumed one at a time:
//   - a column tag N decodes one value of that column's type into slot N-1 of
//     the record; columns for which no tag arrives stay nil;
//   - EndRecord is followed by the record's CRC, which must equal the CRC
//     computed while decoding, and completes the record;
//   - MetaCount starts the stream trailer: the total record count, then a
//     MetaChecksum tag and the checksum over all record CRCs, after which the
//     stream must be exhausted.
//
// Errors from the underlying reader are returned wrapped with a stack trace.
// The end of the stream is reported as the error ReadTag returns once no more
// data is available; Iterator recognises it with errors.Is(err, io.EOF).
// A stream that ends in the middle of the trailer checksum is reported as
// io.ErrUnexpectedEOF so that it is never mistaken for a clean end.
func (r *RecordProtocReader) Read() (data.Record, error) {
	record := make([]data.Data, len(r.columns))

LOOP:
	for {
		tag, _, err := r.protocReader.ReadTag()
		if err != nil {
			return nil, errors.WithStack(err)
		}

		switch tag {
		case EndRecord:
			crc := r.recordCrc.Value()
			expected, err := r.protocReader.ReadUInt32()
			if err != nil {
				return nil, errors.WithStack(err)
			}
			if crc != expected {
				return nil, errors.New("crc value is error")
			}

			// Start a fresh CRC for the next record and fold this record's
			// CRC into the stream-wide checksum.
			r.recordCrc.Reset()
			r.crcOfCrc.Update(crc)
			break LOOP

		case MetaCount:
			total, err := r.protocReader.ReadSInt64()
			if err != nil {
				return nil, errors.WithStack(err)
			}
			if total != r.count {
				return nil, errors.New("record count does not match")
			}

			nextTag, _, err := r.protocReader.ReadTag()
			if err != nil {
				return nil, errors.WithStack(err)
			}
			if nextTag != MetaChecksum {
				return nil, errors.New("invalid stream")
			}

			expected, err := r.protocReader.ReadUInt32()
			if err != nil {
				// The checksum itself could not be read. Surface the real
				// cause rather than comparing against a zero value, and make
				// sure a premature EOF here stays distinguishable from a
				// clean end of stream.
				if errors.Is(err, io.EOF) {
					err = io.ErrUnexpectedEOF
				}
				return nil, errors.WithStack(err)
			}

			// Nothing may follow the trailer: probing for one more byte has
			// to report the end of the stream.
			_, err = r.protocReader.inner.Read(make([]byte, 1))
			if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
				return nil, errors.New("expect end of stream, but not")
			}

			if r.crcOfCrc.Value() != expected {
				return nil, errors.New("checksum is invalid")
			}
			// The loop continues on purpose: the next ReadTag is what reports
			// the end of the stream to the caller.

		default:
			columnIndex := int32(tag)
			// Column tags are 1-based. Reject anything outside
			// [1, len(columns)] so a corrupt tag cannot index out of range.
			if columnIndex < 1 || int(columnIndex) > len(r.columns) {
				return nil, errors.New("invalid protobuf tag")
			}

			r.recordCrc.Update(columnIndex)

			column := r.columns[columnIndex-1]
			value, err := r.readField(column.Type)
			if err != nil {
				return nil, errors.WithStack(err)
			}
			record[columnIndex-1] = value
		}
	}

	r.count++
	return record, nil
}
// Iterator reads records until the stream ends, passing each one to f.
//
// When Read reports io.EOF the iteration stops and nil is returned without
// calling f. On any other error f is called once with that error (and whatever
// record Read returned) and the error is then returned to the caller.
func (r *RecordProtocReader) Iterator(f func(record data.Record, err error)) error {
	for {
		rec, readErr := r.Read()

		switch {
		case readErr == nil:
			f(rec, nil)
		case errors.Is(readErr, io.EOF):
			return nil
		default:
			f(rec, readErr)
			return readErr
		}
	}
}
// Close closes the body of the underlying HTTP response. A close error is
// returned wrapped with a stack trace.
func (r *RecordProtocReader) Close() error {
	closeErr := r.httpRes.Body.Close()
	return errors.WithStack(closeErr)
}
// readField decodes one value of type dt from the stream.
//
// Scalar values are read directly and fed into the current record's CRC.
// ARRAY, MAP and STRUCT are delegated to readArray, readMap and readStruct,
// which call back into readField for their elements.
//
// It returns the decoded value, or an error (wrapped with a stack trace) when
// the underlying reader fails. A type ID without a decoder here is reported as
// an error: no bytes would be consumed for it, so returning a nil value would
// leave the stream out of sync for every tag that follows.
func (r *RecordProtocReader) readField(dt datatype.DataType) (data.Data, error) {
	var fieldValue data.Data

	switch dt.ID() {
	case datatype.DOUBLE:
		v, err := r.protocReader.ReadFloat64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.Double(v)

	case datatype.FLOAT:
		v, err := r.protocReader.ReadFloat32()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.Float(v)

	case datatype.BOOLEAN:
		v, err := r.protocReader.ReadBool()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.Bool(v)

	// All integer-like types share the same signed 64-bit wire encoding and
	// differ only in the data type the value is converted to.
	case datatype.BIGINT:
		v, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.BigInt(v)

	case datatype.IntervalYearMonth:
		v, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.IntervalYearMonth(v)

	case datatype.INT:
		v, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.Int(v)

	case datatype.SMALLINT:
		v, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.SmallInt(v)

	case datatype.TINYINT:
		v, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.TinyInt(v)

	case datatype.STRING:
		v, err := r.protocReader.ReadBytes()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.String(v)

	case datatype.VARCHAR:
		v, err := r.protocReader.ReadBytes()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		t := dt.(datatype.VarcharType)
		// NOTE(review): the error from MakeVarChar is discarded here, as it
		// always has been; confirm whether a failure should be surfaced.
		fieldValue, _ = data.MakeVarChar(t.Length, string(v))

	case datatype.CHAR:
		v, err := r.protocReader.ReadBytes()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		t := dt.(datatype.CharType)
		// NOTE(review): the error from MakeChar is discarded here, as it
		// always has been; confirm whether a failure should be surfaced.
		fieldValue, _ = data.MakeChar(t.Length, string(v))

	case datatype.BINARY:
		v, err := r.protocReader.ReadBytes()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = data.Binary(v)

	case datatype.DATETIME:
		v, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		// TODO: decide from the schema's shouldTransform setting whether the
		// time has to be converted to the local time zone.
		// The value is split into whole seconds and the remaining
		// milliseconds expressed as nanoseconds.
		seconds := v / 1000
		nanoSeconds := (v % 1000) * 1000_000
		// The time returned by time.Unix already carries the local time zone.
		fieldValue = data.DateTime(time.Unix(seconds, nanoSeconds))

	case datatype.DATE:
		v, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		// v is the number of days since 1970-01-01.
		d := epochDay.AddDate(0, 0, int(v))
		fieldValue = data.Date(d)

	// The next three types are encoded as a 64-bit seconds part followed by a
	// 32-bit nanoseconds part; both parts go into the CRC, in that order.
	case datatype.IntervalDayTime:
		seconds, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		nanoSeconds, err := r.protocReader.ReadSInt32()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(seconds)
		r.recordCrc.Update(nanoSeconds)
		fieldValue = data.NewIntervalDayTime(seconds, nanoSeconds)

	case datatype.TIMESTAMP:
		seconds, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		nanoSeconds, err := r.protocReader.ReadSInt32()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(seconds)
		r.recordCrc.Update(nanoSeconds)
		fieldValue = data.Timestamp(time.Unix(seconds, int64(nanoSeconds)))

	case datatype.TIMESTAMP_NTZ:
		seconds, err := r.protocReader.ReadSInt64()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		nanoSeconds, err := r.protocReader.ReadSInt32()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(seconds)
		r.recordCrc.Update(nanoSeconds)
		// Unlike TIMESTAMP, the value is kept in UTC rather than local time.
		fieldValue = data.TimestampNtz(time.Unix(seconds, int64(nanoSeconds)).UTC())

	case datatype.DECIMAL:
		v, err := r.protocReader.ReadBytes()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		decimalType := dt.(datatype.DecimalType)
		fieldValue = data.NewDecimal(int(decimalType.Precision), int(decimalType.Scale), string(v))

	case datatype.ARRAY:
		var err error
		fieldValue, err = r.readArray(dt.(datatype.ArrayType).ElementType)
		if err != nil {
			return nil, errors.WithStack(err)
		}

	case datatype.MAP:
		var err error
		fieldValue, err = r.readMap(dt.(datatype.MapType))
		if err != nil {
			return nil, errors.WithStack(err)
		}

	case datatype.STRUCT:
		var err error
		fieldValue, err = r.readStruct(dt.(datatype.StructType))
		if err != nil {
			return nil, errors.WithStack(err)
		}

	case datatype.JSON:
		v, err := r.protocReader.ReadBytes()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		r.recordCrc.Update(v)
		fieldValue = &data.Json{
			Data:  string(v),
			Valid: true,
		}

	default:
		// Nothing was consumed for this value, so carrying on would decode
		// the value's bytes as the next tag. Fail instead of handing back a
		// nil field together with a nil error.
		return nil, errors.Errorf("unsupported data type id: %v", dt.ID())
	}

	return fieldValue, nil
}
// readArray decodes an array whose elements have type t.
//
// The wire form is an unsigned 32-bit element count followed, for each
// element, by a boolean null marker and, when the marker is false, the
// element value itself (decoded with readField). Null elements are stored as
// nil.
//
// It returns the populated array, or the first read error wrapped with a stack
// trace.
func (r *RecordProtocReader) readArray(t datatype.DataType) (*data.Array, error) {
	arraySize, err := r.protocReader.ReadUInt32()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// The element count comes straight off the wire, so it is not trusted for
	// the up-front allocation: a corrupt count could otherwise request tens of
	// gigabytes before a single element has been read. Reserve at most
	// maxPrealloc slots and let append grow the slice as elements really
	// arrive; a bogus count then fails on the first short read instead.
	const maxPrealloc = 1 << 16
	capacity := arraySize
	if capacity > maxPrealloc {
		capacity = maxPrealloc
	}
	arrayData := make([]data.Data, 0, capacity)

	for i := uint32(0); i < arraySize; i++ {
		isNull, err := r.protocReader.ReadBool()
		if err != nil {
			return nil, errors.WithStack(err)
		}
		if isNull {
			arrayData = append(arrayData, nil)
			continue
		}

		element, err := r.readField(t)
		if err != nil {
			return nil, errors.WithStack(err)
		}
		arrayData = append(arrayData, element)
	}

	at := datatype.NewArrayType(t)
	array := data.NewArrayWithType(at)
	array.UnSafeAppend(arrayData...)

	return array, nil
}
// readMap decodes a map of type t. The stream carries the keys as one array
// followed by the values as a second array; the two are paired up by position.
// Arrays of different lengths are rejected.
func (r *RecordProtocReader) readMap(t datatype.MapType) (*data.Map, error) {
	keyArray, err := r.readArray(t.KeyType)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	valueArray, err := r.readArray(t.ValueType)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	size := keyArray.Len()
	if size != valueArray.Len() {
		return nil, errors.New("failed to read map")
	}

	result := data.NewMapWithType(t)
	for idx := 0; idx < size; idx++ {
		result.Set(keyArray.Index(idx), valueArray.Index(idx))
	}

	return result, nil
}
// readStruct decodes a struct of type t. The fields appear in the order given
// by t.Fields, each introduced by a boolean null marker; a field whose marker
// is true is set to nil, any other field is decoded with readField.
func (r *RecordProtocReader) readStruct(t datatype.StructType) (*data.Struct, error) {
	result := data.NewStructWithTyp(t)

	for _, field := range t.Fields {
		isNull, err := r.protocReader.ReadBool()
		if err != nil {
			return nil, errors.WithStack(err)
		}

		if isNull {
			result.SetField(field.Name, nil)
			continue
		}

		value, err := r.readField(field.Type)
		if err != nil {
			return nil, errors.WithStack(err)
		}
		result.SetField(field.Name, value)
	}

	return result, nil
}