pkg/datasource/sql/exec/at/base_executor.go

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package at

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"strings"

	"github.com/arana-db/parser/ast"
	"github.com/arana-db/parser/test_driver"
	gxsort "github.com/dubbogo/gost/sort"

	"seata.apache.org/seata-go/pkg/datasource/sql/exec"
	"seata.apache.org/seata-go/pkg/datasource/sql/types"
	"seata.apache.org/seata-go/pkg/datasource/sql/undo"
	"seata.apache.org/seata-go/pkg/datasource/sql/util"
	"seata.apache.org/seata-go/pkg/util/reflectx"
)

type baseExecutor struct {
	hooks []exec.SQLHook
}

func (b *baseExecutor) beforeHooks(ctx context.Context, execCtx *types.ExecContext) {
	for _, hook := range b.hooks {
		hook.Before(ctx, execCtx)
	}
}

func (b *baseExecutor) afterHooks(ctx context.Context, execCtx *types.ExecContext) {
	for _, hook := range b.hooks {
		hook.After(ctx, execCtx)
	}
}

// GetScanSlice builds a slice of scan destinations matching the database types of the given columns.
// todo: use ColumnInfo to build the slice
func (*baseExecutor) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []interface{} {
	scanSlice := make([]interface{}, 0, len(columnNames))
	for _, columnName := range columnNames {
		// metadata for this column
		columnMeta := tableMeta.Columns[columnName]
		switch strings.ToUpper(columnMeta.DatabaseTypeString) {
		case "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", "JSON", "TINYTEXT":
			var scanVal sql.NullString
			scanSlice = append(scanSlice, &scanVal)
		case "BIT", "INT", "LONGBLOB", "SMALLINT", "TINYINT", "BIGINT", "MEDIUMINT":
			if columnMeta.IsNullable == 0 {
				scanVal := int64(0)
				scanSlice = append(scanSlice, &scanVal)
			} else {
				scanVal := sql.NullInt64{}
				scanSlice = append(scanSlice, &scanVal)
			}
		case "DATE", "DATETIME", "TIME", "TIMESTAMP", "YEAR":
			var scanVal sql.NullTime
			scanSlice = append(scanSlice, &scanVal)
		case "DECIMAL", "DOUBLE", "FLOAT":
			if columnMeta.IsNullable == 0 {
				scanVal := float64(0)
				scanSlice = append(scanSlice, &scanVal)
			} else {
				scanVal := sql.NullFloat64{}
				scanSlice = append(scanSlice, &scanVal)
			}
		default:
			scanVal := sql.RawBytes{}
			scanSlice = append(scanSlice, &scanVal)
		}
	}
	return scanSlice
}

// buildSelectArgs collects the args bound to the WHERE, ORDER BY and LIMIT clauses of a SELECT statement.
func (b *baseExecutor) buildSelectArgs(stmt *ast.SelectStmt, args []driver.NamedValue) []driver.NamedValue {
	var (
		selectArgsIndexes = make([]int32, 0)
		selectArgs        = make([]driver.NamedValue, 0)
	)

	b.traversalArgs(stmt.Where, &selectArgsIndexes)
	if stmt.OrderBy != nil {
		for _, item := range stmt.OrderBy.Items {
			b.traversalArgs(item, &selectArgsIndexes)
		}
	}
	if stmt.Limit != nil {
		if stmt.Limit.Offset != nil {
			b.traversalArgs(stmt.Limit.Offset, &selectArgsIndexes)
		}
		if stmt.Limit.Count != nil {
			b.traversalArgs(stmt.Limit.Count, &selectArgsIndexes)
		}
	}
	// sort the collected arg indexes so the args keep their original order
	gxsort.Int32(selectArgsIndexes)
	for _, index := range selectArgsIndexes {
		selectArgs = append(selectArgs, args[index])
	}

	return selectArgs
}

// traversalArgs collects the positions of parameter markers (?) under the given AST node.
// todo: cover all SQL operations
func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) {
	if node == nil {
		return
	}
	switch expr := node.(type) {
	case *ast.BinaryOperationExpr:
		b.traversalArgs(expr.L, argsIndex)
		b.traversalArgs(expr.R, argsIndex)
	case *ast.BetweenExpr:
		b.traversalArgs(expr.Left, argsIndex)
		b.traversalArgs(expr.Right, argsIndex)
	case *ast.PatternInExpr:
		for i := 0; i < len(expr.List); i++ {
			b.traversalArgs(expr.List[i], argsIndex)
		}
	case *test_driver.ParamMarkerExpr:
		*argsIndex = append(*argsIndex, int32(expr.Order))
	}
}

// buildRecordImages scans the result rows of an image query into a RecordImage.
func (b *baseExecutor) buildRecordImages(rowsi driver.Rows, tableMetaData *types.TableMeta, sqlType types.SQLType) (*types.RecordImage, error) {
	// select column names
	columnNames := rowsi.Columns()
	rowImages := make([]types.RowImage, 0)
	sqlRows := util.NewScanRows(rowsi)

	for sqlRows.Next() {
		ss := b.GetScanSlice(columnNames, tableMetaData)
		err := sqlRows.Scan(ss...)
		if err != nil {
			return nil, err
		}

		columns := make([]types.ColumnImage, 0)
		// build record image
		for i, name := range columnNames {
			columnMeta := tableMetaData.Columns[name]

			keyType := types.IndexTypeNull
			if _, ok := tableMetaData.GetPrimaryKeyMap()[name]; ok {
				keyType = types.IndexTypePrimaryKey
			}
			jdbcType := types.MySQLStrToJavaType(columnMeta.DatabaseTypeString)

			columns = append(columns, types.ColumnImage{
				KeyType:    keyType,
				ColumnName: name,
				ColumnType: jdbcType,
				Value:      getSqlNullValue(reflectx.GetElemDataValue(ss[i])),
			})
		}
		rowImages = append(rowImages, types.RowImage{Columns: columns})
	}

	return &types.RecordImage{TableName: tableMetaData.TableName, Rows: rowImages, SQLType: sqlType}, nil
}

// getNeedColumns returns the columns the image query must select, always including the primary keys.
func (b *baseExecutor) getNeedColumns(meta *types.TableMeta, columns []string, dbType types.DBType) []string {
	var needUpdateColumns []string
	if undo.UndoConfig.OnlyCareUpdateColumns && len(columns) > 0 {
		needUpdateColumns = columns
		if !b.containsPKByName(meta, columns) {
			pkNames := meta.GetPrimaryKeyOnlyName()
			if len(pkNames) > 0 {
				needUpdateColumns = append(needUpdateColumns, pkNames...)
			}
		}
		// todo: if the table has onUpdate columns, add them as well
	} else {
		needUpdateColumns = meta.ColumnNames
	}
	for i := range needUpdateColumns {
		needUpdateColumns[i] = AddEscape(needUpdateColumns[i], dbType)
	}
	return needUpdateColumns
}

// containsPKByName reports whether columns contains every primary key column of the table.
func (b *baseExecutor) containsPKByName(meta *types.TableMeta, columns []string) bool {
	pkColumnNameList := meta.GetPrimaryKeyOnlyName()
	if len(pkColumnNameList) == 0 {
		return false
	}

	matchCounter := 0
	for _, column := range columns {
		for _, pkName := range pkColumnNameList {
			if strings.EqualFold(pkName, column) ||
				strings.EqualFold(pkName, strings.ToLower(column)) {
				matchCounter++
			}
		}
	}

	return matchCounter == len(pkColumnNameList)
}

// getSqlNullValue unwraps database/sql Null* wrappers, returning nil for NULL values
// and the raw value otherwise.
func getSqlNullValue(value interface{}) interface{} {
	if value == nil {
		return nil
	}
	if v, ok := value.(sql.NullString); ok {
		if v.Valid {
			return v.String
		}
		return nil
	}
	if v, ok := value.(sql.NullFloat64); ok {
		if v.Valid {
			return v.Float64
		}
		return nil
	}
	if v, ok := value.(sql.NullBool); ok {
		if v.Valid {
			return v.Bool
		}
		return nil
	}
	if v, ok := value.(sql.NullTime); ok {
		if v.Valid {
			return v.Time
		}
		return nil
	}
	if v, ok := value.(sql.NullByte); ok {
		if v.Valid {
			return v.Byte
		}
		return nil
	}
	if v, ok := value.(sql.NullInt16); ok {
		if v.Valid {
			return v.Int16
		}
		return nil
	}
	if v, ok := value.(sql.NullInt32); ok {
		if v.Valid {
			return v.Int32
		}
		return nil
	}
	if v, ok := value.(sql.NullInt64); ok {
		if v.Valid {
			return v.Int64
		}
		return nil
	}
	return value
}

// buildWhereConditionByPKs builds a WHERE condition on the primary keys.
// Each batch of rows becomes one IN group; the result looks like:
// " (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?))"
func (b *baseExecutor) buildWhereConditionByPKs(pkNameList []string, rowSize int, dbType string, maxInSize int) string {
	var (
		whereStr  = &strings.Builder{}
		batchSize = rowSize/maxInSize + 1
	)

	if rowSize%maxInSize == 0 {
		batchSize = rowSize / maxInSize
	}

	for batch := 0; batch < batchSize; batch++ {
		if batch > 0 {
			whereStr.WriteString(" OR ")
		}
		whereStr.WriteString("(")

		for i := 0; i < len(pkNameList); i++ {
			if i > 0 {
				whereStr.WriteString(",")
			}
			// todo: add escape
			whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i]))
		}
		whereStr.WriteString(") IN (")

		var eachSize int
		if batch == batchSize-1 {
			if rowSize%maxInSize == 0 {
				eachSize = maxInSize
			} else {
				eachSize = rowSize % maxInSize
			}
		} else {
			eachSize = maxInSize
		}

		for i := 0; i < eachSize; i++ {
			if i > 0 {
				whereStr.WriteString(",")
			}
			whereStr.WriteString("(")
			for j := 0; j < len(pkNameList); j++ {
				if j > 0 {
					whereStr.WriteString(",")
				}
				whereStr.WriteString("?")
			}
			whereStr.WriteString(")")
		}
		whereStr.WriteString(")")
	}

	return whereStr.String()
}

// buildPKParams builds the named values for the primary key placeholders, row by row.
func (b *baseExecutor) buildPKParams(rows []types.RowImage, pkNameList []string) []driver.NamedValue {
	params := make([]driver.NamedValue, 0)
	for _, row := range rows {
		columnMap := row.GetColumnMap()
		for i, pk := range pkNameList {
			if col, ok := columnMap[pk]; ok {
				params = append(params, driver.NamedValue{
					Ordinal: i,
					Value:   col.Value,
				})
			}
		}
	}
	return params
}

// buildLockKey builds the lock key string for the record image,
// e.g. with multiple primary keys: "t_user:1_a,2_b"
func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string {
	var (
		lockKeys      bytes.Buffer
		fieldSequence int
	)
	lockKeys.WriteString(meta.TableName)
	lockKeys.WriteString(":")

	keys := meta.GetPrimaryKeyOnlyName()

	for _, row := range records.Rows {
		if fieldSequence > 0 {
			lockKeys.WriteString(",")
		}
		pkSplitIndex := 0
		for _, column := range row.Columns {
			var hasKeyColumn bool
			for _, key := range keys {
				if column.ColumnName == key {
					hasKeyColumn = true
					if pkSplitIndex > 0 {
						lockKeys.WriteString("_")
					}
					lockKeys.WriteString(fmt.Sprintf("%v", column.Value))
					pkSplitIndex++
				}
			}
			if hasKeyColumn {
				fieldSequence++
			}
		}
	}
	return lockKeys.String()
}
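
For reference, the standalone sketch below (not part of the package; function and variable names are illustrative) reproduces the chunking logic of buildWhereConditionByPKs so the shape of the generated WHERE clause is easy to see: rows are split into groups of at most maxInSize, each group becomes one "(pk,...) IN (...)" block, and the blocks are joined with OR. Identifier escaping is omitted for brevity.

package main

import (
	"fmt"
	"strings"
)

// whereByPKs mirrors the batching in buildWhereConditionByPKs:
// ceil(rowSize/maxInSize) IN groups, the last one holding the remainder.
func whereByPKs(pkNames []string, rowSize, maxInSize int) string {
	var sb strings.Builder
	batches := rowSize / maxInSize
	if rowSize%maxInSize != 0 {
		batches++
	}
	for b := 0; b < batches; b++ {
		if b > 0 {
			sb.WriteString(" OR ")
		}
		// rows in this group: maxInSize, except possibly the last group
		size := maxInSize
		if b == batches-1 && rowSize%maxInSize != 0 {
			size = rowSize % maxInSize
		}
		sb.WriteString("(" + strings.Join(pkNames, ",") + ") IN (")
		for i := 0; i < size; i++ {
			if i > 0 {
				sb.WriteString(",")
			}
			// one "(?,...,?)" tuple per row, one placeholder per primary key
			sb.WriteString("(" + strings.TrimSuffix(strings.Repeat("?,", len(pkNames)), ",") + ")")
		}
		sb.WriteString(")")
	}
	return sb.String()
}

func main() {
	// 5 rows, two primary keys, at most 2 rows per IN group:
	// (id,userCode) IN ((?,?),(?,?)) OR (id,userCode) IN ((?,?),(?,?)) OR (id,userCode) IN ((?,?))
	fmt.Println(whereByPKs([]string{"id", "userCode"}, 5, 2))
}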