in arrow/compute/exprs/types.go [546:697]
// ToSubstraitType converts an Arrow data type into the equivalent Substrait
// type, marking the result nullable or required according to nullable.
//
// Arrow types without a direct Substrait counterpart are handled as follows:
//   - UINT8/16/32/64 and FLOAT16 are emitted as the signed integer type of the
//     same width, tagged with the type variation anchor that ext reports for
//     dt. arrow.ErrNotFound is returned if ext has no such variation.
//   - Extension types named "uuid", "fixed_char", "varchar", "interval_year"
//     and "interval_day" map to the corresponding Substrait types. Any other
//     extension type is emitted as a user-defined type anchored via
//     ext.EncodeType (arrow.ErrNotFound if it is not registered). A
//     "fixed_char" extension whose storage is not fixed-size binary yields
//     arrow.ErrInvalid.
//   - Struct, list-like and map types are converted recursively, using each
//     child field's own nullability; map keys are always emitted as required.
//
// Any Arrow type not listed in the switch yields arrow.ErrNotImplemented.
func ToSubstraitType(dt arrow.DataType, nullable bool, ext ExtensionIDSet) (types.Type, error) {
	var nullability types.Nullability
	if nullable {
		nullability = types.NullabilityNullable
	} else {
		nullability = types.NullabilityRequired
	}

	switch dt.ID() {
	case arrow.BOOL:
		return &types.BooleanType{Nullability: nullability}, nil
	case arrow.INT8:
		return &types.Int8Type{Nullability: nullability}, nil
	case arrow.INT16:
		return &types.Int16Type{Nullability: nullability}, nil
	case arrow.INT32:
		return &types.Int32Type{Nullability: nullability}, nil
	case arrow.INT64:
		return &types.Int64Type{Nullability: nullability}, nil
	case arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64, arrow.FLOAT16:
		// These are carried as the same-width signed integer plus a type
		// variation anchor. NOTE(review): UINT16 and FLOAT16 both become
		// Int16Type here, so a decoder presumably relies on the variation
		// anchor to tell them apart — confirm against FromSubstraitType.
		_, anchor, ok := ext.EncodeTypeVariation(dt)
		if !ok {
			return nil, arrow.ErrNotFound
		}
		switch dt.ID() {
		case arrow.UINT8:
			return &types.Int8Type{Nullability: nullability, TypeVariationRef: anchor}, nil
		case arrow.UINT16, arrow.FLOAT16:
			return &types.Int16Type{Nullability: nullability, TypeVariationRef: anchor}, nil
		case arrow.UINT32:
			return &types.Int32Type{Nullability: nullability, TypeVariationRef: anchor}, nil
		default: // arrow.UINT64
			return &types.Int64Type{Nullability: nullability, TypeVariationRef: anchor}, nil
		}
	case arrow.FLOAT32:
		return &types.Float32Type{Nullability: nullability}, nil
	case arrow.FLOAT64:
		return &types.Float64Type{Nullability: nullability}, nil
	case arrow.STRING, arrow.LARGE_STRING:
		// The 32- vs 64-bit offset distinction is not preserved.
		return &types.StringType{Nullability: nullability}, nil
	case arrow.BINARY, arrow.LARGE_BINARY:
		// The 32- vs 64-bit offset distinction is not preserved.
		return &types.BinaryType{Nullability: nullability}, nil
	case arrow.DATE32:
		return &types.DateType{Nullability: nullability}, nil
	case arrow.EXTENSION:
		dt := dt.(arrow.ExtensionType)
		switch dt.ExtensionName() {
		case "uuid":
			return &types.UUIDType{Nullability: nullability}, nil
		case "fixed_char":
			// Checked assertion: an unrelated extension that happens to use
			// this name must produce an error rather than panic.
			storage, ok := dt.StorageType().(*arrow.FixedSizeBinaryType)
			if !ok {
				return nil, arrow.ErrInvalid
			}
			return &types.FixedCharType{
				Nullability: nullability,
				Length:      int32(storage.ByteWidth),
			}, nil
		case "varchar":
			// NOTE(review): no length is read from the Arrow type; -1 appears
			// to be a placeholder for "unspecified" — confirm.
			return &types.VarCharType{Nullability: nullability, Length: -1}, nil
		case "interval_year":
			return &types.IntervalYearType{Nullability: nullability}, nil
		case "interval_day":
			return &types.IntervalDayType{Nullability: nullability}, nil
		default:
			_, anchor, ok := ext.EncodeType(dt)
			if !ok {
				return nil, arrow.ErrNotFound
			}
			return &types.UserDefinedType{
				Nullability:   nullability,
				TypeReference: anchor,
			}, nil
		}
	case arrow.FIXED_SIZE_BINARY:
		return &types.FixedBinaryType{
			Nullability: nullability,
			Length:      int32(dt.(*arrow.FixedSizeBinaryType).ByteWidth),
		}, nil
	case arrow.DECIMAL128, arrow.DECIMAL256:
		// Both widths map to the same Substrait decimal; only precision and
		// scale are carried over.
		dt := dt.(arrow.DecimalType)
		return &types.DecimalType{
			Nullability: nullability,
			Precision:   dt.GetPrecision(),
			Scale:       dt.GetScale(),
		}, nil
	case arrow.STRUCT:
		dt := dt.(*arrow.StructType)
		fields := make([]types.Type, dt.NumFields())
		var err error
		for i, f := range dt.Fields() {
			fields[i], err = ToSubstraitType(f.Type, f.Nullable, ext)
			if err != nil {
				return nil, err
			}
		}
		return &types.StructType{
			Nullability: nullability,
			Types:       fields,
		}, nil
	case arrow.LIST, arrow.FIXED_SIZE_LIST, arrow.LARGE_LIST:
		// All list variants become a plain Substrait list; a fixed size or
		// 64-bit offsets are not preserved.
		dt := dt.(arrow.NestedType)
		elem := dt.Fields()[0]
		elemType, err := ToSubstraitType(elem.Type, elem.Nullable, ext)
		if err != nil {
			return nil, err
		}
		return &types.ListType{
			Nullability: nullability,
			Type:        elemType,
		}, nil
	case arrow.MAP:
		dt := dt.(*arrow.MapType)
		// Keys are always converted as required (non-nullable).
		keyType, err := ToSubstraitType(dt.KeyType(), false, ext)
		if err != nil {
			return nil, err
		}
		valueType, err := ToSubstraitType(dt.ItemType(), dt.ItemField().Nullable, ext)
		if err != nil {
			return nil, err
		}
		return &types.MapType{
			Nullability: nullability,
			Key:         keyType,
			Value:       valueType,
		}, nil
	}
	return nil, arrow.ErrNotImplemented
}