pkg/datasource/sql/exec/at/update_executor.go (223 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package at import ( "context" "database/sql/driver" "fmt" "strings" "github.com/arana-db/parser/ast" "github.com/arana-db/parser/format" "github.com/arana-db/parser/model" "seata.apache.org/seata-go/pkg/datasource/sql/datasource" "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/util" "seata.apache.org/seata-go/pkg/util/bytes" "seata.apache.org/seata-go/pkg/util/log" ) var ( maxInSize = 1000 ) // updateExecutor execute update SQL type updateExecutor struct { baseExecutor parserCtx *types.ParseContext execContext *types.ExecContext } // NewUpdateExecutor get update executor func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { return &updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}} } // ExecContext exec SQL, and generate before image and after image func (u *updateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { u.beforeHooks(ctx, u.execContext) defer func() { u.afterHooks(ctx, u.execContext) }() beforeImage, err := u.beforeImage(ctx) if err != nil { return nil, err } res, err := f(ctx, u.execContext.Query, u.execContext.NamedValues) if err != nil { return nil, err } afterImage, err := u.afterImage(ctx, *beforeImage) if err != nil { return nil, err } if len(beforeImage.Rows) != len(afterImage.Rows) { return nil, fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.") } u.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage) u.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage) return res, nil } // beforeImage build before image func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) { if !u.isAstStmtValid() { return nil, nil } selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues) if err != nil { return nil, err } tableName, _ := u.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err } var rowsi driver.Rows queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) var queryer driver.Queryer if !ok { queryer, ok = u.execContext.Conn.(driver.Queryer) } if ok { rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) defer func() { if rowsi != nil { rowsi.Close() } }() if err != nil { log.Errorf("ctx driver query: %+v", err) return nil, err } } else { log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") return nil, fmt.Errorf("invalid conn") } image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) if err != nil { return nil, err } lockKey := u.buildLockKey(image, *metaData) u.execContext.TxCtx.LockKeys[lockKey] = struct{}{} image.SQLType = u.parserCtx.SQLType return image, nil } // afterImage build after image func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) { if !u.isAstStmtValid() { return nil, nil } if len(beforeImage.Rows) == 0 { return &types.RecordImage{}, nil } tableName, _ := u.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err } selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) var rowsi driver.Rows queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) var queryer driver.Queryer if !ok { queryer, ok = u.execContext.Conn.(driver.Queryer) } if ok { rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) defer func() { if rowsi != nil { rowsi.Close() } }() if err != nil { log.Errorf("ctx driver query: %+v", err) return nil, err } } else { log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") return nil, fmt.Errorf("invalid conn") } afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) if err != nil { return nil, err } afterImage.SQLType = u.parserCtx.SQLType return afterImage, nil } func (u *updateExecutor) isAstStmtValid() bool { return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil } // buildAfterImageSQL build the SQL to query after image data func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) { if len(beforeImage.Rows) == 0 { return "", nil } sb := strings.Builder{} // todo: OnlyCareUpdateColumns should load from config first var selectFields string var separator = "," if undo.UndoConfig.OnlyCareUpdateColumns { for _, row := range beforeImage.Rows { for _, column := range row.Columns { selectFields += column.ColumnName + separator } } selectFields = strings.TrimSuffix(selectFields, separator) } else { selectFields = "*" } sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ") whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize) sb.WriteString(" " + whereSQL + " ") return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()) } // buildAfterImageSQL build the SQL to query before image data func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) { if !u.isAstStmtValid() { log.Errorf("invalid update stmt") return "", nil, fmt.Errorf("invalid update stmt") } updateStmt := u.parserCtx.UpdateStmt fields := make([]*ast.SelectField, 0, len(updateStmt.List)) if undo.UndoConfig.OnlyCareUpdateColumns { for _, column := range updateStmt.List { fields = append(fields, &ast.SelectField{ Expr: &ast.ColumnNameExpr{ Name: column.Column, }, }) } // select indexes columns tableName, _ := u.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return "", nil, err } for _, columnName := range metaData.GetPrimaryKeyOnlyName() { fields = append(fields, &ast.SelectField{ Expr: &ast.ColumnNameExpr{ Name: &ast.ColumnName{ Name: model.CIStr{ O: columnName, L: columnName, }, }, }, }) } } else { fields = append(fields, &ast.SelectField{ Expr: &ast.ColumnNameExpr{ Name: &ast.ColumnName{ Name: model.CIStr{ O: "*", L: "*", }, }, }, }) } selStmt := ast.SelectStmt{ SelectStmtOpts: &ast.SelectStmtOpts{}, From: updateStmt.TableRefs, Where: updateStmt.Where, Fields: &ast.FieldList{Fields: fields}, OrderBy: updateStmt.Order, Limit: updateStmt.Limit, TableHints: updateStmt.TableHints, LockInfo: &ast.SelectLockInfo{ LockType: ast.SelectLockForUpdate, }, } b := bytes.NewByteBuffer([]byte{}) _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) sql := string(b.Bytes()) log.Infof("build select sql by update sourceQuery, sql {%s}", sql) return sql, u.buildSelectArgs(&selStmt, args), nil }