odps/tunnel/download_session.go (314 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tunnel
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/aliyun/aliyun-odps-go-sdk/arrow"
"github.com/aliyun/aliyun-odps-go-sdk/odps/common"
"github.com/aliyun/aliyun-odps-go-sdk/odps/restclient"
"github.com/aliyun/aliyun-odps-go-sdk/odps/tableschema"
)
// DownLoadStatus describes the state of a download session as reported by the
// tunnel server. Use DownloadStatusFromStr to parse the textual form and
// String to produce it.
type DownLoadStatus int
// The zero value of DownLoadStatus is deliberately skipped with the blank
// identifier, so the named statuses start at 1.
const (
_ DownLoadStatus = iota
// DownloadStatusUnknown is also the result for any unrecognized status text.
DownloadStatusUnknown
DownloadStatusNormal
DownloadStatusClosed
DownloadStatusExpired
DownloadStatusInitiating
)
// DownloadSession is used to download table data; it can be created by Tunnel.
// You can use RecordCount to get the total number of records, and then open
// multiple record readers in parallel, each covering a part of the record
// range, to download the data in less time. OpenRecordArrowReader opens an
// arrow based reader and OpenRecordReader opens a protoc based reader.
//
// Underneath each reader is an http connection; when no data flows through it
// for 300s, the tunnel server will close it.
type DownloadSession struct {
// Id is the download session id. It is overwritten with the id returned by
// the server whenever the session information is loaded.
Id string
ProjectName string
// TODO use schema to get the resource url of a table
SchemaName string
TableName string
// QuotaName, when not empty, is sent as the "quotaName" query parameter of
// the session initiation request.
QuotaName string
// The partition key used by a session can not contain "'", for example, "region=hangzhou" is a
// positive case, and "region='hangzhou'" is a negative case. But partition keys like "region='hangzhou'" are more
// common; to keep users from passing the wrong format, the partitionKey of DownloadSession is private. It can be set
// when creating a session or by using SetPartitionKey.
partitionKey string
// Async, when true, adds "asyncmode=true" to the session initiation request.
Async bool
// ShardId, when non-zero, is sent as the "shard" query parameter.
ShardId int
// Compressor, when not nil, has its Name sent in the "Accept-Encoding"
// header of data download requests.
Compressor Compressor
RestClient restclient.RestClient
// The fields below are filled from the server response by loadInformation.
schema tableschema.TableSchema
status DownLoadStatus
recordCount int
shouldTransformDate bool
arrowSchema *arrow.Schema
}
// CreateDownloadSession creates a new download session on the tunnel server;
// it must be called before any data can be downloaded. The session is
// initiated with a POST request and then populated (id, status, record count,
// schema) from the server response.
//
// The opts can be one or more of:
// SessionCfg.WithPartitionKey
// SessionCfg.WithSchemaName, it doesn't work now
// SessionCfg.WithDefaultDeflateCompressor, using deflate compressor with default level
// SessionCfg.WithDeflateCompressor, using deflate compressor with specific level
// SessionCfg.WithSnappyFramedCompressor
// SessionCfg.Overwrite, overwrite data
// SessionCfg.DisableArrow, disable arrow reader, using protoc reader instead.
// SessionCfg.ShardId, set the shard id of the table
// SessionCfg.Async, enable the async mode of the session which can avoid timeouts when there are many small files
func CreateDownloadSession(
	projectName, tableName, quotaName string,
	restClient restclient.RestClient,
	opts ...Option,
) (*DownloadSession, error) {
	cfg := newSessionConfig(opts...)

	ds := &DownloadSession{
		ProjectName:  projectName,
		SchemaName:   cfg.SchemaName,
		TableName:    tableName,
		QuotaName:    quotaName,
		RestClient:   restClient,
		partitionKey: cfg.PartitionKey,
		ShardId:      cfg.ShardId,
		Async:        cfg.Async,
		Compressor:   cfg.Compressor,
	}

	initReq, err := ds.newInitiationRequest()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if err = ds.loadInformation(initReq); err != nil {
		return nil, errors.WithStack(err)
	}

	return ds, nil
}
// AttachToExistedDownloadSession attaches to an already existing download
// session identified by sessionId. The session information (status, record
// count, schema) is fetched from the server with a GET request.
//
// The opts can be one or more of:
// SessionCfg.WithPartitionKey
// SessionCfg.WithSchemaName, it doesn't work now
// SessionCfg.WithDefaultDeflateCompressor, using deflate compressor with default level
// SessionCfg.WithDeflateCompressor, using deflate compressor with specific level
// SessionCfg.WithSnappyFramedCompressor
// SessionCfg.Overwrite, overwrite data
// SessionCfg.DisableArrow, disable arrow reader, using protoc reader instead.
// SessionCfg.ShardId, set the shard id of the table
// SessionCfg.Async, enable the async mode of the session which can avoid timeouts when there are many small files
func AttachToExistedDownloadSession(
	sessionId, projectName, tableName string,
	restClient restclient.RestClient,
	opts ...Option,
) (*DownloadSession, error) {
	cfg := newSessionConfig(opts...)

	ds := &DownloadSession{
		Id:           sessionId,
		ProjectName:  projectName,
		SchemaName:   cfg.SchemaName,
		TableName:    tableName,
		RestClient:   restClient,
		partitionKey: cfg.PartitionKey,
		ShardId:      cfg.ShardId,
		Async:        cfg.Async,
		Compressor:   cfg.Compressor,
	}

	loadReq, err := ds.newLoadRequest()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if err = ds.loadInformation(loadReq); err != nil {
		return nil, errors.WithStack(err)
	}

	return ds, nil
}
// Schema returns the table schema that was parsed from the server response
// when the session information was loaded.
func (ds *DownloadSession) Schema() tableschema.TableSchema {
return ds.schema
}
// Status returns the session status as of the last time the session
// information was loaded from the server.
func (ds *DownloadSession) Status() DownLoadStatus {
return ds.status
}
// RecordCount returns the record count reported by the server when the
// session information was loaded.
func (ds *DownloadSession) RecordCount() int {
return ds.recordCount
}
// ShouldTransformDate reports whether the server response that carried the
// session information had the date-transform header set to "true". The value
// is passed on to the protoc record reader.
func (ds *DownloadSession) ShouldTransformDate() bool {
return ds.shouldTransformDate
}
// ArrowSchema returns the arrow schema of the table. If it has not been
// computed yet, it is derived from the table schema and cached on the session.
func (ds *DownloadSession) ArrowSchema() *arrow.Schema {
	if ds.arrowSchema == nil {
		ds.arrowSchema = ds.schema.ToArrowSchema()
	}
	return ds.arrowSchema
}
// PartitionKey returns the partition key currently used by the session.
func (ds *DownloadSession) PartitionKey() string {
return ds.partitionKey
}
// SetPartitionKey sets the partition key of the session. All single and
// double quote characters are removed first, so "region='hangzhou'" is
// stored as "region=hangzhou".
func (ds *DownloadSession) SetPartitionKey(partitionKey string) {
	quoteStripper := strings.NewReplacer("'", "", "\"", "")
	ds.partitionKey = quoteStripper.Replace(partitionKey)
}
// ResourceUrl returns the resource path of the session's table, built from
// the project name, schema name and table name.
func (ds *DownloadSession) ResourceUrl() string {
	// Keep the builder in an addressable local so the call works for both
	// value and pointer receivers.
	builder := common.NewResourceBuilder(ds.ProjectName)
	tableResource := builder.Table(ds.SchemaName, ds.TableName)
	return tableResource
}
// OpenRecordArrowReader opens an arrow based reader over the row range
// described by start and count.
//
// columnNames selects the columns to download; when it is empty, every column
// of the table schema is downloaded. An error is returned if a requested
// column does not exist in the arrow schema of the table, or if the download
// connection cannot be established.
func (ds *DownloadSession) OpenRecordArrowReader(start, count int, columnNames []string) (*RecordArrowReader, error) {
	if len(columnNames) == 0 {
		columnNames = make([]string, len(ds.schema.Columns))
		for i, c := range ds.schema.Columns {
			columnNames[i] = c.Name
		}
	}

	// Go through the accessor instead of the raw field, so that a session
	// whose arrow schema has not been computed yet does not cause a nil
	// pointer dereference.
	tableArrowSchema := ds.ArrowSchema()

	arrowFields := make([]arrow.Field, 0, len(columnNames))
	for _, columnName := range columnNames {
		fs, ok := tableArrowSchema.FieldsByName(columnName)
		if !ok {
			return nil, errors.Errorf("no column names %s in table %s", columnName, ds.TableName)
		}
		arrowFields = append(arrowFields, fs...)
	}

	// The reader only sees the selected columns, in the requested order.
	arrowSchema := arrow.NewSchema(arrowFields, nil)

	res, err := ds.newDownloadConnection(start, count, columnNames, true)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	reader := newRecordArrowReader(res, arrowSchema)
	return &reader, nil
}
// OpenRecordReader opens a protoc based reader over the row range described
// by start and count.
//
// columnNames selects the columns to download; when it is empty, every column
// of the table schema is downloaded. An error is returned if a requested
// column is not part of the table schema, or if the download connection
// cannot be established.
func (ds *DownloadSession) OpenRecordReader(start, count int, columnNames []string) (*RecordProtocReader, error) {
	if len(columnNames) == 0 {
		// Default to all columns of the table.
		columnNames = make([]string, 0, len(ds.schema.Columns))
		for _, col := range ds.schema.Columns {
			columnNames = append(columnNames, col.Name)
		}
	}

	selected := make([]tableschema.Column, 0, len(columnNames))
	for _, name := range columnNames {
		col, found := ds.schema.FieldByName(name)
		if !found {
			return nil, errors.Errorf("no column names %s in table", name)
		}
		selected = append(selected, col)
	}

	res, err := ds.newDownloadConnection(start, count, columnNames, false)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	reader := newRecordProtocReader(res, selected, ds.shouldTransformDate)
	return &reader, nil
}
// newInitiationRequest builds the POST request that creates a new download
// session for the table. The partition key, shard id and quota name are only
// added to the query when they are set (non-empty or non-zero), and
// "asyncmode=true" is only added when Async is enabled.
func (ds *DownloadSession) newInitiationRequest() (*http.Request, error) {
	params := make(url.Values, 4)
	params.Set("downloads", "")

	if ds.Async {
		params.Set("asyncmode", "true")
	}
	if ds.partitionKey != "" {
		params.Set("partition", ds.partitionKey)
	}
	if ds.ShardId != 0 {
		params.Set("shard", strconv.Itoa(ds.ShardId))
	}
	if ds.QuotaName != "" {
		params.Set("quotaName", ds.QuotaName)
	}

	req, err := ds.RestClient.NewRequestWithUrlQuery(
		common.HttpMethod.PostMethod,
		ds.ResourceUrl(),
		nil,
		params,
	)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	addCommonSessionHttpHeader(req.Header)
	return req, nil
}
// newLoadRequest builds the GET request that fetches the information of the
// existing download session identified by ds.Id. The partition key and shard
// id are only added to the query when they are set.
func (ds *DownloadSession) newLoadRequest() (*http.Request, error) {
	params := make(url.Values, 2)
	params.Set("downloadid", ds.Id)

	if ds.partitionKey != "" {
		params.Set("partition", ds.partitionKey)
	}
	if ds.ShardId != 0 {
		params.Set("shard", strconv.Itoa(ds.ShardId))
	}

	req, err := ds.RestClient.NewRequestWithUrlQuery(
		common.HttpMethod.GetMethod,
		ds.ResourceUrl(),
		nil,
		params,
	)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	addCommonSessionHttpHeader(req.Header)
	return req, nil
}
// loadInformation sends req and fills the session from the JSON response:
// the session id, status, record count, table schema and arrow schema.
//
// A response whose status code is not 2xx is turned into an error. The
// date-transform flag is taken from the response header before the body is
// decoded, so it is updated even when decoding fails. The other session
// fields are only updated after the schema has been converted successfully.
func (ds *DownloadSession) loadInformation(req *http.Request) error {
	// sessionInfo mirrors the JSON document returned by the tunnel server.
	type sessionInfo struct {
		DownloadID  string         `json:"DownloadID"`
		Initiated   string         `json:"Initiated"`
		Owner       string         `json:"Owner"`
		RecordCount int            `json:"RecordCount"`
		Schema      schemaResModel `json:"Schema"`
		Status      string         `json:"Status"`
	}

	var info sessionInfo

	parse := func(res *http.Response) error {
		if res.StatusCode/100 != 2 {
			return restclient.NewHttpNotOk(res)
		}

		ds.shouldTransformDate = res.Header.Get(common.HttpHeaderOdpsDateTransFrom) == "true"

		return json.NewDecoder(res.Body).Decode(&info)
	}

	if err := ds.RestClient.DoWithParseFunc(req, parse); err != nil {
		return errors.WithStack(err)
	}

	schema, err := info.Schema.toTableSchema(ds.TableName)
	if err != nil {
		return errors.WithStack(err)
	}

	ds.Id = info.DownloadID
	ds.status = DownloadStatusFromStr(info.Status)
	ds.recordCount = info.RecordCount
	ds.schema = schema
	ds.arrowSchema = schema.ToArrowSchema()

	return nil
}
// newDownloadConnection sends the GET request that downloads the rows
// described by start and count, and returns the raw http response whose body
// carries the data.
//
// columnNames, when not empty, is sent as a comma separated "columns" query
// parameter. useArrow adds the "arrow" query parameter. When the session has
// a Compressor, its name is sent in the "Accept-Encoding" header, and a
// response with a non-empty "Content-Encoding" header gets its body wrapped
// by WrapByCompressor.
//
// For a response whose status code is not 2xx, the response is returned
// together with a non-nil error.
func (ds *DownloadSession) newDownloadConnection(start, count int, columnNames []string, useArrow bool) (*http.Response, error) {
	params := make(url.Values, 6)
	params.Set("downloadid", ds.Id)
	params.Set("data", "")
	params.Set("rowrange", fmt.Sprintf("(%d,%d)", start, count))

	if len(columnNames) > 0 {
		params.Set("columns", strings.Join(columnNames, ","))
	}
	if ds.partitionKey != "" {
		params.Set("partition", ds.partitionKey)
	}
	if useArrow {
		params.Set("arrow", "")
	}

	req, err := ds.RestClient.NewRequestWithUrlQuery(common.HttpMethod.GetMethod, ds.ResourceUrl(), nil, params)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// The Accept-Encoding header is set before the common session headers
	// are added; keep this order.
	if ds.Compressor != nil {
		req.Header.Set("Accept-Encoding", ds.Compressor.Name())
	}
	addCommonSessionHttpHeader(req.Header)

	res, err := ds.RestClient.Do(req)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if res.StatusCode/100 != 2 {
		return res, restclient.NewHttpNotOk(res)
	}

	if encoding := res.Header.Get("Content-Encoding"); encoding != "" {
		res.Body = WrapByCompressor(res.Body, encoding)
	}

	return res, nil
}
// DownloadStatusFromStr converts the textual status s to a DownLoadStatus.
// The comparison is case-insensitive. "UNKNOWN" and any unrecognized text
// both map to DownloadStatusUnknown.
func DownloadStatusFromStr(s string) DownLoadStatus {
	switch strings.ToUpper(s) {
	case "NORMAL":
		return DownloadStatusNormal
	case "CLOSED":
		return DownloadStatusClosed
	case "EXPIRED":
		return DownloadStatusExpired
	case "INITIATING":
		return DownloadStatusInitiating
	}

	// Covers "UNKNOWN" as well as every other value.
	return DownloadStatusUnknown
}
// String returns the upper-case textual form of the status.
// DownloadStatusUnknown and any value outside the declared constants are
// both rendered as "UNKNOWN".
func (status DownLoadStatus) String() string {
	switch status {
	case DownloadStatusNormal:
		return "NORMAL"
	case DownloadStatusClosed:
		return "CLOSED"
	case DownloadStatusExpired:
		return "EXPIRED"
	case DownloadStatusInitiating:
		return "INITIATING"
	}

	// Covers DownloadStatusUnknown as well as every other value.
	return "UNKNOWN"
}