github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/utils/svcs/controller_test.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package svcs
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"sync"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  func TestController(t *testing.T) {
    27  	t.Run("NewController", func(t *testing.T) {
    28  		c := NewController()
    29  		require.NotNil(t, c)
    30  	})
    31  	t.Run("Stop", func(t *testing.T) {
    32  		t.Run("CalledBeforeStart", func(t *testing.T) {
    33  			c := NewController()
    34  			c.Stop()
    35  			require.Error(t, c.Start(context.Background()))
    36  			require.NoError(t, c.WaitForStart())
    37  			require.NoError(t, c.WaitForStop())
    38  		})
    39  		t.Run("ReturnsFirstError", func(t *testing.T) {
    40  			c := NewController()
    41  			ctx := context.Background()
    42  			err := errors.New("first")
    43  			require.NoError(t, c.Register(&AnonService{
    44  				InitF: func(context.Context) error { return nil },
    45  				RunF:  func(context.Context) {},
    46  				StopF: func() error { return errors.New("second") },
    47  			}))
    48  			require.NoError(t, c.Register(&AnonService{
    49  				InitF: func(context.Context) error { return nil },
    50  				RunF:  func(context.Context) {},
    51  				StopF: func() error { return err },
    52  			}))
    53  			var wg sync.WaitGroup
    54  			wg.Add(1)
    55  			go func() {
    56  				defer wg.Done()
    57  				require.NoError(t, c.WaitForStart())
    58  				c.Stop()
    59  			}()
    60  			require.ErrorIs(t, c.Start(ctx), err)
    61  			require.ErrorIs(t, c.WaitForStop(), err)
    62  			wg.Wait()
    63  		})
    64  	})
    65  	t.Run("EmptyServices", func(t *testing.T) {
    66  		c := NewController()
    67  		ctx := context.Background()
    68  		var wg sync.WaitGroup
    69  		wg.Add(1)
    70  		go func() {
    71  			defer wg.Done()
    72  			require.NoError(t, c.WaitForStart())
    73  			c.Stop()
    74  		}()
    75  		require.NoError(t, c.Start(ctx))
    76  		require.NoError(t, c.WaitForStop())
    77  		wg.Wait()
    78  	})
    79  	t.Run("Register", func(t *testing.T) {
    80  		t.Run("AfterStartCalled", func(t *testing.T) {
    81  			c := NewController()
    82  			ctx := context.Background()
    83  			var wg sync.WaitGroup
    84  			wg.Add(1)
    85  			go func() {
    86  				defer wg.Done()
    87  				require.NoError(t, c.WaitForStart())
    88  				require.Error(t, c.Register(&AnonService{
    89  					InitF: func(context.Context) error { return nil },
    90  					RunF:  func(context.Context) {},
    91  					StopF: func() error { return nil },
    92  				}))
    93  				c.Stop()
    94  			}()
    95  			require.NoError(t, c.Start(ctx))
    96  			require.NoError(t, c.WaitForStop())
    97  			wg.Wait()
    98  		})
    99  	})
   100  	t.Run("Start", func(t *testing.T) {
   101  		t.Run("CallsInitInOrder", func(t *testing.T) {
   102  			c := NewController()
   103  			var inited []int
   104  			require.NoError(t, c.Register(&AnonService{
   105  				InitF: func(context.Context) error {
   106  					inited = append(inited, 0)
   107  					return nil
   108  				},
   109  				RunF:  func(context.Context) {},
   110  				StopF: func() error { return nil },
   111  			}))
   112  			require.NoError(t, c.Register(&AnonService{
   113  				InitF: func(context.Context) error {
   114  					inited = append(inited, 1)
   115  					return nil
   116  				},
   117  				RunF:  func(context.Context) {},
   118  				StopF: func() error { return nil },
   119  			}))
   120  			require.NoError(t, c.Register(&AnonService{
   121  				InitF: func(context.Context) error {
   122  					inited = append(inited, 2)
   123  					return nil
   124  				},
   125  				RunF:  func(context.Context) {},
   126  				StopF: func() error { return nil },
   127  			}))
   128  			ctx := context.Background()
   129  			var wg sync.WaitGroup
   130  			wg.Add(1)
   131  			go func() {
   132  				defer wg.Done()
   133  				require.NoError(t, c.WaitForStart())
   134  				c.Stop()
   135  			}()
   136  			require.NoError(t, c.Start(ctx))
   137  			require.NoError(t, c.WaitForStop())
   138  			require.Equal(t, inited, []int{0, 1, 2})
   139  			wg.Wait()
   140  		})
   141  		t.Run("StopsCallingInitOnFirstError", func(t *testing.T) {
   142  			err := errors.New("first error")
   143  			c := NewController()
   144  			var inited []int
   145  			require.NoError(t, c.Register(&AnonService{
   146  				InitF: func(context.Context) error {
   147  					inited = append(inited, 0)
   148  					return nil
   149  				},
   150  				RunF:  func(context.Context) {},
   151  				StopF: func() error { return nil },
   152  			}))
   153  			require.NoError(t, c.Register(&AnonService{
   154  				InitF: func(context.Context) error {
   155  					inited = append(inited, 1)
   156  					return nil
   157  				},
   158  				RunF:  func(context.Context) {},
   159  				StopF: func() error { return nil },
   160  			}))
   161  			require.NoError(t, c.Register(&AnonService{
   162  				InitF: func(context.Context) error {
   163  					return err
   164  				},
   165  				RunF:  func(context.Context) {},
   166  				StopF: func() error { return nil },
   167  			}))
   168  			require.NoError(t, c.Register(&AnonService{
   169  				InitF: func(context.Context) error {
   170  					inited = append(inited, 2)
   171  					return nil
   172  				},
   173  				RunF:  func(context.Context) {},
   174  				StopF: func() error { return nil },
   175  			}))
   176  			ctx := context.Background()
   177  			var wg sync.WaitGroup
   178  			wg.Add(1)
   179  			go func() {
   180  				defer wg.Done()
   181  				require.ErrorIs(t, c.WaitForStart(), err)
   182  				c.Stop()
   183  			}()
   184  			require.ErrorIs(t, c.Start(ctx), err)
   185  			require.ErrorIs(t, c.WaitForStop(), err)
   186  			require.Equal(t, inited, []int{0, 1})
   187  			wg.Wait()
   188  		})
   189  		t.Run("CallsStopWhenInitErrors", func(t *testing.T) {
   190  			err := errors.New("first error")
   191  			c := NewController()
   192  			var stopped []int
   193  			require.NoError(t, c.Register(&AnonService{
   194  				InitF: func(context.Context) error {
   195  					return nil
   196  				},
   197  				RunF: func(context.Context) {},
   198  				StopF: func() error {
   199  					stopped = append(stopped, 0)
   200  					return nil
   201  				},
   202  			}))
   203  			require.NoError(t, c.Register(&AnonService{
   204  				InitF: func(context.Context) error {
   205  					return nil
   206  				},
   207  				RunF: func(context.Context) {},
   208  				StopF: func() error {
   209  					stopped = append(stopped, 1)
   210  					return nil
   211  				},
   212  			}))
   213  			require.NoError(t, c.Register(&AnonService{
   214  				InitF: func(context.Context) error {
   215  					return err
   216  				},
   217  				RunF: func(context.Context) {},
   218  				StopF: func() error {
   219  					stopped = append(stopped, 2)
   220  					return nil
   221  				},
   222  			}))
   223  			require.NoError(t, c.Register(&AnonService{
   224  				InitF: func(context.Context) error {
   225  					return nil
   226  				},
   227  				RunF: func(context.Context) {},
   228  				StopF: func() error {
   229  					stopped = append(stopped, 3)
   230  					return nil
   231  				},
   232  			}))
   233  			ctx := context.Background()
   234  			var wg sync.WaitGroup
   235  			wg.Add(1)
   236  			go func() {
   237  				defer wg.Done()
   238  				require.ErrorIs(t, c.WaitForStart(), err)
   239  				c.Stop()
   240  			}()
   241  			require.ErrorIs(t, c.Start(ctx), err)
   242  			require.ErrorIs(t, c.WaitForStop(), err)
   243  			require.Equal(t, stopped, []int{1, 0})
   244  			wg.Wait()
   245  		})
   246  		t.Run("RunsServices", func(t *testing.T) {
   247  			c := NewController()
   248  			var wg sync.WaitGroup
   249  			wg.Add(2)
   250  			require.NoError(t, c.Register(&AnonService{
   251  				InitF: func(context.Context) error { return nil },
   252  				RunF:  func(context.Context) { wg.Done() },
   253  				StopF: func() error { return nil },
   254  			}))
   255  			require.NoError(t, c.Register(&AnonService{
   256  				InitF: func(context.Context) error { return nil },
   257  				RunF:  func(context.Context) { wg.Done() },
   258  				StopF: func() error { return nil },
   259  			}))
   260  			ctx := context.Background()
   261  			var cwg sync.WaitGroup
   262  			cwg.Add(1)
   263  			go func() {
   264  				defer cwg.Done()
   265  				require.NoError(t, c.WaitForStart())
   266  				c.Stop()
   267  			}()
   268  			require.NoError(t, c.Start(ctx))
   269  			require.NoError(t, c.WaitForStop())
   270  			wg.Wait()
   271  			cwg.Wait()
   272  		})
   273  		t.Run("StopsAllServices", func(t *testing.T) {
   274  			c := NewController()
   275  			var wg sync.WaitGroup
   276  			err := errors.New("first error")
   277  			wg.Add(2)
   278  			require.NoError(t, c.Register(&AnonService{
   279  				InitF: func(context.Context) error { return nil },
   280  				RunF:  func(context.Context) {},
   281  				StopF: func() error {
   282  					wg.Done()
   283  					return errors.New("second error")
   284  				},
   285  			}))
   286  			require.NoError(t, c.Register(&AnonService{
   287  				InitF: func(context.Context) error { return nil },
   288  				RunF:  func(context.Context) {},
   289  				StopF: func() error {
   290  					wg.Done()
   291  					return err
   292  				},
   293  			}))
   294  			ctx := context.Background()
   295  			var cwg sync.WaitGroup
   296  			cwg.Add(1)
   297  			go func() {
   298  				defer cwg.Done()
   299  				require.NoError(t, c.WaitForStart())
   300  				c.Stop()
   301  			}()
   302  			require.ErrorIs(t, c.Start(ctx), err)
   303  			require.ErrorIs(t, c.WaitForStop(), err)
   304  			wg.Wait()
   305  			cwg.Wait()
   306  		})
   307  	})
   308  	t.Run("RunStopsControllerExample", func(t *testing.T) {
   309  		// |Run| has no way to return an error, but it *can* use the
   310  		// controller itself to coordinate a shutdown of all the
   311  		// services and to ensure that an error is returned from its
   312  		// Service's |Close| method.
   313  		c := NewController()
   314  		ctx := context.Background()
   315  
   316  		expectedErr := errors.New("error set from run")
   317  		errCh := make(chan error)
   318  		var runErr error
   319  		var runWg sync.WaitGroup
   320  		runWg.Add(1)
   321  		failingService := &AnonService{
   322  			RunF: func(context.Context) {
   323  				runErr = <-errCh
   324  				// Do this in the background, since it will block on StopF down below being completed.
   325  				go c.Stop()
   326  				runWg.Done()
   327  			},
   328  			StopF: func() error {
   329  				runWg.Wait()
   330  				return runErr
   331  			},
   332  		}
   333  		c.Register(failingService)
   334  
   335  		// See how we do not call |Stop| on the controller here. The
   336  		// "failing" Run method of the failingService will call it.
   337  		var cwg sync.WaitGroup
   338  		cwg.Add(1)
   339  		go func() {
   340  			defer cwg.Done()
   341  			require.ErrorIs(t, c.Start(ctx), expectedErr)
   342  		}()
   343  		require.NoError(t, c.WaitForStart())
   344  		errCh <- expectedErr
   345  		require.ErrorIs(t, c.WaitForStop(), expectedErr)
   346  		cwg.Wait()
   347  	})
   348  }