func()

in pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go [149:236]


// buildBeforeImageSQL builds a single "SELECT ... FOR UPDATE" statement that
// locks and reads the before image of the rows touched by a batch of UPDATE
// statements.
//
// The selected fields are the union of the columns assigned in every
// statement's SET list (duplicates dropped, first-seen order kept). The WHERE
// clause is the OR of every statement's WHERE clause. FROM and table hints
// are taken from the first statement only, so the statements are presumably
// expected to target the same table — TODO confirm with callers.
//
// args are the driver arguments of the original SQL. The returned args are
// the concatenation, in statement order, of what buildSelectArgs selects for
// each statement (presumably the args referenced by that statement's WHERE
// clause — verify against buildSelectArgs).
//
// An error is returned when updateStmts is empty; when any statement has a
// LIMIT, an ORDER BY, or no WHERE clause (not supported yet); or when the
// combined WHERE clause / final SELECT cannot be restored or re-parsed.
func (u *MySQLMultiUpdateUndoLogBuilder) buildBeforeImageSQL(updateStmts []*ast.UpdateStmt, args []driver.Value) (string, []driver.Value, error) {
	if len(updateStmts) == 0 {
		log.Errorf("invalid multi update stmt")
		return "", nil, fmt.Errorf("invalid multi update stmt")
	}

	var newArgs []driver.Value
	var fields []*ast.SelectField
	fieldsExist := make(map[string]struct{})
	var whereCondition strings.Builder
	for _, updateStmt := range updateStmts {
		if updateStmt.Limit != nil {
			return "", nil, fmt.Errorf("multi update SQL with limit condition is not support yet")
		}
		if updateStmt.Order != nil {
			return "", nil, fmt.Errorf("multi update SQL with orderBy condition is not support yet")
		}
		// Restoring a nil Where below would panic. A statement without WHERE
		// would also need a whole-table before image, which is not built here.
		if updateStmt.Where == nil {
			return "", nil, fmt.Errorf("multi update SQL without where condition is not support yet")
		}

		// todo use ONLY_CARE_UPDATE_COLUMNS to judge select all columns or not
		for _, assignment := range updateStmt.List {
			name := assignment.Column.String()
			if _, exist := fieldsExist[name]; exist {
				continue
			}
			fieldsExist[name] = struct{}{}
			fields = append(fields, &ast.SelectField{
				Expr: &ast.ColumnNameExpr{
					Name: assignment.Column,
				},
			})
		}

		// Per-statement SELECT built only so buildSelectArgs can pick out the
		// driver args belonging to this statement.
		tmpSelectStmt := ast.SelectStmt{
			SelectStmtOpts: &ast.SelectStmtOpts{},
			From:           updateStmt.TableRefs,
			Where:          updateStmt.Where,
			Fields:         &ast.FieldList{Fields: fields},
			OrderBy:        updateStmt.Order,
			Limit:          updateStmt.Limit,
			TableHints:     updateStmt.TableHints,
			LockInfo: &ast.SelectLockInfo{
				LockType: ast.SelectLockForUpdate,
			},
		}
		newArgs = append(newArgs, u.buildSelectArgs(&tmpSelectStmt, args)...)

		in := bytes.NewByteBuffer([]byte{})
		if err := updateStmt.Where.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, in)); err != nil {
			return "", nil, errors.Wrap(err, "multi update restore where condition error")
		}

		// No parentheses are added around each condition: in MySQL, OR has the
		// lowest precedence of the logical operators (only assignment is
		// lower), so "c1 OR c2 OR ..." keeps every ci grouped as written.
		if whereCondition.Len() > 0 {
			whereCondition.WriteString(" OR ")
		}
		whereCondition.Write(in.Bytes())
	}

	// Round-trip the combined text through the parser to get back an
	// expression tree for the OR'ed WHERE clause.
	fakeSql := "select * from t where " + whereCondition.String()
	fakeStmt, err := parser.New().ParseOneStmt(fakeSql, "", "")
	if err != nil {
		return "", nil, errors.Wrap(err, "multi update parse fake sql error")
	}
	fakeNode, ok := fakeStmt.Accept(&updateVisitor{})
	if !ok {
		// err is nil here. Wrapping it (pkg/errors-style Wrap returns nil for a
		// nil error) could report success, so build a fresh error instead.
		return "", nil, fmt.Errorf("multi update accept update visitor error")
	}
	fakeSelectStmt, ok := fakeNode.(*ast.SelectStmt)
	if !ok {
		return "", nil, fmt.Errorf("multi update fake node is not select stmt")
	}

	selStmt := ast.SelectStmt{
		SelectStmtOpts: &ast.SelectStmtOpts{},
		From:           updateStmts[0].TableRefs,
		Where:          fakeSelectStmt.Where,
		Fields:         &ast.FieldList{Fields: fields},
		TableHints:     updateStmts[0].TableHints,
		LockInfo: &ast.SelectLockInfo{
			LockType: ast.SelectLockForUpdate,
		},
	}

	b := bytes.NewByteBuffer([]byte{})
	if err = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)); err != nil {
		return "", nil, errors.Wrap(err, "multi update restore select sql error")
	}
	sql := string(b.Bytes())
	log.Infof("build select sql by update sourceQuery, sql %s", sql)

	return sql, newArgs, nil
}