odps/tunnel/record_protoc_writer.go (355 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tunnel
import (
"io"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protowire"
"github.com/aliyun/aliyun-odps-go-sdk/odps/data"
"github.com/aliyun/aliyun-odps-go-sdk/odps/datatype"
"github.com/aliyun/aliyun-odps-go-sdk/odps/tableschema"
)
type RecordProtocWriter struct {
httpRes *httpConnection
writeCloser io.Closer
protocWriter *ProtocStreamWriter
columns []tableschema.Column
shouldTransformDate bool
recordCrc Crc32CheckSum
crcOfCrc Crc32CheckSum // crc of record crc
count int64
closed bool
}
func newRecordProtocWriter(w io.WriteCloser, columns []tableschema.Column, shouldTransformDate bool) RecordProtocWriter {
return RecordProtocWriter{
httpRes: nil,
writeCloser: w,
protocWriter: NewProtocStreamWriter(w),
columns: columns,
shouldTransformDate: shouldTransformDate,
recordCrc: NewCrc32CheckSum(),
crcOfCrc: NewCrc32CheckSum(),
closed: false,
}
}
func newRecordProtocHttpWriter(conn *httpConnection, columns []tableschema.Column, shouldTransformDate bool) RecordProtocWriter {
return RecordProtocWriter{
httpRes: conn,
writeCloser: conn.Writer,
protocWriter: NewProtocStreamWriter(conn.Writer),
columns: columns,
shouldTransformDate: shouldTransformDate,
recordCrc: NewCrc32CheckSum(),
crcOfCrc: NewCrc32CheckSum(),
}
}
func (r *RecordProtocWriter) Write(record data.Record) error {
err := r.write(record)
if err != nil {
err1 := r.Close()
if err1 != nil {
return errors.WithStack(err1)
}
}
return errors.WithStack(err)
}
func (r *RecordProtocWriter) write(record data.Record) error {
// 这里加一个判断会引起不必要的耗时
//if r.closed {
// return errors.New("cannot write to a closed RecordProtocWriter")
//}
recordColNum := record.Len()
if recordColNum > len(r.columns) {
return errors.New("record values are more than schema.")
}
for colIndex, value := range record {
if value == nil || value.Type() == datatype.NullType {
continue
}
r.recordCrc.Update(int32(colIndex + 1))
err := r.writeFieldTag(colIndex+1, r.columns[colIndex].Type)
if err != nil {
return errors.WithStack(err)
}
err = r.writeField(value)
if err != nil {
return errors.WithStack(err)
}
}
err := r.protocWriter.WriteTag(EndRecord, protowire.VarintType)
if err != nil {
return errors.WithStack(err)
}
recordCrcVal := r.recordCrc.Value()
err = r.protocWriter.WriteUInt32(recordCrcVal)
if err != nil {
return errors.WithStack(err)
}
r.recordCrc.Reset()
r.crcOfCrc.Update(recordCrcVal)
r.count += 1
return nil
}
func (r *RecordProtocWriter) writeFieldTag(colIndex int, dt datatype.DataType) error {
var wireType protowire.Type
switch dt.ID() {
case datatype.DATETIME,
datatype.BOOLEAN,
datatype.BIGINT,
datatype.TINYINT,
datatype.SMALLINT,
datatype.INT,
datatype.DATE,
datatype.IntervalYearMonth:
wireType = protowire.VarintType
case datatype.DOUBLE:
wireType = protowire.Fixed64Type
case datatype.FLOAT:
wireType = protowire.Fixed32Type
case datatype.IntervalDayTime,
datatype.TIMESTAMP,
datatype.TIMESTAMP_NTZ,
datatype.STRING,
datatype.CHAR,
datatype.VARCHAR,
datatype.BINARY,
datatype.DECIMAL,
datatype.ARRAY,
datatype.MAP,
datatype.STRUCT,
datatype.JSON:
wireType = protowire.BytesType
default:
return errors.Errorf("Invalid data type, %s", dt.Name())
}
err := r.protocWriter.WriteTag(protowire.Number(int32(colIndex)), wireType)
if err != nil {
return errors.WithStack(err)
}
return nil
}
func (r *RecordProtocWriter) writeField(val data.Data) error {
switch val := val.(type) {
case data.Double:
r.recordCrc.Update(val)
return errors.WithStack(r.protocWriter.WriteFloat64(float64(val)))
case data.Float:
r.recordCrc.Update(val)
return errors.WithStack(r.protocWriter.WriteFloat32(float32(val)))
case data.Bool:
r.recordCrc.Update(val)
return errors.WithStack(r.protocWriter.WriteBool(bool(val)))
case data.BigInt:
r.recordCrc.Update(val)
return errors.WithStack(r.protocWriter.WriteSInt64(int64(val)))
case data.IntervalYearMonth:
r.recordCrc.Update(int64(val))
return errors.WithStack(r.protocWriter.WriteSInt64(int64(val)))
case data.Int:
r.recordCrc.Update(int64(val))
return errors.WithStack(r.protocWriter.WriteSInt64(int64(val)))
case data.SmallInt:
r.recordCrc.Update(int64(val))
return errors.WithStack(r.protocWriter.WriteSInt64(int64(val)))
case data.TinyInt:
r.recordCrc.Update(int64(val))
return errors.WithStack(r.protocWriter.WriteSInt64(int64(val)))
case *data.String:
b := []byte(string(*val))
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
case data.String:
b := []byte(string(val))
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
case *data.VarChar:
b := []byte(val.Data())
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
case data.VarChar:
b := []byte(val.Data())
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
case *data.Char:
b := []byte(val.Data())
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
case data.Char:
b := []byte(val.Data())
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
case data.Binary:
r.recordCrc.Update(val)
return errors.WithStack(r.protocWriter.WriteBytes(val))
case data.DateTime:
t := val.Time()
// 应该直接写成: milliSeconds := t.UnixMilli()
// 但是 Time.UnixMilli is added in go.1.17
// func (t Time) UnixMilli() int64 {
// return t.unixSec()*1e3 + int64(t.nsec())/1e6
// }
unixSec := t.Unix()
nanoSec := t.Nanosecond()
milliSeconds := unixSec*1e3 + int64(nanoSec)/1e6
// TODO 需要根据schema中的shouldTransform,来确定是否将时间转换为本地时区的时间
r.recordCrc.Update(milliSeconds)
return errors.WithStack(r.protocWriter.WriteSInt64(milliSeconds))
case data.Date:
t := val.Time()
// 获取从1970年以来的天数
days := t.Unix() / data.SecondsPerDay
r.recordCrc.Update(days)
return errors.WithStack(r.protocWriter.WriteSInt64(days))
case data.IntervalDayTime:
seconds := val.Seconds()
nanoSeconds := val.NanosFraction()
r.recordCrc.Update(seconds)
r.recordCrc.Update(nanoSeconds)
err := r.protocWriter.WriteSInt64(seconds)
if err != nil {
return errors.WithStack(err)
}
return errors.WithStack(r.protocWriter.WriteSInt32(nanoSeconds))
case data.Timestamp:
t := val.Time()
seconds := t.Unix()
nanoSeconds := int32(t.Nanosecond())
r.recordCrc.Update(seconds)
r.recordCrc.Update(nanoSeconds)
err := r.protocWriter.WriteSInt64(seconds)
if err != nil {
return errors.WithStack(err)
}
return errors.WithStack(r.protocWriter.WriteSInt32(nanoSeconds))
case data.TimestampNtz:
t := val.Time()
seconds := t.Unix()
nanoSeconds := int32(t.Nanosecond())
r.recordCrc.Update(seconds)
r.recordCrc.Update(nanoSeconds)
err := r.protocWriter.WriteSInt64(seconds)
if err != nil {
return errors.WithStack(err)
}
return errors.WithStack(r.protocWriter.WriteSInt32(nanoSeconds))
case data.Decimal:
b := []byte(val.Value())
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
case *data.Decimal:
b := []byte(val.Value())
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
case data.Array:
return errors.WithStack(r.writeArray(val.ToSlice()))
case *data.Array:
return errors.WithStack(r.writeArray(val.ToSlice()))
case data.Map:
return errors.WithStack(r.writeMap(&val))
case *data.Map:
return errors.WithStack(r.writeMap(val))
case data.Struct:
return errors.WithStack(r.writeStruct(&val))
case *data.Struct:
return errors.WithStack(r.writeStruct(val))
case data.Json:
return errors.WithStack(r.writeJson(&val))
case *data.Json:
return errors.WithStack(r.writeJson(val))
}
return errors.Errorf("invalid data type %v", val.Type())
}
func (r *RecordProtocWriter) writeArray(val []data.Data) error {
err := r.protocWriter.WriteInt32(int32(len(val)))
if err != nil {
return errors.WithStack(err)
}
for _, d := range val {
if d == nil {
err = r.protocWriter.WriteBool(true)
if err != nil {
return errors.WithStack(err)
}
} else {
err = r.protocWriter.WriteBool(false)
if err != nil {
return errors.WithStack(err)
}
err = r.writeField(d)
if err != nil {
return errors.WithStack(err)
}
}
}
return nil
}
func (r *RecordProtocWriter) writeMap(val *data.Map) error {
m := val.ToGoMap()
l := len(m)
keys, values := make([]data.Data, 0, l), make([]data.Data, 0, l)
for k, v := range m {
keys = append(keys, k)
values = append(values, v)
}
err := r.writeArray(keys)
if err != nil {
return errors.WithStack(err)
}
return errors.WithStack(r.writeArray(values))
}
func (r *RecordProtocWriter) writeStruct(val *data.Struct) error {
for _, field := range val.Fields() {
if field.Value == nil {
err := r.protocWriter.WriteBool(true)
if err != nil {
return errors.WithStack(err)
}
} else {
err := r.protocWriter.WriteBool(false)
if err != nil {
return errors.WithStack(err)
}
err = r.writeField(field.Value)
if err != nil {
return errors.WithStack(err)
}
}
}
return nil
}
func (r *RecordProtocWriter) writeJson(val *data.Json) error {
jsonStr := val.GetData()
b := []byte(jsonStr)
r.recordCrc.Update(b)
return errors.WithStack(r.protocWriter.WriteBytes(b))
}
func (r *RecordProtocWriter) close() error {
err := r.protocWriter.WriteTag(MetaCount, protowire.VarintType)
if err != nil {
return errors.WithStack(err)
}
err = r.protocWriter.WriteSInt64(r.count)
if err != nil {
return errors.WithStack(err)
}
err = r.protocWriter.WriteTag(MetaChecksum, protowire.VarintType)
if err != nil {
return errors.WithStack(err)
}
err = r.protocWriter.WriteUInt32(r.crcOfCrc.Value())
if err != nil {
return errors.WithStack(err)
}
err = r.writeCloser.Close()
if err != nil {
return errors.WithStack(err)
}
return nil
}
func (r *RecordProtocWriter) Close() error {
if r.closed {
return errors.New("try to close a closed RecordProtocWriter")
}
r.closed = true
err := r.close()
if r.httpRes != nil {
closeHttpError := errors.WithStack(r.httpRes.closeRes())
if closeHttpError != nil {
return errors.WithStack(closeHttpError)
}
}
return errors.WithStack(err)
}
func (r *RecordProtocWriter) RecordCount() int64 {
return r.count
}
func (r *RecordProtocWriter) BytesCount() int64 {
return int64(r.httpRes.bytesCount())
}