common/persistence/sql/sqlplugin/mysql/plugin.go (197 lines of code) (raw):
// Copyright (c) 2019 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package mysql
import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"net"
	"net/url"
	"sort"
	"strings"

	"github.com/go-sql-driver/mysql"
	"github.com/iancoleman/strcase"
	"github.com/jmoiron/sqlx"

	"github.com/uber/cadence/common/config"
	pt "github.com/uber/cadence/common/persistence/persistence-tests"
	"github.com/uber/cadence/common/persistence/sql"
	"github.com/uber/cadence/common/persistence/sql/sqldriver"
	"github.com/uber/cadence/common/persistence/sql/sqlplugin"
	"github.com/uber/cadence/environment"
)
// Names and DSN building blocks used by the MySQL plugin.
const (
	// PluginName is the name of the plugin; it is also the driver name passed
	// to sqlx when connecting.
	PluginName = "mysql"

	// dsnFmt lays out the attribute-less part of the DSN as
	// user:password@protocol(address)/database.
	dsnFmt = "%s:%s@%v(%v)/%s"

	// customTLSName is the name used if a custom tls configuration is created
	customTLSName = "tls-custom"
)

// Isolation level attribute keys recognised in the connect attributes, and the
// value applied when neither key is supplied.
const (
	isolationLevelAttrName       = "transaction_isolation"
	isolationLevelAttrNameLegacy = "tx_isolation"
	defaultIsolationLevel        = "'READ-COMMITTED'"
)
// dsnAttrOverrides lists the DSN attributes that are forced onto every DSN
// built by this plugin, replacing any value supplied through the connect
// attributes.
var dsnAttrOverrides = map[string]string{
	"multiStatements": "true",
	"clientFoundRows": "true",
	"parseTime":       "true",
}
// plugin is the MySQL implementation of sqlplugin.Plugin. It has no fields.
type plugin struct{}
// Compile-time assertion that *plugin implements sqlplugin.Plugin.
var _ sqlplugin.Plugin = (*plugin)(nil)
func init() {
sql.RegisterPlugin(PluginName, &plugin{})
}
// CreateDB creates the database connections described by cfg, opening each
// one through createSingleDBConn, and wraps them in a sqlplugin.DB.
// Any error from creating the connections is returned unchanged.
func (p *plugin) CreateDB(cfg *config.SQL) (sqlplugin.DB, error) {
	connections, err := sqldriver.CreateDBConnections(cfg, p.createSingleDBConn)
	if err != nil {
		return nil, err
	}
	return newDB(connections, nil, sqlplugin.DbShardUndefined, cfg.NumShards)
}
// CreateAdminDB creates the database connections described by cfg, opening
// each one through createSingleDBConn, and wraps them in a sqlplugin.AdminDB.
// Any error from creating the connections is returned unchanged.
func (p *plugin) CreateAdminDB(cfg *config.SQL) (sqlplugin.AdminDB, error) {
	connections, err := sqldriver.CreateDBConnections(cfg, p.createSingleDBConn)
	if err != nil {
		return nil, err
	}
	return newDB(connections, nil, sqlplugin.DbShardUndefined, cfg.NumShards)
}
// createSingleDBConn opens one connection pool for cfg.
//
// The TLS configuration is registered first because registerTLSConfig may add
// a "tls" connect attribute that buildDSN then includes in the DSN.
func (p *plugin) createSingleDBConn(cfg *config.SQL) (*sqlx.DB, error) {
	if err := registerTLSConfig(cfg); err != nil {
		return nil, err
	}

	conn, err := sqlx.Connect(PluginName, buildDSN(cfg))
	if err != nil {
		return nil, err
	}

	// Pool limits are only applied when configured with a positive value.
	if limit := cfg.MaxConns; limit > 0 {
		conn.SetMaxOpenConns(limit)
	}
	if idle := cfg.MaxIdleConns; idle > 0 {
		conn.SetMaxIdleConns(idle)
	}
	if lifetime := cfg.MaxConnLifetime; lifetime > 0 {
		conn.SetConnMaxLifetime(lifetime)
	}

	// Maps struct names in CamelCase to snake without need for db struct tags.
	conn.MapperFunc(strcase.ToSnake)
	return conn, nil
}
// registerTLSConfig builds a tls.Config from cfg.TLS and registers it with the
// MySQL driver under customTLSName. It does nothing and returns nil when
// cfg.TLS is nil or not enabled.
//
// cfg is mutated: unless a non-empty "tls" connect attribute is already
// present, cfg.ConnectAttributes["tls"] is set to cfg.TLS.SSLMode when that is
// non-empty, and to customTLSName otherwise.
//
// An error is returned when cfg.ConnectAddr is not a host:port pair, when the
// CA file cannot be read or parsed, when the client key pair cannot be loaded,
// or when the driver rejects the registration. Underlying errors are wrapped
// with %w so callers can inspect them with errors.Is / errors.As.
//
// NOTE(review): every call registers under the same customTLSName, so a later
// call presumably replaces the configuration registered by an earlier one;
// confirm this is acceptable when connections to different hosts are created.
func registerTLSConfig(cfg *config.SQL) error {
	if cfg.TLS == nil || !cfg.TLS.Enabled {
		return nil
	}

	host, _, err := net.SplitHostPort(cfg.ConnectAddr)
	if err != nil {
		return fmt.Errorf("error in host port from ConnectAddr: %w", err)
	}

	// TODO: create a way to set MinVersion and CipherSuites via cfg.
	tlsConfig := &tls.Config{
		ServerName:         host,
		InsecureSkipVerify: !cfg.TLS.EnableHostVerification,
	}

	if cfg.TLS.CaFile != "" {
		pem, err := ioutil.ReadFile(cfg.TLS.CaFile)
		if err != nil {
			return fmt.Errorf("failed to load CA files: %w", err)
		}
		rootCertPool := x509.NewCertPool()
		if !rootCertPool.AppendCertsFromPEM(pem) {
			return fmt.Errorf("failed to append CA file")
		}
		tlsConfig.RootCAs = rootCertPool
	}

	// A client certificate is only loaded when both the certificate and the
	// key file are configured; a lone value is ignored.
	if cfg.TLS.CertFile != "" && cfg.TLS.KeyFile != "" {
		cert, err := tls.LoadX509KeyPair(cfg.TLS.CertFile, cfg.TLS.KeyFile)
		if err != nil {
			return fmt.Errorf("failed to load tls x509 key pair: %w", err)
		}
		tlsConfig.Certificates = []tls.Certificate{cert}
	}

	// In order to use the TLS configuration you need to register it. Once registered you use it by specifying
	// `tls` in the connect attributes.
	if err := mysql.RegisterTLSConfig(customTLSName, tlsConfig); err != nil {
		return fmt.Errorf("failed to register tls config: %w", err)
	}

	if cfg.ConnectAttributes == nil {
		cfg.ConnectAttributes = map[string]string{}
	}

	// If no `tls` connect attribute is provided then we override it to our newly registered tls config automatically.
	// This allows users to simply provide a tls config without needing to remember to also set the connect attribute
	if cfg.ConnectAttributes["tls"] == "" {
		if cfg.TLS.SSLMode != "" {
			cfg.ConnectAttributes["tls"] = cfg.TLS.SSLMode
		} else {
			cfg.ConnectAttributes["tls"] = customTLSName
		}
	}
	return nil
}
// buildDSN returns the DSN for cfg: the user, password, protocol, address and
// database name laid out per dsnFmt, followed by "?" and the encoded
// attributes from buildDSNAttrs when that string is non-empty.
func buildDSN(cfg *config.SQL) string {
	query := buildDSNAttrs(cfg)
	base := fmt.Sprintf(dsnFmt, cfg.User, cfg.Password, cfg.ConnectProtocol, cfg.ConnectAddr, cfg.DatabaseName)
	if query == "" {
		return base
	}
	return base + "?" + query
}
// buildDSNAttrs returns the attribute portion of the DSN (without the leading
// "?") as key=value pairs joined by "&" and passed through url.PathEscape.
//
// The attribute set is assembled in three steps:
//  1. every entry of cfg.ConnectAttributes, after sanitizeAttr;
//  2. the default isolation level, only when neither isolation level key was
//     supplied;
//  3. dsnAttrOverrides, which always replace user-supplied values.
//
// Pairs are written in sorted key order so that the same attribute set always
// produces the same string; Go map iteration order is otherwise unspecified.
func buildDSNAttrs(cfg *config.SQL) string {
	attrs := make(map[string]string, len(dsnAttrOverrides)+len(cfg.ConnectAttributes)+1)
	for k, v := range cfg.ConnectAttributes {
		k1, v1 := sanitizeAttr(k, v)
		attrs[k1] = v1
	}

	// only override isolation level if not specified
	if !hasAttr(attrs, isolationLevelAttrName) &&
		!hasAttr(attrs, isolationLevelAttrNameLegacy) {
		attrs[isolationLevelAttrName] = defaultIsolationLevel
	}

	// these attrs are always overridden
	for k, v := range dsnAttrOverrides {
		attrs[k] = v
	}

	keys := make([]string, 0, len(attrs))
	for k := range attrs {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var buf bytes.Buffer
	for i, k := range keys {
		if i > 0 {
			buf.WriteString("&")
		}
		buf.WriteString(k)
		buf.WriteString("=")
		buf.WriteString(attrs[k])
	}
	return url.PathEscape(buf.String())
}
// hasAttr reports whether key is present in attrs, regardless of the value
// stored under it.
func hasAttr(attrs map[string]string, key string) bool {
	_, present := attrs[key]
	return present
}
// sanitizeAttr normalises a single connect attribute.
//
// When the trimmed, lower-cased key is one of the isolation level attribute
// names, the lower-cased key is returned together with the trimmed,
// lower-cased value, wrapped in single quotes unless it already starts with
// one. An empty or whitespace-only value yields "''" instead of panicking on
// an out-of-range index.
//
// Every other attribute is returned exactly as given.
func sanitizeAttr(inkey string, invalue string) (string, string) {
	key := strings.ToLower(strings.TrimSpace(inkey))
	switch key {
	case isolationLevelAttrName, isolationLevelAttrNameLegacy:
		value := strings.ToLower(strings.TrimSpace(invalue))
		// mysql sys variable values must be enclosed in single quotes
		if !strings.HasPrefix(value, "'") {
			value = "'" + value + "'"
		}
		return key, value
	default:
		return inkey, invalue
	}
}
// testSchemaDir is the schema directory reported by GetTestClusterOption.
const testSchemaDir = "schema/mysql/v8"
// GetTestClusterOption returns the persistence test options for MySQL, with
// the user, password, address and port taken from the environment package.
// The error from environment.GetMySQLPort is returned unchanged.
func GetTestClusterOption() (*pt.TestBaseOptions, error) {
	mysqlPort, err := environment.GetMySQLPort()
	if err != nil {
		return nil, err
	}

	opts := &pt.TestBaseOptions{
		DBPluginName: PluginName,
		StoreType:    config.StoreTypeSQL,
		SchemaDir:    testSchemaDir,
		DBUsername:   environment.GetMySQLUser(),
		DBPassword:   environment.GetMySQLPassword(),
		DBHost:       environment.GetMySQLAddress(),
		DBPort:       mysqlPort,
	}
	return opts, nil
}