ctxtool/merge.go (131 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package ctxtool import ( "context" "sync" "time" ) type cancelledContext struct { context.Context err error } type mergeCancelCtx struct { context.Context cancel canceller ch <-chan struct{} mu sync.Mutex err error } // cancelOverwriteContext uses the canceller for Done and error calls, and the // original context for Deadline and Value calls. type cancelOverwriteContext struct { ctx context.Context cancel canceller } type mergedDeadlineCtx struct { context.Context deadline time.Time } type mergeValueCtx struct { context.Context overwrites valuer } // MergeContexts merges cancellation and values of 2 contexts. // The resulting context is canceled by the first context that got canceled. // The ctx2 overwrites values in ctx1 during value lookup. func MergeContexts(ctx1, ctx2 context.Context) (context.Context, context.CancelFunc) { return MergeCancellation(MergeValues(MergeDeadline(ctx1, ctx2), ctx2), ctx2) } // MergeCancellation creates a new context that will be cancelled if one of the // two input contexts gets canceled. The `Values` and `Deadline` are taken from the first context. func MergeCancellation(parent, other canceller) (context.Context, context.CancelFunc) { ctx := FromCanceller(parent) err := ctx.Err() if err == nil { err = other.Err() } if err != nil { // at least one context is already cancelled return &cancelledContext{Context: ctx, err: err}, func() {} } if ctx.Done() == nil { if other.Done() == nil { // context is never cancelled. return ctx, func() {} } return &cancelOverwriteContext{ctx: ctx, cancel: other}, func() {} } chDone := make(chan struct{}) merged := &mergeCancelCtx{ Context: ctx, cancel: other, ch: chDone, } go merged.waitCancel(chDone) canceller := func() { merged.mu.Lock() defer merged.mu.Unlock() if merged.err == nil { merged.err = context.Canceled close(chDone) } } return merged, canceller } func (c *cancelledContext) Done() <-chan struct{} { return closedChan } func (c *cancelledContext) Err() error { return c.err } func (c *mergeCancelCtx) waitCancel(chDone chan struct{}) { var err error defer func() { c.mu.Lock() defer c.mu.Unlock() if c.err == nil { c.err = err close(chDone) } }() select { case <-chDone: // CancelFunc triggered cleanup case <-c.Context.Done(): err = c.Context.Err() case <-c.cancel.Done(): err = c.cancel.Err() } } func (c *mergeCancelCtx) Done() <-chan struct{} { return c.ch } func (c *mergeCancelCtx) Err() error { c.mu.Lock() defer c.mu.Unlock() return c.err } func (c *cancelOverwriteContext) Deadline() (deadline time.Time, ok bool) { return c.ctx.Deadline() } func (c *cancelOverwriteContext) Done() <-chan struct{} { return c.cancel.Done() } func (c *cancelOverwriteContext) Err() error { return c.cancel.Err() } func (c *cancelOverwriteContext) Value(key interface{}) interface{} { return c.ctx.Value(key) } // MergeValues merges the values from ctx and overwrites. Value lookup will occur on `overwrites` first. // Deadline and cancellation are still driven by the first context. In order to merge cancellation use // MergeCancellation. func MergeValues(ctx context.Context, overwrites valuer) context.Context { return &mergeValueCtx{ctx, overwrites} } func (c *mergeValueCtx) Value(key interface{}) interface{} { if val := c.overwrites.Value(key); val != nil { return val } return c.Context.Value(key) } // MergeDeadline merges the deadline of two contexts. The resulting context // deadline will be the lesser deadline between the two context. If neither // context configures a deadline, the original context is returned. func MergeDeadline(ctx context.Context, deadliner deadliner) context.Context { deadline, ok := deadliner.Deadline() if !ok { return ctx } ctxDeadline, ok := ctx.Deadline() if ok && ctxDeadline.Before(deadline) { return ctx } return &mergedDeadlineCtx{ctx, deadline} } func (ctx mergedDeadlineCtx) Deadline() (time.Time, bool) { return ctx.deadline, true }