internal/database/lock.go (177 lines of code) (raw):

// Copyright 2025 Microsoft Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package database import ( "context" "encoding/json" "fmt" "net/http" "os" "strconv" "time" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" ) // Copied from azcore/internal/shared/shared.go func Delay(ctx context.Context, delay time.Duration) error { select { case <-time.After(delay): return nil case <-ctx.Done(): return ctx.Err() } } type LockClient struct { name string containerClient *azcosmos.ContainerClient defaultTimeToLive int32 } // lockDocument implements a global distributed lock. // Its contents should be opaque outside of LockClient. type lockDocument struct { baseDocument Owner string `json:"owner,omitempty"` TTL int32 `json:"ttl,omitempty"` } // NewLockClient creates a LockClient around a ContainerClient. It attempts to // read container properties to extract a default TTL. If this fails or if the // container does not define a default TTL, the function returns an error. func NewLockClient(ctx context.Context, containerClient *azcosmos.ContainerClient) (*LockClient, error) { hostname, err := os.Hostname() if err != nil { return nil, err } c := &LockClient{ name: hostname, containerClient: containerClient, } response, err := containerClient.Read(ctx, nil) if err != nil { return nil, err } if response.ContainerProperties != nil && response.ContainerProperties.DefaultTimeToLive != nil { c.defaultTimeToLive = *response.ContainerProperties.DefaultTimeToLive } else { return nil, fmt.Errorf("container '%s' does not have a default TTL", containerClient.ID()) } return c, nil } // SetName overrides how a lock item identifies the owner. This is for // informational purposes only. LockClient uses the hostname by default. func (c *LockClient) SetName(name string) { c.name = name } // GetDefaultTimeToLive returns the default time-to-live value of the // container as a time.Duration. func (c *LockClient) GetDefaultTimeToLive() time.Duration { return time.Duration(c.defaultTimeToLive) * time.Second } // SetRetryAfterHeader sets a "Retry-After" header to the default TTL value. func (c *LockClient) SetRetryAfterHeader(header http.Header) { header.Set("Retry-After", strconv.Itoa(int(c.defaultTimeToLive))) } // AcquireLock persistently tries to acquire a lock for the given ID. If a // timeout is provided, the function will cease after the timeout duration // and return a context.DeadlineExceeded error. func (c *LockClient) AcquireLock(ctx context.Context, id string, timeout *time.Duration) (*azcosmos.ItemResponse, error) { var lock *azcosmos.ItemResponse if timeout != nil { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, *timeout) defer cancel() } for lock == nil { var err error lock, err = c.TryAcquireLock(ctx, id) if err != nil { return nil, err } if lock == nil { // TTL values are in whole seconds, // so wait one second before retrying. err = Delay(ctx, time.Second) if err != nil { return nil, err } } } return lock, nil } // TryAcquireLock tries once to acquire a lock for the given ID. If the lock // is already taken, it returns a nil azcosmos.ItemResponse and no error. func (c *LockClient) TryAcquireLock(ctx context.Context, id string) (*azcosmos.ItemResponse, error) { doc := &lockDocument{ baseDocument: baseDocument{ID: id}, Owner: c.name, TTL: c.defaultTimeToLive, } data, err := json.Marshal(doc) if err != nil { return nil, err } pk := azcosmos.NewPartitionKeyString(doc.ID) options := &azcosmos.ItemOptions{ EnableContentResponseOnWrite: true, } response, err := c.containerClient.CreateItem(ctx, pk, data, options) if isResponseError(err, http.StatusConflict) { return nil, nil // lock already acquired by someone else } else if err != nil { return nil, err } return &response, nil } type StopHoldLock func() *azcosmos.ItemResponse // HoldLock tries to hold an acquired lock by renewing it periodically from a // goroutine until the returned stop function is called. The function also returns // a new context which is cancelled if the lock is lost or some other error occurs. // The stop function terminates the goroutine and returns the current lock, or nil // if the lock was lost. func (c *LockClient) HoldLock(ctx context.Context, item *azcosmos.ItemResponse) (cancelCtx context.Context, stop StopHoldLock) { cancelCtx, cancelCause := context.WithCancelCause(ctx) done := make(chan struct{}) stop = func() *azcosmos.ItemResponse { cancelCause(nil) <-done // wait for goroutine to finish return item } go func() { defer close(done) for { var doc *lockDocument err := json.Unmarshal(item.Value, &doc) if err != nil { cancelCause(fmt.Errorf("failed to unmarshal lock: %w", err)) return } // Aim to renew one second before TTL expires. timeToRenew := time.Unix(int64(doc.CosmosTimestamp), 0) if doc.TTL > 0 { timeToRenew = timeToRenew.Add(time.Duration(doc.TTL-1) * time.Second) } select { case <-time.After(time.Until(timeToRenew)): item, err = c.RenewLock(cancelCtx, item) if err != nil { cancelCause(fmt.Errorf("failed to renew lock: %w", err)) return } if item == nil { // We lost the lock, cancel the context. cancelCause(nil) return } case <-cancelCtx.Done(): return } } }() return } // RenewLock attempts to renew an acquired lock. If successful it returns a new lock. // If the lock was somehow lost, it returns a nil azcosmos.ItemResponse and no error. func (c *LockClient) RenewLock(ctx context.Context, item *azcosmos.ItemResponse) (*azcosmos.ItemResponse, error) { var doc *lockDocument err := json.Unmarshal(item.Value, &doc) if err != nil { return nil, err } pk := azcosmos.NewPartitionKeyString(doc.ID) options := &azcosmos.ItemOptions{ EnableContentResponseOnWrite: true, IfMatchEtag: &item.ETag, } response, err := c.containerClient.UpsertItem(ctx, pk, item.Value, options) if isResponseError(err, http.StatusPreconditionFailed) { return nil, nil // lock already acquired by someone else } else if err != nil { return nil, err } return &response, nil } // ReleaseLock attempts to release an acquired lock. Errors should be logged but not // treated as fatal, since the container item's TTL value guarantees that it will be // released eventually. func (c *LockClient) ReleaseLock(ctx context.Context, item *azcosmos.ItemResponse) error { var doc *lockDocument err := json.Unmarshal(item.Value, &doc) if err != nil { return err } pk := azcosmos.NewPartitionKeyString(doc.ID) options := &azcosmos.ItemOptions{ IfMatchEtag: &item.ETag, } _, err = c.containerClient.DeleteItem(ctx, pk, doc.ID, options) if isResponseError(err, http.StatusPreconditionFailed) { return nil // lock already acquired by someone else } return err }