pkg/datasource/sql/driver.go (191 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"reflect"
"strings"
"github.com/go-sql-driver/mysql"
"seata.apache.org/seata-go/pkg/datasource/sql/datasource"
mysql2 "seata.apache.org/seata-go/pkg/datasource/sql/datasource/mysql"
"seata.apache.org/seata-go/pkg/datasource/sql/types"
"seata.apache.org/seata-go/pkg/datasource/sql/util"
"seata.apache.org/seata-go/pkg/protocol/branch"
"seata.apache.org/seata-go/pkg/util/log"
)
// Names under which the Seata driver wrappers are registered with
// database/sql (see initDriver). Pass one of them to sql.Open to obtain a
// connection pool that goes through the corresponding wrapper.
const (
	// SeataATMySQLDriver is the driver name of the MySQL wrapper for AT mode.
	SeataATMySQLDriver = "seata-at-mysql"
	// SeataXAMySQLDriver is the driver name of the MySQL wrapper for XA mode.
	SeataXAMySQLDriver = "seata-xa-mysql"
)
func initDriver() {
sql.Register(SeataATMySQLDriver, &seataATDriver{
seataDriver: &seataDriver{
branchType: branch.BranchTypeAT,
transType: types.ATMode,
target: mysql.MySQLDriver{},
},
})
sql.Register(SeataXAMySQLDriver, &seataXADriver{
seataDriver: &seataDriver{
branchType: branch.BranchTypeXA,
transType: types.XAMode,
target: mysql.MySQLDriver{},
},
})
}
// seataATDriver is the database/sql driver for AT mode. It embeds the shared
// seataDriver and only customises how the connector is built.
type seataATDriver struct {
	*seataDriver
}

// OpenConnector builds the proxied connector through the embedded seataDriver,
// switches it to AT mode, attaches the configuration parsed from name and
// wraps it in a seataATConnector.
//
// name is the MySQL DSN. An error is returned when the embedded driver fails,
// when it yields something other than a *seataConnector, or when the DSN
// cannot be parsed.
func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) {
	connector, err := d.seataDriver.OpenConnector(name)
	if err != nil {
		return nil, err
	}

	// Check the assertion instead of discarding its result: a failed assertion
	// yields a nil pointer that would panic on the first field write below.
	base, ok := connector.(*seataConnector)
	if !ok {
		return nil, fmt.Errorf("unexpected connector type %T", connector)
	}

	// Surface a DSN error rather than silently storing a nil config.
	cfg, err := mysql.ParseDSN(name)
	if err != nil {
		return nil, fmt.Errorf("parse dsn: %w", err)
	}

	base.transType = types.ATMode
	base.cfg = cfg

	return &seataATConnector{
		seataConnector: base,
	}, nil
}
// seataXADriver is the database/sql driver for XA mode. It embeds the shared
// seataDriver and only customises how the connector is built.
type seataXADriver struct {
	*seataDriver
}

