arrow/cdata/cdata_exports.go (468 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cdata // #include <errno.h> // #include <stdint.h> // #include <stdlib.h> // #include "abi.h" // #include "helpers.h" // // extern void releaseExportedSchema(struct ArrowSchema* schema); // extern void releaseExportedArray(struct ArrowArray* array); // // const uint8_t kGoCdataZeroRegion[8] = {0}; // // void goReleaseArray(struct ArrowArray* array) { // releaseExportedArray(array); // } // void goReleaseSchema(struct ArrowSchema* schema) { // releaseExportedSchema(schema); // } // // void goCallCancel(struct ArrowAsyncProducer* producer) { // producer->cancel(producer); // } // // int goExtractTaskData(struct ArrowAsyncTask* task, struct ArrowDeviceArray* out) { // return task->extract_data(task, out); // } // // static void goCallRequest(struct ArrowAsyncProducer* producer, int64_t n) { // producer->request(producer, n); // } import "C" import ( "bytes" "context" "encoding/binary" "fmt" "runtime/cgo" "strconv" "strings" "unsafe" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/endian" "github.com/apache/arrow-go/v18/arrow/internal" "github.com/apache/arrow-go/v18/arrow/ipc" ) func encodeCMetadata(keys, values []string) []byte { if len(keys) != len(values) { panic("unequal metadata key/values length") } npairs := int32(len(keys)) var b bytes.Buffer totalSize := 4 for i := range keys { totalSize += 8 + len(keys[i]) + len(values[i]) } b.Grow(totalSize) b.Write((*[4]byte)(unsafe.Pointer(&npairs))[:]) for i := range keys { binary.Write(&b, endian.Native, int32(len(keys[i]))) b.WriteString(keys[i]) binary.Write(&b, endian.Native, int32(len(values[i]))) b.WriteString(values[i]) } return b.Bytes() } type schemaExporter struct { format, name string extraMeta arrow.Metadata metadata []byte flags int64 children []schemaExporter dict *schemaExporter } func (exp *schemaExporter) handleExtension(dt arrow.DataType) arrow.DataType { if dt.ID() != arrow.EXTENSION { return dt } ext := dt.(arrow.ExtensionType) exp.extraMeta = arrow.NewMetadata([]string{ipc.ExtensionTypeKeyName, ipc.ExtensionMetadataKeyName}, []string{ext.ExtensionName(), ext.Serialize()}) return ext.StorageType() } func (exp *schemaExporter) exportMeta(m *arrow.Metadata) { var ( finalKeys []string finalValues []string ) if m == nil { if exp.extraMeta.Len() > 0 { finalKeys = exp.extraMeta.Keys() finalValues = exp.extraMeta.Values() } exp.metadata = encodeCMetadata(finalKeys, finalValues) return } finalKeys = m.Keys() finalValues = m.Values() if exp.extraMeta.Len() > 0 { for i, k := range exp.extraMeta.Keys() { if m.FindKey(k) != -1 { continue } finalKeys = append(finalKeys, k) finalValues = append(finalValues, exp.extraMeta.Values()[i]) } } exp.metadata = encodeCMetadata(finalKeys, finalValues) } func (exp *schemaExporter) exportFormat(dt arrow.DataType) string { switch dt := dt.(type) { case *arrow.NullType: return "n" case *arrow.BooleanType: return "b" case *arrow.Int8Type: return "c" case *arrow.Uint8Type: return "C" case *arrow.Int16Type: return "s" case *arrow.Uint16Type: return "S" case *arrow.Int32Type: return "i" case *arrow.Uint32Type: return "I" case *arrow.Int64Type: return "l" case *arrow.Uint64Type: return "L" case *arrow.Float16Type: return "e" case *arrow.Float32Type: return "f" case *arrow.Float64Type: return "g" case *arrow.FixedSizeBinaryType: return fmt.Sprintf("w:%d", dt.ByteWidth) case *arrow.Decimal32Type: return fmt.Sprintf("d:%d,%d,32", dt.Precision, dt.Scale) case *arrow.Decimal64Type: return fmt.Sprintf("d:%d,%d,64", dt.Precision, dt.Scale) case *arrow.Decimal128Type: return fmt.Sprintf("d:%d,%d", dt.Precision, dt.Scale) case *arrow.Decimal256Type: return fmt.Sprintf("d:%d,%d,256", dt.Precision, dt.Scale) case *arrow.BinaryType: return "z" case *arrow.LargeBinaryType: return "Z" case *arrow.StringType: return "u" case *arrow.LargeStringType: return "U" case *arrow.BinaryViewType: return "vz" case *arrow.StringViewType: return "vu" case *arrow.Date32Type: return "tdD" case *arrow.Date64Type: return "tdm" case *arrow.Time32Type: switch dt.Unit { case arrow.Second: return "tts" case arrow.Millisecond: return "ttm" default: panic(fmt.Sprintf("invalid time unit for time32: %s", dt.Unit)) } case *arrow.Time64Type: switch dt.Unit { case arrow.Microsecond: return "ttu" case arrow.Nanosecond: return "ttn" default: panic(fmt.Sprintf("invalid time unit for time64: %s", dt.Unit)) } case *arrow.TimestampType: var b strings.Builder switch dt.Unit { case arrow.Second: b.WriteString("tss:") case arrow.Millisecond: b.WriteString("tsm:") case arrow.Microsecond: b.WriteString("tsu:") case arrow.Nanosecond: b.WriteString("tsn:") default: panic(fmt.Sprintf("invalid time unit for timestamp: %s", dt.Unit)) } b.WriteString(dt.TimeZone) return b.String() case *arrow.DurationType: switch dt.Unit { case arrow.Second: return "tDs" case arrow.Millisecond: return "tDm" case arrow.Microsecond: return "tDu" case arrow.Nanosecond: return "tDn" default: panic(fmt.Sprintf("invalid time unit for duration: %s", dt.Unit)) } case *arrow.MonthIntervalType: return "tiM" case *arrow.DayTimeIntervalType: return "tiD" case *arrow.MonthDayNanoIntervalType: return "tin" case *arrow.ListType: return "+l" case *arrow.LargeListType: return "+L" case *arrow.ListViewType: return "+vl" case *arrow.LargeListViewType: return "+vL" case *arrow.FixedSizeListType: return fmt.Sprintf("+w:%d", dt.Len()) case *arrow.StructType: return "+s" case *arrow.RunEndEncodedType: return "+r" case *arrow.MapType: if dt.KeysSorted { exp.flags |= C.ARROW_FLAG_MAP_KEYS_SORTED } return "+m" case *arrow.DictionaryType: if dt.Ordered { exp.flags |= C.ARROW_FLAG_DICTIONARY_ORDERED } return exp.exportFormat(dt.IndexType) case arrow.UnionType: var b strings.Builder if dt.Mode() == arrow.SparseMode { b.WriteString("+us:") } else { b.WriteString("+ud:") } for i, c := range dt.TypeCodes() { if i != 0 { b.WriteByte(',') } b.WriteString(strconv.Itoa(int(c))) } return b.String() } panic("unsupported data type for export") } func (exp *schemaExporter) export(field arrow.Field) { exp.name = field.Name exp.format = exp.exportFormat(exp.handleExtension(field.Type)) if field.Nullable { exp.flags |= C.ARROW_FLAG_NULLABLE } switch dt := field.Type.(type) { case *arrow.DictionaryType: exp.dict = new(schemaExporter) exp.dict.export(arrow.Field{Type: dt.ValueType}) case arrow.NestedType: exp.children = make([]schemaExporter, dt.NumFields()) for i, f := range dt.Fields() { exp.children[i].export(f) } } exp.exportMeta(&field.Metadata) } func (exp *schemaExporter) finish(out *CArrowSchema) { out.dictionary = nil if exp.dict != nil { out.dictionary = (*CArrowSchema)(C.calloc(C.sizeof_struct_ArrowSchema, C.size_t(1))) exp.dict.finish(out.dictionary) } out.name = C.CString(exp.name) out.format = C.CString(exp.format) out.metadata = (*C.char)(C.CBytes(exp.metadata)) out.flags = C.int64_t(exp.flags) out.n_children = C.int64_t(len(exp.children)) if len(exp.children) > 0 { children := allocateArrowSchemaArr(len(exp.children)) childPtrs := allocateArrowSchemaPtrArr(len(exp.children)) for i, c := range exp.children { c.finish(&children[i]) childPtrs[i] = &children[i] } out.children = (**CArrowSchema)(unsafe.Pointer(&childPtrs[0])) } else { out.children = nil } out.release = (*[0]byte)(C.goReleaseSchema) } func exportField(field arrow.Field, out *CArrowSchema) { var exp schemaExporter exp.export(field) exp.finish(out) } func exportArray(arr arrow.Array, out *CArrowArray, outSchema *CArrowSchema) { if outSchema != nil { exportField(arrow.Field{Type: arr.DataType()}, outSchema) } buffers := arr.Data().Buffers() // Some types don't have validity bitmaps, but we keep them shifted // to make processing easier in other contexts. This means that // we have to adjust when exporting. has_validity_bitmap := internal.DefaultHasValidityBitmap(arr.DataType().ID()) if len(buffers) > 0 && !has_validity_bitmap { buffers = buffers[1:] } nbuffers := len(buffers) has_buffer_sizes_buffer := internal.HasBufferSizesBuffer(arr.DataType().ID()) if has_buffer_sizes_buffer { nbuffers++ } out.dictionary = nil out.null_count = C.int64_t(arr.NullN()) out.length = C.int64_t(arr.Len()) out.offset = C.int64_t(arr.Data().Offset()) out.n_buffers = C.int64_t(nbuffers) out.buffers = nil if nbuffers > 0 { cBufs := allocateBufferPtrArr(nbuffers) for i, buf := range buffers { if buf == nil || buf.Len() == 0 { if i > 0 || !has_validity_bitmap { // apache/arrow#33936: export a dummy buffer to be friendly to // implementations that don't import NULL properly cBufs[i] = (*C.void)(unsafe.Pointer(&C.kGoCdataZeroRegion)) } else { // null pointer permitted for the validity bitmap // (assuming null count is 0) cBufs[i] = nil } continue } cBufs[i] = (*C.void)(unsafe.Pointer(&buf.Bytes()[0])) } if has_buffer_sizes_buffer { sizes := allocateBufferSizeArr(len(buffers[2:])) for i, buf := range buffers[2:] { sizes[i] = C.int64_t(buf.Len()) } if len(sizes) > 0 { cBufs[nbuffers-1] = (*C.void)(unsafe.Pointer(&sizes[0])) } } out.buffers = (*unsafe.Pointer)(unsafe.Pointer(&cBufs[0])) } arr.Data().Retain() h := cgo.NewHandle(arr.Data()) out.private_data = createHandle(h) out.release = (*[0]byte)(C.goReleaseArray) switch arr := arr.(type) { case array.ListLike: out.n_children = 1 childPtrs := allocateArrowArrayPtrArr(1) children := allocateArrowArrayArr(1) exportArray(arr.ListValues(), &children[0], nil) childPtrs[0] = &children[0] out.children = (**CArrowArray)(unsafe.Pointer(&childPtrs[0])) case *array.Struct: out.n_children = C.int64_t(arr.NumField()) if arr.NumField() == 0 { return } childPtrs := allocateArrowArrayPtrArr(arr.NumField()) children := allocateArrowArrayArr(arr.NumField()) for i := 0; i < arr.NumField(); i++ { exportArray(arr.Field(i), &children[i], nil) childPtrs[i] = &children[i] } out.children = (**CArrowArray)(unsafe.Pointer(&childPtrs[0])) case *array.RunEndEncoded: out.n_children = 2 childPtrs := allocateArrowArrayPtrArr(2) children := allocateArrowArrayArr(2) exportArray(arr.RunEndsArr(), &children[0], nil) exportArray(arr.Values(), &children[1], nil) childPtrs[0], childPtrs[1] = &children[0], &children[1] out.children = (**CArrowArray)(unsafe.Pointer(&childPtrs[0])) case *array.Dictionary: out.dictionary = (*CArrowArray)(C.calloc(C.sizeof_struct_ArrowArray, C.size_t(1))) exportArray(arr.Dictionary(), out.dictionary, nil) case array.Union: out.n_children = C.int64_t(arr.NumFields()) if arr.NumFields() == 0 { return } childPtrs := allocateArrowArrayPtrArr(arr.NumFields()) children := allocateArrowArrayArr(arr.NumFields()) for i := 0; i < arr.NumFields(); i++ { exportArray(arr.Field(i), &children[i], nil) childPtrs[i] = &children[i] } out.children = (**CArrowArray)(unsafe.Pointer(&childPtrs[0])) default: out.n_children = 0 out.children = nil } } type cRecordReader struct { rdr array.RecordReader err *C.char } func (rr cRecordReader) getSchema(out *CArrowSchema) int { schema := rr.rdr.Schema() if schema == nil { return rr.maybeError() } ExportArrowSchema(schema, out) return 0 } func (rr cRecordReader) next(out *CArrowArray) int { if rr.rdr.Next() { ExportArrowRecordBatch(rr.rdr.Record(), out, nil) return 0 } C.ArrowArrayMarkReleased(out) return rr.maybeError() } func (rr cRecordReader) maybeError() int { err := rr.rdr.Err() if err != nil { return C.EIO } return 0 } func (rr cRecordReader) getLastError() *C.char { err := rr.rdr.Err() if err != nil { if rr.err != nil { C.free(unsafe.Pointer(rr.err)) } rr.err = C.CString(err.Error()) } return rr.err } func (rr cRecordReader) release() { if rr.err != nil { C.free(unsafe.Pointer(rr.err)) } rr.rdr.Release() } type cAsyncStreamHandler struct { producer *CArrowAsyncProducer taskQueue chan taskState ctx context.Context } func asyncTaskQueue(ctx context.Context, schema *arrow.Schema, recordStream chan<- RecordMessage, taskQueue <-chan taskState, producer *CArrowAsyncProducer) { defer close(recordStream) for { select { case <-ctx.Done(): C.goCallCancel(producer) return case task, ok := <-taskQueue: // if the queue closes or we receive a nil task, we're done if !ok || (task.err == nil && task.task.extract_data == nil) { return } if task.err != nil { recordStream <- RecordMessage{Err: task.err} continue } // request another batch now that we've processed this one C.goCallRequest(producer, C.int64_t(1)) var out CArrowDeviceArray if C.goExtractTaskData(&task.task, &out) != C.int(0) { continue } rec, err := ImportCRecordBatchWithSchema(&out.array, schema) if err != nil { recordStream <- RecordMessage{Err: err} } else { recordStream <- RecordMessage{Record: rec, AdditionalMetadata: task.meta} } } } } func (h *cAsyncStreamHandler) onNextTask(task *CArrowAsyncTask, metadata *C.char) C.int { if task == nil { h.taskQueue <- taskState{} return 0 } ts := taskState{task: *task} if metadata != nil { ts.meta = decodeCMetadata(metadata) } h.taskQueue <- ts return 0 } func (h *cAsyncStreamHandler) onError(code C.int, message, metadata *C.char) { h.taskQueue <- taskState{err: AsyncStreamError{ Code: int(code), Msg: C.GoString(message), Metadata: C.GoString(metadata)}} } func (h *cAsyncStreamHandler) release() { close(h.taskQueue) h.taskQueue, h.producer = nil, nil h.producer = nil }