vitess.io/vitess@v0.16.2/go/vt/concurrency/error_group_test.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8  	http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package concurrency
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/stretchr/testify/require"
    26  )
    27  
    28  func TestErrorGroup(t *testing.T) {
    29  	testcases := []struct {
    30  		name                string
    31  		errorGroup          ErrorGroup
    32  		numSuccesses        int
    33  		numErrors           int
    34  		numWaitFor          int
    35  		numDelayedSuccesses int
    36  		expectedError       string
    37  	}{
    38  		{
    39  			name: "Wait() returns immediately when NumGoroutines = 0",
    40  			errorGroup: ErrorGroup{
    41  				NumGoroutines: 0,
    42  			},
    43  		}, {
    44  			name: "Require all successes - pass",
    45  			errorGroup: ErrorGroup{
    46  				NumGoroutines:        4,
    47  				NumRequiredSuccesses: 4,
    48  			},
    49  			numSuccesses: 4,
    50  		}, {
    51  			name: "Require all successes - failure",
    52  			errorGroup: ErrorGroup{
    53  				NumGoroutines:        4,
    54  				NumRequiredSuccesses: 4,
    55  			},
    56  			numSuccesses:  3,
    57  			numErrors:     1,
    58  			expectedError: "a general error",
    59  		}, {
    60  			name: "1 allowed failure",
    61  			errorGroup: ErrorGroup{
    62  				NumGoroutines:        4,
    63  				NumRequiredSuccesses: 3,
    64  				NumAllowedErrors:     1,
    65  			},
    66  			numSuccesses:  3,
    67  			numErrors:     1,
    68  			expectedError: "a general error",
    69  		}, {
    70  			name: "less than allowed failures",
    71  			errorGroup: ErrorGroup{
    72  				NumGoroutines:        4,
    73  				NumRequiredSuccesses: 2,
    74  				NumAllowedErrors:     2,
    75  			},
    76  			numSuccesses:  3,
    77  			numErrors:     1,
    78  			expectedError: "a general error",
    79  		}, {
    80  			name: "1 must wait for routine",
    81  			errorGroup: ErrorGroup{
    82  				NumGoroutines:        4,
    83  				NumRequiredSuccesses: 2,
    84  				NumAllowedErrors:     2,
    85  				NumErrorsToWaitFor:   1,
    86  			},
    87  			numSuccesses: 3,
    88  			numWaitFor:   1,
    89  		}, {
    90  			name: "delayed success should be cancelled",
    91  			errorGroup: ErrorGroup{
    92  				NumGoroutines:        4,
    93  				NumRequiredSuccesses: 2,
    94  				NumAllowedErrors:     2,
    95  			},
    96  			numSuccesses:        3,
    97  			numDelayedSuccesses: 1,
    98  			expectedError:       "context cancelled",
    99  		},
   100  	}
   101  	for _, testcase := range testcases {
   102  		t.Run(testcase.name, func(t *testing.T) {
   103  			groupContext, groupCancel := context.WithTimeout(context.Background(), 10*time.Second)
   104  			defer groupCancel()
   105  			errCh := make(chan Error)
   106  			defer close(errCh)
   107  
   108  			spawnGoRoutines(errCh, false, testcase.numSuccesses)
   109  			spawnGoRoutines(errCh, true, testcase.numErrors)
   110  			spawnDelayedGoRoutine(groupContext, errCh, true, testcase.numWaitFor)
   111  			spawnDelayedGoRoutine(groupContext, errCh, false, testcase.numDelayedSuccesses)
   112  
   113  			err := testcase.errorGroup.Wait(groupCancel, errCh)
   114  			if testcase.expectedError == "" {
   115  				require.False(t, err.HasErrors())
   116  				require.NoError(t, err.Error())
   117  			} else {
   118  				require.True(t, err.HasErrors())
   119  				require.EqualError(t, err.Error(), testcase.expectedError)
   120  			}
   121  		})
   122  	}
   123  }
   124  
   125  func spawnGoRoutines(errCh chan Error, shouldError bool, count int) {
   126  	for i := 0; i < count; i++ {
   127  		go func() {
   128  			time.Sleep(100 * time.Millisecond)
   129  			var err Error
   130  			if shouldError {
   131  				err.Err = fmt.Errorf("a general error")
   132  			}
   133  			errCh <- err
   134  		}()
   135  	}
   136  }
   137  
   138  func spawnDelayedGoRoutine(groupContext context.Context, errCh chan Error, mustWaitFor bool, count int) {
   139  	for i := 0; i < count; i++ {
   140  		go func() {
   141  			select {
   142  			case <-groupContext.Done():
   143  				err := Error{
   144  					Err: fmt.Errorf("context cancelled"),
   145  				}
   146  				errCh <- err
   147  			case <-time.After(300 * time.Millisecond):
   148  				err := Error{
   149  					MustWaitFor: mustWaitFor,
   150  				}
   151  				errCh <- err
   152  			}
   153  		}()
   154  	}
   155  }