vitess.io/vitess@v0.16.2/go/vt/concurrency/error_group.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8  	http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package concurrency
    18  
    19  import "context"
    20  
    21  // ErrorGroup provides a function for waiting for N goroutines to complete with
    22  // at least Z successes which we wanted to wait for and
    23  // at least X overall successes and no more than Y failures, and cancelling the rest.
    24  //
    25  // It should be used as follows:
    26  //
    27  //		errCh := make(chan concurrency.Error)
    28  //		errgroupCtx, errgroupCancel := context.WithCancel(ctx)
    29  //
    30  //		for _, arg := range args {
    31  //			arg := arg
    32  //
    33  //			go func() {
    34  //				err := doWork(errGroupCtx, arg)
    35  //				errCh <- concurrency.Error{
    36  //					Err: err,
    37  //					MustWaitFor: <boolean>,
    38  //				}
    39  //			}()
    40  //		}
    41  //
    42  //		errgroup := concurrency.ErrorGroup{
    43  //			NumGoroutines: len(args),
    44  //			NumRequiredSuccess: 5, // need at least 5 to respond with nil error before cancelling the rest
    45  //			NumAllowedErrors: 1, // if more than 1 responds with non-nil error, cancel the rest
    46  //			NumErrorsToWaitFor: 1, // if there is 1 response that we must wait for, before cancelling the rest
    47  //		}
    48  //		errRec := errgroup.Wait(errgroupCancel, errCh)
    49  //
    50  //		if errRec.HasErrors() {
    51  //			// ...
    52  //		}
    53  //
    54  //	The NumErrorsToWaitFor should be equal to the number of
    55  //	Errors that are received on the channel which have MustWaitFor set
    56  type ErrorGroup struct {
    57  	NumGoroutines        int
    58  	NumRequiredSuccesses int
    59  	NumAllowedErrors     int
    60  	NumErrorsToWaitFor   int
    61  }
    62  
    63  // Error is used in ErrGroup.Wait function
    64  // It contains the error that was received along
    65  // with the information of whether the received error
    66  // originated from a tablet that we must wait for
    67  type Error struct {
    68  	Err         error
    69  	MustWaitFor bool
    70  }
    71  
    72  // Wait waits for a group of goroutines that are sending errors to the given
    73  // Error channel, and are cancellable by the given cancel function.
    74  //
    75  // Wait will cancel any outstanding goroutines when the following condition is met:
    76  //
    77  //   - At least NumErrorsToWaitFor results with MustWaitFor set have been consumed
    78  //     on the error channel AND one of the following two -
    79  //     (1) More than NumAllowedErrors non-nil results have been consumed on the
    80  //     error channel.
    81  //
    82  //     (2) At least NumRequiredSuccesses nil results have been consumed on the error
    83  //     channel.
    84  //
    85  // After the cancellation condition is triggered, Wait will continue to consume
    86  // results off the Error channel so as to not permanently block any of those
    87  // cancelled goroutines.
    88  //
    89  // When finished consuming results from all goroutines, cancelled or otherwise,
    90  // Wait returns an AllErrorRecorder that contains all errors returned by any of
    91  // those goroutines. It does not close the Error channel.
    92  func (eg ErrorGroup) Wait(cancel context.CancelFunc, errors chan Error) *AllErrorRecorder {
    93  	errCounter := 0
    94  	successCounter := 0
    95  	responseCounter := 0
    96  	mustWaitForCounter := 0
    97  	rec := &AllErrorRecorder{}
    98  
    99  	if eg.NumGoroutines < 1 {
   100  		return rec
   101  	}
   102  
   103  	for err := range errors {
   104  		responseCounter++
   105  		if err.MustWaitFor {
   106  			mustWaitForCounter++
   107  		}
   108  
   109  		switch err.Err {
   110  		case nil:
   111  			successCounter++
   112  		default:
   113  			errCounter++
   114  			rec.RecordError(err.Err)
   115  		}
   116  
   117  		// Even though we cancel in the next conditional, we need to keep
   118  		// consuming off the channel, or those goroutines will get stuck
   119  		// forever.
   120  		if responseCounter == eg.NumGoroutines {
   121  			break
   122  		}
   123  
   124  		if mustWaitForCounter >= eg.NumErrorsToWaitFor && (errCounter > eg.NumAllowedErrors || successCounter >= eg.NumRequiredSuccesses) {
   125  			cancel()
   126  		}
   127  	}
   128  
   129  	return rec
   130  }