unison/safewaitgroup.go (51 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package unison import ( "context" "errors" "sync" "github.com/elastic/go-concert/ctxtool" ) // SafeWaitGroup provides a safe alternative to WaitGroup, that instead of // panicing returns an error when Wait has been called. type SafeWaitGroup struct { mu sync.RWMutex wg sync.WaitGroup cancel context.CancelFunc closed bool } // ErrGroupClosed indicates that the WaitGroup is currently closed, and no more // routines can be started. var ErrGroupClosed = errors.New("group closed") // SafeWaitGroupWithCancel creates a SafeWaitGroup that will be closed when // the given canceler signals shutdown. // // Associated resources are cleaned when the parent context is cancelled, or Stop is called. func SafeWaitGroupWithCancel(parent Canceler) *SafeWaitGroup { grp := &SafeWaitGroup{} _, cancel := ctxtool.WithFunc(parent, grp.Close) grp.cancel = cancel return grp } // Add adds the delta to the WaitGroup counter. // If the counter becomes 0, all goroutines are blocked on Wait will continue. // // Add returns an error if 'Wait' has already been called, indicating that no more // go-routines should be started. func (s *SafeWaitGroup) Add(n int) error { if n < 0 { s.wg.Add(n) return nil } s.mu.RLock() defer s.mu.RUnlock() if s.closed { return ErrGroupClosed } s.wg.Add(n) return nil } // Done decrements the WaitGroup counter. func (s *SafeWaitGroup) Done() { s.wg.Done() } // Close marks the wait group as closed. All calls to Add will fail with ErrGroupClosed after // close has been called. Close does not wait until the WaitGroup counter has // reached zero, but will return immediately. Use Wait to wait for the counter to become 0. func (s *SafeWaitGroup) Close() { // When the context is cancelled, either by the parent context or by calling // 'cancel' directly, Close will be called. // The `cancel` function must always be called in order to clean up the context resources. // Due to `cancel` calling `Close`, we better be sure to have the mutex // released before calling cancel. // Although `cancel` is likely to be run in another go-routine, we don't want // to make any assumptions about implementation details of the context and cancel function. var wasClosed bool func() { s.mu.Lock() defer s.mu.Unlock() wasClosed, s.closed = s.closed, true }() if !wasClosed && s.cancel != nil { s.cancel() } } // Wait closes the WaitGroup and blocks until the WaitGroup counter is zero. // Add will return errors the moment 'Wait' has been called. func (s *SafeWaitGroup) Wait() { s.Close() s.wg.Wait() }