// pkg/datasource/sql/tx.go (163 lines of code) (raw)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"context"
"database/sql/driver"
"fmt"
"sync"
"time"
"seata.apache.org/seata-go/pkg/datasource/sql/datasource"
"seata.apache.org/seata-go/pkg/datasource/sql/types"
"seata.apache.org/seata-go/pkg/protocol/branch"
"seata.apache.org/seata-go/pkg/rm"
"seata.apache.org/seata-go/pkg/util/backoff"
"seata.apache.org/seata-go/pkg/util/log"
)
// hl guards txHooks. RegisterTxHook and CleanTxHooks take the write lock;
// hook dispatch in Commit and Rollback takes the read lock.
var hl sync.RWMutex

// txHooks holds the registered transaction hooks in registration order.
var txHooks []txHook
// RegisterTxHook appends h to the package-level hook list. The write lock is
// held for the duration of the append.
func RegisterTxHook(h txHook) {
	hl.Lock()
	txHooks = append(txHooks, h)
	hl.Unlock()
}
// CleanTxHooks discards every registered hook by replacing the hook list with
// a fresh, empty slice (capacity 4) while holding the write lock.
func CleanTxHooks() {
	hl.Lock()
	txHooks = make([]txHook, 0, 4)
	hl.Unlock()
}
// txOption mutates a Tx while newTx is constructing it.
type txOption func(tx *Tx)

