github.com/nya3jp/tast@v0.0.0-20230601000426-85c8e4d83a9b/src/go.chromium.org/tast/core/internal/xcontext/xcontext.go (about)

     1  // Copyright 2021 The ChromiumOS Authors
     2  // Use of this source code is governed by a BSD-style license that can be
     3  // found in the LICENSE file.
     4  
     5  // Package xcontext provides Context with custom errors.
     6  package xcontext
     7  
     8  import (
     9  	"context"
    10  	"errors"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"code.cloudfoundry.org/clock"
    15  )
    16  
    17  // clk is replaced in unit tests to use fake clocks.
    18  var clk = clock.NewClock()
    19  
    20  type keyType string
    21  
    22  const (
    23  	contextTimeoutKey keyType = "context_timeout_duration"
    24  )
    25  
    26  // CancelFunc is a function to cancel an associated context with a specified
    27  // error. If a context is already canceled, calling this function has no effect.
    28  // It panics if err is nil.
    29  // Upon returning from this function, an associated context is guaranteed to be
    30  // in a canceled state (i.e. Done channel is closed, Err returns non-nil).
    31  type CancelFunc func(err error)
    32  
    33  // contextImpl implements context.Context with custom errors.
    34  type contextImpl struct {
    35  	// parent is a parent context.
    36  	parent context.Context
    37  
    38  	// hasDeadline indicates whether this context has a deadline.
    39  	hasDeadline bool
    40  
    41  	// deadline is a deadline of this context. It is valid only when
    42  	// hasDeadline is true.
    43  	deadline time.Time
    44  
    45  	// done is a channel returned by Done.
    46  	done chan struct{}
    47  
    48  	// req is a channel over which cancellation errors are sent. The channel
    49  	// has capacity=1 so that sending a first error over it does not block.
    50  	req chan error
    51  
    52  	// errValue holds an error value returned by Err.
    53  	errValue atomic.Value
    54  }
    55  
    56  // newContext returns a new context. It also starts a background goroutine to
    57  // handle cancellation signals if needed.
    58  //
    59  // If deadlineErr is nil, a new context has the same deadline as its parent, and
    60  // reqDeadline is ignored. If deadlineErr is non-nil, the deadline of a new
    61  // context is set to reqDeadline or that of the parent context, whichever comes
    62  // earlier.
    63  func newContext(parent context.Context, deadlineErr error, reqDeadline time.Time) (context.Context, CancelFunc) {
    64  	newDeadline := false
    65  	deadline, hasDeadline := parent.Deadline()
    66  	if deadlineErr != nil && (!hasDeadline || reqDeadline.Before(deadline)) {
    67  		deadline = reqDeadline
    68  		hasDeadline = true
    69  		newDeadline = true
    70  	}
    71  
    72  	ctx := &contextImpl{
    73  		parent:      parent,
    74  		hasDeadline: hasDeadline,
    75  		deadline:    deadline,
    76  		done:        make(chan struct{}),
    77  		req:         make(chan error, 1),
    78  	}
    79  
    80  	// Handle the cases where the new context is immediately canceled.
    81  	if err := func() error {
    82  		if err := parent.Err(); err != nil {
    83  			return err
    84  		}
    85  		if newDeadline && !deadline.After(clk.Now()) {
    86  			return deadlineErr
    87  		}
    88  		return nil
    89  	}(); err != nil {
    90  		ctx.errValue.Store(err)
    91  		close(ctx.done)
    92  		return ctx, ctx.cancel
    93  	}
    94  
    95  	// Start a background goroutine that handles cancellation signals.
    96  	go func() {
    97  		err := func() error {
    98  			var dl <-chan time.Time
    99  			if newDeadline {
   100  				tm := clk.NewTimer(deadline.Sub(clk.Now()))
   101  				defer tm.Stop()
   102  				dl = tm.C()
   103  			}
   104  
   105  			select {
   106  			case <-parent.Done():
   107  				return parent.Err()
   108  			case <-dl:
   109  				return deadlineErr
   110  			case err := <-ctx.req:
   111  				return err
   112  			}
   113  		}()
   114  		ctx.errValue.Store(err)
   115  		close(ctx.done)
   116  	}()
   117  
   118  	return ctx, ctx.cancel
   119  }
   120  
   121  // Deadline returns the deadline of the context.
   122  func (c *contextImpl) Deadline() (deadline time.Time, ok bool) {
   123  	return c.deadline, c.hasDeadline
   124  }
   125  
   126  // Done returns a channel that is closed on cancellation of the context.
   127  func (c *contextImpl) Done() <-chan struct{} {
   128  	return c.done
   129  }
   130  
   131  // Err returns a non-nil error if the context has been canceled.
   132  // This method does not strictly follow the contract of the context.Context
   133  // interface; it may return an error different from context.Canceled or
   134  // context.DeadlineExceeded.
   135  func (c *contextImpl) Err() error {
   136  	if val := c.errValue.Load(); val != nil {
   137  		return val.(error)
   138  	}
   139  	return nil
   140  }
   141  
   142  // Value returns a value associated with the context.
   143  func (c *contextImpl) Value(key interface{}) interface{} {
   144  	return c.parent.Value(key)
   145  }
   146  
   147  // cancel requests to cancel the context.
   148  func (c *contextImpl) cancel(err error) {
   149  	if err == nil {
   150  		panic("xcontext: Cancel called with nil")
   151  	}
   152  
   153  	// Attempt to send an error to the background goroutine.
   154  	// req has capacity=1, so at least the first send should succeed.
   155  	select {
   156  	case c.req <- err:
   157  	default:
   158  	}
   159  
   160  	// Wait until the context is canceled.
   161  	<-c.done
   162  }
   163  
   164  // WithCancel returns a context that can be canceled with arbitrary errors.
   165  func WithCancel(parent context.Context) (context.Context, CancelFunc) {
   166  	return newContext(parent, nil, time.Time{})
   167  }
   168  
   169  // WithDeadline returns a context that can be canceled with arbitrary errors on
   170  // reaching a specified deadline. It panics if err is nil.
   171  func WithDeadline(parent context.Context, t time.Time, err error) (context.Context, CancelFunc) {
   172  	if err == nil {
   173  		panic("xcontext: WithDeadline called with nil err")
   174  	}
   175  	return newContext(parent, err, t)
   176  }
   177  
   178  // WithTimeout returns a context that can be canceled with arbitrary errors on
   179  // reaching a specified timeout. It panics if err is nil.
   180  func WithTimeout(parent context.Context, d time.Duration, err error) (context.Context, CancelFunc) {
   181  	if err == nil {
   182  		panic("xcontext: WithTimeout called with nil err")
   183  	}
   184  	ctx, cFunc := WithDeadline(parent, clk.Now().Add(d), err)
   185  
   186  	// Adding timeout to the context for access when reporting the timeout has been reached
   187  	ctx = context.WithValue(ctx, contextTimeoutKey, d)
   188  
   189  	return ctx, cFunc
   190  }
   191  
   192  // GetContextTimeout returns the duration set on the Context for its Timeout
   193  // If the Timeout has not been set, then -1 is returned
   194  func GetContextTimeout(ctx context.Context) (time.Duration, error) {
   195  	t, ok := ctx.Value(contextTimeoutKey).(time.Duration)
   196  	if ok {
   197  		return t, nil
   198  	}
   199  	return 0, errors.New("timeout not set on context")
   200  }