module/apmsql/conn.go (188 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package apmsql // import "go.elastic.co/apm/module/apmsql/v2" import ( "context" "database/sql/driver" "errors" "go.elastic.co/apm/v2" ) var _ driver.Validator = (*conn)(nil) func newConn(in driver.Conn, d *tracingDriver, dsnInfo DSNInfo) driver.Conn { conn := &conn{Conn: in, driver: d} conn.dsnInfo = dsnInfo conn.namedValueChecker, _ = in.(namedValueChecker) conn.pinger, _ = in.(driver.Pinger) conn.queryer, _ = in.(driver.Queryer) conn.queryerContext, _ = in.(driver.QueryerContext) conn.connPrepareContext, _ = in.(driver.ConnPrepareContext) conn.execer, _ = in.(driver.Execer) conn.execerContext, _ = in.(driver.ExecerContext) conn.connBeginTx, _ = in.(driver.ConnBeginTx) conn.sessionResetter, _ = in.(driver.SessionResetter) conn.validator, _ = in.(driver.Validator) if in, ok := in.(driver.ConnBeginTx); ok { return &connBeginTx{conn, in} } return conn } type conn struct { driver.Conn driver *tracingDriver dsnInfo DSNInfo namedValueChecker namedValueChecker pinger driver.Pinger queryer driver.Queryer queryerContext driver.QueryerContext connPrepareContext driver.ConnPrepareContext execer driver.Execer execerContext driver.ExecerContext connBeginTx driver.ConnBeginTx sessionResetter driver.SessionResetter validator driver.Validator } func (c *conn) startStmtSpan(ctx context.Context, stmt, spanType string) (*apm.Span, context.Context) { return c.startSpan(ctx, c.driver.querySignature(stmt), spanType, stmt) } func (c *conn) startSpan(ctx context.Context, name, spanType, stmt string) (*apm.Span, context.Context) { span, ctx := apm.StartSpanOptions(ctx, name, spanType, apm.SpanOptions{ ExitSpan: true, }) if !span.Dropped() { if c.dsnInfo.Address != "" { span.Context.SetDestinationAddress(c.dsnInfo.Address, c.dsnInfo.Port) span.Context.SetDestinationService(apm.DestinationServiceSpanContext{ Name: c.driver.driverName, Resource: c.driver.driverName, }) span.Context.SetServiceTarget(apm.ServiceTargetSpanContext{ Type: c.driver.driverName, Name: c.dsnInfo.Database, }) } span.Context.SetDatabase(apm.DatabaseSpanContext{ Instance: c.dsnInfo.Database, Statement: stmt, Type: "sql", User: c.dsnInfo.User, }) } return span, ctx } func (c *conn) finishSpan(ctx context.Context, span *apm.Span, result *driver.Result, resultError *error) { if *resultError == driver.ErrSkip { // TODO(axw) mark span as abandoned, // so it's not sent and not counted // in the span limit. Ideally remove // from the slice so memory is kept // in check. return } switch *resultError { case nil: if !span.Dropped() && result != nil && *result != nil && *result != driver.ResultNoRows { rowsAffected, err := (*result).RowsAffected() if err == nil && rowsAffected >= 0 { span.Context.SetDatabaseRowsAffected(rowsAffected) } } case driver.ErrBadConn, context.Canceled: // ErrBadConn is used by the connection pooling // logic in database/sql, and so is expected and // should not be reported. // // context.Canceled means the callers canceled // the operation, so this is also expected. default: if e := apm.CaptureError(ctx, *resultError); e != nil { e.Send() } } span.End() } func (c *conn) Ping(ctx context.Context) (resultError error) { if c.pinger == nil { return nil } span, ctx := c.startSpan(ctx, "ping", c.driver.pingSpanType, "") defer c.finishSpan(ctx, span, nil, &resultError) return c.pinger.Ping(ctx) } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (_ driver.Rows, resultError error) { if c.queryerContext == nil && c.queryer == nil { return nil, driver.ErrSkip } span, ctx := c.startStmtSpan(ctx, query, c.driver.querySpanType) defer c.finishSpan(ctx, span, nil, &resultError) if c.queryerContext != nil { return c.queryerContext.QueryContext(ctx, query, args) } dargs, err := namedValueToValue(args) if err != nil { return nil, err } select { default: case <-ctx.Done(): return nil, ctx.Err() } return c.queryer.Query(query, dargs) } func (*conn) Query(query string, args []driver.Value) (driver.Rows, error) { return nil, errors.New("Query should never be called") } func (c *conn) PrepareContext(ctx context.Context, query string) (_ driver.Stmt, resultError error) { span, ctx := c.startStmtSpan(ctx, query, c.driver.prepareSpanType) defer c.finishSpan(ctx, span, nil, &resultError) var stmt driver.Stmt var err error if c.connPrepareContext != nil { stmt, err = c.connPrepareContext.PrepareContext(ctx, query) } else { stmt, err = c.Prepare(query) if err == nil { select { default: case <-ctx.Done(): stmt.Close() return nil, ctx.Err() } } } if stmt != nil { stmt = newStmt(stmt, c, query) } return stmt, err } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, resultError error) { if c.execerContext == nil && c.execer == nil { return nil, driver.ErrSkip } span, ctx := c.startStmtSpan(ctx, query, c.driver.execSpanType) defer c.finishSpan(ctx, span, &result, &resultError) if c.execerContext != nil { return c.execerContext.ExecContext(ctx, query, args) } dargs, err := namedValueToValue(args) if err != nil { return nil, err } select { default: case <-ctx.Done(): return nil, ctx.Err() } return c.execer.Exec(query, dargs) } func (*conn) Exec(query string, args []driver.Value) (driver.Result, error) { return nil, errors.New("Exec should never be called") } func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { return checkNamedValue(nv, c.namedValueChecker) } func (c *conn) ResetSession(ctx context.Context) error { if c.sessionResetter != nil { return c.sessionResetter.ResetSession(ctx) } return nil } func (c *conn) IsValid() bool { if c.validator != nil { return c.validator.IsValid() } return true } type connBeginTx struct { *conn connBeginTx driver.ConnBeginTx } func (c *connBeginTx) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { // TODO(axw) instrument commit/rollback? return c.connBeginTx.BeginTx(ctx, opts) }