pkg/datasource/sql/exec/at/multi_update_excutor.go (280 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package at import ( "context" "database/sql/driver" "fmt" "strings" "github.com/arana-db/parser" "github.com/arana-db/parser/ast" "github.com/arana-db/parser/format" "github.com/arana-db/parser/model" "github.com/pkg/errors" "seata.apache.org/seata-go/pkg/datasource/sql/datasource" "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/util" "seata.apache.org/seata-go/pkg/util/bytes" "seata.apache.org/seata-go/pkg/util/log" ) // multiUpdateExecutor execute multiple update SQL type multiUpdateExecutor struct { baseExecutor parserCtx *types.ParseContext execContext *types.ExecContext } var rows driver.Rows var comma = "," // NewMultiUpdateExecutor get new multi update executor func NewMultiUpdateExecutor(parserCtx *types.ParseContext, execContext *types.ExecContext, hooks []exec.SQLHook) *multiUpdateExecutor { return &multiUpdateExecutor{parserCtx: parserCtx, execContext: execContext, baseExecutor: baseExecutor{hooks: hooks}} } // ExecContext exec SQL, and generate before image and after image func (u *multiUpdateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { u.beforeHooks(ctx, u.execContext) defer func() { u.afterHooks(ctx, u.execContext) }() //single update sql handler if len(u.parserCtx.MultiStmt) == 1 { u.parserCtx.UpdateStmt = u.parserCtx.MultiStmt[0].UpdateStmt return NewUpdateExecutor(u.parserCtx, u.execContext, u.hooks).ExecContext(ctx, f) } beforeImages, err := u.beforeImage(ctx) if err != nil { return nil, err } res, err := f(ctx, u.execContext.Query, u.execContext.NamedValues) if err != nil { return nil, err } afterImages, err := u.afterImage(ctx, beforeImages) if err != nil { return nil, err } if len(afterImages) != len(beforeImages) { return nil, errors.New("Before image size is not equaled to after image size, probably because you updated the primary keys.") } for i, afterImage := range afterImages { beforeImage := afterImages[i] if len(beforeImage.Rows) != len(afterImage.Rows) { return nil, errors.New("Before image size is not equaled to after image size, probably because you updated the primary keys.") } u.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage) u.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage) } return res, nil } func (u *multiUpdateExecutor) beforeImage(ctx context.Context) ([]*types.RecordImage, error) { if !u.isAstStmtValid() { return nil, nil } tableName := u.parserCtx.MultiStmt[0].UpdateStmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err } // use selectSQL, selectArgs, err := u.buildBeforeImageSQL(u.execContext.NamedValues, metaData) if err != nil { return nil, err } rows, err := u.rowsPrepare(ctx, selectSQL, selectArgs) defer func() { if err := rows.Close(); err != nil { log.Errorf("rows close fail, err:%v", err) return } }() if err != nil { return nil, err } image, err := u.buildRecordImages(rows, metaData, types.SQLTypeUpdate) if err != nil { return nil, err } lockKey := u.buildLockKey(image, *metaData) u.execContext.TxCtx.LockKeys[lockKey] = struct{}{} image.SQLType = u.parserCtx.SQLType return []*types.RecordImage{image}, nil } func (u *multiUpdateExecutor) afterImage(ctx context.Context, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { if !u.isAstStmtValid() { return nil, nil } if len(beforeImages) == 0 { return nil, errors.New("empty beforeImages") } beforeImage := beforeImages[0] tableName := u.parserCtx.MultiStmt[0].UpdateStmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err } // use selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, *metaData) rows, err = u.rowsPrepare(ctx, selectSQL, selectArgs) defer func() { if err := rows.Close(); err != nil { log.Errorf("rows close fail, err:%v", err) return } }() if err != nil { return nil, err } image, err := u.buildRecordImages(rows, metaData, types.SQLTypeUpdate) if err != nil { return nil, err } image.SQLType = u.parserCtx.SQLType return []*types.RecordImage{image}, nil } func (u *multiUpdateExecutor) rowsPrepare(ctx context.Context, selectSQL string, selectArgs []driver.NamedValue) (driver.Rows, error) { var queryer driver.Queryer queryerContext, ok := u.execContext.Conn.(driver.QueryerContext) if !ok { queryer, ok = u.execContext.Conn.(driver.Queryer) } if ok { var err error rows, err = util.CtxDriverQuery(ctx, queryerContext, queryer, selectSQL, selectArgs) if err != nil { log.Errorf("ctx driver query: %+v", err) return nil, err } } else { log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") return nil, errors.New("invalid conn") } return rows, nil } // buildAfterImageSQL build the SQL to query after image data func (u *multiUpdateExecutor) buildAfterImageSQL(beforeImage *types.RecordImage, meta types.TableMeta) (string, []driver.NamedValue) { if !u.isAstStmtValid() { return "", nil } selectSql := strings.Builder{} selectFields := make([]string, 0, len(meta.ColumnNames)) var selectFieldsStr string var fieldsExits = make(map[string]struct{}) if undo.UndoConfig.OnlyCareUpdateColumns { for _, row := range beforeImage.Rows { for _, column := range row.Columns { if _, exist := fieldsExits[column.ColumnName]; exist { continue } fieldsExits[column.ColumnName] = struct{}{} selectFields = append(selectFields, column.ColumnName) } } selectFieldsStr = strings.Join(selectFields, comma) } else { selectFieldsStr = strings.Join(meta.ColumnNames, comma) } selectSql.WriteString("SELECT " + selectFieldsStr + " FROM " + meta.TableName + " WHERE ") whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize) selectSql.WriteString(" " + whereSQL + " ") return selectSql.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()) } // buildSelectSQLByUpdate build select sql from update sql func (u *multiUpdateExecutor) buildBeforeImageSQL(args []driver.NamedValue, meta *types.TableMeta) (string, []driver.NamedValue, error) { if !u.isAstStmtValid() { log.Errorf("invalid multi update stmt") return "", nil, errors.New("invalid muliti update stmt") } var ( whereCondition strings.Builder multiStmts = u.parserCtx.MultiStmt newArgs = make([]driver.NamedValue, 0, len(u.parserCtx.MultiStmt)) fields = make([]*ast.SelectField, 0, len(meta.ColumnNames)) fieldsExits = make(map[string]struct{}, len(meta.ColumnNames)) ) for _, multiStmt := range u.parserCtx.MultiStmt { updateStmt := multiStmt.UpdateStmt if updateStmt.Limit != nil { return "", nil, fmt.Errorf("multi update SQL with limit condition is not support yet") } if updateStmt.Order != nil { return "", nil, fmt.Errorf("multi update SQL with orderBy condition is not support yet") } if undo.UndoConfig.OnlyCareUpdateColumns { //select update columns for _, column := range updateStmt.List { if _, exist := fieldsExits[column.Column.String()]; exist { continue } fieldsExits[column.Column.String()] = struct{}{} fields = append(fields, &ast.SelectField{Expr: &ast.ColumnNameExpr{Name: column.Column}}) } for _, columnName := range meta.GetPrimaryKeyOnlyName() { if _, exist := fieldsExits[columnName]; exist { continue } //select index columns fieldsExits[columnName] = struct{}{} fields = append(fields, &ast.SelectField{ Expr: &ast.ColumnNameExpr{Name: &ast.ColumnName{Name: model.CIStr{O: columnName, L: columnName}}}, }) } } else { fields = make([]*ast.SelectField, 0, len(meta.ColumnNames)) for _, column := range meta.ColumnNames { fields = append(fields, &ast.SelectField{ Expr: &ast.ColumnNameExpr{Name: &ast.ColumnName{Name: model.CIStr{O: column}}}}) } } tmpSelectStmt := ast.SelectStmt{ SelectStmtOpts: &ast.SelectStmtOpts{}, From: updateStmt.TableRefs, Where: updateStmt.Where, Fields: &ast.FieldList{Fields: fields}, OrderBy: updateStmt.Order, Limit: updateStmt.Limit, TableHints: updateStmt.TableHints, LockInfo: &ast.SelectLockInfo{ LockType: ast.SelectLockForUpdate, }, } newArgs = append(newArgs, u.buildSelectArgs(&tmpSelectStmt, args)...) in := bytes.NewByteBuffer([]byte{}) _ = updateStmt.Where.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, in)) if whereCondition.Len() > 0 { whereCondition.Write([]byte(" OR ")) } whereCondition.Write(in.Bytes()) } // only just get the where condition fakeSql := "select * from t where " + whereCondition.String() fakeStmt, err := parser.New().ParseOneStmt(fakeSql, "", "") if err != nil { log.Errorf("multi update parse fake sql error") return "", nil, err } fakeNode, ok := fakeStmt.Accept(&updateVisitor{}) if !ok { log.Errorf("multi update accept update visitor error") return "", nil, err } fakeSelectStmt, ok := fakeNode.(*ast.SelectStmt) if !ok { log.Errorf("multi update fake node is not select stmt") return "", nil, err } selStmt := ast.SelectStmt{ SelectStmtOpts: &ast.SelectStmtOpts{}, From: multiStmts[0].UpdateStmt.TableRefs, Where: fakeSelectStmt.Where, Fields: &ast.FieldList{Fields: fields}, TableHints: multiStmts[0].UpdateStmt.TableHints, LockInfo: &ast.SelectLockInfo{ LockType: ast.SelectLockForUpdate, }, } b := bytes.NewByteBuffer([]byte{}) selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) log.Infof("build select sql by update sourceQuery, sql {}", string(b.Bytes())) return string(b.Bytes()), newArgs, nil } func (u *multiUpdateExecutor) isAstStmtValid() bool { return u.parserCtx != nil && u.parserCtx.MultiStmt != nil && len(u.parserCtx.MultiStmt) > 0 } type updateVisitor struct { stmt *ast.UpdateStmt } func (m *updateVisitor) Enter(n ast.Node) (ast.Node, bool) { return n, true } func (m *updateVisitor) Leave(n ast.Node) (ast.Node, bool) { node := n return node, true }