github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/ctxgroup/ctxgroup.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  /*
    12  Package ctxgroup wraps golang.org/x/sync/errgroup with a context func.
    13  
    14  This package extends and modifies the errgroup API slightly to
    15  make context variables more explicit. WithContext no longer returns
    16  a context. Instead, the GoCtx method explicitly passes one to the
    17  invoked func. The goal is to make misuse of context vars with errgroups
    18  more difficult. Example usage:
    19  
    20  	ctx := context.Background()
    21  	g := ctxgroup.WithContext(ctx)
    22  	ch := make(chan bool)
    23  	g.GoCtx(func(ctx context.Context) error {
    24  		defer close(ch)
    25  		for _, val := range []bool{true, false} {
    26  			select {
    27  			case ch <- val:
    28  			case <-ctx.Done():
    29  				return ctx.Err()
    30  			}
    31  		}
    32  		return nil
    33  	})
    34  	g.GoCtx(func(ctx context.Context) error {
    35  		for val := range ch {
    36  			if err := api.Call(ctx, val); err != nil {
    37  				return err
    38  			}
    39  		}
    40  		return nil
    41  	})
    42  	if err := g.Wait(); err != nil {
    43  		return err
    44  	}
    45  	api.Call(ctx, "done")
    46  
    47  Problems with errgroup
    48  
The bugs this package attempts to prevent are misuse of a shadowed ctx
variable after the errgroup has finished, and confusion in the face of
multiple ctx variables when trying to avoid shadowing. The following
are example bugs that Cockroach has had during its use of errgroup:
    53  
    54  	ctx := context.Background()
    55  	g, ctx := errgroup.WithContext(ctx)
    56  	ch := make(chan bool)
    57  	g.Go(func() error {
    58  		defer close(ch)
    59  		for _, val := range []bool{true, false} {
    60  			select {
    61  			case ch <- val:
    62  			case <-ctx.Done():
    63  				return ctx.Err()
    64  			}
    65  		}
    66  		return nil
    67  	})
    68  	g.Go(func() error {
    69  		for val := range ch {
    70  			if err := api.Call(ctx, val); err != nil {
    71  				return err
    72  			}
    73  		}
    74  		return nil
    75  	})
    76  	if err := g.Wait(); err != nil {
    77  		return err
    78  	}
    79  	api.Call(ctx, "done")
    80  
The ctx used by the final api.Call is already canceled because the
errgroup's Wait has returned. This happened because we wanted to avoid
creating another ctx variable, so we shadowed the original ctx var but
then incorrectly continued to use it after the errgroup had canceled
its context. So we make a modification and create a new gCtx variable
that doesn't shadow the original ctx:
    87  
    88  	ctx := context.Background()
    89  	g, gCtx := errgroup.WithContext(ctx)
    90  	ch := make(chan bool)
    91  	g.Go(func() error {
    92  		defer close(ch)
    93  		for _, val := range []bool{true, false} {
    94  			select {
    95  			case ch <- val:
    96  			case <-ctx.Done():
    97  				return ctx.Err()
    98  			}
    99  		}
   100  		return nil
   101  	})
   102  	g.Go(func() error {
   103  		for val := range ch {
   104  			if err := api.Call(ctx, val); err != nil {
   105  				return err
   106  			}
   107  		}
   108  		return nil
   109  	})
   110  	if err := g.Wait(); err != nil {
   111  		return err
   112  	}
   113  	api.Call(ctx, "done")
   114  
Now the final api.Call is correct. But the api.Call and the ctx.Done
receive inside the goroutines are now incorrect: they still use the
original ctx instead of gCtx, and thus won't exit early when the
errgroup cancels its context (for example, because the other goroutine
returned an error).
   119  
   120  */
   121  package ctxgroup
   122  
   123  import (
   124  	"context"
   125  
   126  	"golang.org/x/sync/errgroup"
   127  )
   128  
   129  // Group wraps errgroup.
   130  type Group struct {
   131  	wrapped *errgroup.Group
   132  	ctx     context.Context
   133  }
   134  
   135  // Wait blocks until all function calls from the Go method have returned, then
   136  // returns the first non-nil error (if any) from them. If Wait() is invoked
   137  // after the context (originally supplied to WithContext) is canceled, Wait
   138  // returns an error, even if no Go invocation did. In particular, calling
   139  // Wait() after Done has been closed is guaranteed to return an error.
   140  func (g Group) Wait() error {
   141  	ctxErr := g.ctx.Err()
   142  	err := g.wrapped.Wait()
   143  	if err != nil {
   144  		return err
   145  	}
   146  	return ctxErr
   147  }
   148  
   149  // WithContext returns a new Group and an associated Context derived from ctx.
   150  func WithContext(ctx context.Context) Group {
   151  	grp, ctx := errgroup.WithContext(ctx)
   152  	return Group{
   153  		wrapped: grp,
   154  		ctx:     ctx,
   155  	}
   156  }
   157  
   158  // Go calls the given function in a new goroutine.
   159  func (g Group) Go(f func() error) {
   160  	g.wrapped.Go(f)
   161  }
   162  
   163  // GoCtx calls the given function in a new goroutine.
   164  func (g Group) GoCtx(f func(ctx context.Context) error) {
   165  	g.wrapped.Go(func() error {
   166  		return f(g.ctx)
   167  	})
   168  }
   169  
   170  // GroupWorkers runs num worker go routines in an errgroup.
   171  func GroupWorkers(ctx context.Context, num int, f func(context.Context, int) error) error {
   172  	group := WithContext(ctx)
   173  	for i := 0; i < num; i++ {
   174  		workerID := i
   175  		group.GoCtx(func(ctx context.Context) error { return f(ctx, workerID) })
   176  	}
   177  	return group.Wait()
   178  }