pkg/datasource/sql/undo/executor/executor.go (141 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package executor import ( "context" "database/sql" "database/sql/driver" "fmt" "strings" "github.com/goccy/go-json" "seata.apache.org/seata-go/pkg/datasource/sql/datasource" "seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/util/log" ) var _ undo.UndoExecutor = (*BaseExecutor)(nil) const ( checkSQLTemplate = "SELECT * FROM %s WHERE %s FOR UPDATE" maxInSize = 1000 ) type BaseExecutor struct { sqlUndoLog undo.SQLUndoLog undoImage *types.RecordImage } // ExecuteOn func (b *BaseExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, conn *sql.Conn) error { // check data if valid return nil } // UndoPrepare func (b *BaseExecutor) UndoPrepare(undoPST *sql.Stmt, undoValues []types.ColumnImage, pkValueList []types.ColumnImage) { } func (b *BaseExecutor) dataValidationAndGoOn(ctx context.Context, conn *sql.Conn) (bool, error) { if !undo.UndoConfig.DataValidation { return true, nil } beforeImage := b.sqlUndoLog.BeforeImage afterImage := b.sqlUndoLog.AfterImage equals, err := IsRecordsEquals(beforeImage, afterImage) if err != nil { return false, err } if equals { log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.") return false, nil } // Validate if data is dirty. currentImage, err := b.queryCurrentRecords(ctx, conn) if err != nil { return false, err } // compare with current data and after image. equals, err = IsRecordsEquals(afterImage, currentImage) if err != nil { return false, err } if !equals { // If current data is not equivalent to the after data, then compare the current data with the before // data, too. No need continue to undo if current data is equivalent to the before data snapshot equals, err = IsRecordsEquals(beforeImage, currentImage) if err != nil { return false, err } if equals { log.Infof("Stop rollback because there is no data change between the before data snapshot and the current data snapshot.") // no need continue undo. return false, nil } else { oldRowJson, _ := json.Marshal(afterImage.Rows) newRowJson, _ := json.Marshal(currentImage.Rows) log.Infof("check dirty data failed, old and new data are not equal, "+ "tableName:[%s], oldRows:[%s],newRows:[%s].", afterImage.TableName, oldRowJson, newRowJson) return false, fmt.Errorf("Has dirty records when undo.") } } return true, nil } func (b *BaseExecutor) queryCurrentRecords(ctx context.Context, conn *sql.Conn) (*types.RecordImage, error) { if b.undoImage == nil { return nil, fmt.Errorf("undo image is nil") } tableMeta := b.undoImage.TableMeta pkNameList := tableMeta.GetPrimaryKeyOnlyName() pkValues := b.parsePkValues(b.undoImage.Rows, pkNameList) if len(pkValues) == 0 { return nil, nil } where := buildWhereConditionByPKs(pkNameList, len(b.undoImage.Rows), maxInSize) checkSQL := fmt.Sprintf(checkSQLTemplate, b.undoImage.TableName, where) params := buildPKParams(b.undoImage.Rows, pkNameList) rows, err := conn.QueryContext(ctx, checkSQL, params...) if err != nil { return nil, err } defer rows.Close() image := types.RecordImage{ TableName: b.undoImage.TableName, TableMeta: tableMeta, SQLType: types.SQLTypeSelect, } rowImages := make([]types.RowImage, 0) for rows.Next() { columnTypes, err := rows.ColumnTypes() if err != nil { return nil, err } slice := datasource.GetScanSlice(columnTypes) if err = rows.Scan(slice...); err != nil { return nil, err } colNames, err := rows.Columns() if err != nil { return nil, err } columns := make([]types.ColumnImage, 0) for i, val := range slice { actualVal := val if v, ok := val.(driver.Valuer); ok { actualVal, _ = v.Value() } columns = append(columns, types.ColumnImage{ ColumnName: colNames[i], Value: actualVal, }) } rowImages = append(rowImages, types.RowImage{Columns: columns}) } if err := rows.Err(); err != nil { return nil, err } image.Rows = rowImages return &image, nil } func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) map[string][]types.ColumnImage { pkValues := make(map[string][]types.ColumnImage) // todo optimize 3 fors for _, row := range rows { for _, column := range row.Columns { for _, pk := range pkNameList { if strings.EqualFold(pk, column.ColumnName) { values := pkValues[strings.ToUpper(pk)] if values == nil { values = make([]types.ColumnImage, 0) } values = append(values, column) pkValues[pk] = values } } } } return pkValues }