// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tunnel

import (
	"bytes"
	"time"

	"github.com/aliyun/aliyun-odps-go-sdk/odps/data"
	"github.com/aliyun/aliyun-odps-go-sdk/odps/datatype"
	"github.com/aliyun/aliyun-odps-go-sdk/odps/tableschema"
	"github.com/pkg/errors"
)
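
// RecordPackStreamWriter buffers records as a protoc-encoded pack in memory
// and uploads the whole pack to the server in a single Flush call.
//
// A minimal usage sketch (illustrative; session is an established
// *StreamUploadSession and record is a prepared data.Record):
//
//	pack := newRecordStreamHttpWriter(session)
//	if err := pack.Append(record); err != nil {
//		// handle the append error
//	}
//	traceId, count, nBytes, err := pack.Flush()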
type RecordPackStreamWriter struct {
	session      *StreamUploadSession
	protocWriter RecordProtocWriter
	flushing     bool // set when a Flush has started and has not yet succeeded
	buffer       *bytes.Buffer
	recordCount  int64
}
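
// newRecordStreamHttpWriter creates a RecordPackStreamWriter that encodes
// records into an in-memory buffer using the schema of the given session.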
func newRecordStreamHttpWriter(session *StreamUploadSession) RecordPackStreamWriter {
buffer := bytes.NewBuffer(nil)
return RecordPackStreamWriter{
session: session,
buffer: buffer,
protocWriter: newRecordProtocWriter(&bufWriter{buffer}, session.schema.Columns, false),
}
}
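
// Append encodes a single record into the buffer. It fails if a previous
// Flush did not succeed, or if schema checking is enabled and the record
// does not match the session schema.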
func (rsw *RecordPackStreamWriter) Append(record data.Record) error {
	if rsw.flushing {
		return errors.New("a previous flush failed; call Flush to retry or reset to drop the buffered data")
	}
if !rsw.session.allowSchemaMismatch {
err := checkIfRecordSchemaMatchSessionSchema(&record, rsw.session.schema.Columns)
if err != nil {
return errors.WithStack(err)
}
}
err := rsw.protocWriter.Write(record)
if err == nil {
		rsw.recordCount++
}
return errors.WithStack(err)
}
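
// checkIfRecordSchemaMatchSessionSchema verifies that the record has the
// same number of columns as the session schema and that every non-null
// value matches its column's type.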
func checkIfRecordSchemaMatchSessionSchema(record *data.Record, schema []tableschema.Column) error {
if record.Len() != len(schema) {
return errors.Errorf("Record schema not match session schema, record len: %d, session schema len: %d",
record.Len(), len(schema))
}
for index, recordData := range *record {
colType := schema[index].Type.ID()
if recordData != nil && recordData.Type() != datatype.NullType && recordData.Type().ID() != colType {
return errors.Errorf("Record schema not match session schema, index: %d, record type: %s, session schema type: %s",
index, recordData.Type().Name(), schema[index].Type.Name())
}
}
return nil
}

// Flush sends all buffered data to the server and returns
// (traceId, recordCount, recordBytes, error). recordCount and recordBytes
// are the number and total byte size of the records uploaded. An optional
// timeout is forwarded to the flush request.
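//
// A retry sketch (illustrative; pack is a *RecordPackStreamWriter):
//
//	if _, _, _, err := pack.Flush(10 * time.Second); err != nil {
//		_, _, _, err = pack.Flush() // the buffered data is kept for retry
//	}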
func (rsw *RecordPackStreamWriter) Flush(timeout ...time.Duration) (string, int64, int64, error) {
	flushTimeout := time.Duration(0)
	if len(timeout) > 0 {
		flushTimeout = timeout[0]
	}
if rsw.recordCount == 0 {
return "", 0, 0, nil
}
	// Close the protoc writer so that it writes the trailing protoc tags
	// into the buffer before the data is sent.
	if !rsw.flushing && !rsw.protocWriter.closed {
		err := rsw.protocWriter.Close()
		if err != nil {
			return "", 0, 0, errors.WithStack(err)
		}
	}
	// Mark the pack as flushing; on failure the buffer is kept so the
	// caller can retry Flush or drop the data with reset.
	rsw.flushing = true
	reqId, bytesSent, err := rsw.session.flushStream(rsw, flushTimeout)
	if err != nil {
		return "", 0, 0, err
	}
	recordCount := rsw.recordCount
	rsw.flushing = false
	rsw.reset()
	return reqId, recordCount, int64(bytesSent), nil
}

// RecordCount returns the number of records currently buffered.
func (rsw *RecordPackStreamWriter) RecordCount() int64 {
return rsw.recordCount
}

// DataSize returns the size in bytes of the buffered encoded data.
func (rsw *RecordPackStreamWriter) DataSize() int64 {
return int64(rsw.buffer.Len())
}
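
// reset clears the buffer and recreates the protoc writer so the pack can
// accept new records.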
func (rsw *RecordPackStreamWriter) reset() {
rsw.buffer.Reset()
rsw.protocWriter = newRecordProtocWriter(&bufWriter{rsw.buffer}, rsw.session.schema.Columns, false)
rsw.recordCount = 0
}