csharp/src/Client/AdbcCommand.cs (491 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlTypes;
using System.Globalization;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Types;
namespace Apache.Arrow.Adbc.Client
{
/// <summary>
/// Creates an ADO.NET command over an Adbc statement.
/// </summary>
public sealed class AdbcCommand : DbCommand
{
private readonly AdbcStatement _adbcStatement;
private AdbcParameterCollection? _dbParameterCollection;
private int _timeout = 30;
private bool _disposed;
private string? _commandTimeoutProperty;
/// <summary>
/// Overloaded. Initializes <see cref="AdbcCommand"/>.
/// </summary>
/// <param name="adbcConnection">
/// The <see cref="AdbcConnection"/> to use.
/// </param>
/// <exception cref="ArgumentNullException"></exception>
public AdbcCommand(AdbcConnection adbcConnection) : base()
{
if (adbcConnection == null)
throw new ArgumentNullException(nameof(adbcConnection));
this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
this.StructBehavior = adbcConnection.StructBehavior;
this._adbcStatement = adbcConnection.CreateStatement();
}
/// <summary>
/// Overloaded. Initializes <see cref="AdbcCommand"/>.
/// </summary>
/// <param name="query">The command text to use.</param>
/// <param name="adbcConnection">The <see cref="AdbcConnection"/> to use.</param>
public AdbcCommand(string query, AdbcConnection adbcConnection) : base()
{
if (string.IsNullOrEmpty(query))
throw new ArgumentNullException(nameof(query));
if (adbcConnection == null)
throw new ArgumentNullException(nameof(adbcConnection));
this._adbcStatement = adbcConnection.CreateStatement();
this.CommandText = query;
this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
this.StructBehavior = adbcConnection.StructBehavior;
}
// For testing
internal AdbcCommand(AdbcStatement adbcStatement, AdbcConnection adbcConnection)
{
this._adbcStatement = adbcStatement;
this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
this.StructBehavior = adbcConnection.StructBehavior;
if (adbcConnection.CommandTimeoutValue != null)
{
this.AdbcCommandTimeoutProperty = adbcConnection.CommandTimeoutValue.DriverPropertyName;
this.CommandTimeout = adbcConnection.CommandTimeoutValue.Value;
}
}
/// <summary>
/// Gets the <see cref="AdbcStatement"/> associated with
/// this <see cref="AdbcCommand"/>.
/// </summary>
public AdbcStatement AdbcStatement => _disposed ? throw new ObjectDisposedException(nameof(AdbcCommand)) : this._adbcStatement;
public DecimalBehavior DecimalBehavior { get; set; }
public StructBehavior StructBehavior { get; set; }
public override string CommandText
{
get => AdbcStatement.SqlQuery ?? string.Empty;
#nullable disable
set => AdbcStatement.SqlQuery = string.IsNullOrEmpty(value) ? null : value;
#nullable restore
}
public override CommandType CommandType
{
get
{
return CommandType.Text;
}
set
{
if (value != CommandType.Text)
{
throw new AdbcException("Only CommandType.Text is supported");
}
}
}
/// <summary>
/// Gets or sets the name of the command timeout property for the underlying ADBC driver.
/// </summary>
public string AdbcCommandTimeoutProperty
{
get
{
if (string.IsNullOrEmpty(_commandTimeoutProperty))
throw new InvalidOperationException("CommandTimeoutProperty is not set.");
return _commandTimeoutProperty!;
}
set => _commandTimeoutProperty = value;
}
public override int CommandTimeout
{
get => _timeout;
set
{
// ensures the property exists before setting the CommandTimeout value
string property = AdbcCommandTimeoutProperty;
_adbcStatement.SetOption(property, value.ToString(CultureInfo.InvariantCulture));
_timeout = value;
}
}
protected override DbParameterCollection DbParameterCollection
{
get
{
if (_dbParameterCollection == null)
{
_dbParameterCollection = new AdbcParameterCollection();
}
return _dbParameterCollection;
}
}
/// <summary>
/// Gets or sets the Substrait plan used by the command.
/// </summary>
public byte[]? SubstraitPlan
{
get => AdbcStatement.SubstraitPlan;
set => AdbcStatement.SubstraitPlan = value;
}
protected override DbConnection? DbConnection { get; set; }
public override int ExecuteNonQuery()
{
BindParameters();
return Convert.ToInt32(AdbcStatement.ExecuteUpdate().AffectedRows);
}
/// <summary>
/// Similar to <see cref="ExecuteNonQuery"/> but returns Int64
/// instead of Int32.
/// </summary>
/// <returns></returns>
public long ExecuteUpdate()
{
BindParameters();
return AdbcStatement.ExecuteUpdate().AffectedRows;
}
/// <summary>
/// Executes the query
/// </summary>
/// <returns><see cref="Result"></returns>
public QueryResult ExecuteQuery()
{
BindParameters();
QueryResult executed = AdbcStatement.ExecuteQuery();
return executed;
}
protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
{
return ExecuteReader(behavior);
}
/// <summary>
/// Executes the reader with the default behavior.
/// </summary>
/// <returns><see cref="AdbcDataReader"/></returns>
public new AdbcDataReader ExecuteReader()
{
return ExecuteReader(CommandBehavior.Default);
}
/// <summary>
/// Executes the reader with the specified behavior.
/// </summary>
/// <param name="behavior">
/// The <see cref="CommandBehavior"/>
/// </param>
/// <returns><see cref="AdbcDataReader"/></returns>
public new AdbcDataReader ExecuteReader(CommandBehavior behavior)
{
if (_disposed)
throw new ObjectDisposedException(nameof(AdbcCommand));
bool closeConnection = (behavior & CommandBehavior.CloseConnection) != 0;
switch (behavior & ~CommandBehavior.CloseConnection)
{
case CommandBehavior.SchemaOnly: // The schema is not known until a read happens
case CommandBehavior.Default:
QueryResult result = this.ExecuteQuery();
return new AdbcDataReader(this, result, this.DecimalBehavior, this.StructBehavior, closeConnection);
default:
throw new InvalidOperationException($"{behavior} is not supported with this provider");
}
}
protected override void Dispose(bool disposing)
{
if (disposing && !_disposed)
{
// TODO: ensure not in the middle of pulling
this._adbcStatement.Dispose();
_disposed = true;
}
base.Dispose(disposing);
}
private void BindParameters()
{
if (_dbParameterCollection?.Count > 0)
{
Field[] fields = new Field[_dbParameterCollection.Count];
IArrowArray[] parameters = new IArrowArray[_dbParameterCollection.Count];
for (int i = 0; i < fields.Length; i++)
{
AdbcParameter param = (AdbcParameter)_dbParameterCollection[i];
switch (param.DbType)
{
case DbType.Binary:
var binaryBuilder = new BinaryArray.Builder();
switch (param.Value)
{
case null: binaryBuilder.AppendNull(); break;
case byte[] array: binaryBuilder.Append(array.AsSpan()); break;
default: throw new NotSupportedException($"Values of type {param.Value.GetType().Name} cannot be bound as binary");
}
parameters[i] = binaryBuilder.Build();
break;
case DbType.Boolean:
var boolBuilder = new BooleanArray.Builder();
switch (param.Value)
{
case null: boolBuilder.AppendNull(); break;
case bool boolValue: boolBuilder.Append(boolValue); break;
default: boolBuilder.Append(ConvertValue(param.Value, Convert.ToBoolean, DbType.Boolean)); break;
}
parameters[i] = boolBuilder.Build();
break;
case DbType.Byte:
var uint8Builder = new UInt8Array.Builder();
switch (param.Value)
{
case null: uint8Builder.AppendNull(); break;
case byte byteValue: uint8Builder.Append(byteValue); break;
default: uint8Builder.Append(ConvertValue(param.Value, Convert.ToByte, DbType.Byte)); break;
}
parameters[i] = uint8Builder.Build();
break;
case DbType.Date:
var dateBuilder = new Date32Array.Builder();
switch (param.Value)
{
case null: dateBuilder.AppendNull(); break;
case DateTime datetime: dateBuilder.Append(datetime); break;
#if NET5_0_OR_GREATER
case DateOnly dateonly: dateBuilder.Append(dateonly); break;
#endif
default: dateBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.Date)); break;
}
parameters[i] = dateBuilder.Build();
break;
case DbType.DateTime:
var timestampBuilder = new TimestampArray.Builder();
switch (param.Value)
{
case null: timestampBuilder.AppendNull(); break;
case DateTime datetime: timestampBuilder.Append(datetime); break;
default: timestampBuilder.Append(ConvertValue(param.Value, Convert.ToDateTime, DbType.DateTime)); break;
}
parameters[i] = timestampBuilder.Build();
break;
case DbType.Decimal:
var value = param.Value switch
{
null => (SqlDecimal?)null,
SqlDecimal sqlDecimal => sqlDecimal,
decimal d => new SqlDecimal(d),
_ => new SqlDecimal(ConvertValue(param.Value, Convert.ToDecimal, DbType.Decimal)),
};
var decimalBuilder = new Decimal128Array.Builder(new Decimal128Type(value?.Precision ?? 10, value?.Scale ?? 0));
if (value is null)
{
decimalBuilder.AppendNull();
}
else
{
decimalBuilder.Append(value.Value);
}
parameters[i] = decimalBuilder.Build();
break;
case DbType.Double:
var doubleBuilder = new DoubleArray.Builder();
switch (param.Value)
{
case null: doubleBuilder.AppendNull(); break;
case double dbl: doubleBuilder.Append(dbl); break;
default: doubleBuilder.Append(ConvertValue(param.Value, Convert.ToDouble, DbType.Double)); break;
}
parameters[i] = doubleBuilder.Build();
break;
case DbType.Int16:
var int16Builder = new Int16Array.Builder();
switch (param.Value)
{
case null: int16Builder.AppendNull(); break;
case short shortValue: int16Builder.Append(shortValue); break;
default: int16Builder.Append(ConvertValue(param.Value, Convert.ToInt16, DbType.Int16)); break;
}
parameters[i] = int16Builder.Build();
break;
case DbType.Int32:
var int32Builder = new Int32Array.Builder();
switch (param.Value)
{
case null: int32Builder.AppendNull(); break;
case int intValue: int32Builder.Append(intValue); break;
default: int32Builder.Append(ConvertValue(param.Value, Convert.ToInt32, DbType.Int32)); break;
}
parameters[i] = int32Builder.Build();
break;
case DbType.Int64:
var int64Builder = new Int64Array.Builder();
switch (param.Value)
{
case null: int64Builder.AppendNull(); break;
case long longValue: int64Builder.Append(longValue); break;
default: int64Builder.Append(ConvertValue(param.Value, Convert.ToInt64, DbType.Int64)); break;
}
parameters[i] = int64Builder.Build();
break;
case DbType.SByte:
var int8Builder = new Int8Array.Builder();
switch (param.Value)
{
case null: int8Builder.AppendNull(); break;
case sbyte sbyteValue: int8Builder.Append(sbyteValue); break;
default: int8Builder.Append(ConvertValue(param.Value, Convert.ToSByte, DbType.SByte)); break;
}
parameters[i] = int8Builder.Build();
break;
case DbType.Single:
var floatBuilder = new FloatArray.Builder();
switch (param.Value)
{
case null: floatBuilder.AppendNull(); break;
case float floatValue: floatBuilder.Append(floatValue); break;
default: floatBuilder.Append(ConvertValue(param.Value, Convert.ToSingle, DbType.Single)); break;
}
parameters[i] = floatBuilder.Build();
break;
case DbType.String:
var stringBuilder = new StringArray.Builder();
switch (param.Value)
{
case null: stringBuilder.AppendNull(); break;
case string stringValue: stringBuilder.Append(stringValue); break;
default: stringBuilder.Append(ConvertValue(param.Value, Convert.ToString, DbType.String)); break;
}
parameters[i] = stringBuilder.Build();
break;
case DbType.Time:
var timeBuilder = new Time32Array.Builder();
switch (param.Value)
{
case null: timeBuilder.AppendNull(); break;
case DateTime datetime: timeBuilder.Append((int)(datetime.TimeOfDay.Ticks / TimeSpan.TicksPerMillisecond)); break;
#if NET5_0_OR_GREATER
case TimeOnly timeonly: timeBuilder.Append(timeonly); break;
#endif
default:
DateTime convertedDateTime = ConvertValue(param.Value, Convert.ToDateTime, DbType.Time);
timeBuilder.Append((int)(convertedDateTime.TimeOfDay.Ticks / TimeSpan.TicksPerMillisecond));
break;
}
parameters[i] = timeBuilder.Build();
break;
case DbType.UInt16:
var uint16Builder = new UInt16Array.Builder();
switch (param.Value)
{
case null: uint16Builder.AppendNull(); break;
case ushort ushortValue: uint16Builder.Append(ushortValue); break;
default: uint16Builder.Append(ConvertValue(param.Value, Convert.ToUInt16, DbType.UInt16)); break;
}
parameters[i] = uint16Builder.Build();
break;
case DbType.UInt32:
var uint32Builder = new UInt32Array.Builder();
switch (param.Value)
{
case null: uint32Builder.AppendNull(); break;
case uint uintValue: uint32Builder.Append(uintValue); break;
default: uint32Builder.Append(ConvertValue(param.Value, Convert.ToUInt32, DbType.UInt32)); break;
}
parameters[i] = uint32Builder.Build();
break;
case DbType.UInt64:
var uint64Builder = new UInt64Array.Builder();
switch (param.Value)
{
case null: uint64Builder.AppendNull(); break;
case ulong ulongValue: uint64Builder.Append(ulongValue); break;
default: uint64Builder.Append(ConvertValue(param.Value, Convert.ToUInt64, DbType.UInt64)); break;
}
parameters[i] = uint64Builder.Build();
break;
default:
throw new NotSupportedException($"Parameters of type {param.DbType} are not supported");
}
fields[i] = new Field(
string.IsNullOrWhiteSpace(param.ParameterName) ? Guid.NewGuid().ToString() : param.ParameterName,
parameters[i].Data.DataType,
param.IsNullable || param.Value == null);
}
Schema schema = new Schema(fields, null);
AdbcStatement.Bind(new RecordBatch(schema, parameters, 1), schema);
}
}
private static T ConvertValue<T>(object value, Func<object, T> converter, DbType type)
{
try
{
return converter(value);
}
catch (Exception)
{
throw new NotSupportedException($"Values of type {value.GetType().Name} cannot be bound as {type}.");
}
}
public override void Prepare()
{
_adbcStatement.Prepare();
var schema = _adbcStatement.GetParameterSchema();
DbParameterCollection.Clear();
foreach (Field field in schema.FieldsList)
{
AdbcParameter parameter = new AdbcParameter
{
ParameterName = field.Name,
IsNullable = field.IsNullable,
DbType = field.DataType.TypeId switch
{
ArrowTypeId.UInt8 => DbType.Byte,
ArrowTypeId.UInt16 => DbType.UInt16,
ArrowTypeId.UInt32 => DbType.UInt32,
ArrowTypeId.UInt64 => DbType.UInt64,
ArrowTypeId.Int8 => DbType.SByte,
ArrowTypeId.Int16 => DbType.Int16,
ArrowTypeId.Int32 => DbType.Int32,
ArrowTypeId.Int64 => DbType.Int64,
ArrowTypeId.Float => DbType.Single,
ArrowTypeId.Double => DbType.Double,
ArrowTypeId.Boolean => DbType.Boolean,
ArrowTypeId.String => DbType.String,
ArrowTypeId.Date32 => DbType.Date,
ArrowTypeId.Date64 => DbType.DateTime,
ArrowTypeId.Time32 => DbType.Time,
ArrowTypeId.Time64 => DbType.Time,
ArrowTypeId.Timestamp => DbType.DateTime,
ArrowTypeId.Decimal32 or
ArrowTypeId.Decimal64 or
ArrowTypeId.Decimal128 or
ArrowTypeId.Decimal256 => DbType.Decimal,
_ => DbType.Object,
},
};
DbParameterCollection.Add(parameter);
}
}
protected override DbParameter CreateDbParameter()
{
return new AdbcParameter();
}
#if NET5_0_OR_GREATER
public override ValueTask DisposeAsync()
{
return base.DisposeAsync();
}
#endif
#region NOT_IMPLEMENTED
public override bool DesignTimeVisible { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
public override UpdateRowSource UpdatedRowSource { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
protected override DbTransaction? DbTransaction { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
public override void Cancel()
{
throw new NotImplementedException();
}
public override object ExecuteScalar()
{
throw new NotImplementedException();
}
#endregion
private class AdbcParameterCollection : DbParameterCollection
{
readonly List<AdbcParameter> _parameters = new List<AdbcParameter>();
public override int Count => _parameters.Count;
public override object SyncRoot => throw new NotImplementedException();
public override int Add(object value)
{
int result = _parameters.Count;
_parameters.Add((AdbcParameter)value);
return result;
}
public override void AddRange(System.Array values) => _parameters.AddRange(values.Cast<AdbcParameter>());
public override void Clear() => _parameters.Clear();
public override bool Contains(object value) => _parameters.Contains((AdbcParameter)value);
public override bool Contains(string value) => IndexOf(value) >= 0;
public override void CopyTo(System.Array array, int index) => throw new NotImplementedException();
public override IEnumerator GetEnumerator() => _parameters.GetEnumerator();
public override int IndexOf(object value) => _parameters.IndexOf((AdbcParameter)value);
public override int IndexOf(string parameterName) => GetParameterIndex(parameterName, throwOnFailure: false);
public override void Insert(int index, object value) => _parameters.Insert(index, (AdbcParameter)value);
public override void Remove(object value) => _parameters.Remove((AdbcParameter)value);
public override void RemoveAt(int index) => _parameters.RemoveAt(index);
public override void RemoveAt(string parameterName) => _parameters.RemoveAt(GetParameterIndex(parameterName));
protected override DbParameter GetParameter(int index) => _parameters[index];
protected override DbParameter GetParameter(string parameterName) => _parameters[GetParameterIndex(parameterName)];
protected override void SetParameter(int index, DbParameter value) => _parameters[index] = (AdbcParameter)value;
protected override void SetParameter(string parameterName, DbParameter value) => throw new NotImplementedException();
private int GetParameterIndex(string parameterName, bool throwOnFailure = true)
{
for (int i = 0; i < _parameters.Count; i++)
{
if (parameterName == _parameters[i].ParameterName)
{
return i;
}
}
if (throwOnFailure)
{
throw new IndexOutOfRangeException("parameterName not found");
}
return -1;
}
}
}
}