db/open.go (81 lines of code) (raw):
// Copyright (c) 2017 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package db
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/uber/storagetapper/log"
)
// Deadlines applied to the contexts handed to resolver calls in this file.
const (
	// connInfoTimeout bounds the resolver calls made by GetConnInfoByType
	// and IsValidConnByType.
	connInfoTimeout time.Duration = 15 * time.Second

	// enumeratorTimeout bounds the resolver call made by GetEnumeratorByType.
	enumeratorTimeout time.Duration = 30 * time.Second
)
// Replaceable entry points for the resolver-backed operations. Each one
// defaults to the corresponding *ByType implementation defined in this file.
var (
	// GetConnInfo is the database address resolving function.
	GetConnInfo = GetConnInfoByType

	// GetEnumerator is the database location enumerating function.
	GetEnumerator = GetEnumeratorByType

	// IsValidConn is the database connection validator function that
	// verifies a connection is to the correct DB.
	IsValidConn = IsValidConnByType
)
// Log returns a logger tagged with the address's user, host, port and db
// fields. The password (Pwd) is not among the attached fields.
func (a *Addr) Log() log.Logger {
	fields := log.Fields{
		"user": a.User,
		"host": a.Host,
		"port": a.Port,
		"db":   a.DB,
	}
	return log.WithFields(fields)
}
// SQLMode is the sql_mode that Open passes to OpenModeType. It is a variable
// so that tests can substitute it.
//
// The only difference from the default value is that ONLY_FULL_GROUP_BY is
// removed, which is required by state SQL joins.
var SQLMode = "STRICT_TRANS_TABLES," +
	"NO_ZERO_IN_DATE," +
	"NO_ZERO_DATE," +
	"ERROR_FOR_DIVISION_BY_ZERO," +
	"NO_ENGINE_SUBSTITUTION"
// OpenModeType opens a database connection to the address ci using the
// database/sql driver named drv, with mode supplied as the sql_mode
// connection parameter.
//
// sql.Open does not establish a connection, so the handle is verified with a
// Ping before it is returned. Any error is passed to log.EL together with the
// address logger and returned to the caller; when the Ping fails the handle
// is closed first so that it is not leaked.
func OpenModeType(ci *Addr, drv, mode string) (*sql.DB, error) {
	ci.Log().Debugf("Connect string")

	// mode is passed as a formatting argument rather than concatenated into
	// the format string, so a '%' inside it cannot be misread as a verb.
	dsn := fmt.Sprintf(
		"%v:%v@tcp(%v:%v)/%v?sql_mode='%v'&interpolateParams=true&parseTime=true&loc=Local&time_zone='SYSTEM'",
		ci.User, ci.Pwd, ci.Host, ci.Port, ci.DB, mode,
	)

	dbc, err := sql.Open(drv, dsn)
	if log.EL(ci.Log(), err) {
		return nil, err
	}

	// Open doesn't open a connection. Validate DSN data:
	//FIXME: in go1.8 change to: err = dbc.PingContext(shutdown.Context)
	err = dbc.Ping()
	if log.EL(ci.Log(), err) {
		// The handle is unusable to the caller (nil is returned), so release
		// it here; a Close failure is only reported, the Ping error wins.
		log.EL(ci.Log(), dbc.Close())
		return nil, err
	}

	ci.Log().Infof("Connected")
	return dbc, nil
}
// Open opens a database connection to the given address using the "mysql"
// driver and the package-level SQLMode. It delegates to OpenModeType.
func Open(ci *Addr) (*sql.DB, error) {
	const driver = "mysql"

	conn, err := OpenModeType(ci, driver, SQLMode)
	return conn, err
}
// OpenService resolves the connection info for the location dbl through
// GetConnInfo (requesting the Slave connection type) and opens a connection
// to the resolved address.
//
// A non-empty substDB overrides the database name returned by the resolver.
// A resolution error is returned as is, without opening anything.
func OpenService(dbl *Loc, substDB string, inputType string) (*sql.DB, error) {
	dbl.LogFields().Infof("Fetching connection info")

	addr, err := GetConnInfo(dbl, Slave, inputType)
	if err != nil {
		return nil, err
	}

	// Caller-supplied database name takes precedence over the resolved one.
	if len(substDB) > 0 {
		addr.DB = substDB
	}

	return Open(addr)
}
// GetConnInfoByType returns the DB connection info for the location dbl and
// the requested connection type, using the resolver selected by inputType.
//
// The resolver lookup runs under a context limited to connInfoTimeout. If the
// resolver cannot be created its error is returned unchanged; if the lookup
// fails the error is wrapped with "Failed to fetch DB connection info",
// logged with the location's fields, and returned.
func GetConnInfoByType(dbl *Loc, connType ConnectionType, inputType string) (*Addr, error) {
	res, err := NewResolver(inputType)
	if log.E(err) {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), connInfoTimeout)
	defer cancel()

	ci, err := res.GetInfo(ctx, dbl, connType)
	if err != nil {
		err = errors.Wrap(err, "Failed to fetch DB connection info")
		// Use a constant format: the error text may contain '%', which must
		// not be interpreted as a formatting verb.
		dbl.LogFields().Errorf("%s", err.Error())
		return nil, err
	}

	return ci, nil
}
// GetEnumeratorByType returns a DB location enumerator for the given service,
// cluster, database and table, produced by the resolver that inputType
// selects.
//
// The resolver call runs under a context limited to enumeratorTimeout, and
// that context is cancelled when this function returns.
// NOTE(review): if the returned Enumerator keeps using the context after this
// call, it will observe it as cancelled — confirm against the resolver
// implementations.
func GetEnumeratorByType(svc, cluster, sdb, table, inputType string) (Enumerator, error) {
	resolver, err := NewResolver(inputType)
	if log.E(err) {
		return nil, err
	}

	background := context.Background()
	ctx, cancel := context.WithTimeout(background, enumeratorTimeout)
	defer cancel()

	enumerator, err := resolver.GetEnumerator(ctx, svc, cluster, sdb, table)
	return enumerator, err
}
// IsValidConnByType checks the validity of the connection described by addr,
// to make sure it is a connection to the correct DB for the location dbl and
// the connection type connType.
//
// The check is delegated to the resolver selected by inputType and runs under
// a context limited to connInfoTimeout. If the resolver cannot be created the
// error is logged (wrapped with "Invalid connection") and false is returned.
func IsValidConnByType(dbl *Loc, connType ConnectionType, addr *Addr, inputType string) bool {
	res, err := NewResolver(inputType)
	if err != nil {
		// Use a constant format: the error text may contain '%', which must
		// not be interpreted as a formatting verb.
		log.Errorf("%s", errors.Wrap(err, "Invalid connection").Error())
		return false
	}

	ctx, cancel := context.WithTimeout(context.Background(), connInfoTimeout)
	defer cancel()

	return res.IsValidConn(ctx, dbl, connType, addr)
}