github.com/mh-cbon/go@v0.0.0-20160603070303-9e112a3fe4c0/src/context/context_test.go (about)

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package context
     6  
     7  import (
     8  	"fmt"
     9  	"math/rand"
    10  	"runtime"
    11  	"strings"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  )
    16  
    17  // otherContext is a Context that's not one of the types defined in context.go.
    18  // This lets us test code paths that differ based on the underlying type of the
    19  // Context.
    20  type otherContext struct {
    21  	Context
    22  }
    23  
    24  func TestBackground(t *testing.T) {
    25  	c := Background()
    26  	if c == nil {
    27  		t.Fatalf("Background returned nil")
    28  	}
    29  	select {
    30  	case x := <-c.Done():
    31  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    32  	default:
    33  	}
    34  	if got, want := fmt.Sprint(c), "context.Background"; got != want {
    35  		t.Errorf("Background().String() = %q want %q", got, want)
    36  	}
    37  }
    38  
    39  func TestTODO(t *testing.T) {
    40  	c := TODO()
    41  	if c == nil {
    42  		t.Fatalf("TODO returned nil")
    43  	}
    44  	select {
    45  	case x := <-c.Done():
    46  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    47  	default:
    48  	}
    49  	if got, want := fmt.Sprint(c), "context.TODO"; got != want {
    50  		t.Errorf("TODO().String() = %q want %q", got, want)
    51  	}
    52  }
    53  
    54  func TestWithCancel(t *testing.T) {
    55  	c1, cancel := WithCancel(Background())
    56  
    57  	if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want {
    58  		t.Errorf("c1.String() = %q want %q", got, want)
    59  	}
    60  
    61  	o := otherContext{c1}
    62  	c2, _ := WithCancel(o)
    63  	contexts := []Context{c1, o, c2}
    64  
    65  	for i, c := range contexts {
    66  		if d := c.Done(); d == nil {
    67  			t.Errorf("c[%d].Done() == %v want non-nil", i, d)
    68  		}
    69  		if e := c.Err(); e != nil {
    70  			t.Errorf("c[%d].Err() == %v want nil", i, e)
    71  		}
    72  
    73  		select {
    74  		case x := <-c.Done():
    75  			t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    76  		default:
    77  		}
    78  	}
    79  
    80  	cancel()
    81  	time.Sleep(100 * time.Millisecond) // let cancelation propagate
    82  
    83  	for i, c := range contexts {
    84  		select {
    85  		case <-c.Done():
    86  		default:
    87  			t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i)
    88  		}
    89  		if e := c.Err(); e != Canceled {
    90  			t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled)
    91  		}
    92  	}
    93  }
    94  
    95  func TestParentFinishesChild(t *testing.T) {
    96  	// Context tree:
    97  	// parent -> cancelChild
    98  	// parent -> valueChild -> timerChild
    99  	parent, cancel := WithCancel(Background())
   100  	cancelChild, stop := WithCancel(parent)
   101  	defer stop()
   102  	valueChild := WithValue(parent, "key", "value")
   103  	timerChild, stop := WithTimeout(valueChild, 10000*time.Hour)
   104  	defer stop()
   105  
   106  	select {
   107  	case x := <-parent.Done():
   108  		t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
   109  	case x := <-cancelChild.Done():
   110  		t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x)
   111  	case x := <-timerChild.Done():
   112  		t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x)
   113  	case x := <-valueChild.Done():
   114  		t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x)
   115  	default:
   116  	}
   117  
   118  	// The parent's children should contain the two cancelable children.
   119  	pc := parent.(*cancelCtx)
   120  	cc := cancelChild.(*cancelCtx)
   121  	tc := timerChild.(*timerCtx)
   122  	pc.mu.Lock()
   123  	if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] {
   124  		t.Errorf("bad linkage: pc.children = %v, want %v and %v",
   125  			pc.children, cc, tc)
   126  	}
   127  	pc.mu.Unlock()
   128  
   129  	if p, ok := parentCancelCtx(cc.Context); !ok || p != pc {
   130  		t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc)
   131  	}
   132  	if p, ok := parentCancelCtx(tc.Context); !ok || p != pc {
   133  		t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc)
   134  	}
   135  
   136  	cancel()
   137  
   138  	pc.mu.Lock()
   139  	if len(pc.children) != 0 {
   140  		t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children)
   141  	}
   142  	pc.mu.Unlock()
   143  
   144  	// parent and children should all be finished.
   145  	check := func(ctx Context, name string) {
   146  		select {
   147  		case <-ctx.Done():
   148  		default:
   149  			t.Errorf("<-%s.Done() blocked, but shouldn't have", name)
   150  		}
   151  		if e := ctx.Err(); e != Canceled {
   152  			t.Errorf("%s.Err() == %v want %v", name, e, Canceled)
   153  		}
   154  	}
   155  	check(parent, "parent")
   156  	check(cancelChild, "cancelChild")
   157  	check(valueChild, "valueChild")
   158  	check(timerChild, "timerChild")
   159  
   160  	// WithCancel should return a canceled context on a canceled parent.
   161  	precanceledChild := WithValue(parent, "key", "value")
   162  	select {
   163  	case <-precanceledChild.Done():
   164  	default:
   165  		t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have")
   166  	}
   167  	if e := precanceledChild.Err(); e != Canceled {
   168  		t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled)
   169  	}
   170  }
   171  
   172  func TestChildFinishesFirst(t *testing.T) {
   173  	cancelable, stop := WithCancel(Background())
   174  	defer stop()
   175  	for _, parent := range []Context{Background(), cancelable} {
   176  		child, cancel := WithCancel(parent)
   177  
   178  		select {
   179  		case x := <-parent.Done():
   180  			t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
   181  		case x := <-child.Done():
   182  			t.Errorf("<-child.Done() == %v want nothing (it should block)", x)
   183  		default:
   184  		}
   185  
   186  		cc := child.(*cancelCtx)
   187  		pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background()
   188  		if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) {
   189  			t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok)
   190  		}
   191  
   192  		if pcok {
   193  			pc.mu.Lock()
   194  			if len(pc.children) != 1 || !pc.children[cc] {
   195  				t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc)
   196  			}
   197  			pc.mu.Unlock()
   198  		}
   199  
   200  		cancel()
   201  
   202  		if pcok {
   203  			pc.mu.Lock()
   204  			if len(pc.children) != 0 {
   205  				t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children)
   206  			}
   207  			pc.mu.Unlock()
   208  		}
   209  
   210  		// child should be finished.
   211  		select {
   212  		case <-child.Done():
   213  		default:
   214  			t.Errorf("<-child.Done() blocked, but shouldn't have")
   215  		}
   216  		if e := child.Err(); e != Canceled {
   217  			t.Errorf("child.Err() == %v want %v", e, Canceled)
   218  		}
   219  
   220  		// parent should not be finished.
   221  		select {
   222  		case x := <-parent.Done():
   223  			t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
   224  		default:
   225  		}
   226  		if e := parent.Err(); e != nil {
   227  			t.Errorf("parent.Err() == %v want nil", e)
   228  		}
   229  	}
   230  }
   231  
   232  func testDeadline(c Context, name string, failAfter time.Duration, t *testing.T) {
   233  	select {
   234  	case <-time.After(failAfter):
   235  		t.Fatalf("%s: context should have timed out", name)
   236  	case <-c.Done():
   237  	}
   238  	if e := c.Err(); e != DeadlineExceeded {
   239  		t.Errorf("%s: c.Err() == %v; want %v", name, e, DeadlineExceeded)
   240  	}
   241  }
   242  
   243  func TestDeadline(t *testing.T) {
   244  	c, _ := WithDeadline(Background(), time.Now().Add(50*time.Millisecond))
   245  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   246  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   247  	}
   248  	testDeadline(c, "WithDeadline", time.Second, t)
   249  
   250  	c, _ = WithDeadline(Background(), time.Now().Add(50*time.Millisecond))
   251  	o := otherContext{c}
   252  	testDeadline(o, "WithDeadline+otherContext", time.Second, t)
   253  
   254  	c, _ = WithDeadline(Background(), time.Now().Add(50*time.Millisecond))
   255  	o = otherContext{c}
   256  	c, _ = WithDeadline(o, time.Now().Add(4*time.Second))
   257  	testDeadline(c, "WithDeadline+otherContext+WithDeadline", 2*time.Second, t)
   258  }
   259  
   260  func TestTimeout(t *testing.T) {
   261  	c, _ := WithTimeout(Background(), 50*time.Millisecond)
   262  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   263  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   264  	}
   265  	testDeadline(c, "WithTimeout", time.Second, t)
   266  
   267  	c, _ = WithTimeout(Background(), 50*time.Millisecond)
   268  	o := otherContext{c}
   269  	testDeadline(o, "WithTimeout+otherContext", time.Second, t)
   270  
   271  	c, _ = WithTimeout(Background(), 50*time.Millisecond)
   272  	o = otherContext{c}
   273  	c, _ = WithTimeout(o, 3*time.Second)
   274  	testDeadline(c, "WithTimeout+otherContext+WithTimeout", 2*time.Second, t)
   275  }
   276  
   277  func TestCanceledTimeout(t *testing.T) {
   278  	c, _ := WithTimeout(Background(), time.Second)
   279  	o := otherContext{c}
   280  	c, cancel := WithTimeout(o, 2*time.Second)
   281  	cancel()
   282  	time.Sleep(100 * time.Millisecond) // let cancelation propagate
   283  	select {
   284  	case <-c.Done():
   285  	default:
   286  		t.Errorf("<-c.Done() blocked, but shouldn't have")
   287  	}
   288  	if e := c.Err(); e != Canceled {
   289  		t.Errorf("c.Err() == %v want %v", e, Canceled)
   290  	}
   291  }
   292  
   293  type key1 int
   294  type key2 int
   295  
   296  var k1 = key1(1)
   297  var k2 = key2(1) // same int as k1, different type
   298  var k3 = key2(3) // same type as k2, different int
   299  
   300  func TestValues(t *testing.T) {
   301  	check := func(c Context, nm, v1, v2, v3 string) {
   302  		if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 {
   303  			t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0)
   304  		}
   305  		if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 {
   306  			t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0)
   307  		}
   308  		if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 {
   309  			t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0)
   310  		}
   311  	}
   312  
   313  	c0 := Background()
   314  	check(c0, "c0", "", "", "")
   315  
   316  	c1 := WithValue(Background(), k1, "c1k1")
   317  	check(c1, "c1", "c1k1", "", "")
   318  
   319  	if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want {
   320  		t.Errorf("c.String() = %q want %q", got, want)
   321  	}
   322  
   323  	c2 := WithValue(c1, k2, "c2k2")
   324  	check(c2, "c2", "c1k1", "c2k2", "")
   325  
   326  	c3 := WithValue(c2, k3, "c3k3")
   327  	check(c3, "c2", "c1k1", "c2k2", "c3k3")
   328  
   329  	c4 := WithValue(c3, k1, nil)
   330  	check(c4, "c4", "", "c2k2", "c3k3")
   331  
   332  	o0 := otherContext{Background()}
   333  	check(o0, "o0", "", "", "")
   334  
   335  	o1 := otherContext{WithValue(Background(), k1, "c1k1")}
   336  	check(o1, "o1", "c1k1", "", "")
   337  
   338  	o2 := WithValue(o1, k2, "o2k2")
   339  	check(o2, "o2", "c1k1", "o2k2", "")
   340  
   341  	o3 := otherContext{c4}
   342  	check(o3, "o3", "", "c2k2", "c3k3")
   343  
   344  	o4 := WithValue(o3, k3, nil)
   345  	check(o4, "o4", "", "c2k2", "")
   346  }
   347  
   348  func TestAllocs(t *testing.T) {
   349  	bg := Background()
   350  	for _, test := range []struct {
   351  		desc       string
   352  		f          func()
   353  		limit      float64
   354  		gccgoLimit float64
   355  	}{
   356  		{
   357  			desc:       "Background()",
   358  			f:          func() { Background() },
   359  			limit:      0,
   360  			gccgoLimit: 0,
   361  		},
   362  		{
   363  			desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1),
   364  			f: func() {
   365  				c := WithValue(bg, k1, nil)
   366  				c.Value(k1)
   367  			},
   368  			limit:      3,
   369  			gccgoLimit: 3,
   370  		},
   371  		{
   372  			desc: "WithTimeout(bg, 15*time.Millisecond)",
   373  			f: func() {
   374  				c, _ := WithTimeout(bg, 15*time.Millisecond)
   375  				<-c.Done()
   376  			},
   377  			limit:      8,
   378  			gccgoLimit: 15,
   379  		},
   380  		{
   381  			desc: "WithCancel(bg)",
   382  			f: func() {
   383  				c, cancel := WithCancel(bg)
   384  				cancel()
   385  				<-c.Done()
   386  			},
   387  			limit:      5,
   388  			gccgoLimit: 8,
   389  		},
   390  		{
   391  			desc: "WithTimeout(bg, 5*time.Millisecond)",
   392  			f: func() {
   393  				c, cancel := WithTimeout(bg, 5*time.Millisecond)
   394  				cancel()
   395  				<-c.Done()
   396  			},
   397  			limit:      8,
   398  			gccgoLimit: 25,
   399  		},
   400  	} {
   401  		limit := test.limit
   402  		if runtime.Compiler == "gccgo" {
   403  			// gccgo does not yet do escape analysis.
   404  			// TOOD(iant): Remove this when gccgo does do escape analysis.
   405  			limit = test.gccgoLimit
   406  		}
   407  		numRuns := 100
   408  		if testing.Short() {
   409  			numRuns = 10
   410  		}
   411  		if n := testing.AllocsPerRun(numRuns, test.f); n > limit {
   412  			t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit))
   413  		}
   414  	}
   415  }
   416  
   417  func TestSimultaneousCancels(t *testing.T) {
   418  	root, cancel := WithCancel(Background())
   419  	m := map[Context]CancelFunc{root: cancel}
   420  	q := []Context{root}
   421  	// Create a tree of contexts.
   422  	for len(q) != 0 && len(m) < 100 {
   423  		parent := q[0]
   424  		q = q[1:]
   425  		for i := 0; i < 4; i++ {
   426  			ctx, cancel := WithCancel(parent)
   427  			m[ctx] = cancel
   428  			q = append(q, ctx)
   429  		}
   430  	}
   431  	// Start all the cancels in a random order.
   432  	var wg sync.WaitGroup
   433  	wg.Add(len(m))
   434  	for _, cancel := range m {
   435  		go func(cancel CancelFunc) {
   436  			cancel()
   437  			wg.Done()
   438  		}(cancel)
   439  	}
   440  	// Wait on all the contexts in a random order.
   441  	for ctx := range m {
   442  		select {
   443  		case <-ctx.Done():
   444  		case <-time.After(1 * time.Second):
   445  			buf := make([]byte, 10<<10)
   446  			n := runtime.Stack(buf, true)
   447  			t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n])
   448  		}
   449  	}
   450  	// Wait for all the cancel functions to return.
   451  	done := make(chan struct{})
   452  	go func() {
   453  		wg.Wait()
   454  		close(done)
   455  	}()
   456  	select {
   457  	case <-done:
   458  	case <-time.After(1 * time.Second):
   459  		buf := make([]byte, 10<<10)
   460  		n := runtime.Stack(buf, true)
   461  		t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n])
   462  	}
   463  }
   464  
   465  func TestInterlockedCancels(t *testing.T) {
   466  	parent, cancelParent := WithCancel(Background())
   467  	child, cancelChild := WithCancel(parent)
   468  	go func() {
   469  		parent.Done()
   470  		cancelChild()
   471  	}()
   472  	cancelParent()
   473  	select {
   474  	case <-child.Done():
   475  	case <-time.After(1 * time.Second):
   476  		buf := make([]byte, 10<<10)
   477  		n := runtime.Stack(buf, true)
   478  		t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n])
   479  	}
   480  }
   481  
   482  func TestLayersCancel(t *testing.T) {
   483  	testLayers(t, time.Now().UnixNano(), false)
   484  }
   485  
   486  func TestLayersTimeout(t *testing.T) {
   487  	testLayers(t, time.Now().UnixNano(), true)
   488  }
   489  
   490  func testLayers(t *testing.T, seed int64, testTimeout bool) {
   491  	rand.Seed(seed)
   492  	errorf := func(format string, a ...interface{}) {
   493  		t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...)
   494  	}
   495  	const (
   496  		timeout   = 200 * time.Millisecond
   497  		minLayers = 30
   498  	)
   499  	type value int
   500  	var (
   501  		vals      []*value
   502  		cancels   []CancelFunc
   503  		numTimers int
   504  		ctx       = Background()
   505  	)
   506  	for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ {
   507  		switch rand.Intn(3) {
   508  		case 0:
   509  			v := new(value)
   510  			ctx = WithValue(ctx, v, v)
   511  			vals = append(vals, v)
   512  		case 1:
   513  			var cancel CancelFunc
   514  			ctx, cancel = WithCancel(ctx)
   515  			cancels = append(cancels, cancel)
   516  		case 2:
   517  			var cancel CancelFunc
   518  			ctx, cancel = WithTimeout(ctx, timeout)
   519  			cancels = append(cancels, cancel)
   520  			numTimers++
   521  		}
   522  	}
   523  	checkValues := func(when string) {
   524  		for _, key := range vals {
   525  			if val := ctx.Value(key).(*value); key != val {
   526  				errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key)
   527  			}
   528  		}
   529  	}
   530  	select {
   531  	case <-ctx.Done():
   532  		errorf("ctx should not be canceled yet")
   533  	default:
   534  	}
   535  	if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) {
   536  		t.Errorf("ctx.String() = %q want prefix %q", s, prefix)
   537  	}
   538  	t.Log(ctx)
   539  	checkValues("before cancel")
   540  	if testTimeout {
   541  		select {
   542  		case <-ctx.Done():
   543  		case <-time.After(timeout + time.Second):
   544  			errorf("ctx should have timed out")
   545  		}
   546  		checkValues("after timeout")
   547  	} else {
   548  		cancel := cancels[rand.Intn(len(cancels))]
   549  		cancel()
   550  		select {
   551  		case <-ctx.Done():
   552  		default:
   553  			errorf("ctx should be canceled")
   554  		}
   555  		checkValues("after cancel")
   556  	}
   557  }
   558  
   559  func TestCancelRemoves(t *testing.T) {
   560  	checkChildren := func(when string, ctx Context, want int) {
   561  		if got := len(ctx.(*cancelCtx).children); got != want {
   562  			t.Errorf("%s: context has %d children, want %d", when, got, want)
   563  		}
   564  	}
   565  
   566  	ctx, _ := WithCancel(Background())
   567  	checkChildren("after creation", ctx, 0)
   568  	_, cancel := WithCancel(ctx)
   569  	checkChildren("with WithCancel child ", ctx, 1)
   570  	cancel()
   571  	checkChildren("after cancelling WithCancel child", ctx, 0)
   572  
   573  	ctx, _ = WithCancel(Background())
   574  	checkChildren("after creation", ctx, 0)
   575  	_, cancel = WithTimeout(ctx, 60*time.Minute)
   576  	checkChildren("with WithTimeout child ", ctx, 1)
   577  	cancel()
   578  	checkChildren("after cancelling WithTimeout child", ctx, 0)
   579  }
   580  
   581  func TestWithValueChecksKey(t *testing.T) {
   582  	panicVal := recoveredValue(func() { WithValue(Background(), []byte("foo"), "bar") })
   583  	if panicVal == nil {
   584  		t.Error("expected panic")
   585  	}
   586  	panicVal = recoveredValue(func() { WithValue(Background(), nil, "bar") })
   587  	if got, want := fmt.Sprint(panicVal), "nil key"; got != want {
   588  		t.Errorf("panic = %q; want %q", got, want)
   589  	}
   590  }
   591  
   592  func recoveredValue(fn func()) (v interface{}) {
   593  	defer func() { v = recover() }()
   594  	fn()
   595  	return
   596  }
   597  
   598  func TestDeadlineExceededSupportsTimeout(t *testing.T) {
   599  	i, ok := DeadlineExceeded.(interface {
   600  		Timeout() bool
   601  	})
   602  	if !ok {
   603  		t.Fatal("DeadlineExceeded does not support Timeout interface")
   604  	}
   605  	if !i.Timeout() {
   606  		t.Fatal("wrong value for timeout")
   607  	}
   608  }