pulsar/transaction_impl.go (197 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package pulsar
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto"
"github.com/apache/pulsar-client-go/pulsar/log"
)
// subscription names one subscription on one topic. It is comparable, which is
// what allows it to be used as the key of transaction.registerAckSubscriptions.
type subscription struct {
	topic, subscription string
}
// transaction is the client-side handle of a Pulsar transaction.
//
// The lifecycle state is kept in state as an int32-encoded TxnState and is
// read/written atomically. The topics and subscriptions already registered
// with the transaction coordinator are cached in registerPartitions and
// registerAckSubscriptions; writes to those maps are made with mu held.
type transaction struct {
	mu       sync.Mutex
	txnID    TxnID
	state    atomic.Int32 // holds a TxnState value
	tcClient *transactionCoordinatorClient

	// Topics already added to the transaction on the coordinator.
	registerPartitions map[string]bool
	// (topic, subscription) pairs already added to the transaction on the coordinator.
	registerAckSubscriptions map[subscription]bool

	// opsFlow is a channel carrying a single token. It serves two purposes:
	//  1. Commit/Abort wait for every send/ack operation registered with the
	//     transaction to complete by taking the token from the channel.
	//  2. It keeps messages from being sent or acked with a transaction that
	//     is being (or has been) committed or aborted.
	//
	// The token is put into the channel when:
	//  1. the transaction is created;
	//  2. opsCount is decremented from 1 to 0;
	//  3. committing or aborting the transaction fails with a retryable error.
	// The token is taken from the channel when:
	//  1. the transaction is committed or aborted;
	//  2. opsCount is incremented from 0 to 1.
	opsFlow chan bool
	// opsCount is the number of send/ack operations that have not completed yet.
	opsCount atomic.Int32

	opTimeout time.Duration
	log       log.Logger
}
// newTransaction builds the client-side handle for the transaction id that the
// coordinator assigned.
//
// The returned transaction:
//   - starts in TxnOpen with the single opsFlow token available, meaning no
//     send/ack operation is pending and Commit/Abort may proceed at once;
//   - uses the client's operation timeout for its internal waits;
//   - switches from TxnOpen to TxnTimeout once timeout has elapsed. The switch
//     is a compare-and-swap, so it has no effect if the state has already left
//     TxnOpen (committing, aborting, finished, errored).
func newTransaction(id TxnID, tcClient *transactionCoordinatorClient, timeout time.Duration) *transaction {
	txn := &transaction{
		txnID:                    id,
		registerPartitions:       make(map[string]bool),
		registerAckSubscriptions: make(map[subscription]bool),
		opsFlow:                  make(chan bool, 1),
		opTimeout:                tcClient.client.operationTimeout,
		tcClient:                 tcClient,
		log:                      tcClient.log.SubLogger(log.Fields{}),
	}
	txn.state.Store(int32(TxnOpen))
	// This means there are no pending requests with this transaction, so it
	// can be committed or aborted right away.
	txn.opsFlow <- true

	// A runtime timer instead of a goroutine blocked on time.After: the latter
	// keeps one goroutine alive per transaction for the full transaction
	// timeout, even after the transaction was committed or aborted.
	time.AfterFunc(timeout, func() {
		txn.state.CompareAndSwap(int32(TxnOpen), int32(TxnTimeout))
	})
	return txn
}
// GetState reports the current lifecycle state of the transaction. The state is
// loaded atomically, so this may be called while another goroutine is
// committing or aborting the transaction.
func (txn *transaction) GetState() TxnState {
	current := txn.state.Load()
	return TxnState(current)
}
// Commit commits the transaction on the transaction coordinator.
//
// It requires the transaction to be in TxnOpen. It waits, bounded by ctx and by
// the client operation timeout, until every registered send/ack operation has
// completed, and then asks the coordinator to commit. After a retryable
// coordinator error the transaction is returned to TxnOpen, so Commit (or
// Abort) may be called again.
func (txn *transaction) Commit(ctx context.Context) error {
	return txn.endTxnWithAction(ctx, pb.TxnAction_COMMIT, TxnCommitting, TxnCommitted)
}

