pkg/datasource/sql/driver.go (191 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sql import ( "context" "database/sql" "database/sql/driver" "errors" "fmt" "io" "reflect" "strings" "github.com/go-sql-driver/mysql" "seata.apache.org/seata-go/pkg/datasource/sql/datasource" mysql2 "seata.apache.org/seata-go/pkg/datasource/sql/datasource/mysql" "seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/util" "seata.apache.org/seata-go/pkg/protocol/branch" "seata.apache.org/seata-go/pkg/util/log" ) const ( // SeataATMySQLDriver MySQL driver for AT mode SeataATMySQLDriver = "seata-at-mysql" // SeataXAMySQLDriver MySQL driver for XA mode SeataXAMySQLDriver = "seata-xa-mysql" ) func initDriver() { sql.Register(SeataATMySQLDriver, &seataATDriver{ seataDriver: &seataDriver{ branchType: branch.BranchTypeAT, transType: types.ATMode, target: mysql.MySQLDriver{}, }, }) sql.Register(SeataXAMySQLDriver, &seataXADriver{ seataDriver: &seataDriver{ branchType: branch.BranchTypeXA, transType: types.XAMode, target: mysql.MySQLDriver{}, }, }) } type seataATDriver struct { *seataDriver } func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) { connector, err := d.seataDriver.OpenConnector(name) if err != nil { return nil, err } _connector, _ := connector.(*seataConnector) _connector.transType = types.ATMode cfg, _ := mysql.ParseDSN(name) _connector.cfg = cfg return &seataATConnector{ seataConnector: _connector, }, nil } type seataXADriver struct { *seataDriver } func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) { connector, err := d.seataDriver.OpenConnector(name) if err != nil { return nil, err } _connector, _ := connector.(*seataConnector) _connector.transType = types.XAMode cfg, _ := mysql.ParseDSN(name) _connector.cfg = cfg return &seataXAConnector{ seataConnector: _connector, }, nil } type seataDriver struct { branchType branch.BranchType transType types.TransactionMode target driver.Driver } // Open never be called, because seataDriver implemented dri.DriverContext interface. // reference package: datasource/sql [https://cs.opensource.google/go/go/+/master:src/database/sql/sql.go;l=813] // and maybe the sql.BD will be call Driver() method, but it obtain the Driver is fron Connector that is proxed by seataConnector. func (d *seataDriver) Open(name string) (driver.Conn, error) { return nil, errors.New(("operation unsupport.")) } func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) { c = &dsnConnector{dsn: name, driver: d.target} if driverCtx, ok := d.target.(driver.DriverContext); ok { c, err = driverCtx.OpenConnector(name) if err != nil { log.Errorf("open connector: %w", err) return nil, err } } dbType := types.ParseDBType(d.getTargetDriverName()) if dbType == types.DBTypeUnknown { return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName()) } proxy, err := d.getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) if err != nil { log.Errorf("register resource: %w", err) return nil, err } return proxy, nil } func (d *seataDriver) getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *sql.DB, dataSourceName string) (driver.Connector, error) { cfg, _ := mysql.ParseDSN(dataSourceName) options := []dbOption{ withResourceID(parseResourceID(dataSourceName)), withTarget(db), withBranchType(d.branchType), withDBType(dbType), withDBName(cfg.DBName), withConnector(connector), } res, err := newResource(options...) if err != nil { log.Errorf("create new resource: %w", err) return nil, err } datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db, cfg)) if err = datasource.GetDataSourceManager(d.branchType).RegisterResource(res); err != nil { log.Errorf("regisiter resource: %w", err) return nil, err } return &seataConnector{ res: res, target: connector, cfg: cfg, }, nil } func (d *seataDriver) getTargetDriverName() string { return "mysql" } type dsnConnector struct { dsn string driver driver.Driver } func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) { return t.driver.Open(t.dsn) } func (t *dsnConnector) Driver() driver.Driver { return t.driver } func parseResourceID(dsn string) string { i := strings.Index(dsn, "?") res := dsn if i > 0 { res = dsn[:i] } return strings.ReplaceAll(res, ",", "|") } func selectDBVersion(ctx context.Context, conn driver.Conn) (string, error) { var rowsi driver.Rows var err error queryerCtx, ok := conn.(driver.QueryerContext) var queryer driver.Queryer if !ok { queryer, ok = conn.(driver.Queryer) } if ok { rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, "SELECT VERSION()", nil) defer func() { if rowsi != nil { rowsi.Close() } }() if err != nil { log.Errorf("ctx driver query: %+v", err) return "", err } } else { log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") return "", fmt.Errorf("invalid conn") } dest := make([]driver.Value, 1) var version string if err = rowsi.Next(dest); err != nil { if err == io.EOF { return version, nil } return "", err } if len(dest) != 1 { return "", errors.New("get db version is not column 1") } switch reflect.TypeOf(dest[0]).Kind() { case reflect.Slice, reflect.Array: val := reflect.ValueOf(dest[0]).Bytes() version = string(val) case reflect.String: version = reflect.ValueOf(dest[0]).String() default: return "", errors.New("get db version is not a string") } return version, nil }