github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/contextutil/context_test.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package contextutil
    12  
    13  import (
    14  	"context"
    15  	"net"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/cockroachdb/errors"
    20  	"github.com/stretchr/testify/assert"
    21  )
    22  
    23  func TestRunWithTimeout(t *testing.T) {
    24  	ctx := context.Background()
    25  	err := RunWithTimeout(ctx, "foo", 1, func(ctx context.Context) error {
    26  		time.Sleep(10 * time.Millisecond)
    27  		return nil
    28  	})
    29  	if err != nil {
    30  		t.Fatal("RunWithTimeout shouldn't return a timeout error if nobody touched the context.")
    31  	}
    32  
    33  	err = RunWithTimeout(ctx, "foo", 1, func(ctx context.Context) error {
    34  		time.Sleep(10 * time.Millisecond)
    35  		return ctx.Err()
    36  	})
    37  	expectedMsg := "operation \"foo\" timed out after 1ns"
    38  	if err.Error() != expectedMsg {
    39  		t.Fatalf("expected %s, actual %s", expectedMsg, err.Error())
    40  	}
    41  	var netError net.Error
    42  	if !errors.As(err, &netError) {
    43  		t.Fatal("RunWithTimeout should return a net.Error")
    44  	}
    45  	if !netError.Timeout() || !netError.Temporary() {
    46  		t.Fatal("RunWithTimeout should return a timeout and temporary error")
    47  	}
    48  	if !errors.Is(err, context.DeadlineExceeded) {
    49  		t.Fatalf("RunWithTimeout should return an error with a DeadlineExceeded cause")
    50  	}
    51  
    52  	err = RunWithTimeout(ctx, "foo", 1, func(ctx context.Context) error {
    53  		time.Sleep(10 * time.Millisecond)
    54  		return errors.Wrap(ctx.Err(), "custom error")
    55  	})
    56  	expExtended := expectedMsg + ": custom error: context deadline exceeded"
    57  	if err.Error() != expExtended {
    58  		t.Fatalf("expected %q, actual %q", expExtended, err.Error())
    59  	}
    60  	if !errors.As(err, &netError) {
    61  		t.Fatal("RunWithTimeout should return a net.Error")
    62  	}
    63  	if !netError.Timeout() || !netError.Temporary() {
    64  		t.Fatal("RunWithTimeout should return a timeout and temporary error")
    65  	}
    66  	if !errors.Is(err, context.DeadlineExceeded) {
    67  		t.Fatalf("RunWithTimeout should return an error with a DeadlineExceeded cause")
    68  	}
    69  }
    70  
    71  // TestRunWithTimeoutWithoutDeadlineExceeded ensures that when a timeout on the
    72  // context occurs but the underlying error does not have
    73  // context.DeadlineExceeded as its Cause (perhaps due to serialization) the
    74  // returned error is still a TimeoutError. In this case however the underlying
    75  // cause should be the returned error and not context.DeadlineExceeded.
    76  func TestRunWithTimeoutWithoutDeadlineExceeded(t *testing.T) {
    77  	ctx := context.Background()
    78  	notContextDeadlineExceeded := errors.Handled(context.DeadlineExceeded)
    79  	err := RunWithTimeout(ctx, "foo", 1, func(ctx context.Context) error {
    80  		<-ctx.Done()
    81  		return notContextDeadlineExceeded
    82  	})
    83  	var netError net.Error
    84  	if !errors.As(err, &netError) {
    85  		t.Fatal("RunWithTimeout should return a net.Error")
    86  	}
    87  	if !netError.Timeout() || !netError.Temporary() {
    88  		t.Fatal("RunWithTimeout should return a timeout and temporary error")
    89  	}
    90  	if !errors.Is(err, notContextDeadlineExceeded) {
    91  		t.Fatalf("RunWithTimeout should return an error caused by the underlying " +
    92  			"returned error")
    93  	}
    94  }
    95  
    96  func TestCancelWithReason(t *testing.T) {
    97  	ctx := context.Background()
    98  
    99  	var cancel CancelWithReasonFunc
   100  	ctx, cancel = WithCancelReason(ctx)
   101  
   102  	e := errors.New("hodor")
   103  	go func() {
   104  		cancel(e)
   105  	}()
   106  
   107  	<-ctx.Done()
   108  
   109  	expected := "context canceled"
   110  	found := ctx.Err().Error()
   111  	assert.Equal(t, expected, found)
   112  	assert.Equal(t, e, GetCancelReason(ctx))
   113  }