pkg/xcontext/event_handler.go (271 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
package xcontext
import (
"context"
"errors"
"runtime"
"sync"
"time"
)
// WithTimeout is the counterpart of context.WithTimeout for the extended
// Context. It converts the relative timeout into an absolute point in time
// (measured from the moment of the call) and delegates to WithDeadline.
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
	deadline := time.Now().Add(timeout)
	return WithDeadline(parent, deadline)
}
// WithDeadline is the analog of context.WithDeadline, but with support of the
// extended Context.
//
// The returned context is a clone of parent (or of the background context if
// parent is nil) with a fresh event handler. Once t is reached the handler
// receives ErrDeadlineExceeded as a cancel signal.
//
// Calling the returned CancelFunc cancels the context with ErrCanceled and
// also stops the deadline timer. Without stopping it, the timer (and the
// handler it references) would stay alive until t, and ErrDeadlineExceeded
// would later be recorded on a context that was already cancelled explicitly.
func WithDeadline(parent Context, t time.Time) (Context, CancelFunc) {
	if parent == nil {
		parent = background
	}
	ctx := parent.Clone()
	h := ctx.addEventHandler()
	h.deadline = &t
	timer := time.AfterFunc(time.Until(t), func() {
		h.cancel(ErrDeadlineExceeded)
	})
	return ctx, func() {
		// Stop reports false if the timer already fired; that is harmless,
		// the explicit cancel below is then simply an additional signal.
		timer.Stop()
		h.cancel(ErrCanceled)
	}
}
// WithCancel is the counterpart of context.WithCancel for the extended
// Context.
//
// The returned CancelFunc sends errs as cancel signals to the new context.
// When no errs are given, the handler falls back to ErrCanceled. A nil parent
// is replaced by the background context.
func WithCancel(parent Context, errs ...error) (Context, CancelFunc) {
	if parent == nil {
		parent = background
	}

	child := parent.Clone()
	handler := child.addEventHandler()
	cancel := func() {
		handler.cancel(errs...)
	}
	return child, cancel
}
// WithResetSignalers returns a clone of parent whose event handler has been
// dropped, i.e. all signalers (cancelers and notifiers) are reset.
//
// A nil parent is replaced by the background context.
func WithResetSignalers(parent Context) Context {
	if parent == nil {
		parent = background
	}

	child := parent.Clone()
	child.resetEventHandler()
	return child
}
// WithNotify is similar to WithCancel, but the returned function emits
// notification signals instead of cancel signals, so the context is not
// closed by it.
//
// Panics if no errs are passed. A nil parent is replaced by the background
// context.
func WithNotify(parent Context, errs ...error) (Context, CancelFunc) {
	if len(errs) == 0 {
		panic("len(errs) == 0")
	}
	if parent == nil {
		parent = background
	}

	child := parent.Clone()
	handler := child.addEventHandler()
	notify := func() {
		handler.notify(errs...)
	}
	return child, notify
}
// WithStdContext derives a context from parent that additionally carries the
// events and values of the standard context stdCtx.
//
// A nil parent is replaced by the background context.
func WithStdContext(parent Context, stdCtx context.Context) Context {
	base := parent
	if base == nil {
		base = background
	}
	return base.cloneWithStdContext(stdCtx)
}
// cloneWithStdContext returns a clone of ctx that is wired to stdCtx:
//
//   - values: stdCtx becomes the values handler, or the outer layer of a
//     valuesMerger if ctx already has a values handler;
//   - events: a new event handler is attached, and a goroutine cancels it with
//     stdCtx.Err() as soon as stdCtx is done.
//
// The goroutine exits either when stdCtx is done or when the returned context
// is finalized by the garbage collector.
func (ctx *ctxValue) cloneWithStdContext(stdCtx context.Context) Context {
	child := ctx.clone()
	if inner := child.valuesHandler; inner == nil {
		child.valuesHandler = stdCtx
	} else {
		child.valuesHandler = valuesMerger{
			outer: stdCtx,
			inner: inner,
		}
	}

	handler := child.addEventHandler()

	// NOTE(review): presumably the second clone exists because
	// addEventHandler may already have set a finalizer on the first clone and
	// an object can carry only one finalizer — confirm. Also verify that the
	// first clone's finalizer firing early (it becomes unreachable here unless
	// clone() keeps a reference to it) does not detach handler from its parent.
	child = child.clone()

	stop := make(chan struct{})
	runtime.SetFinalizer(child, func(*ctxValue) {
		close(stop)
	})

	go func() {
		select {
		case <-stdCtx.Done():
			handler.cancel(stdCtx.Err())
		case <-stop:
		}
	}()

	return child
}
// eventHandler keeps track of the cancel and notification signals of one
// node of the context tree and forwards them to the handlers derived from it.
type eventHandler struct {
// cancelSignal is the channel handed out by Done. It is created lazily by
// Done and closed (and reset to nil) when a cancel signal arrives.
cancelSignal chan struct{}
// children is the set of all eventHandlers derived from this one. For
// example, if this one receives a cancel signal, then all children
// receive that cancel signal too.
children map[*eventHandler]struct{}
// locker is used to exclude concurrent access to the data of this
// structure. The methods in this file also hold it while touching
// cancelSignal and children.
locker sync.Mutex
// firstCancel is the first cancel signal ever received by this node,
// either directly or inherited from the parent handler at creation time.
//
// It is used to implement method Err compatible with the co-named method
// of the original context.Context.
firstCancel error
// receivedCancels is the list of cancel signals received by this node. It
// is seeded with a copy of the parent's list when the handler is created.
receivedCancels []error
// receivedNotifications is the same as receivedCancels, but for
// notification signals.
receivedNotifications []error
// notifySignal holds the wait channels handed out by Until, keyed by the
// error being waited for; the nil key means "any event". A channel is
// closed and its entry reset to nil when a matching signal arrives.
notifySignal map[error]chan struct{}
// deadline is when the context will be closed by exceeding a timeout.
//
// If nil then never.
deadline *time.Time
}
// resetEventHandler detaches ctx from its current event handler by setting
// the field to nil. The previous handler itself is not modified.
func (ctx *ctxValue) resetEventHandler() {
ctx.eventHandler = nil
}
// addEventHandler installs a fresh eventHandler on ctx and returns it.
//
// If ctx already had a handler, the new one becomes its child: it starts with
// a copy of the previous handler's first cancel, received cancels and received
// notifications, and it is registered in the previous handler's children so
// that later signals are propagated to it. A finalizer on ctx unregisters the
// new handler again once ctx is garbage collected.
//
// NOTE(review): deadline is not copied from the previous handler — confirm
// that this is intended.
func (ctx *ctxValue) addEventHandler() *eventHandler {
	prev := ctx.eventHandler
	next := &eventHandler{
		children:     make(map[*eventHandler]struct{}),
		notifySignal: map[error]chan struct{}{},
	}
	ctx.eventHandler = next
	if prev == nil {
		return next
	}

	// Snapshot the state of the previous handler and register the new one as
	// its child in one critical section, so no signal can slip in between.
	func() {
		prev.locker.Lock()
		defer prev.locker.Unlock()

		next.firstCancel = prev.firstCancel
		next.receivedNotifications = make([]error, len(prev.receivedNotifications))
		copy(next.receivedNotifications, prev.receivedNotifications)
		next.receivedCancels = make([]error, len(prev.receivedCancels))
		copy(next.receivedCancels, prev.receivedCancels)
		prev.children[next] = struct{}{}
	}()

	runtime.SetFinalizer(ctx, func(*ctxValue) {
		prev.locker.Lock()
		defer prev.locker.Unlock()
		delete(prev.children, next)
	})
	return next
}
// IsSignaledWith reports whether the handler has received a cancel or a
// notification signal matching any of errs.
//
// With empty errs it reports whether any cancel or notification signal has
// been received at all. A nil handler has never been signaled.
func (h *eventHandler) IsSignaledWith(errs ...error) bool {
	if h != nil {
		h.locker.Lock()
		defer h.locker.Unlock()
		return h.isSignaledWith(errs...)
	}
	return false
}
// isSignaledWith is the lock-free core of IsSignaledWith; the caller must
// hold h.locker.
//
// Received cancels are checked before received notifications, and a received
// signal matches a requested one according to errors.Is.
func (h *eventHandler) isSignaledWith(errs ...error) bool {
	if len(errs) == 0 {
		// No filter: any recorded signal counts.
		return len(h.receivedCancels)+len(h.receivedNotifications) > 0
	}

	matchesAny := func(received []error) bool {
		for _, got := range received {
			for _, want := range errs {
				if errors.Is(got, want) {
					return true
				}
			}
		}
		return false
	}
	return matchesAny(h.receivedCancels) || matchesAny(h.receivedNotifications)
}
// Err implements context.Context.Err: it returns the first cancel signal
// known to the handler, or nil if there is none (or if the handler is nil).
func (h *eventHandler) Err() error {
	if h == nil {
		return nil
	}

	h.locker.Lock()
	defer h.locker.Unlock()
	return h.firstCancel
}
// cancel records errs as cancel signals on this handler, wakes everybody
// waiting on it and then propagates the same signals to all children.
//
// If no errs are passed, ErrCanceled is used. The default is applied before
// anything else so that it is treated exactly like an explicitly passed
// error; in particular, waiters registered through Until(ErrCanceled) are
// woken as well, not only the listeners of Done.
//
// The handler's lock is held for the whole call, including the recursive
// calls into the children.
func (h *eventHandler) cancel(errs ...error) {
	if len(errs) == 0 {
		errs = []error{ErrCanceled}
	}

	h.locker.Lock()
	defer h.locker.Unlock()

	h.receivedCancels = append(h.receivedCancels, errs...)
	if h.firstCancel == nil {
		h.firstCancel = h.receivedCancels[0]
	}

	// Wake the listeners of Done. The field is reset first so the channel can
	// never be closed twice.
	if h.cancelSignal != nil {
		cancelSignal := h.cancelSignal
		h.cancelSignal = nil
		close(cancelSignal)
	}

	// Wake the waiters registered for each specific error. A duplicate error
	// in errs is harmless: its entry is already nil on the second pass.
	for _, err := range errs {
		if h.notifySignal[err] == nil {
			continue
		}
		close(h.notifySignal[err])
		h.notifySignal[err] = nil
	}

	// Wake the waiters registered for "any event" (nil key).
	if h.notifySignal[nil] != nil {
		close(h.notifySignal[nil])
		h.notifySignal[nil] = nil
	}

	for child := range h.children {
		child.cancel(errs...)
	}
}
// notify records errs as notification signals on this handler, wakes the
// waiters registered for any of errs as well as the "any event" waiters, and
// then forwards the same notifications to all children.
//
// Unlike cancel it does not touch the Done channel or the first cancel error.
// The handler's lock is held for the whole call.
func (h *eventHandler) notify(errs ...error) {
	h.locker.Lock()
	defer h.locker.Unlock()

	h.receivedNotifications = append(h.receivedNotifications, errs...)

	// wake closes the wait channel registered under key, if there is one, and
	// resets the entry so the channel is never closed twice.
	wake := func(key error) {
		if waiter := h.notifySignal[key]; waiter != nil {
			close(waiter)
			h.notifySignal[key] = nil
		}
	}
	for _, err := range errs {
		wake(err)
	}
	wake(nil) // the nil key means "any event"

	for child := range h.children {
		child.notify(errs...)
	}
}
// Notifications returns every notification recorded on the handler, which
// includes the ones copied from the parent handler at creation time.
//
// The returned slice is the handler's own storage: treat it as read-only.
// A nil handler yields nil.
func (h *eventHandler) Notifications() []error {
	if h == nil {
		return nil
	}

	h.locker.Lock()
	received := h.receivedNotifications
	h.locker.Unlock()
	return received
}
var (
	// openChan is a shared channel that is never closed; receiving from it
	// blocks forever.
	openChan = make(chan struct{})

	// closedChan is a shared channel that is already closed; receiving from
	// it never blocks.
	closedChan = func() chan struct{} {
		ch := make(chan struct{})
		close(ch)
		return ch
	}()
)
// Done implements context.Context.Done.
//
// A nil handler returns a nil channel. If a cancel signal is already known,
// the shared closedChan is returned. Otherwise the handler's cancel channel
// is returned, being created on first use.
func (h *eventHandler) Done() <-chan struct{} {
	if h == nil {
		return nil
	}

	h.locker.Lock()
	defer h.locker.Unlock()

	switch {
	case h.firstCancel != nil:
		return closedChan
	case h.cancelSignal == nil:
		h.cancelSignal = make(chan struct{})
	}
	return h.cancelSignal
}
// Until works like Done, but lets the caller choose which signal to wait
// for: the returned channel is closed once a signal matching err is received.
//
// If err is nil, it waits for any event. A nil handler returns a channel
// that is never closed.
func (h *eventHandler) Until(err error) <-chan struct{} {
	if h != nil {
		return h.newWaiter(err)
	}
	return openChan
}
// Deadline implements context.Context.Deadline.
//
// ok is false (and deadline is the zero time) when the handler is nil or has
// no deadline set.
func (h *eventHandler) Deadline() (deadline time.Time, ok bool) {
	if h == nil {
		return time.Time{}, false
	}

	h.locker.Lock()
	defer h.locker.Unlock()

	if d := h.deadline; d != nil {
		return *d, true
	}
	return time.Time{}, false
}
// newWaiter returns a channel that is closed once a signal matching err is
// received; a nil err stands for "any signal".
//
// If such a signal has already been received, the shared closedChan is
// returned. Otherwise all waiters for the same err share one channel, which
// is created on first use. The handler's lock is held for the whole call.
func (h *eventHandler) newWaiter(err error) <-chan struct{} {
	h.locker.Lock()
	defer h.locker.Unlock()

	var alreadySignaled bool
	if err == nil {
		alreadySignaled = h.isSignaledWith()
	} else {
		alreadySignaled = h.isSignaledWith(err)
	}
	if alreadySignaled {
		return closedChan
	}

	waiter := h.notifySignal[err]
	if waiter == nil {
		waiter = make(chan struct{})
		h.notifySignal[err] = waiter
	}
	return waiter
}