github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/context/afterfunc_test.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package context_test
     6  
     7  import (
     8  	"context"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  )
    13  
    14  // afterFuncContext is a context that's not one of the types
    15  // defined in context.go, that supports registering AfterFuncs.
    16  type afterFuncContext struct {
    17  	mu         sync.Mutex
    18  	afterFuncs map[*struct{}]func()
    19  	done       chan struct{}
    20  	err        error
    21  }
    22  
    23  func newAfterFuncContext() context.Context {
    24  	return &afterFuncContext{}
    25  }
    26  
    27  func (c *afterFuncContext) Deadline() (time.Time, bool) {
    28  	return time.Time{}, false
    29  }
    30  
    31  func (c *afterFuncContext) Done() <-chan struct{} {
    32  	c.mu.Lock()
    33  	defer c.mu.Unlock()
    34  	if c.done == nil {
    35  		c.done = make(chan struct{})
    36  	}
    37  	return c.done
    38  }
    39  
    40  func (c *afterFuncContext) Err() error {
    41  	c.mu.Lock()
    42  	defer c.mu.Unlock()
    43  	return c.err
    44  }
    45  
    46  func (c *afterFuncContext) Value(key any) any {
    47  	return nil
    48  }
    49  
    50  func (c *afterFuncContext) AfterFunc(f func()) func() bool {
    51  	c.mu.Lock()
    52  	defer c.mu.Unlock()
    53  	k := &struct{}{}
    54  	if c.afterFuncs == nil {
    55  		c.afterFuncs = make(map[*struct{}]func())
    56  	}
    57  	c.afterFuncs[k] = f
    58  	return func() bool {
    59  		c.mu.Lock()
    60  		defer c.mu.Unlock()
    61  		_, ok := c.afterFuncs[k]
    62  		delete(c.afterFuncs, k)
    63  		return ok
    64  	}
    65  }
    66  
    67  func (c *afterFuncContext) cancel(err error) {
    68  	c.mu.Lock()
    69  	defer c.mu.Unlock()
    70  	if c.err != nil {
    71  		return
    72  	}
    73  	c.err = err
    74  	for _, f := range c.afterFuncs {
    75  		go f()
    76  	}
    77  	c.afterFuncs = nil
    78  }
    79  
    80  func TestCustomContextAfterFuncCancel(t *testing.T) {
    81  	ctx0 := &afterFuncContext{}
    82  	ctx1, cancel := context.WithCancel(ctx0)
    83  	defer cancel()
    84  	ctx0.cancel(context.Canceled)
    85  	<-ctx1.Done()
    86  }
    87  
    88  func TestCustomContextAfterFuncTimeout(t *testing.T) {
    89  	ctx0 := &afterFuncContext{}
    90  	ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
    91  	defer cancel()
    92  	ctx0.cancel(context.Canceled)
    93  	<-ctx1.Done()
    94  }
    95  
    96  func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
    97  	ctx0 := &afterFuncContext{}
    98  	donec := make(chan struct{})
    99  	stop := context.AfterFunc(ctx0, func() {
   100  		close(donec)
   101  	})
   102  	defer stop()
   103  	ctx0.cancel(context.Canceled)
   104  	<-donec
   105  }
   106  
   107  func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
   108  	ctx0 := &afterFuncContext{}
   109  	_, cancel := context.WithCancel(ctx0)
   110  	if got, want := len(ctx0.afterFuncs), 1; got != want {
   111  		t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
   112  	}
   113  	cancel()
   114  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   115  		t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
   116  	}
   117  }
   118  
   119  func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
   120  	ctx0 := &afterFuncContext{}
   121  	_, cancel := context.WithTimeout(ctx0, veryLongDuration)
   122  	if got, want := len(ctx0.afterFuncs), 1; got != want {
   123  		t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
   124  	}
   125  	cancel()
   126  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   127  		t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
   128  	}
   129  }
   130  
   131  func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
   132  	ctx0 := &afterFuncContext{}
   133  	stop := context.AfterFunc(ctx0, func() {})
   134  	if got, want := len(ctx0.afterFuncs), 1; got != want {
   135  		t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
   136  	}
   137  	stop()
   138  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   139  		t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
   140  	}
   141  }