pkg/datasource/sql/util/ctxutil.go (91 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package util
import (
"context"
"database/sql/driver"
"errors"
)
func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
si, err := ci.Prepare(query)
if err == nil {
select {
default:
case <-ctx.Done():
si.Close()
return nil, ctx.Err()
}
}
return si, err
}
func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
if execerCtx != nil {
return execerCtx.ExecContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return execer.Exec(query, dargs)
}
func CtxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx != nil {
return queryerCtx.QueryContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return queryer.Query(query, dargs)
}
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
if siCtx, is := si.(driver.StmtExecContext); is {
return siCtx.ExecContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return si.Exec(dargs)
}
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
if siCtx, is := si.(driver.StmtQueryContext); is {
return siCtx.QueryContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return si.Query(dargs)
}
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}