github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/contextutil/context.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package contextutil
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  	"net"
    17  	"runtime/debug"
    18  	"sync/atomic"
    19  	"time"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/util/log"
    22  	"github.com/cockroachdb/errors"
    23  )
    24  
    25  // WithCancel adds an info log to context.WithCancel's CancelFunc. Prefer using
    26  // WithCancelReason when possible.
    27  func WithCancel(parent context.Context) (context.Context, context.CancelFunc) {
    28  	return wrap(context.WithCancel(parent))
    29  }
    30  
    31  // reasonKey is a marker struct that's used to save the reason a context was
    32  // canceled.
    33  type reasonKey struct{}
    34  
    35  // CancelWithReasonFunc is a context.CancelFunc that also passes along an error
    36  // that is the reason for cancellation.
    37  type CancelWithReasonFunc func(reason error)
    38  
    39  // WithCancelReason adds a CancelFunc to this context, returning a new
    40  // cancellable context and a CancelWithReasonFunc, which is like
    41  // context.CancelFunc, except it also takes a "reason" error. The context that
    42  // is canceled with this CancelWithReasonFunc will immediately be updated to
    43  // contain this "reason". The reason can be retrieved with GetCancelReason.
    44  // This function doesn't change the deadline of a context if it already exists.
    45  func WithCancelReason(ctx context.Context) (context.Context, CancelWithReasonFunc) {
    46  	val := new(atomic.Value)
    47  	ctx = context.WithValue(ctx, reasonKey{}, val)
    48  	ctx, cancel := wrap(context.WithCancel(ctx))
    49  	return ctx, func(reason error) {
    50  		val.Store(reason)
    51  		cancel()
    52  	}
    53  }
    54  
    55  // GetCancelReason retrieves the cancel reason for a context that has been
    56  // created via WithCancelReason. The reason will be nil if the context was not
    57  // created with WithCancelReason, or if the context has not been canceled yet.
    58  // Otherwise, the reason will be the error that the context's
    59  // CancelWithReasonFunc was invoked with.
    60  func GetCancelReason(ctx context.Context) error {
    61  	i := ctx.Value(reasonKey{})
    62  	switch t := i.(type) {
    63  	case *atomic.Value:
    64  		return t.Load().(error)
    65  	}
    66  	return nil
    67  }
    68  
    69  func wrap(ctx context.Context, cancel context.CancelFunc) (context.Context, context.CancelFunc) {
    70  	if !log.V(1) {
    71  		return ctx, cancel
    72  	}
    73  	return ctx, func() {
    74  		if log.V(2) {
    75  			log.InfofDepth(ctx, 1, "canceling context:\n%s", debug.Stack())
    76  		} else if log.V(1) {
    77  			log.InfofDepth(ctx, 1, "canceling context")
    78  		}
    79  		cancel()
    80  	}
    81  }
    82  
    83  // TimeoutError is a wrapped ContextDeadlineExceeded error. It indicates that
    84  // an operation didn't complete within its designated timeout.
    85  type TimeoutError struct {
    86  	operation string
    87  	duration  time.Duration
    88  	cause     error
    89  }
    90  
    91  var _ error = (*TimeoutError)(nil)
    92  var _ fmt.Formatter = (*TimeoutError)(nil)
    93  var _ errors.Formatter = (*TimeoutError)(nil)
    94  
    95  // We implement net.Error the same way that context.DeadlineExceeded does, so
    96  // that people looking for net.Error attributes will still find them.
    97  var _ net.Error = (*TimeoutError)(nil)
    98  
    99  func (t *TimeoutError) Error() string { return fmt.Sprintf("%v", t) }
   100  
   101  // Format implements fmt.Formatter.
   102  func (t *TimeoutError) Format(s fmt.State, verb rune) { errors.FormatError(t, s, verb) }
   103  
   104  // FormatError implements errors.Formatter.
   105  func (t *TimeoutError) FormatError(p errors.Printer) error {
   106  	p.Printf("operation %q timed out after %s", t.operation, t.duration)
   107  	if errors.UnwrapOnce(t.cause) != nil {
   108  		// If there were details (wrappers, stack trace etc.) ensure
   109  		// they get printed.
   110  		return t.cause
   111  	}
   112  	// We omit the "context deadline exceeded" suffix in the common case.
   113  	return nil
   114  }
   115  
   116  // Timeout implements net.Error.
   117  func (*TimeoutError) Timeout() bool { return true }
   118  
   119  // Temporary implements net.Error.
   120  func (*TimeoutError) Temporary() bool { return true }
   121  
   122  // Cause implements Causer.
   123  func (t *TimeoutError) Cause() error {
   124  	return t.cause
   125  }
   126  
   127  // RunWithTimeout runs a function with a timeout, the same way you'd do with
   128  // context.WithTimeout. It improves the opaque error messages returned by
   129  // WithTimeout by augmenting them with the op string that is passed in.
   130  func RunWithTimeout(
   131  	ctx context.Context, op string, timeout time.Duration, fn func(ctx context.Context) error,
   132  ) error {
   133  	ctx, cancel := context.WithTimeout(ctx, timeout)
   134  	defer cancel()
   135  	err := fn(ctx)
   136  	if err != nil && errors.Is(ctx.Err(), context.DeadlineExceeded) {
   137  		err = &TimeoutError{
   138  			operation: op,
   139  			duration:  timeout,
   140  			cause:     err,
   141  		}
   142  	}
   143  	return err
   144  }