github.com/m3db/m3@v1.5.0/src/cluster/kv/util/lock_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package util
    22  
    23  import (
    24  	"errors"
    25  	"fmt"
    26  	"math/rand"
    27  	"strings"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/m3db/m3/src/cluster/generated/proto/commonpb"
    33  	"github.com/m3db/m3/src/cluster/generated/proto/testpb"
    34  	"github.com/m3db/m3/src/cluster/kv"
    35  	"github.com/m3db/m3/src/cluster/kv/mem"
    36  
    37  	"github.com/fortytw2/leaktest"
    38  	"github.com/stretchr/testify/require"
    39  )
    40  
    41  var (
    42  	testNow = time.Now()
    43  )
    44  
    45  func TestWatchAndUpdateBool(t *testing.T) {
    46  	testConfig := struct {
    47  		sync.RWMutex
    48  		v bool
    49  	}{}
    50  
    51  	valueFn := func() bool {
    52  		testConfig.RLock()
    53  		defer testConfig.RUnlock()
    54  
    55  		return testConfig.v
    56  	}
    57  
    58  	var (
    59  		store        = mem.NewStore()
    60  		defaultValue = false
    61  	)
    62  
    63  	watch, err := WatchAndUpdateBool(
    64  		store, "foo", &testConfig.v, &testConfig.RWMutex, defaultValue, nil,
    65  	)
    66  	require.NoError(t, err)
    67  
    68  	// Valid update.
    69  	_, err = store.Set("foo", &commonpb.BoolProto{Value: false})
    70  	require.NoError(t, err)
    71  	for {
    72  		if !valueFn() {
    73  			break
    74  		}
    75  	}
    76  
    77  	// Malformed updates should not be applied.
    78  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 20})
    79  	require.NoError(t, err)
    80  	time.Sleep(100 * time.Millisecond)
    81  	require.False(t, valueFn())
    82  
    83  	_, err = store.Set("foo", &commonpb.BoolProto{Value: true})
    84  	require.NoError(t, err)
    85  	for {
    86  		if valueFn() {
    87  			break
    88  		}
    89  	}
    90  
    91  	// Nil updates should apply the default value.
    92  	_, err = store.Delete("foo")
    93  	require.NoError(t, err)
    94  	for {
    95  		if !valueFn() {
    96  			break
    97  		}
    98  	}
    99  
   100  	_, err = store.Set("foo", &commonpb.BoolProto{Value: true})
   101  	require.NoError(t, err)
   102  	for {
   103  		if valueFn() {
   104  			break
   105  		}
   106  	}
   107  
   108  	// Updates should not be applied after the watch is closed and there should not
   109  	// be any goroutines still running.
   110  	watch.Close()
   111  	time.Sleep(100 * time.Millisecond)
   112  	_, err = store.Set("foo", &commonpb.BoolProto{Value: false})
   113  	require.NoError(t, err)
   114  	time.Sleep(100 * time.Millisecond)
   115  	require.True(t, valueFn())
   116  
   117  	leaktest.Check(t)()
   118  }
   119  
   120  func TestWatchAndUpdateFloat64(t *testing.T) {
   121  	testConfig := struct {
   122  		sync.RWMutex
   123  		v float64
   124  	}{}
   125  
   126  	valueFn := func() float64 {
   127  		testConfig.RLock()
   128  		defer testConfig.RUnlock()
   129  
   130  		return testConfig.v
   131  	}
   132  
   133  	var (
   134  		store        = mem.NewStore()
   135  		defaultValue = 1.35
   136  	)
   137  
   138  	watch, err := WatchAndUpdateFloat64(
   139  		store, "foo", &testConfig.v, &testConfig.RWMutex, defaultValue, nil,
   140  	)
   141  	require.NoError(t, err)
   142  
   143  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 3.7})
   144  	require.NoError(t, err)
   145  	for {
   146  		if valueFn() == 3.7 {
   147  			break
   148  		}
   149  	}
   150  
   151  	// Malformed updates should not be applied.
   152  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: 1})
   153  	require.NoError(t, err)
   154  	time.Sleep(100 * time.Millisecond)
   155  	require.Equal(t, 3.7, valueFn())
   156  
   157  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 1.2})
   158  	require.NoError(t, err)
   159  	for {
   160  		if valueFn() == 1.2 {
   161  			break
   162  		}
   163  	}
   164  
   165  	// Nil updates should apply the default value.
   166  	_, err = store.Delete("foo")
   167  	require.NoError(t, err)
   168  	for {
   169  		if valueFn() == defaultValue {
   170  			break
   171  		}
   172  	}
   173  
   174  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 6.2})
   175  	require.NoError(t, err)
   176  	for {
   177  		if valueFn() == 6.2 {
   178  			break
   179  		}
   180  	}
   181  
   182  	// Updates should not be applied after the watch is closed and there should not
   183  	// be any goroutines still running.
   184  	watch.Close()
   185  	time.Sleep(100 * time.Millisecond)
   186  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 7.2})
   187  	require.NoError(t, err)
   188  	time.Sleep(100 * time.Millisecond)
   189  	require.Equal(t, 6.2, valueFn())
   190  
   191  	leaktest.Check(t)
   192  }
   193  
   194  func TestWatchAndUpdateInt64(t *testing.T) {
   195  	testConfig := struct {
   196  		sync.RWMutex
   197  		v int64
   198  	}{}
   199  
   200  	valueFn := func() int64 {
   201  		testConfig.RLock()
   202  		defer testConfig.RUnlock()
   203  
   204  		return testConfig.v
   205  	}
   206  
   207  	var (
   208  		store              = mem.NewStore()
   209  		defaultValue int64 = 3
   210  	)
   211  
   212  	watch, err := WatchAndUpdateInt64(
   213  		store, "foo", &testConfig.v, &testConfig.RWMutex, defaultValue, nil,
   214  	)
   215  	require.NoError(t, err)
   216  
   217  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: 1})
   218  	require.NoError(t, err)
   219  	for {
   220  		if valueFn() == 1 {
   221  			break
   222  		}
   223  	}
   224  
   225  	// Malformed updates should not be applied.
   226  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 100})
   227  	require.NoError(t, err)
   228  	time.Sleep(100 * time.Millisecond)
   229  	require.Equal(t, int64(1), valueFn())
   230  
   231  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: 7})
   232  	require.NoError(t, err)
   233  	for {
   234  		if valueFn() == 7 {
   235  			break
   236  		}
   237  	}
   238  
   239  	// Nil updates should apply the default value.
   240  	_, err = store.Delete("foo")
   241  	require.NoError(t, err)
   242  	for {
   243  		if valueFn() == defaultValue {
   244  			break
   245  		}
   246  	}
   247  
   248  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: 21})
   249  	require.NoError(t, err)
   250  	for {
   251  		if valueFn() == 21 {
   252  			break
   253  		}
   254  	}
   255  
   256  	// Updates should not be applied after the watch is closed and there should not
   257  	// be any goroutines still running.
   258  	watch.Close()
   259  	time.Sleep(100 * time.Millisecond)
   260  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: 13})
   261  	require.NoError(t, err)
   262  	time.Sleep(100 * time.Millisecond)
   263  	require.Equal(t, int64(21), valueFn())
   264  
   265  	leaktest.Check(t)
   266  }
   267  
   268  func TestWatchAndUpdateString(t *testing.T) {
   269  	testConfig := struct {
   270  		sync.RWMutex
   271  		v string
   272  	}{}
   273  
   274  	valueFn := func() string {
   275  		testConfig.RLock()
   276  		defer testConfig.RUnlock()
   277  
   278  		return testConfig.v
   279  	}
   280  
   281  	var (
   282  		store        = mem.NewStore()
   283  		defaultValue = "abc"
   284  	)
   285  
   286  	watch, err := WatchAndUpdateString(
   287  		store, "foo", &testConfig.v, &testConfig.RWMutex, defaultValue, nil,
   288  	)
   289  	require.NoError(t, err)
   290  
   291  	_, err = store.Set("foo", &commonpb.StringProto{Value: "fizz"})
   292  	require.NoError(t, err)
   293  	for {
   294  		if valueFn() == "fizz" {
   295  			break
   296  		}
   297  	}
   298  
   299  	// Malformed updates should not be applied.
   300  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 100})
   301  	require.NoError(t, err)
   302  	time.Sleep(100 * time.Millisecond)
   303  	require.Equal(t, "fizz", valueFn())
   304  
   305  	_, err = store.Set("foo", &commonpb.StringProto{Value: "buzz"})
   306  	require.NoError(t, err)
   307  	for {
   308  		if valueFn() == "buzz" {
   309  			break
   310  		}
   311  	}
   312  
   313  	// Nil updates should apply the default value.
   314  	_, err = store.Delete("foo")
   315  	require.NoError(t, err)
   316  	for {
   317  		if valueFn() == defaultValue {
   318  			break
   319  		}
   320  	}
   321  
   322  	_, err = store.Set("foo", &commonpb.StringProto{Value: "lol"})
   323  	require.NoError(t, err)
   324  	for {
   325  		if valueFn() == "lol" {
   326  			break
   327  		}
   328  	}
   329  
   330  	// Updates should not be applied after the watch is closed and there should not
   331  	// be any goroutines still running.
   332  	watch.Close()
   333  	time.Sleep(100 * time.Millisecond)
   334  	_, err = store.Set("foo", &commonpb.StringProto{Value: "abc"})
   335  	require.NoError(t, err)
   336  	time.Sleep(100 * time.Millisecond)
   337  	require.Equal(t, "lol", valueFn())
   338  
   339  	leaktest.Check(t)
   340  }
   341  
   342  func TestWatchAndUpdateStringArray(t *testing.T) {
   343  	testConfig := struct {
   344  		sync.RWMutex
   345  		v []string
   346  	}{}
   347  
   348  	valueFn := func() []string {
   349  		testConfig.RLock()
   350  		defer testConfig.RUnlock()
   351  
   352  		return testConfig.v
   353  	}
   354  
   355  	var (
   356  		store        = mem.NewStore()
   357  		defaultValue = []string{"abc", "def"}
   358  	)
   359  
   360  	watch, err := WatchAndUpdateStringArray(
   361  		store, "foo", &testConfig.v, &testConfig.RWMutex, defaultValue, nil,
   362  	)
   363  	require.NoError(t, err)
   364  
   365  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"fizz", "buzz"}})
   366  	require.NoError(t, err)
   367  	for {
   368  		if stringSliceEquals(valueFn(), []string{"fizz", "buzz"}) {
   369  			break
   370  		}
   371  	}
   372  
   373  	// Malformed updates should not be applied.
   374  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 12.3})
   375  	require.NoError(t, err)
   376  	time.Sleep(100 * time.Millisecond)
   377  	require.Equal(t, []string{"fizz", "buzz"}, valueFn())
   378  
   379  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"foo", "bar"}})
   380  	require.NoError(t, err)
   381  	for {
   382  		if stringSliceEquals(valueFn(), []string{"foo", "bar"}) {
   383  			break
   384  		}
   385  	}
   386  
   387  	// Nil updates should apply the default value.
   388  	_, err = store.Delete("foo")
   389  	require.NoError(t, err)
   390  	for {
   391  		if stringSliceEquals(valueFn(), defaultValue) {
   392  			break
   393  		}
   394  	}
   395  
   396  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"jim", "jam"}})
   397  	require.NoError(t, err)
   398  	for {
   399  		if stringSliceEquals(valueFn(), []string{"jim", "jam"}) {
   400  			break
   401  		}
   402  	}
   403  
   404  	// Updates should not be applied after the watch is closed and there should not
   405  	// be any goroutines still running.
   406  	watch.Close()
   407  	time.Sleep(100 * time.Millisecond)
   408  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"abc", "def"}})
   409  	require.NoError(t, err)
   410  	time.Sleep(100 * time.Millisecond)
   411  	require.Equal(t, []string{"jim", "jam"}, valueFn())
   412  
   413  	leaktest.Check(t)
   414  }
   415  
   416  func TestWatchAndUpdateStringArrayPointer(t *testing.T) {
   417  	testConfig := struct {
   418  		sync.RWMutex
   419  		v *[]string
   420  	}{}
   421  
   422  	valueFn := func() *[]string {
   423  		testConfig.RLock()
   424  		defer testConfig.RUnlock()
   425  
   426  		return testConfig.v
   427  	}
   428  
   429  	var (
   430  		store        = mem.NewStore()
   431  		defaultValue = []string{"abc", "def"}
   432  	)
   433  
   434  	watch, err := WatchAndUpdateStringArrayPointer(
   435  		store, "foo", &testConfig.v, &testConfig.RWMutex, &defaultValue, nil,
   436  	)
   437  	require.NoError(t, err)
   438  
   439  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"fizz", "buzz"}})
   440  	require.NoError(t, err)
   441  	for {
   442  		res := valueFn()
   443  		if res != nil && stringSliceEquals(*res, []string{"fizz", "buzz"}) {
   444  			break
   445  		}
   446  	}
   447  
   448  	// Malformed updates should not be applied.
   449  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 12.3})
   450  	require.NoError(t, err)
   451  	time.Sleep(100 * time.Millisecond)
   452  	require.Equal(t, []string{"fizz", "buzz"}, *valueFn())
   453  
   454  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"foo", "bar"}})
   455  	require.NoError(t, err)
   456  	for {
   457  		if stringSliceEquals(*valueFn(), []string{"foo", "bar"}) {
   458  			break
   459  		}
   460  	}
   461  
   462  	// Nil updates should apply the default value.
   463  	_, err = store.Delete("foo")
   464  	require.NoError(t, err)
   465  	for {
   466  		if stringSliceEquals(*valueFn(), defaultValue) {
   467  			break
   468  		}
   469  	}
   470  
   471  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"jim", "jam"}})
   472  	require.NoError(t, err)
   473  	for {
   474  		if stringSliceEquals(*valueFn(), []string{"jim", "jam"}) {
   475  			break
   476  		}
   477  	}
   478  
   479  	// Updates should not be applied after the watch is closed and there should not
   480  	// be any goroutines still running.
   481  	watch.Close()
   482  	time.Sleep(100 * time.Millisecond)
   483  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"abc", "def"}})
   484  	require.NoError(t, err)
   485  	time.Sleep(100 * time.Millisecond)
   486  	require.Equal(t, []string{"jim", "jam"}, *valueFn())
   487  
   488  	leaktest.Check(t)
   489  }
   490  
   491  func TestWatchAndUpdateStringArrayPointerWithNilDefault(t *testing.T) {
   492  	testConfig := struct {
   493  		sync.RWMutex
   494  		v *[]string
   495  	}{}
   496  
   497  	valueFn := func() *[]string {
   498  		testConfig.RLock()
   499  		defer testConfig.RUnlock()
   500  
   501  		return testConfig.v
   502  	}
   503  
   504  	var store = mem.NewStore()
   505  	watch, err := WatchAndUpdateStringArrayPointer(
   506  		store, "foo", &testConfig.v, &testConfig.RWMutex, nil, nil,
   507  	)
   508  	require.NoError(t, err)
   509  
   510  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"fizz", "buzz"}})
   511  	require.NoError(t, err)
   512  	for {
   513  		res := valueFn()
   514  		if res != nil && stringSliceEquals(*res, []string{"fizz", "buzz"}) {
   515  			break
   516  		}
   517  	}
   518  
   519  	// Malformed updates should not be applied.
   520  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 12.3})
   521  	require.NoError(t, err)
   522  	time.Sleep(100 * time.Millisecond)
   523  	require.Equal(t, []string{"fizz", "buzz"}, *valueFn())
   524  
   525  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"foo", "bar"}})
   526  	require.NoError(t, err)
   527  	for {
   528  		if stringSliceEquals(*valueFn(), []string{"foo", "bar"}) {
   529  			break
   530  		}
   531  	}
   532  
   533  	// Nil updates should apply the default value.
   534  	_, err = store.Delete("foo")
   535  	require.NoError(t, err)
   536  	for {
   537  		if valueFn() == nil {
   538  			break
   539  		}
   540  	}
   541  
   542  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"jim", "jam"}})
   543  	require.NoError(t, err)
   544  	for {
   545  		res := valueFn()
   546  		if res != nil && stringSliceEquals(*res, []string{"jim", "jam"}) {
   547  			break
   548  		}
   549  	}
   550  
   551  	// Updates should not be applied after the watch is closed and there should not
   552  	// be any goroutines still running.
   553  	watch.Close()
   554  	time.Sleep(100 * time.Millisecond)
   555  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"abc", "def"}})
   556  	require.NoError(t, err)
   557  	time.Sleep(100 * time.Millisecond)
   558  	require.Equal(t, []string{"jim", "jam"}, *valueFn())
   559  
   560  	leaktest.Check(t)
   561  }
   562  
   563  func TestWatchAndUpdateTime(t *testing.T) {
   564  	testConfig := struct {
   565  		sync.RWMutex
   566  		v time.Time
   567  	}{}
   568  
   569  	valueFn := func() time.Time {
   570  		testConfig.RLock()
   571  		defer testConfig.RUnlock()
   572  
   573  		return testConfig.v
   574  	}
   575  
   576  	var (
   577  		store        = mem.NewStore()
   578  		defaultValue = time.Now()
   579  	)
   580  
   581  	watch, err := WatchAndUpdateTime(store, "foo", &testConfig.v, &testConfig.RWMutex, defaultValue, nil)
   582  	require.NoError(t, err)
   583  
   584  	newTime := defaultValue.Add(time.Minute)
   585  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: newTime.Unix()})
   586  	require.NoError(t, err)
   587  	for {
   588  		if valueFn().Unix() == newTime.Unix() {
   589  			break
   590  		}
   591  	}
   592  
   593  	// Malformed updates should not be applied.
   594  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 100})
   595  	require.NoError(t, err)
   596  	time.Sleep(100 * time.Millisecond)
   597  	require.Equal(t, newTime.Unix(), valueFn().Unix())
   598  
   599  	newTime = newTime.Add(time.Minute)
   600  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: newTime.Unix()})
   601  	require.NoError(t, err)
   602  	for {
   603  		if valueFn().Unix() == newTime.Unix() {
   604  			break
   605  		}
   606  	}
   607  
   608  	// Nil updates should apply the default value.
   609  	_, err = store.Delete("foo")
   610  	require.NoError(t, err)
   611  	for {
   612  		if valueFn().Unix() == defaultValue.Unix() {
   613  			break
   614  		}
   615  	}
   616  
   617  	newTime = newTime.Add(time.Minute)
   618  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: newTime.Unix()})
   619  	require.NoError(t, err)
   620  	for {
   621  		if valueFn().Unix() == newTime.Unix() {
   622  			break
   623  		}
   624  	}
   625  
   626  	// Updates should not be applied after the watch is closed and there should not
   627  	// be any goroutines still running.
   628  	watch.Close()
   629  	time.Sleep(100 * time.Millisecond)
   630  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: defaultValue.Unix()})
   631  	require.NoError(t, err)
   632  	time.Sleep(100 * time.Millisecond)
   633  	require.Equal(t, newTime.Unix(), valueFn().Unix())
   634  
   635  	leaktest.Check(t)
   636  }
   637  
   638  func TestWatchAndUpdateDuration(t *testing.T) {
   639  	testConfig := struct {
   640  		sync.RWMutex
   641  		v time.Duration
   642  	}{}
   643  
   644  	valueFn := func() time.Duration {
   645  		testConfig.RLock()
   646  		defer testConfig.RUnlock()
   647  
   648  		return testConfig.v
   649  	}
   650  
   651  	var (
   652  		store        = mem.NewStore()
   653  		defaultValue = time.Duration(rand.Int63())
   654  	)
   655  
   656  	watch, err := WatchAndUpdateDuration(store, "foo", &testConfig.v, &testConfig.RWMutex, defaultValue, nil)
   657  	require.NoError(t, err)
   658  
   659  	newDuration := time.Duration(rand.Int63())
   660  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: int64(newDuration)})
   661  	require.NoError(t, err)
   662  	for {
   663  		if valueFn() == newDuration {
   664  			break
   665  		}
   666  	}
   667  
   668  	// Malformed updates should not be applied.
   669  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 100})
   670  	require.NoError(t, err)
   671  	time.Sleep(100 * time.Millisecond)
   672  	require.Equal(t, newDuration, valueFn())
   673  
   674  	newDuration = time.Duration(rand.Int63())
   675  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: int64(newDuration)})
   676  	require.NoError(t, err)
   677  	for {
   678  		if valueFn() == newDuration {
   679  			break
   680  		}
   681  	}
   682  
   683  	// Nil updates should apply the default value.
   684  	_, err = store.Delete("foo")
   685  	require.NoError(t, err)
   686  	for {
   687  		if valueFn() == defaultValue {
   688  			break
   689  		}
   690  	}
   691  
   692  	newDuration = time.Duration(rand.Int63())
   693  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: int64(newDuration)})
   694  	require.NoError(t, err)
   695  	for {
   696  		if valueFn() == newDuration {
   697  			break
   698  		}
   699  	}
   700  
   701  	// Updates should not be applied after the watch is closed and there should not
   702  	// be any goroutines still running.
   703  	watch.Close()
   704  	time.Sleep(100 * time.Millisecond)
   705  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: int64(defaultValue)})
   706  	require.NoError(t, err)
   707  	time.Sleep(100 * time.Millisecond)
   708  	require.Equal(t, newDuration, valueFn())
   709  
   710  	leaktest.Check(t)
   711  }
   712  
   713  func TestWatchAndUpdateWithValidationBool(t *testing.T) {
   714  	testConfig := struct {
   715  		sync.RWMutex
   716  		v bool
   717  	}{}
   718  
   719  	valueFn := func() bool {
   720  		testConfig.RLock()
   721  		defer testConfig.RUnlock()
   722  
   723  		return testConfig.v
   724  	}
   725  
   726  	var (
   727  		store = mem.NewStore()
   728  		opts  = NewOptions().SetValidateFn(testValidateBoolFn)
   729  	)
   730  
   731  	_, err := WatchAndUpdateBool(store, "foo", &testConfig.v, &testConfig.RWMutex, false, opts)
   732  	require.NoError(t, err)
   733  
   734  	_, err = store.Set("foo", &commonpb.BoolProto{Value: true})
   735  	require.NoError(t, err)
   736  	for {
   737  		if valueFn() {
   738  			break
   739  		}
   740  	}
   741  
   742  	// Invalid updates should not be applied.
   743  	_, err = store.Set("foo", &commonpb.BoolProto{Value: false})
   744  	require.NoError(t, err)
   745  	for {
   746  		if valueFn() {
   747  			break
   748  		}
   749  	}
   750  }
   751  
   752  func TestWatchAndUpdateWithValidationFloat64(t *testing.T) {
   753  	testConfig := struct {
   754  		sync.RWMutex
   755  		v float64
   756  	}{}
   757  
   758  	valueFn := func() float64 {
   759  		testConfig.RLock()
   760  		defer testConfig.RUnlock()
   761  
   762  		return testConfig.v
   763  	}
   764  
   765  	var (
   766  		store = mem.NewStore()
   767  		opts  = NewOptions().SetValidateFn(testValidateFloat64Fn)
   768  	)
   769  
   770  	_, err := WatchAndUpdateFloat64(store, "foo", &testConfig.v, &testConfig.RWMutex, 1.2, opts)
   771  	require.NoError(t, err)
   772  
   773  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 17})
   774  	require.NoError(t, err)
   775  	for {
   776  		if valueFn() == 17 {
   777  			break
   778  		}
   779  	}
   780  
   781  	// Invalid updates should not be applied.
   782  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 22})
   783  	require.NoError(t, err)
   784  	time.Sleep(100 * time.Millisecond)
   785  	require.Equal(t, float64(17), valueFn())
   786  
   787  	_, err = store.Set("foo", &commonpb.Float64Proto{Value: 1})
   788  	require.NoError(t, err)
   789  	for {
   790  		if valueFn() == 1 {
   791  			break
   792  		}
   793  	}
   794  }
   795  
   796  func TestWatchAndUpdateWithValidationInt64(t *testing.T) {
   797  	testConfig := struct {
   798  		sync.RWMutex
   799  		v int64
   800  	}{}
   801  
   802  	valueFn := func() int64 {
   803  		testConfig.RLock()
   804  		defer testConfig.RUnlock()
   805  
   806  		return testConfig.v
   807  	}
   808  
   809  	var (
   810  		store = mem.NewStore()
   811  		opts  = NewOptions().SetValidateFn(testValidateInt64Fn)
   812  	)
   813  
   814  	_, err := WatchAndUpdateInt64(store, "foo", &testConfig.v, &testConfig.RWMutex, 16, opts)
   815  	require.NoError(t, err)
   816  
   817  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: 17})
   818  	require.NoError(t, err)
   819  	for {
   820  		if valueFn() == 17 {
   821  			break
   822  		}
   823  	}
   824  
   825  	// Invalid updates should not be applied.
   826  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: 22})
   827  	require.NoError(t, err)
   828  	time.Sleep(100 * time.Millisecond)
   829  	require.Equal(t, int64(17), valueFn())
   830  
   831  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: 1})
   832  	require.NoError(t, err)
   833  	for {
   834  		if valueFn() == 1 {
   835  			break
   836  		}
   837  	}
   838  }
   839  
   840  func TestWatchAndUpdateWithValidationString(t *testing.T) {
   841  	testConfig := struct {
   842  		sync.RWMutex
   843  		v string
   844  	}{}
   845  
   846  	valueFn := func() string {
   847  		testConfig.RLock()
   848  		defer testConfig.RUnlock()
   849  
   850  		return testConfig.v
   851  	}
   852  
   853  	var (
   854  		store = mem.NewStore()
   855  		opts  = NewOptions().SetValidateFn(testValidateStringFn)
   856  	)
   857  
   858  	_, err := WatchAndUpdateString(store, "foo", &testConfig.v, &testConfig.RWMutex, "bcd", opts)
   859  	require.NoError(t, err)
   860  
   861  	_, err = store.Set("foo", &commonpb.StringProto{Value: "bar"})
   862  	require.NoError(t, err)
   863  	for {
   864  		if valueFn() == "bar" {
   865  			break
   866  		}
   867  	}
   868  
   869  	// Invalid updates should not be applied.
   870  	_, err = store.Set("foo", &commonpb.StringProto{Value: "cat"})
   871  	require.NoError(t, err)
   872  	time.Sleep(100 * time.Millisecond)
   873  	require.Equal(t, "bar", valueFn())
   874  
   875  	_, err = store.Set("foo", &commonpb.StringProto{Value: "baz"})
   876  	require.NoError(t, err)
   877  	for {
   878  		if valueFn() == "bar" {
   879  			break
   880  		}
   881  	}
   882  }
   883  
   884  func TestWatchAndUpdateWithValidationStringArray(t *testing.T) {
   885  	testConfig := struct {
   886  		sync.RWMutex
   887  		v []string
   888  	}{}
   889  
   890  	valueFn := func() []string {
   891  		testConfig.RLock()
   892  		defer testConfig.RUnlock()
   893  
   894  		return testConfig.v
   895  	}
   896  
   897  	var (
   898  		store = mem.NewStore()
   899  		opts  = NewOptions().SetValidateFn(testValidateStringArrayFn)
   900  	)
   901  
   902  	_, err := WatchAndUpdateStringArray(
   903  		store, "foo", &testConfig.v, &testConfig.RWMutex, []string{"a", "b"}, opts,
   904  	)
   905  	require.NoError(t, err)
   906  
   907  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"fizz", "buzz"}})
   908  	require.NoError(t, err)
   909  	for {
   910  		if stringSliceEquals([]string{"fizz", "buzz"}, valueFn()) {
   911  			break
   912  		}
   913  	}
   914  
   915  	// Invalid updates should not be applied.
   916  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"cat"}})
   917  	require.NoError(t, err)
   918  	time.Sleep(100 * time.Millisecond)
   919  	require.Equal(t, []string{"fizz", "buzz"}, valueFn())
   920  
   921  	_, err = store.Set("foo", &commonpb.StringArrayProto{Values: []string{"jim", "jam"}})
   922  	require.NoError(t, err)
   923  	for {
   924  		if stringSliceEquals([]string{"jim", "jam"}, valueFn()) {
   925  			break
   926  		}
   927  	}
   928  }
   929  
   930  func TestWatchAndUpdateWithValidationTime(t *testing.T) {
   931  	testConfig := struct {
   932  		sync.RWMutex
   933  		v time.Time
   934  	}{}
   935  
   936  	valueFn := func() time.Time {
   937  		testConfig.RLock()
   938  		defer testConfig.RUnlock()
   939  
   940  		return testConfig.v
   941  	}
   942  
   943  	var (
   944  		store = mem.NewStore()
   945  		opts  = NewOptions().SetValidateFn(testValidateTimeFn)
   946  	)
   947  
   948  	_, err := WatchAndUpdateTime(store, "foo", &testConfig.v, &testConfig.RWMutex, testNow, opts)
   949  	require.NoError(t, err)
   950  
   951  	newTime := testNow.Add(30 * time.Second)
   952  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: newTime.Unix()})
   953  	require.NoError(t, err)
   954  	for {
   955  		if valueFn().Unix() == newTime.Unix() {
   956  			break
   957  		}
   958  	}
   959  
   960  	// Invalid updates should not be applied.
   961  	invalidTime := testNow.Add(2 * time.Minute)
   962  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: invalidTime.Unix()})
   963  	require.NoError(t, err)
   964  	time.Sleep(100 * time.Millisecond)
   965  	require.Equal(t, newTime.Unix(), valueFn().Unix())
   966  
   967  	newTime = testNow.Add(45 * time.Second)
   968  	_, err = store.Set("foo", &commonpb.Int64Proto{Value: newTime.Unix()})
   969  	require.NoError(t, err)
   970  	for {
   971  		if valueFn().Unix() == newTime.Unix() {
   972  			break
   973  		}
   974  	}
   975  }
   976  
   977  // NB: uses a map[string]int64 as a standin for a generic type
   978  func TestWatchAndUpdateWithValidationGeneric(t *testing.T) {
   979  	testConfig := struct {
   980  		sync.RWMutex
   981  		v map[string]int64
   982  	}{}
   983  
   984  	valueFn := func() map[string]int64 {
   985  		testConfig.RLock()
   986  		defer testConfig.RUnlock()
   987  
   988  		clonedMap := make(map[string]int64, len(testConfig.v))
   989  		for k, v := range testConfig.v {
   990  			clonedMap[k] = v
   991  		}
   992  
   993  		return clonedMap
   994  	}
   995  
   996  	var (
   997  		store      = mem.NewStore()
   998  		opts       = NewOptions().SetValidateFn(testValidateIntMapFn)
   999  		invalidMap = map[string]int64{"1": 1, "2": 2}
  1000  		defaultMap = map[string]int64{"1": 1, "2": 2, "3": 3}
  1001  		newMap     = map[string]int64{"1": 1, "2": 2, "5": 5}
  1002  		newMap2    = map[string]int64{"1": 1, "2": 2, "100": 100}
  1003  	)
  1004  
  1005  	genericGetFn := func(v kv.Value) (interface{}, error) {
  1006  		var mapProto testpb.MapProto
  1007  		if err := v.Unmarshal(&mapProto); err != nil {
  1008  			return nil, err
  1009  		}
  1010  
  1011  		return mapProto.GetValue(), nil
  1012  	}
  1013  
  1014  	genericUpdateFn := func(i interface{}) {
  1015  		testConfig.v = i.(map[string]int64)
  1016  	}
  1017  
  1018  	waitForExpectedValue := func(m map[string]int64) {
  1019  		start := time.Now()
  1020  		for {
  1021  			value := valueFn()
  1022  			if time.Since(start) >= time.Second*5 {
  1023  				require.FailNow(t, fmt.Sprintf("Exceeded timeout while waiting for "+
  1024  					"generic update result; expected %v, got %v", m, value))
  1025  			}
  1026  
  1027  			if len(value) != len(m) {
  1028  				continue
  1029  			}
  1030  
  1031  			for k, v := range m {
  1032  				if xv, found := value[k]; !found || xv != v {
  1033  					continue
  1034  				}
  1035  			}
  1036  
  1037  			break
  1038  		}
  1039  	}
  1040  
  1041  	_, err := WatchAndUpdateGeneric(store, "foo", genericGetFn, genericUpdateFn,
  1042  		&testConfig.RWMutex, defaultMap, opts)
  1043  	require.NoError(t, err)
  1044  
  1045  	_, err = store.Set("foo", &testpb.MapProto{Value: newMap})
  1046  	require.NoError(t, err)
  1047  
  1048  	waitForExpectedValue(newMap)
  1049  
  1050  	// Invalid updates should not be applied.
  1051  	_, err = store.Set("foo", &testpb.MapProto{Value: invalidMap})
  1052  	require.NoError(t, err)
  1053  	time.Sleep(100 * time.Millisecond)
  1054  
  1055  	// NB: should still be old value.
  1056  	waitForExpectedValue(newMap)
  1057  
  1058  	_, err = store.Set("foo", &testpb.MapProto{Value: newMap2})
  1059  	require.NoError(t, err)
  1060  
  1061  	waitForExpectedValue(newMap2)
  1062  
  1063  	// Check default values with a different kv key.
  1064  	_, err = WatchAndUpdateGeneric(store, "bar", genericGetFn, genericUpdateFn,
  1065  		&testConfig.RWMutex, defaultMap, opts)
  1066  	require.NoError(t, err)
  1067  	waitForExpectedValue(defaultMap)
  1068  }
  1069  
  1070  func testValidateBoolFn(val interface{}) error {
  1071  	v, ok := val.(bool)
  1072  	if !ok {
  1073  		return fmt.Errorf("invalid type for val, expected bool, received %T", val)
  1074  	}
  1075  
  1076  	if !v {
  1077  		return errors.New("value of update is false, must be true")
  1078  	}
  1079  
  1080  	return nil
  1081  }
  1082  
  1083  func testValidateFloat64Fn(val interface{}) error {
  1084  	v, ok := val.(float64)
  1085  	if !ok {
  1086  		return fmt.Errorf("invalid type for val, expected float64, received %T", val)
  1087  	}
  1088  
  1089  	if v > 20 {
  1090  		return fmt.Errorf("val must be < 20, is %v", v)
  1091  	}
  1092  
  1093  	return nil
  1094  }
  1095  
  1096  func testValidateInt64Fn(val interface{}) error {
  1097  	v, ok := val.(int64)
  1098  	if !ok {
  1099  		return fmt.Errorf("invalid type for val, expected int64, received %T", val)
  1100  	}
  1101  
  1102  	if v > 20 {
  1103  		return fmt.Errorf("val must be < 20, is %v", v)
  1104  	}
  1105  
  1106  	return nil
  1107  }
  1108  
  1109  func testValidateStringFn(val interface{}) error {
  1110  	v, ok := val.(string)
  1111  	if !ok {
  1112  		return fmt.Errorf("invalid type for val, expected string, received %T", val)
  1113  	}
  1114  
  1115  	if !strings.HasPrefix(v, "b") {
  1116  		return fmt.Errorf("val must start with 'b', is %v", v)
  1117  	}
  1118  
  1119  	return nil
  1120  }
  1121  
  1122  func testValidateStringArrayFn(val interface{}) error {
  1123  	v, ok := val.([]string)
  1124  	if !ok {
  1125  		return fmt.Errorf("invalid type for val, expected string, received %T", val)
  1126  	}
  1127  
  1128  	if len(v) != 2 {
  1129  		return fmt.Errorf("val must contain 2 entries, is %v", v)
  1130  	}
  1131  
  1132  	return nil
  1133  }
  1134  
  1135  func testValidateIntMapFn(val interface{}) error {
  1136  	v, ok := val.(map[string]int64)
  1137  	if !ok {
  1138  		return fmt.Errorf("invalid type for val, expected int map, received %T", val)
  1139  	}
  1140  
  1141  	// NB: for the purpose of this test, valid maps must have 3 values.
  1142  	if len(v) != 3 {
  1143  		return fmt.Errorf("val must contain 3 entries, has %v", v)
  1144  	}
  1145  
  1146  	return nil
  1147  }
  1148  
  1149  func testValidateTimeFn(val interface{}) error {
  1150  	v, ok := val.(time.Time)
  1151  	if !ok {
  1152  		return fmt.Errorf("invalid type for val, expected time.Time, received %T", val)
  1153  	}
  1154  
  1155  	bound := testNow.Add(time.Minute)
  1156  	if v.After(bound) {
  1157  		return fmt.Errorf("val must be before %v, is %v", bound, v)
  1158  	}
  1159  
  1160  	return nil
  1161  }
  1162  
  1163  func stringSliceEquals(a, b []string) bool {
  1164  	if len(a) != len(b) {
  1165  		return false
  1166  	}
  1167  
  1168  	for i := range a {
  1169  		if a[i] != b[i] {
  1170  			return false
  1171  		}
  1172  	}
  1173  
  1174  	return true
  1175  }