module/apmsql/stmt.go (78 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package apmsql // import "go.elastic.co/apm/module/apmsql/v2" import ( "context" "database/sql/driver" "go.elastic.co/apm/v2" ) func newStmt(in driver.Stmt, conn *conn, query string) driver.Stmt { stmt := &stmt{ Stmt: in, conn: conn, signature: conn.driver.querySignature(query), query: query, } stmt.columnConverter, _ = in.(driver.ColumnConverter) stmt.stmtExecContext, _ = in.(driver.StmtExecContext) stmt.stmtQueryContext, _ = in.(driver.StmtQueryContext) stmt.namedValueChecker, _ = in.(namedValueChecker) if stmt.namedValueChecker == nil { stmt.namedValueChecker = conn.namedValueChecker } return stmt } type stmt struct { driver.Stmt conn *conn signature string query string columnConverter driver.ColumnConverter namedValueChecker namedValueChecker stmtExecContext driver.StmtExecContext stmtQueryContext driver.StmtQueryContext } func (s *stmt) startSpan(ctx context.Context, spanType string) (*apm.Span, context.Context) { return s.conn.startSpan(ctx, s.signature, spanType, s.query) } func (s *stmt) ColumnConverter(idx int) driver.ValueConverter { if s.columnConverter != nil { return s.columnConverter.ColumnConverter(idx) } return driver.DefaultParameterConverter } func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (result driver.Result, resultError error) { span, ctx := s.startSpan(ctx, s.conn.driver.execSpanType) defer s.conn.finishSpan(ctx, span, &result, &resultError) if s.stmtExecContext != nil { return s.stmtExecContext.ExecContext(ctx, args) } dargs, err := namedValueToValue(args) if err != nil { return nil, err } select { default: case <-ctx.Done(): return nil, ctx.Err() } return s.Exec(dargs) } func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (_ driver.Rows, resultError error) { span, ctx := s.startSpan(ctx, s.conn.driver.querySpanType) defer s.conn.finishSpan(ctx, span, nil, &resultError) if s.stmtQueryContext != nil { return s.stmtQueryContext.QueryContext(ctx, args) } dargs, err := namedValueToValue(args) if err != nil { return nil, err } select { default: case <-ctx.Done(): return nil, ctx.Err() } return s.Query(dargs) } func (s *stmt) CheckNamedValue(nv *driver.NamedValue) error { return checkNamedValue(nv, s.namedValueChecker) }