// Abort aborts the transaction on the transaction coordinator.
//
// It requires the transaction to be in TxnOpen. It waits, bounded by ctx and by
// the client operation timeout, until every registered send/ack operation has
// completed, and then asks the coordinator to abort. After a retryable
// coordinator error the transaction is returned to TxnOpen, so Abort (or
// Commit) may be called again.
func (txn *transaction) Abort(ctx context.Context) error {
	return txn.endTxnWithAction(ctx, pb.TxnAction_ABORT, TxnAborting, TxnAborted)
}

// endTxnWithAction is the shared implementation of Commit and Abort.
//
// action is the END_TXN action sent to the coordinator; pending is the
// transitional state held while the request is in progress, and done is the
// state stored on success. State transitions:
//   - not in TxnOpen on entry: unchanged, returns an InvalidStatus error;
//   - ctx done while waiting for operations: pending -> TxnOpen, returns ctx.Err();
//   - operations still pending after opTimeout: -> TxnTimeout, returns a TimeoutError;
//   - coordinator success: -> done, returns nil;
//   - coordinator TransactionNoFoundError/InvalidStatus: -> TxnError, returns err;
//   - any other coordinator error (treated as retryable): the opsFlow token is
//     given back and pending -> TxnOpen, returns err.
//
// Restores to TxnOpen use compare-and-swap so that a TxnError recorded
// concurrently (e.g. by endSendOrAckOp) is not overwritten.
func (txn *transaction) endTxnWithAction(ctx context.Context, action pb.TxnAction,
	pending, done TxnState) error {
	if !txn.state.CompareAndSwap(int32(TxnOpen), int32(pending)) {
		return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txn.state.Load())))
	}

	timer := time.NewTimer(txn.opTimeout)
	defer timer.Stop()

	// Wait for all operations to complete: the opsFlow token is only available
	// when no send/ack operation is outstanding.
	select {
	case <-txn.opsFlow:
	case <-ctx.Done():
		txn.state.CompareAndSwap(int32(pending), int32(TxnOpen))
		return ctx.Err()
	case <-timer.C:
		txn.state.Store(int32(TxnTimeout))
		return newError(TimeoutError, "There are some operations that are not completed after the timeout.")
	}

	// NOTE(review): if an in-flight operation failed while we were waiting,
	// endSendOrAckOp has stored TxnError, yet the request below is still sent
	// and a success overwrites that state — confirm this is intended.
	err := txn.tcClient.endTxn(&txn.txnID, action)
	if err == nil {
		txn.state.Store(int32(done))
		return nil
	}

	var e *Error
	if errors.As(err, &e) && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
		// The coordinator no longer accepts this transaction; retrying is pointless.
		txn.state.Store(int32(TxnError))
		return err
	}

	// Retryable failure: reopen the transaction so that the caller can retry
	// Commit/Abort (both require TxnOpen), and return the token we took.
	txn.state.CompareAndSwap(int32(pending), int32(TxnOpen))
	txn.opsFlow <- true
	return err
}
// registerSendOrAckOp records the start of a send or ack operation that uses
// this transaction.
//
// Only the operation that moves opsCount from 0 to 1 has to take the opsFlow
// token; operations registered while the count is already positive return nil
// immediately. If the token cannot be taken within opTimeout, the result is an
// InvalidStatus error when the transaction is no longer open, and a
// TimeoutError otherwise.
//
// NOTE(review): on those error paths opsCount is not decremented here, so the
// increment remains unless the caller balances it — confirm how callers
// handle a failed registration.
func (txn *transaction) registerSendOrAckOp() error {
	if txn.opsCount.Add(1) != 1 {
		// Someone else already holds the token on behalf of all pending ops.
		return nil
	}

	select {
	case <-txn.opsFlow:
		return nil
	case <-time.After(txn.opTimeout):
	}

	if err := txn.verifyOpen(); err != nil {
		return err
	}
	return newError(TimeoutError, "Failed to get the semaphore to register the send/ack operation")
}
// endSendOrAckOp marks one registered send/ack operation as finished.
//
// A non-nil err moves the transaction to TxnError. When the last outstanding
// operation finishes (opsCount reaches 0), the opsFlow token is put back so
// that Commit/Abort, or the next operation to be registered, can take it.
func (txn *transaction) endSendOrAckOp(err error) {
	failed := err != nil
	if failed {
		txn.state.Store(int32(TxnError))
	}

	remaining := txn.opsCount.Add(-1)
	if remaining != 0 {
		return
	}
	// Nothing is in flight any more: release the token.
	txn.opsFlow <- true
}
// registerProducerTopic makes sure topic has been added to the transaction on
// the transaction coordinator before messages are published to it with this
// transaction.
//
// It returns an InvalidStatus error if the transaction is not open, and the
// coordinator's error if the registration request fails. Successful
// registrations are cached, so the request is only repeated for a topic whose
// earlier attempt failed.
func (txn *transaction) registerProducerTopic(topic string) error {
	if err := txn.verifyOpen(); err != nil {
		return err
	}

	// Every access to registerPartitions happens with mu held. An unlocked
	// "fast path" read would race with the write below, and concurrent map
	// read/write is a data race in Go that can abort the process. The lock is
	// also held across the coordinator call so a topic is never registered
	// twice concurrently.
	txn.mu.Lock()
	defer txn.mu.Unlock()

	if _, ok := txn.registerPartitions[topic]; ok {
		return nil
	}
	if err := txn.tcClient.addPublishPartitionToTxn(&txn.txnID, []string{topic}); err != nil {
		return err
	}
	txn.registerPartitions[topic] = true
	return nil
}
// registerAckTopic makes sure the (topic, subName) subscription has been added
// to the transaction on the transaction coordinator before messages are
// acknowledged on it with this transaction.
//
// It returns an InvalidStatus error if the transaction is not open, and the
// coordinator's error if the registration request fails. Successful
// registrations are cached, so the request is only repeated for a pair whose
// earlier attempt failed.
func (txn *transaction) registerAckTopic(topic string, subName string) error {
	if err := txn.verifyOpen(); err != nil {
		return err
	}
	sub := subscription{
		topic:        topic,
		subscription: subName,
	}

	// Every access to registerAckSubscriptions happens with mu held. An
	// unlocked "fast path" read would race with the write below, and
	// concurrent map read/write is a data race in Go that can abort the
	// process. The lock is also held across the coordinator call so a pair is
	// never registered twice concurrently.
	txn.mu.Lock()
	defer txn.mu.Unlock()

	if _, ok := txn.registerAckSubscriptions[sub]; ok {
		return nil
	}
	if err := txn.tcClient.addSubscriptionToTxn(&txn.txnID, topic, subName); err != nil {
		return err
	}
	txn.registerAckSubscriptions[sub] = true
	return nil
}
// GetTxnID returns the identifier the transaction was created with.
func (txn *transaction) GetTxnID() TxnID {
	id := txn.txnID
	return id
}
// verifyOpen returns nil while the transaction is in TxnOpen, and otherwise an
// InvalidStatus error that names the state the transaction is actually in.
func (txn *transaction) verifyOpen() error {
	if current := TxnState(txn.state.Load()); current != TxnOpen {
		return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, current))
	}
	return nil
}
// String returns the name of the state constant (for example "TxnOpen"), or
// "Unknown" for any value that is not one of the defined states.
func (state TxnState) String() string {
	name := "Unknown"
	switch state {
	case TxnOpen:
		name = "TxnOpen"
	case TxnCommitting:
		name = "TxnCommitting"
	case TxnAborting:
		name = "TxnAborting"
	case TxnCommitted:
		name = "TxnCommitted"
	case TxnAborted:
		name = "TxnAborted"
	case TxnTimeout:
		name = "TxnTimeout"
	case TxnError:
		name = "TxnError"
	}
	return name
}
// txnStateErrorMessage formats the message used when a transaction is found in
// state actual while state expected was required.
//
//nolint:unparam
func txnStateErrorMessage(expected, actual TxnState) string {
	// Adjacent operands are always a string and a TxnState, so Sprint inserts
	// no extra spaces; each TxnState is rendered through its String method.
	return fmt.Sprint("Expected transaction state: ", expected, ", actual: ", actual)
}