// txHook receives callbacks from a Tx right before it commits or rolls back.
type txHook interface {
	// BeforeCommit is invoked from Tx.Commit before the local commit runs.
	BeforeCommit(tx *Tx)
	// BeforeRollback is invoked from Tx.Rollback before the wrapped
	// transaction is rolled back.
	BeforeRollback(tx *Tx)
}
// newTx allocates a Tx, applies every option in the order given and then runs
// the Tx's init step. It returns the Tx as a driver.Tx, or the init error.
func newTx(opts ...txOption) (driver.Tx, error) {
	t := &Tx{}
	for _, apply := range opts {
		apply(t)
	}

	err := t.init()
	if err != nil {
		return nil, err
	}
	return t, nil
}
// withDriverConn returns an option that stores conn on the Tx being built.
func withDriverConn(conn *Conn) txOption {
	bindConn := func(t *Tx) {
		t.conn = conn
	}
	return bindConn
}
// withOriginTx returns an option that stores tx as the wrapped (target)
// driver transaction of the Tx being built.
func withOriginTx(tx driver.Tx) txOption {
	bindTarget := func(t *Tx) {
		t.target = tx
	}
	return bindTarget
}
// withTxCtx returns an option that stores ctx as the transaction context of
// the Tx being built.
func withTxCtx(ctx *types.TransactionContext) txOption {
	bindCtx := func(t *Tx) {
		t.tranCtx = ctx
	}
	return bindCtx
}
// Tx wraps a driver.Tx (target) together with the connection and transaction
// context supplied through txOption values when newTx builds it.
type Tx struct {
// conn is the connection assigned by withDriverConn.
conn *Conn
// tranCtx is the transaction context assigned by withTxCtx; report reads the
// XID, branch ID and transaction mode from it.
tranCtx *types.TransactionContext
// target is the wrapped driver transaction assigned by withOriginTx; commit
// and rollback are delegated to it.
target driver.Tx
}
// Commit runs the registered BeforeCommit hooks and then commits the wrapped
// transaction, returning whatever error that commit produces.
func (tx *Tx) Commit() error {
	tx.beforeCommit()

	err := tx.commitOnLocal()
	return err
}
// beforeCommit invokes BeforeCommit on every registered hook, in registration
// order, passing this transaction.
//
// The read lock is acquired before txHooks is looked at. Checking
// len(txHooks) without the lock, as was done previously, is an unsynchronised
// read that races with RegisterTxHook and CleanTxHooks. Ranging over an empty
// slice is a no-op, so no separate emptiness check is needed.
//
// Hooks run while the read lock is held, so a hook must not call
// RegisterTxHook or CleanTxHooks (both need the write lock).
func (tx *Tx) beforeCommit() {
	hl.RLock()
	defer hl.RUnlock()

	for i := range txHooks {
		txHooks[i].BeforeCommit(tx)
	}
}
// Rollback invokes BeforeRollback on every registered hook, in registration
// order, and then rolls back the wrapped transaction, returning its error.
//
// Hook dispatch is wrapped in its own function for two reasons:
//   - the read lock is taken before txHooks is inspected, removing the
//     unsynchronised len(txHooks) read that raced with RegisterTxHook and
//     CleanTxHooks;
//   - the read lock is released before tx.target.Rollback() runs, matching
//     Commit, so hook registration is not blocked for the duration of the
//     underlying rollback.
//
// Hooks run while the read lock is held, so a hook must not call
// RegisterTxHook or CleanTxHooks (both need the write lock).
func (tx *Tx) Rollback() error {
	func() {
		hl.RLock()
		// Deferred so the lock is released even if a hook panics.
		defer hl.RUnlock()

		for i := range txHooks {
			txHooks[i].BeforeRollback(tx)
		}
	}()

	return tx.target.Rollback()
}
// init is the post-option initialisation step called by newTx. It currently
// performs no work and always reports success.
func (tx *Tx) init() error {
	// Nothing to set up yet.
	return nil
}
// commitOnLocal commits the wrapped driver transaction and returns its error
// unchanged.
func (tx *Tx) commitOnLocal() error {
	err := tx.target.Commit()
	return err
}
// register registers a branch for the transaction described by ctx and, on
// success, stores the returned branch ID in ctx.BranchID.
//
// It returns nil without registering anything when the branch type is
// unknown, or when the transaction is an AT branch that has no undo log or no
// lock key. It returns an error when no data source manager is available for
// the branch type or when the BranchRegister call fails.
func (tx *Tx) register(ctx *types.TransactionContext) error {
	branchType := ctx.TransactionMode.BranchType()
	if branchType == branch.BranchTypeUnknow {
		return nil
	}
	if branchType == branch.BranchTypeAT {
		if !ctx.HasUndoLog() || !ctx.HasLockKey() {
			return nil
		}
	}

	request := rm.BranchRegisterParam{
		Xid:        ctx.XID,
		BranchType: branchType,
		ResourceId: ctx.ResourceID,
	}

	if ctx.TransactionMode == types.ATMode {
		if !ctx.HasUndoLog() || !ctx.HasLockKey() {
			return nil
		}
		// Join the lock keys as "k1;k2;...;" in a single growing buffer rather
		// than re-allocating the whole string on every iteration.
		var lockKeys []byte
		for k := range ctx.LockKeys {
			lockKeys = append(lockKeys, k...)
			lockKeys = append(lockKeys, ';')
		}
		request.LockKeys = string(lockKeys)
	}

	dataSourceManager := datasource.GetDataSourceManager(branchType)
	// Same guard as report: without it a missing manager would surface as a
	// nil dereference on the BranchRegister call below.
	if dataSourceManager == nil {
		return fmt.Errorf("get dataSourceManager failed")
	}

	branchId, err := dataSourceManager.BranchRegister(context.Background(), request)
	if err != nil {
		log.Errorf("Failed to register branch: %s", err.Error())
		return err
	}
	ctx.BranchID = uint64(branchId)
	return nil
}
// report sends the phase-one outcome of this branch to the data source
// manager selected by the transaction mode's branch type.
//
// success selects the reported status via getStatus. When the transaction
// context has no branch ID (zero) nothing is reported and nil is returned.
// The BranchReport call is retried under a backoff configured for at most 5
// retries with 100ms-200ms waits; the error of the last attempt (nil on
// success) is returned. An error is also returned when no data source manager
// is available.
func (tx *Tx) report(success bool) error {
	if tx.tranCtx.BranchID == 0 {
		return nil
	}

	status := getStatus(success)
	request := rm.BranchReportParam{
		Xid:      tx.tranCtx.XID,
		BranchId: int64(tx.tranCtx.BranchID),
		Status:   status,
	}

	dataSourceManager := datasource.GetDataSourceManager(tx.tranCtx.TransactionMode.BranchType())
	if dataSourceManager == nil {
		return fmt.Errorf("get dataSourceManager failed")
	}

	retry := backoff.New(context.Background(), backoff.Config{
		MinBackoff: 100 * time.Millisecond,
		MaxBackoff: 200 * time.Millisecond,
		MaxRetries: 5,
	})

	var err error
	for retry.Ongoing() {
		if err = dataSourceManager.BranchReport(context.Background(), request); err == nil {
			break
		}
		// Verbs match the argument types: BranchID is an unsigned integer and
		// success is a bool; %s on either prints "%!s(...)" noise. %v is used
		// for the backoff value because its concrete formatting is not known
		// here.
		log.Infof("Failed to report [%d / %s] commit done [%t] Retry Countdown: %v", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry)
		retry.Wait()
	}
	return err
}
// getStatus maps the phase-one outcome to a branch status:
// BranchStatusPhaseoneDone when success is true, BranchStatusPhaseoneFailed
// otherwise.
func getStatus(success bool) branch.BranchStatus {
	if !success {
		return branch.BranchStatusPhaseoneFailed
	}
	return branch.BranchStatusPhaseoneDone
}