// OpenConnector builds the proxied connector through the embedded seataDriver,
// switches it to XA mode, attaches the configuration parsed from name and
// wraps it in a seataXAConnector.
//
// name is the MySQL DSN. An error is returned when the embedded driver fails,
// when it yields something other than a *seataConnector, or when the DSN
// cannot be parsed.
func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) {
	connector, err := d.seataDriver.OpenConnector(name)
	if err != nil {
		return nil, err
	}

	// Check the assertion instead of discarding its result: a failed assertion
	// yields a nil pointer that would panic on the first field write below.
	base, ok := connector.(*seataConnector)
	if !ok {
		return nil, fmt.Errorf("unexpected connector type %T", connector)
	}

	// Surface a DSN error rather than silently storing a nil config.
	cfg, err := mysql.ParseDSN(name)
	if err != nil {
		return nil, fmt.Errorf("parse dsn: %w", err)
	}

	base.transType = types.XAMode
	base.cfg = cfg

	return &seataXAConnector{
		seataConnector: base,
	}, nil
}
// seataDriver carries the state shared by the AT and XA drivers and wraps the
// underlying target driver.
type seataDriver struct {
	// branchType is passed to the resource options and selects the data source
	// manager the resource is registered with.
	branchType branch.BranchType
	// transType is the transaction mode assigned when the driver is registered.
	transType types.TransactionMode
	// target is the wrapped driver that opens the real connector.
	target driver.Driver
}
// Open exists only to satisfy driver.Driver and always fails.
//
// seataDriver also provides OpenConnector, i.e. it implements
// driver.DriverContext, and database/sql uses that method instead of Open when
// it is available (see
// https://cs.opensource.google/go/go/+/master:src/database/sql/sql.go;l=813).
// sql.DB may still call Driver(), but the driver it receives comes from the
// connector proxied by seataConnector.
func (d *seataDriver) Open(name string) (driver.Conn, error) {
	err := errors.New("operation unsupport.")
	return nil, err
}
// OpenConnector creates the connector of the wrapped target driver for the
// DSN name and returns it wrapped by getOpenConnectorProxy.
//
// When the target implements driver.DriverContext its own OpenConnector is
// used; otherwise a dsnConnector that opens connections through target.Open
// serves as fallback. An error is returned when the target cannot open the
// connector, when the target driver name does not map to a known DB type, or
// when building the proxy fails.
func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) {
	c = &dsnConnector{dsn: name, driver: d.target}
	if driverCtx, ok := d.target.(driver.DriverContext); ok {
		c, err = driverCtx.OpenConnector(name)
		if err != nil {
			// %w is only understood by fmt.Errorf; a logger needs %v.
			log.Errorf("open connector: %v", err)
			return nil, err
		}
	}

	driverName := d.getTargetDriverName()
	dbType := types.ParseDBType(driverName)
	if dbType == types.DBTypeUnknown {
		return nil, fmt.Errorf("unsupport conn type %s", driverName)
	}

	proxy, err := d.getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name)
	if err != nil {
		log.Errorf("register resource: %v", err)
		return nil, err
	}
	return proxy, nil
}
// getOpenConnectorProxy builds the resource for the given connection pool,
// registers it and returns a seataConnector that proxies connector.
//
// connector is the target driver's connector, dbType the detected database
// type, db the pool opened on connector and dataSourceName the MySQL DSN. The
// DSN supplies both the resource ID and the database name. An error is
// returned when the DSN cannot be parsed, when the resource cannot be created,
// or when registering it with the data source manager fails.
func (d *seataDriver) getOpenConnectorProxy(connector driver.Connector, dbType types.DBType,
	db *sql.DB, dataSourceName string) (driver.Connector, error) {
	// cfg is dereferenced below, so a parse failure must stop here instead of
	// being discarded.
	cfg, err := mysql.ParseDSN(dataSourceName)
	if err != nil {
		log.Errorf("parse dsn: %v", err)
		return nil, err
	}

	options := []dbOption{
		withResourceID(parseResourceID(dataSourceName)),
		withTarget(db),
		withBranchType(d.branchType),
		withDBType(dbType),
		withDBName(cfg.DBName),
		withConnector(connector),
	}

	res, err := newResource(options...)
	if err != nil {
		// %w is only understood by fmt.Errorf; a logger needs %v.
		log.Errorf("create new resource: %v", err)
		return nil, err
	}

	// NOTE(review): the table cache is registered under types.DBTypeMySQL
	// regardless of dbType; confirm before adding other database types.
	datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db, cfg))

	if err = datasource.GetDataSourceManager(d.branchType).RegisterResource(res); err != nil {
		log.Errorf("register resource: %v", err)
		return nil, err
	}

	return &seataConnector{
		res:    res,
		target: connector,
		cfg:    cfg,
	}, nil
}
// getTargetDriverName returns the name of the wrapped driver, which is always
// "mysql".
func (d *seataDriver) getTargetDriverName() string {
	const targetDriverName = "mysql"
	return targetDriverName
}
// dsnConnector is a minimal connector that pairs a DSN with the driver used to
// open it.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

// Connect opens a new connection by handing the stored DSN to the driver's
// Open method. The context argument is ignored.
func (c *dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
	conn, err := c.driver.Open(c.dsn)
	return conn, err
}

// Driver returns the driver this connector opens connections with.
func (c *dsnConnector) Driver() driver.Driver {
	return c.driver
}
// parseResourceID derives the resource ID from a DSN: everything from the
// first "?" onwards is dropped and every "," in the remainder becomes "|".
//
// Only a "?" found after the first byte truncates the DSN; a DSN that starts
// with "?" is kept whole.
func parseResourceID(dsn string) string {
	if q := strings.IndexByte(dsn, '?'); q > 0 {
		dsn = dsn[:q]
	}
	return strings.Replace(dsn, ",", "|", -1)
}
func selectDBVersion(ctx context.Context, conn driver.Conn) (string, error) {
var rowsi driver.Rows
var err error
queryerCtx, ok := conn.(driver.QueryerContext)
var queryer driver.Queryer
if !ok {
queryer, ok = conn.(driver.Queryer)
}
if ok {
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, "SELECT VERSION()", nil)
defer func() {
if rowsi != nil {
rowsi.Close()
}
}()
if err != nil {
log.Errorf("ctx driver query: %+v", err)
return "", err
}
} else {
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
return "", fmt.Errorf("invalid conn")
}
dest := make([]driver.Value, 1)
var version string
if err = rowsi.Next(dest); err != nil {
if err == io.EOF {
return version, nil
}
return "", err
}
if len(dest) != 1 {
return "", errors.New("get db version is not column 1")
}
switch reflect.TypeOf(dest[0]).Kind() {
case reflect.Slice, reflect.Array:
val := reflect.ValueOf(dest[0]).Bytes()
version = string(val)
case reflect.String:
version = reflect.ValueOf(dest[0]).String()
default:
return "", errors.New("get db version is not a string")
}
return version, nil
}