pkg/rm/two_phase.go (273 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rm
import (
"context"
"fmt"
"reflect"
"seata.apache.org/seata-go/pkg/tm"
)
// Struct-tag keys read from the fields of a tag-described two-phase service.
const (
	// TwoPhaseActionTag is the struct-tag key whose value assigns a func field
	// to one of the phases (see the TwoPhaseAction*TagVal constants).
	TwoPhaseActionTag = "seataTwoPhaseAction"
	// TwoPhaseActionNameTag is the struct-tag key whose value is used as the
	// action name of the service.
	TwoPhaseActionNameTag = "seataTwoPhaseServiceName"
)

// Values recognized under the TwoPhaseActionTag key.
const (
	// TwoPhaseActionPrepareTagVal marks the prepare func field.
	TwoPhaseActionPrepareTagVal = "prepare"
	// TwoPhaseActionCommitTagVal marks the commit func field.
	TwoPhaseActionCommitTagVal = "commit"
	// TwoPhaseActionRollbackTagVal marks the rollback func field.
	TwoPhaseActionRollbackTagVal = "rollback"
)
// Reflected types used to validate the signatures of tag-described phase
// functions.
var (
	// typError is the reflect.Type of the error interface.
	typError = reflect.TypeOf((*error)(nil)).Elem()
	// typContext is the reflect.Type of the context.Context interface.
	typContext = reflect.TypeOf((*context.Context)(nil)).Elem()
	// typBool is the reflect.Type of bool.
	typBool = reflect.TypeOf(false)
	// TypBusinessContextInterface is the reflect.Type of
	// *tm.BusinessActionContext. Despite the name, it is a pointer type, not
	// an interface type.
	TypBusinessContextInterface = reflect.TypeOf((*tm.BusinessActionContext)(nil))
)
// TwoPhaseInterface is the method set a service can implement to be accepted
// directly by ParseTwoPhaseAction, without any struct-tag inspection.
//
// Each phase method returns a bool and an error; TwoPhaseAction forwards both
// to its caller unchanged. How the bool is interpreted is not determined in
// this file.
type TwoPhaseInterface interface {
// Prepare is the prepare-phase entry point. params is whatever value the
// caller handed to TwoPhaseAction.Prepare.
Prepare(ctx context.Context, params interface{}) (bool, error)
// Commit is the commit-phase entry point.
Commit(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error)
// Rollback is the rollback-phase entry point.
Rollback(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error)
// GetActionName returns the name recorded as the action name of the
// resulting TwoPhaseAction.
GetActionName() string
}
// TwoPhaseAction is the reflected description of a two-phase service: the
// original service value, its action name, and a callable reflect.Value plus
// name for each of the prepare, commit and rollback phases.
//
// Values are produced by ParseTwoPhaseAction / ParseTwoPhaseActionByInterface.
type TwoPhaseAction struct {
// twoPhaseService is the value the action was parsed from.
twoPhaseService interface{}
// actionName is taken from GetActionName() or from the
// seataTwoPhaseServiceName struct tag.
actionName string
// prepareMethodName is "Prepare" for TwoPhaseInterface services, or the
// name of the tagged struct field otherwise.
prepareMethodName string
// prepareMethod is the callable invoked by Prepare.
prepareMethod *reflect.Value
// commitMethodName is "Commit" for TwoPhaseInterface services, or the
// name of the tagged struct field otherwise.
commitMethodName string
// commitMethod is the callable invoked by Commit.
commitMethod *reflect.Value
// rollbackMethodName is "Rollback" for TwoPhaseInterface services, or the
// name of the tagged struct field otherwise.
rollbackMethodName string
// rollbackMethod is the callable invoked by Rollback.
rollbackMethod *reflect.Value
}
// GetTwoPhaseService returns the service value this action was parsed from.
func (t *TwoPhaseAction) GetTwoPhaseService() interface{} {
return t.twoPhaseService
}
// GetPrepareMethodName returns the recorded name of the prepare method:
// "Prepare" for TwoPhaseInterface services, or the tagged field's name for
// tag-described services.
func (t *TwoPhaseAction) GetPrepareMethodName() string {
return t.prepareMethodName
}
// GetCommitMethodName returns the recorded name of the commit method:
// "Commit" for TwoPhaseInterface services, or the tagged field's name for
// tag-described services.
func (t *TwoPhaseAction) GetCommitMethodName() string {
return t.commitMethodName
}
// GetRollbackMethodName returns the recorded name of the rollback method:
// "Rollback" for TwoPhaseInterface services, or the tagged field's name for
// tag-described services.
func (t *TwoPhaseAction) GetRollbackMethodName() string {
return t.rollbackMethodName
}
// Prepare invokes the reflected prepare method with ctx and params and
// returns its (bool, error) results.
//
// A nil ctx or nil params is passed to the method as the zero value of the
// corresponding declared parameter type. Without this substitution
// reflect.ValueOf(nil) would produce the zero reflect.Value and Call would
// panic.
//
// Prepare still panics if the prepare method was never resolved or if its
// parameter list is incompatible with (ctx, params).
func (t *TwoPhaseAction) Prepare(ctx context.Context, params interface{}) (bool, error) {
	args := []reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(params)}

	methodType := t.prepareMethod.Type()
	for i := range args {
		if args[i].IsValid() || i >= methodType.NumIn() {
			continue
		}
		paramType := methodType.In(i)
		// For a variadic method the last declared parameter is a slice;
		// Call expects individual elements, so use the element type.
		if methodType.IsVariadic() && i == methodType.NumIn()-1 {
			paramType = paramType.Elem()
		}
		args[i] = reflect.Zero(paramType)
	}

	res := t.prepareMethod.Call(args)
	var (
		r0   = res[0].Interface()
		r1   = res[1].Interface()
		res0 bool
		res1 error
	)
	if r0 != nil {
		res0 = r0.(bool)
	}
	// A nil error result comes back as a nil interface value; only assert
	// when an actual error was returned.
	if r1 != nil {
		res1 = r1.(error)
	}
	return res0, res1
}
// Commit invokes the reflected commit method with ctx and
// businessActionContext and returns its (bool, error) results.
func (t *TwoPhaseAction) Commit(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error) {
	args := []reflect.Value{
		reflect.ValueOf(ctx),
		reflect.ValueOf(businessActionContext),
	}
	out := t.commitMethod.Call(args)
	rawOK, rawErr := out[0].Interface(), out[1].Interface()

	var committed bool
	if rawOK != nil {
		committed = rawOK.(bool)
	}

	// A nil error result is surfaced by reflection as a nil interface value.
	var err error
	if rawErr != nil {
		err = rawErr.(error)
	}
	return committed, err
}
// Rollback invokes the reflected rollback method with ctx and
// businessActionContext and returns its (bool, error) results.
func (t *TwoPhaseAction) Rollback(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error) {
	in := make([]reflect.Value, 0, 2)
	in = append(in, reflect.ValueOf(ctx), reflect.ValueOf(businessActionContext))

	results := t.rollbackMethod.Call(in)
	first := results[0].Interface()
	second := results[1].Interface()

	var (
		rolledBack bool
		err        error
	)
	if first != nil {
		rolledBack = first.(bool)
	}
	// Only assert to error when the method actually returned one.
	if second != nil {
		err = second.(error)
	}
	return rolledBack, err
}
// GetActionName returns the action name recorded when the action was parsed:
// the wrapped service's GetActionName() result for TwoPhaseInterface
// services, or the seataTwoPhaseServiceName tag value otherwise.
func (t *TwoPhaseAction) GetActionName() string {
return t.actionName
}
// IsTwoPhaseAction reports whether v can be parsed into a TwoPhaseAction by
// ParseTwoPhaseAction.
func IsTwoPhaseAction(v interface{}) bool {
	action, err := ParseTwoPhaseAction(v)
	if err != nil {
		return false
	}
	return action != nil
}
// ParseTwoPhaseAction builds a TwoPhaseAction from v.
//
// Values implementing TwoPhaseInterface are wrapped directly and never
// produce an error. Every other value is handed to
// ParseTwoPhaseActionByInterface, which inspects struct tags.
func ParseTwoPhaseAction(v interface{}) (*TwoPhaseAction, error) {
	switch svc := v.(type) {
	case TwoPhaseInterface:
		return parseTwoPhaseActionByTwoPhaseInterface(svc), nil
	default:
		return ParseTwoPhaseActionByInterface(v)
	}
}
// parseTwoPhaseActionByTwoPhaseInterface wraps a TwoPhaseInterface
// implementation, resolving its Prepare, Commit and Rollback methods by name
// through reflection and taking the action name from v.GetActionName().
func parseTwoPhaseActionByTwoPhaseInterface(v TwoPhaseInterface) *TwoPhaseAction {
	const (
		prepareName  = "Prepare"
		commitName   = "Commit"
		rollbackName = "Rollback"
	)

	service := reflect.ValueOf(v)
	prepare := service.MethodByName(prepareName)
	commit := service.MethodByName(commitName)
	rollback := service.MethodByName(rollbackName)

	action := new(TwoPhaseAction)
	action.twoPhaseService = v
	action.actionName = v.GetActionName()
	action.prepareMethodName, action.prepareMethod = prepareName, &prepare
	action.commitMethodName, action.commitMethod = commitName, &commit
	action.rollbackMethodName, action.rollbackMethod = rollbackName, &rollback
	return action
}
// ParseTwoPhaseActionByInterface builds a TwoPhaseAction from a pointer to a
// struct whose func-typed fields are labelled with the seataTwoPhaseAction
// struct tag ("prepare", "commit", "rollback") and one of whose fields
// carries the seataTwoPhaseServiceName tag.
//
// It returns an error when v is nil, is not a pointer, is a nil pointer, does
// not point to a struct, lacks a correctly-typed prepare, commit or rollback
// field, or has no action name tag. When several fields qualify for the same
// phase, the last one wins.
func ParseTwoPhaseActionByInterface(v interface{}) (*TwoPhaseAction, error) {
	ptr := reflect.ValueOf(v)
	// Elem panics on non-pointer kinds (including the invalid Value produced
	// by a nil interface) and yields an unusable Value for a nil pointer, so
	// reject those inputs with an error instead.
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() {
		return nil, fmt.Errorf("param should be a non-nil pointer to a struct")
	}

	valueOfElem := ptr.Elem()
	typeOf := valueOfElem.Type()
	if typeOf.Kind() != reflect.Struct {
		return nil, fmt.Errorf("param should be a struct, instead of a pointer")
	}

	result := TwoPhaseAction{twoPhaseService: v}
	for i := 0; i < typeOf.NumField(); i++ {
		t := typeOf.Field(i)
		f := valueOfElem.Field(i)
		if ms, m, ok := getPrepareAction(t, f); ok {
			result.prepareMethod = m
			result.prepareMethodName = ms
		} else if ms, m, ok = getCommitMethod(t, f); ok {
			result.commitMethod = m
			result.commitMethodName = ms
		} else if ms, m, ok = getRollbackMethod(t, f); ok {
			result.rollbackMethod = m
			result.rollbackMethodName = ms
		}
	}

	// The helpers return a non-nil *reflect.Value whenever they succeed, so a
	// nil method pointer means no matching field was found.
	if result.prepareMethod == nil {
		return nil, fmt.Errorf("missing prepare method")
	}
	if result.commitMethod == nil {
		return nil, fmt.Errorf("missing commit method")
	}
	if result.rollbackMethod == nil {
		return nil, fmt.Errorf("missing rollback method")
	}

	twoPhaseName := getActionName(v)
	if twoPhaseName == "" {
		return nil, fmt.Errorf("missing two phase name")
	}
	result.actionName = twoPhaseName
	return &result, nil
}
// getPrepareAction checks whether struct field t (with value f) is a usable
// prepare function. It qualifies when the field is tagged
// seataTwoPhaseAction:"prepare", is a func returning exactly (bool, error),
// and takes at least one parameter whose first is context.Context.
//
// On success it returns the field name, a pointer to the field value and
// true; otherwise it returns "", nil, false.
func getPrepareAction(t reflect.StructField, f reflect.Value) (string, *reflect.Value, bool) {
	if t.Tag.Get(TwoPhaseActionTag) != TwoPhaseActionPrepareTagVal {
		return "", nil, false
	}
	if !f.IsValid() || f.Kind() != reflect.Func {
		return "", nil, false
	}

	sig := t.Type
	// Results must be exactly (bool, error).
	resultsOK := sig.NumOut() == 2 && sig.Out(0) == typBool && sig.Out(1) == typError
	if !resultsOK {
		return "", nil, false
	}
	// The first parameter must be context.Context; any further parameters
	// are unconstrained.
	paramsOK := sig.NumIn() >= 1 && sig.In(0) == typContext
	if !paramsOK {
		return "", nil, false
	}
	return t.Name, &f, true
}
// getCommitMethod checks whether struct field t (with value f) is a usable
// commit function. It qualifies when the field is tagged
// seataTwoPhaseAction:"commit", is a func returning exactly (bool, error),
// and takes exactly two parameters: context.Context and
// *tm.BusinessActionContext.
//
// On success it returns the field name, a pointer to the field value and
// true; otherwise it returns "", nil, false.
func getCommitMethod(t reflect.StructField, f reflect.Value) (string, *reflect.Value, bool) {
	if t.Tag.Get(TwoPhaseActionTag) != TwoPhaseActionCommitTagVal {
		return "", nil, false
	}
	if !f.IsValid() || f.Kind() != reflect.Func {
		return "", nil, false
	}

	sig := t.Type
	// Results must be exactly (bool, error).
	resultsOK := sig.NumOut() == 2 && sig.Out(0) == typBool && sig.Out(1) == typError
	if !resultsOK {
		return "", nil, false
	}
	// Parameters must be exactly (context.Context, *tm.BusinessActionContext).
	paramsOK := sig.NumIn() == 2 &&
		sig.In(0) == typContext &&
		sig.In(1) == TypBusinessContextInterface
	if !paramsOK {
		return "", nil, false
	}
	return t.Name, &f, true
}
// getRollbackMethod checks whether struct field t (with value f) is a usable
// rollback function. It qualifies when the field is tagged
// seataTwoPhaseAction:"rollback", is a func returning exactly (bool, error),
// and takes exactly two parameters: context.Context and
// *tm.BusinessActionContext.
//
// On success it returns the field name, a pointer to the field value and
// true; otherwise it returns "", nil, false.
func getRollbackMethod(t reflect.StructField, f reflect.Value) (string, *reflect.Value, bool) {
	if t.Tag.Get(TwoPhaseActionTag) != TwoPhaseActionRollbackTagVal {
		return "", nil, false
	}
	if !f.IsValid() || f.Kind() != reflect.Func {
		return "", nil, false
	}

	sig := t.Type
	// Results must be exactly (bool, error).
	resultsOK := sig.NumOut() == 2 && sig.Out(0) == typBool && sig.Out(1) == typError
	if !resultsOK {
		return "", nil, false
	}
	// Parameters must be exactly (context.Context, *tm.BusinessActionContext).
	paramsOK := sig.NumIn() == 2 &&
		sig.In(0) == typContext &&
		sig.In(1) == TypBusinessContextInterface
	if !paramsOK {
		return "", nil, false
	}
	return t.Name, &f, true
}
// getActionName returns the value of the first non-empty
// seataTwoPhaseServiceName struct tag found on the fields of the struct that
// v points to, or "" when v points to a non-struct or no field carries the tag.
//
// v must be a non-nil pointer; other inputs make the reflection calls panic.
func getActionName(v interface{}) string {
	structType := reflect.ValueOf(v).Elem().Type()
	if structType.Kind() != reflect.Struct {
		return ""
	}
	for idx := 0; idx < structType.NumField(); idx++ {
		if name := structType.Field(idx).Tag.Get(TwoPhaseActionNameTag); name != "" {
			return name
		}
	}
	return ""
}