odps/data/struct.go (198 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package data
import (
"reflect"
"strings"
"sync"
"github.com/pkg/errors"
"github.com/aliyun/aliyun-odps-go-sdk/odps/datatype"
)
// StructField is a single named entry of a Struct: a field name paired with
// the Data value stored under that name.
type StructField struct {
// Name is the field name.
Name string
// Value is the data held by the field.
Value Data
}
// NewStructField builds a StructField that stores value under name.
func NewStructField(name string, value Data) StructField {
	var field StructField
	field.Name = name
	field.Value = value
	return field
}
// Struct is an ordered collection of named fields. The fields are kept in a
// slice rather than a map so that field order is preserved.
type Struct struct {
// typ is the struct type reported by Type.
typ datatype.StructType
// fields holds the entries in insertion order.
fields []StructField
// fieldIndexes maps a field name to its position in fields.
fieldIndexes map[string]int
// Valid is set to true by NewStructWithTyp and left false by NewStruct.
// NOTE(review): presumably distinguishes a real value from NULL — confirm
// against the code that reads it.
Valid bool
}
// NewStruct returns an empty Struct with a zero-valued StructType, no fields
// and an initialized name index. Valid is left at its zero value (false).
func NewStruct() *Struct {
	s := new(Struct)
	s.typ = datatype.StructType{}
	s.fields = make([]StructField, 0)
	s.fieldIndexes = make(map[string]int)
	return s
}
// NewStructWithTyp returns an empty Struct whose type is typ. The field list
// starts empty, the name index is initialized and Valid is set to true.
func NewStructWithTyp(typ datatype.StructType) *Struct {
	s := new(Struct)
	s.typ = typ
	s.fields = make([]StructField, 0)
	s.fieldIndexes = make(map[string]int)
	s.Valid = true
	return s
}
// Type returns the struct type stored on s.
func (s Struct) Type() datatype.DataType {
	var t datatype.DataType = s.typ
	return t
}
// String renders the struct as "struct<name:value,name:value,...>", listing
// the fields in their stored order. A field whose Value is nil is rendered as
// "null", matching what Sql does, instead of panicking on the nil value.
func (s Struct) String() string {
	var sb strings.Builder
	sb.WriteString("struct<")
	for i, field := range s.fields {
		if i > 0 {
			sb.WriteString(",")
		}
		sb.WriteString(field.Name)
		sb.WriteString(":")
		if field.Value != nil {
			sb.WriteString(field.Value.String())
		} else {
			sb.WriteString("null")
		}
	}
	sb.WriteString(">")
	return sb.String()
}
// Sql renders the struct as a named_struct(...) SQL expression of the form
// named_struct('name1', value1, 'name2', value2). Each value is produced by
// the field value's own Sql method; a nil value is written as null. Field
// names are emitted verbatim between single quotes.
func (s Struct) Sql() string {
	var out strings.Builder
	out.WriteString("named_struct(")
	for idx := range s.fields {
		// Separate consecutive name/value pairs.
		if idx > 0 {
			out.WriteString(", ")
		}
		f := s.fields[idx]
		out.WriteString("'")
		out.WriteString(f.Name)
		out.WriteString("'")
		out.WriteString(", ")
		if f.Value == nil {
			out.WriteString("null")
			continue
		}
		out.WriteString(f.Value.Sql())
	}
	out.WriteString(")")
	return out.String()
}
// Scan stores value into s by delegating to tryConvertType and returns its
// error, if any, wrapped with a stack trace.
// NOTE(review): the signature matches database/sql's Scanner — confirm that
// is the intended use.
func (s *Struct) Scan(value interface{}) error {
	err := tryConvertType(value, s)
	return errors.WithStack(err)
}
// Fields returns the struct's fields in their stored order. The returned
// slice is the internal one, not a copy.
func (s *Struct) Fields() []StructField {
	fields := s.fields
	return fields
}
// GetField returns the value stored under fieldName, or nil when the struct
// has no field with that name.
func (s *Struct) GetField(fieldName string) Data {
	if idx, found := s.fieldIndexes[fieldName]; found {
		return s.fields[idx].Value
	}
	return nil
}
// SetField converts a to an ODPS Data value with TryConvertGoToOdpsData and
// stores it under fieldName. An existing field is overwritten in place, so it
// keeps its position; a new field is appended after the current ones.
//
// The name index is allocated on first use, so SetField also works on a
// zero-value Struct instead of panicking on a write to a nil map.
//
// It returns the conversion error, wrapped with a stack trace, when a cannot
// be converted; in that case the struct is left unchanged.
func (s *Struct) SetField(fieldName string, a interface{}) error {
	d, err := TryConvertGoToOdpsData(a)
	if err != nil {
		return errors.WithStack(err)
	}

	if i, ok := s.fieldIndexes[fieldName]; ok {
		s.fields[i] = NewStructField(fieldName, d)
		return nil
	}

	// NOTE(review): this mutex is created per call, so it cannot exclude
	// other goroutines and Struct is NOT safe for concurrent mutation. It is
	// kept only to preserve the existing statement sequence; real
	// synchronization needs a lock that lives on Struct itself.
	m := sync.Mutex{}
	m.Lock()
	if s.fieldIndexes == nil {
		s.fieldIndexes = make(map[string]int)
	}
	s.fields = append(s.fields, NewStructField(fieldName, d))
	s.fieldIndexes[fieldName] = len(s.fields) - 1
	m.Unlock()

	return nil
}
// SafeSetField is the type-checked variant of SetField. It converts i to an
// ODPS Data value, looks fieldName up in the struct's declared type and stores
// the value only when the declared field type equals the value's type.
//
// It returns an error when the struct type has no field list, when i cannot
// be converted, when fieldName is not declared in the struct type, when the
// types differ, or when the underlying SetField call fails.
func (s *Struct) SafeSetField(fieldName string, i interface{}) error {
	if s.typ.Fields == nil {
		return errors.New("type of Struct has not be set")
	}

	d, err := TryConvertGoToOdpsData(i)
	if err != nil {
		return errors.WithStack(err)
	}

	// Find the declared type of the target field.
	var fieldType datatype.DataType
	for _, f := range s.typ.Fields {
		if f.Name == fieldName {
			fieldType = f.Type
			break
		}
	}
	if fieldType == nil {
		return errors.Errorf("cannot set %s to %s", fieldName, s.typ)
	}

	if !datatype.IsTypeEqual(fieldType, d.Type()) {
		return errors.Errorf("cannot set type %s to %s of %s", d.Type(), fieldName, s.typ)
	}

	// Report a failed store instead of silently dropping it.
	if err := s.SetField(fieldName, d); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// TypeInfer derives a struct type from the fields currently stored in s,
// pairing each field name with the type reported by its value, in field order.
//
// It returns an error when the struct has no fields, or when a field holds a
// nil value (whose type cannot be determined).
func (s *Struct) TypeInfer() (datatype.DataType, error) {
	if len(s.fields) == 0 {
		return nil, errors.New("cannot infer type for empty struct")
	}

	fieldTypes := make([]datatype.StructFieldType, len(s.fields))
	for i, field := range s.fields {
		if field.Value == nil {
			// Calling Type on a nil Data would panic; report it instead.
			return nil, errors.Errorf("cannot infer type for field %s: value is nil", field.Name)
		}
		fieldTypes[i] = datatype.NewStructFieldType(field.Name, field.Value.Type())
	}

	return datatype.NewStructType(fieldTypes...), nil
}
// StructFromGoStruct converts the Go struct value i into a *Struct by way of
// TryConvertGoToOdpsData.
//
// It returns an error when i is nil, when i is not of struct kind (a pointer
// to a struct is rejected too), when the conversion fails, or when the
// conversion yields something other than a *Struct.
func StructFromGoStruct(i interface{}) (*Struct, error) {
	it := reflect.TypeOf(i)
	if it == nil {
		// reflect.TypeOf returns nil for a nil interface value; calling
		// Kind on it would panic.
		return nil, errors.New("nil is not a struct")
	}
	if it.Kind() != reflect.Struct {
		return nil, errors.Errorf("%s is not a struct", it.Name())
	}

	s, err := TryConvertGoToOdpsData(i)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	ret, ok := s.(*Struct)
	if !ok {
		// Never hand back a nil *Struct together with a nil error.
		return nil, errors.Errorf("%s was converted to %T instead of *Struct", it.Name(), s)
	}
	return ret, nil
}
// FillGoStruct copies the fields of s into the Go struct that i points to.
//
// For every field of the target struct the ODPS field name is taken from the
// `odps` struct tag, falling back to the Go field name when the tag is empty.
// Target fields with no matching entry in s, and unexported target fields
// (which cannot be written through reflection), are skipped.
//
// A nested *Struct is written recursively into a target field of struct or
// pointer-to-struct kind; a nil pointer field is allocated first. *Array and
// *Map values are stored via ToSlice and ToGoMap, *String as a Go string, and
// every other value as-is. A value is stored when its type is assignable or
// convertible to the target field type; otherwise the field is left untouched.
//
// It returns an error when i is nil, is not a pointer, is a nil pointer, does
// not point to a struct, or when filling a nested struct fails.
func (s *Struct) FillGoStruct(i interface{}) error {
	if i == nil {
		// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
		return errors.New("nil is not a pointer")
	}

	it := reflect.TypeOf(i)
	if it.Kind() != reflect.Ptr {
		return errors.Errorf("%s is not a pointer", it.Name())
	}

	it = it.Elem()
	if it.Kind() != reflect.Struct {
		return errors.Errorf("%s is not a struct", it.Name())
	}

	ptr := reflect.ValueOf(i)
	if ptr.IsNil() {
		return errors.Errorf("cannot fill nil pointer to %s", it.Name())
	}
	iv := ptr.Elem()

	for j, n := 0, it.NumField(); j < n; j++ {
		field := it.Field(j)

		fieldName := field.Tag.Get("odps")
		if fieldName == "" {
			fieldName = field.Name
		}

		data := s.GetField(fieldName)
		if data == nil {
			continue
		}

		fv := iv.Field(j)
		if !fv.CanSet() {
			// Unexported field: Set would panic.
			continue
		}

		var goData interface{}
		switch dt := data.(type) {
		case *Struct:
			switch {
			case fv.Kind() == reflect.Struct:
				// Recurse through the address of the embedded value so the
				// nested call receives the pointer it requires.
				if err := dt.FillGoStruct(fv.Addr().Interface()); err != nil {
					return errors.WithStack(err)
				}
				continue
			case fv.Kind() == reflect.Ptr && fv.Type().Elem().Kind() == reflect.Struct:
				if fv.IsNil() {
					fv.Set(reflect.New(fv.Type().Elem()))
				}
				if err := dt.FillGoStruct(fv.Interface()); err != nil {
					return errors.WithStack(err)
				}
				continue
			}
			// Any other target kind: try to store the *Struct itself below.
			goData = dt
		case *Array:
			goData = dt.ToSlice()
		case *Map:
			goData = dt.ToGoMap()
		case *String:
			goData = *(*string)(dt)
		default:
			goData = dt
		}

		if goData == nil {
			// Nothing to store, and reflect cannot describe a nil interface.
			continue
		}

		goDataV := reflect.ValueOf(goData)
		goDataT := goDataV.Type()
		switch {
		case goDataT.AssignableTo(field.Type):
			fv.Set(goDataV)
		case goDataT.ConvertibleTo(field.Type):
			fv.Set(goDataV.Convert(field.Type))
		}
	}

	return nil
}