func()

in pkg/datasource/sql/exec/at/multi_update_excutor.go [233:343]


// buildBeforeImageSQL builds the "SELECT ... FOR UPDATE" statement used to read
// (and lock) the before image of a multi-statement UPDATE.
//
// The WHERE clauses of all update statements are joined with OR, so a single
// query covers every row any of the statements may modify. If at least one
// statement has no WHERE clause it may modify every row, so the generated query
// has no WHERE clause either and no arguments are returned.
//
// When undo.UndoConfig.OnlyCareUpdateColumns is set, only the columns assigned
// by the statements plus the primary key columns are selected; otherwise every
// column listed in meta.ColumnNames is selected.
//
// The FROM clause and table hints are taken from the first statement
// (NOTE(review): this assumes all statements target the same table —
// presumably guaranteed by isAstStmtValid; confirm).
//
// It returns the restored SQL text, the arguments collected via
// buildSelectArgs for each statement's WHERE clause, and an error if the
// statements are invalid, use LIMIT / ORDER BY, or cannot be restored/parsed.
func (u *multiUpdateExecutor) buildBeforeImageSQL(args []driver.NamedValue, meta *types.TableMeta) (string, []driver.NamedValue, error) {
	if !u.isAstStmtValid() {
		log.Errorf("invalid multi update stmt")
		return "", nil, errors.New("invalid muliti update stmt")
	}

	multiStmts := u.parserCtx.MultiStmt
	if len(multiStmts) == 0 {
		log.Errorf("multi update stmt is empty")
		return "", nil, errors.New("multi update stmt is empty")
	}

	var (
		whereCondition strings.Builder
		// noWhereCondition is set when some statement has no WHERE clause,
		// i.e. it may update every row of the table.
		noWhereCondition bool
		newArgs          = make([]driver.NamedValue, 0, len(multiStmts))
		fields           = make([]*ast.SelectField, 0, len(meta.ColumnNames))
		fieldsExits      = make(map[string]struct{}, len(meta.ColumnNames))
	)

	onlyCareUpdateColumns := undo.UndoConfig.OnlyCareUpdateColumns
	if !onlyCareUpdateColumns {
		// Select every column; this list does not depend on the statements,
		// so it is built once instead of on every loop iteration.
		for _, column := range meta.ColumnNames {
			fields = append(fields, &ast.SelectField{
				Expr: &ast.ColumnNameExpr{Name: &ast.ColumnName{Name: model.CIStr{O: column}}}})
		}
	}

	for _, multiStmt := range multiStmts {
		updateStmt := multiStmt.UpdateStmt
		if updateStmt.Limit != nil {
			return "", nil, fmt.Errorf("multi update SQL with limit condition is not support yet")
		}
		if updateStmt.Order != nil {
			return "", nil, fmt.Errorf("multi update SQL with orderBy condition is not support yet")
		}

		if onlyCareUpdateColumns {
			// select update columns, de-duplicated across all statements
			for _, column := range updateStmt.List {
				if _, exist := fieldsExits[column.Column.String()]; exist {
					continue
				}

				fieldsExits[column.Column.String()] = struct{}{}
				fields = append(fields, &ast.SelectField{Expr: &ast.ColumnNameExpr{Name: column.Column}})
			}

			// select primary key columns so before-image rows can be identified
			for _, columnName := range meta.GetPrimaryKeyOnlyName() {
				if _, exist := fieldsExits[columnName]; exist {
					continue
				}

				fieldsExits[columnName] = struct{}{}
				fields = append(fields, &ast.SelectField{
					Expr: &ast.ColumnNameExpr{Name: &ast.ColumnName{Name: model.CIStr{O: columnName, L: columnName}}},
				})
			}
		}

		if updateStmt.Where == nil {
			// Restoring a nil WHERE would panic. A statement without WHERE may
			// touch every row, so the combined before image is unfiltered.
			noWhereCondition = true
			continue
		}

		tmpSelectStmt := ast.SelectStmt{
			SelectStmtOpts: &ast.SelectStmtOpts{},
			From:           updateStmt.TableRefs,
			Where:          updateStmt.Where,
			Fields:         &ast.FieldList{Fields: fields},
			OrderBy:        updateStmt.Order,
			Limit:          updateStmt.Limit,
			TableHints:     updateStmt.TableHints,
			LockInfo: &ast.SelectLockInfo{
				LockType: ast.SelectLockForUpdate,
			},
		}
		newArgs = append(newArgs, u.buildSelectArgs(&tmpSelectStmt, args)...)

		in := bytes.NewByteBuffer([]byte{})
		if err := updateStmt.Where.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, in)); err != nil {
			log.Errorf("multi update restore where condition error: %v", err)
			return "", nil, err
		}

		// OR has the lowest precedence of the logical operators, so joining
		// the restored conditions without extra parentheses keeps each intact.
		if whereCondition.Len() > 0 {
			whereCondition.WriteString(" OR ")
		}
		whereCondition.Write(in.Bytes())
	}

	selStmt := ast.SelectStmt{
		SelectStmtOpts: &ast.SelectStmtOpts{},
		From:           multiStmts[0].UpdateStmt.TableRefs,
		Fields:         &ast.FieldList{Fields: fields},
		TableHints:     multiStmts[0].UpdateStmt.TableHints,
		LockInfo: &ast.SelectLockInfo{
			LockType: ast.SelectLockForUpdate,
		},
	}

	if noWhereCondition {
		// The query has no WHERE clause, so the collected WHERE arguments
		// would have no placeholders to bind to.
		newArgs = newArgs[:0]
	} else {
		// Parse the combined condition through a throwaway SELECT only to
		// turn it back into an AST expression.
		fakeSQL := "select * from t where " + whereCondition.String()
		fakeStmt, err := parser.New().ParseOneStmt(fakeSQL, "", "")
		if err != nil {
			log.Errorf("multi update parse fake sql error: %v", err)
			return "", nil, err
		}
		fakeNode, ok := fakeStmt.Accept(&updateVisitor{})
		if !ok {
			log.Errorf("multi update accept update visitor error")
			return "", nil, errors.New("multi update accept update visitor error")
		}
		fakeSelectStmt, ok := fakeNode.(*ast.SelectStmt)
		if !ok {
			log.Errorf("multi update fake node is not select stmt")
			return "", nil, errors.New("multi update fake node is not select stmt")
		}
		selStmt.Where = fakeSelectStmt.Where
	}

	b := bytes.NewByteBuffer([]byte{})
	if err := selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)); err != nil {
		log.Errorf("multi update restore select stmt error: %v", err)
		return "", nil, err
	}
	selectSQL := string(b.Bytes())
	log.Infof("build select sql by update sourceQuery, sql %s", selectSQL)

	return selectSQL, newArgs, nil
}