vitess.io/vitess@v0.16.2/go/vt/srvtopo/resilient_server_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package srvtopo
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"fmt"
    23  	"html/template"
    24  	"reflect"
    25  	"sync"
    26  	"testing"
    27  	"time"
    28  
    29  	"vitess.io/vitess/go/vt/key"
    30  
    31  	"vitess.io/vitess/go/sync2"
    32  
    33  	"github.com/stretchr/testify/assert"
    34  	"github.com/stretchr/testify/require"
    35  	"google.golang.org/protobuf/proto"
    36  
    37  	"vitess.io/vitess/go/vt/status"
    38  	"vitess.io/vitess/go/vt/topo"
    39  	"vitess.io/vitess/go/vt/topo/memorytopo"
    40  
    41  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    42  	vschemapb "vitess.io/vitess/go/vt/proto/vschema"
    43  )
    44  
    45  // TestGetSrvKeyspace will test we properly return updated SrvKeyspace.
    46  func TestGetSrvKeyspace(t *testing.T) {
    47  	ts, factory := memorytopo.NewServerAndFactory("test_cell")
    48  	srvTopoCacheTTL = 200 * time.Millisecond
    49  	srvTopoCacheRefresh = 80 * time.Millisecond
    50  	defer func() {
    51  		srvTopoCacheTTL = 1 * time.Second
    52  		srvTopoCacheRefresh = 1 * time.Second
    53  	}()
    54  
    55  	rs := NewResilientServer(ts, "TestGetSrvKeyspace")
    56  
    57  	// Ask for a not-yet-created keyspace
    58  	_, err := rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
    59  	if !topo.IsErrType(err, topo.NoNode) {
    60  		t.Fatalf("GetSrvKeyspace(not created) got unexpected error: %v", err)
    61  	}
    62  
    63  	// Wait until the cached error expires.
    64  	time.Sleep(srvTopoCacheRefresh + 10*time.Millisecond)
    65  
    66  	// Set SrvKeyspace with value
    67  	want := &topodatapb.SrvKeyspace{}
    68  	err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
    69  	require.NoError(t, err, "UpdateSrvKeyspace(test_cell, test_ks, %s) failed", want)
    70  
    71  	// wait until we get the right value
    72  	var got *topodatapb.SrvKeyspace
    73  	expiry := time.Now().Add(srvTopoCacheRefresh - 20*time.Millisecond)
    74  	for {
    75  		ctx, cancel := context.WithCancel(context.Background())
    76  		got, err = rs.GetSrvKeyspace(ctx, "test_cell", "test_ks")
    77  		cancel()
    78  
    79  		if err != nil {
    80  			t.Fatalf("GetSrvKeyspace got unexpected error: %v", err)
    81  		}
    82  		if proto.Equal(want, got) {
    83  			break
    84  		}
    85  		if time.Now().After(expiry) {
    86  			t.Fatalf("GetSrvKeyspace() timeout = %+v, want %+v", got, want)
    87  		}
    88  		time.Sleep(2 * time.Millisecond)
    89  	}
    90  
    91  	// Update the value and check it again to verify that the watcher
    92  	// is still up and running
    93  	want = &topodatapb.SrvKeyspace{Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{{ServedType: topodatapb.TabletType_REPLICA}}}
    94  	err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
    95  	require.NoError(t, err, "UpdateSrvKeyspace(test_cell, test_ks, %s) failed", want)
    96  
    97  	// Wait a bit to give the watcher enough time to update the value.
    98  	time.Sleep(10 * time.Millisecond)
    99  	got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   100  
   101  	if err != nil {
   102  		t.Fatalf("GetSrvKeyspace got unexpected error: %v", err)
   103  	}
   104  	if !proto.Equal(want, got) {
   105  		t.Fatalf("GetSrvKeyspace() = %+v, want %+v", got, want)
   106  	}
   107  
   108  	// make sure the HTML template works
   109  	funcs := map[string]any{}
   110  	for k, v := range status.StatusFuncs {
   111  		funcs[k] = v
   112  	}
   113  	for k, v := range StatusFuncs {
   114  		funcs[k] = v
   115  	}
   116  	templ := template.New("").Funcs(funcs)
   117  	templ, err = templ.Parse(TopoTemplate)
   118  	if err != nil {
   119  		t.Fatalf("error parsing template: %v", err)
   120  	}
   121  	wr := &bytes.Buffer{}
   122  	if err := templ.Execute(wr, rs.CacheStatus()); err != nil {
   123  		t.Fatalf("error executing template: %v", err)
   124  	}
   125  
   126  	// Now delete the SrvKeyspace, wait until we get the error.
   127  	if err := ts.DeleteSrvKeyspace(context.Background(), "test_cell", "test_ks"); err != nil {
   128  		t.Fatalf("DeleteSrvKeyspace() failed: %v", err)
   129  	}
   130  	expiry = time.Now().Add(5 * time.Second)
   131  	for {
   132  		_, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   133  		if topo.IsErrType(err, topo.NoNode) {
   134  			break
   135  		}
   136  		if time.Now().After(expiry) {
   137  			t.Fatalf("timeout waiting for no keyspace error")
   138  		}
   139  		time.Sleep(time.Millisecond)
   140  	}
   141  
   142  	// Now send an updated real value, see it come through.
   143  	keyRange, err := key.ParseShardingSpec("-")
   144  	if err != nil || len(keyRange) != 1 {
   145  		t.Fatalf("ParseShardingSpec failed. Expected non error and only one element. Got err: %v, len(%v)", err, len(keyRange))
   146  	}
   147  
   148  	want = &topodatapb.SrvKeyspace{
   149  		Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{
   150  			{
   151  				ServedType: topodatapb.TabletType_PRIMARY,
   152  				ShardReferences: []*topodatapb.ShardReference{
   153  					{
   154  						Name:     "-",
   155  						KeyRange: keyRange[0],
   156  					},
   157  				},
   158  			},
   159  		},
   160  	}
   161  
   162  	err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
   163  	require.NoError(t, err, "UpdateSrvKeyspace(test_cell, test_ks, %s) failed", want)
   164  	expiry = time.Now().Add(5 * time.Second)
   165  	updateTime := time.Now()
   166  	for {
   167  		got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   168  		if err == nil && proto.Equal(want, got) {
   169  			break
   170  		}
   171  		if time.Now().After(expiry) {
   172  			t.Fatalf("timeout waiting for new keyspace value")
   173  		}
   174  		time.Sleep(time.Millisecond)
   175  	}
   176  
   177  	// Now simulate a topo service error and see that the last value is
   178  	// cached for at least half of the expected ttl.
   179  	errorTestStart := time.Now()
   180  	errorReqsBefore := rs.counts.Counts()[errorCategory]
   181  	forceErr := topo.NewError(topo.Timeout, "test topo error")
   182  	factory.SetError(forceErr)
   183  
   184  	expiry = time.Now().Add(srvTopoCacheTTL / 2)
   185  	for {
   186  		got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   187  		if err != nil || !proto.Equal(want, got) {
   188  			// On a slow test machine it is possible that we never end up
   189  			// verifying the value is cached because it could take too long to
   190  			// even get into this loop... so log this as an informative message
   191  			// but don't fail the test
   192  			if time.Now().After(expiry) {
   193  				t.Logf("test execution was too slow -- caching was not verified")
   194  				break
   195  			}
   196  
   197  			t.Errorf("expected keyspace to be cached for at least %s seconds, got error %v", time.Since(updateTime), err)
   198  		}
   199  
   200  		if time.Now().After(expiry) {
   201  			break
   202  		}
   203  
   204  		time.Sleep(time.Millisecond)
   205  	}
   206  
   207  	// Now wait for the TTL to expire and we should get the expected error
   208  	expiry = time.Now().Add(1 * time.Second)
   209  	for {
   210  		_, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   211  		if err != nil || err == forceErr {
   212  			break
   213  		}
   214  
   215  		if time.Now().After(expiry) {
   216  			t.Fatalf("timed out waiting for error to be returned")
   217  		}
   218  		time.Sleep(time.Millisecond)
   219  	}
   220  
   221  	// Clear the error away and check that the cached error is still returned
   222  	// until the refresh interval elapses
   223  	factory.SetError(nil)
   224  	_, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   225  	if err == nil || err != forceErr {
   226  		t.Errorf("expected error to be cached")
   227  	}
   228  
   229  	// Now sleep for the rest of the interval and we should get the value again
   230  	time.Sleep(srvTopoCacheRefresh)
   231  	got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   232  	if err != nil || !proto.Equal(want, got) {
   233  		t.Errorf("expected value to be restored, got %v", err)
   234  	}
   235  
   236  	// Now sleep for the full TTL before setting the error again to test
   237  	// that even when there is no activity on the key, it is still cached
   238  	// for the full configured TTL.
   239  	time.Sleep(srvTopoCacheTTL)
   240  	forceErr = topo.NewError(topo.Interrupted, "another test topo error")
   241  	factory.SetError(forceErr)
   242  
   243  	expiry = time.Now().Add(srvTopoCacheTTL / 2)
   244  	for {
   245  		_, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   246  		if err != nil {
   247  			t.Fatalf("value should have been cached for the full ttl, error %v", err)
   248  		}
   249  		if time.Now().After(expiry) {
   250  			break
   251  		}
   252  		time.Sleep(time.Millisecond)
   253  	}
   254  
   255  	// Wait again until the TTL expires and we get the error
   256  	expiry = time.Now().Add(time.Second)
   257  	for {
   258  		_, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   259  		if err != nil {
   260  			if err == forceErr {
   261  				break
   262  			}
   263  			t.Fatalf("expected %v got %v", forceErr, err)
   264  		}
   265  
   266  		if time.Now().After(expiry) {
   267  			t.Fatalf("timed out waiting for error")
   268  		}
   269  		time.Sleep(time.Millisecond)
   270  	}
   271  
   272  	factory.SetError(nil)
   273  
   274  	// Check that the expected number of errors were counted during the
   275  	// interval
   276  	errorReqs := rs.counts.Counts()[errorCategory]
   277  	expectedErrors := int64(time.Since(errorTestStart) / srvTopoCacheRefresh)
   278  	if errorReqs-errorReqsBefore > expectedErrors {
   279  		t.Errorf("expected <= %v error requests got %d", expectedErrors, errorReqs-errorReqsBefore)
   280  	}
   281  
   282  	// Check that the watch now works to update the value
   283  	want = &topodatapb.SrvKeyspace{}
   284  	err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
   285  	require.NoError(t, err, "UpdateSrvKeyspace(test_cell, test_ks, %s) failed", want)
   286  	expiry = time.Now().Add(5 * time.Second)
   287  	for {
   288  		got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   289  		if err == nil && proto.Equal(want, got) {
   290  			break
   291  		}
   292  		if time.Now().After(expiry) {
   293  			t.Fatalf("timeout waiting for new keyspace value")
   294  		}
   295  		time.Sleep(time.Millisecond)
   296  	}
   297  
   298  	// Now test with a new error in which the topo service is locked during
   299  	// the test which prevents all queries from proceeding.
   300  	forceErr = fmt.Errorf("test topo error with factory locked")
   301  	factory.SetError(forceErr)
   302  	factory.Lock()
   303  	go func() {
   304  		time.Sleep(srvTopoCacheRefresh * 2)
   305  		factory.Unlock()
   306  	}()
   307  
   308  	expiry = time.Now().Add(srvTopoCacheTTL / 2)
   309  	for {
   310  		got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   311  		if err != nil || !proto.Equal(want, got) {
   312  			// On a slow test machine it is possible that we never end up
   313  			// verifying the value is cached because it could take too long to
   314  			// even get into this loop... so log this as an informative message
   315  			// but don't fail the test
   316  			if time.Now().After(expiry) {
   317  				t.Logf("test execution was too slow -- caching was not verified")
   318  				break
   319  			}
   320  
   321  			t.Errorf("expected keyspace to be cached for at least %s seconds, got error %v", time.Since(updateTime), err)
   322  		}
   323  
   324  		if time.Now().After(expiry) {
   325  			break
   326  		}
   327  
   328  		time.Sleep(time.Millisecond)
   329  	}
   330  
   331  	// Clear the error, wait for things to proceed again
   332  	factory.SetError(nil)
   333  	time.Sleep(srvTopoCacheTTL)
   334  
   335  	got, err = rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   336  	if err != nil || !proto.Equal(want, got) {
   337  		t.Errorf("expected error to clear, got %v", err)
   338  	}
   339  
   340  	// Force another error and lock the topo. Then wait for the TTL to
   341  	// expire and verify that the context timeout unblocks the request.
   342  
   343  	// TODO(deepthi): Commenting out this test until we fix https://github.com/vitessio/vitess/issues/6134
   344  
   345  	/*
   346  		forceErr = fmt.Errorf("force long test error")
   347  		factory.SetError(forceErr)
   348  		factory.Lock()
   349  
   350  		time.Sleep(*srvTopoCacheTTL)
   351  
   352  		timeoutCtx, cancel := context.WithTimeout(context.Background(), *srvTopoCacheRefresh*2) //nolint
   353  		defer cancel()
   354  		_, err = rs.GetSrvKeyspace(timeoutCtx, "test_cell", "test_ks")
   355  		wantErr := "timed out waiting for keyspace"
   356  		if err == nil || err.Error() != wantErr {
   357  			t.Errorf("expected error '%v', got '%v'", wantErr, err)
   358  		}
   359  		factory.Unlock()
   360  	*/
   361  }
   362  
   363  // TestSrvKeyspaceCachedError will test we properly re-try to query
   364  // the topo server upon failure.
   365  func TestSrvKeyspaceCachedError(t *testing.T) {
   366  	ts := memorytopo.NewServer("test_cell")
   367  	srvTopoCacheTTL = 100 * time.Millisecond
   368  	srvTopoCacheRefresh = 40 * time.Millisecond
   369  	defer func() {
   370  		srvTopoCacheTTL = 1 * time.Second
   371  		srvTopoCacheRefresh = 1 * time.Second
   372  	}()
   373  	rs := NewResilientServer(ts, "TestSrvKeyspaceCachedErrors")
   374  
   375  	// Ask for an unknown keyspace, should get an error.
   376  	ctx := context.Background()
   377  	_, err := rs.GetSrvKeyspace(ctx, "test_cell", "unknown_ks")
   378  	if err == nil {
   379  		t.Fatalf("First GetSrvKeyspace didn't return an error")
   380  	}
   381  	entry := rs.SrvKeyspaceWatcher.rw.getEntry(&srvKeyspaceKey{"test_cell", "unknown_ks"})
   382  	if err != entry.lastError {
   383  		t.Errorf("Error wasn't saved properly")
   384  	}
   385  
   386  	time.Sleep(srvTopoCacheTTL + 10*time.Millisecond)
   387  	// Ask again with a different context, should get an error and
   388  	// save that context.
   389  	ctx, cancel := context.WithCancel(ctx)
   390  	defer cancel()
   391  	_, err2 := rs.GetSrvKeyspace(ctx, "test_cell", "unknown_ks")
   392  	if err2 == nil {
   393  		t.Fatalf("Second GetSrvKeyspace didn't return an error")
   394  	}
   395  	if err2 != entry.lastError {
   396  		t.Errorf("Error wasn't saved properly")
   397  	}
   398  }
   399  
   400  // TestGetSrvKeyspaceCreated will test we properly get the initial
   401  // value if the SrvKeyspace already exists.
   402  func TestGetSrvKeyspaceCreated(t *testing.T) {
   403  	ts := memorytopo.NewServer("test_cell")
   404  	rs := NewResilientServer(ts, "TestGetSrvKeyspaceCreated")
   405  
   406  	// Set SrvKeyspace with value.
   407  	want := &topodatapb.SrvKeyspace{}
   408  	err := ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
   409  	require.NoError(t, err, "UpdateSrvKeyspace(test_cell, test_ks, %s) failed", want)
   410  
   411  	// Wait until we get the right value.
   412  	expiry := time.Now().Add(5 * time.Second)
   413  	for {
   414  		got, err := rs.GetSrvKeyspace(context.Background(), "test_cell", "test_ks")
   415  		switch {
   416  		case topo.IsErrType(err, topo.NoNode):
   417  			// keep trying
   418  		case err == nil:
   419  			// we got a value, see if it's good
   420  			if proto.Equal(want, got) {
   421  				return
   422  			}
   423  		default:
   424  			t.Fatalf("GetSrvKeyspace got unexpected error: %v", err)
   425  		}
   426  		if time.Now().After(expiry) {
   427  			t.Fatalf("GetSrvKeyspace() timeout = %+v, want %+v", got, want)
   428  		}
   429  		time.Sleep(10 * time.Millisecond)
   430  	}
   431  }
   432  
   433  func TestWatchSrvVSchema(t *testing.T) {
   434  	srvTopoCacheRefresh = 10 * time.Millisecond
   435  	ctx := context.Background()
   436  	ts := memorytopo.NewServer("test_cell")
   437  	rs := NewResilientServer(ts, "TestWatchSrvVSchema")
   438  
   439  	// mu protects watchValue and watchErr.
   440  	mu := sync.Mutex{}
   441  	var watchValue *vschemapb.SrvVSchema
   442  	var watchErr error
   443  	rs.WatchSrvVSchema(ctx, "test_cell", func(v *vschemapb.SrvVSchema, e error) bool {
   444  		mu.Lock()
   445  		defer mu.Unlock()
   446  		watchValue = v
   447  		watchErr = e
   448  		return true
   449  	})
   450  	get := func() (*vschemapb.SrvVSchema, error) {
   451  		mu.Lock()
   452  		defer mu.Unlock()
   453  		return watchValue, watchErr
   454  	}
   455  
   456  	// WatchSrvVSchema won't return until it gets the initial value,
   457  	// which is not there, so we should get watchErr=topo.ErrNoNode.
   458  	if _, err := get(); !topo.IsErrType(err, topo.NoNode) {
   459  		t.Fatalf("WatchSrvVSchema didn't return topo.ErrNoNode at first, but got: %v", err)
   460  	}
   461  
   462  	// Save a value, wait for it.
   463  	newValue := &vschemapb.SrvVSchema{
   464  		Keyspaces: map[string]*vschemapb.Keyspace{
   465  			"ks1": {},
   466  		},
   467  	}
   468  	if err := ts.UpdateSrvVSchema(ctx, "test_cell", newValue); err != nil {
   469  		t.Fatalf("UpdateSrvVSchema failed: %v", err)
   470  	}
   471  	start := time.Now()
   472  	for {
   473  		if v, err := get(); err == nil && proto.Equal(newValue, v) {
   474  			break
   475  		}
   476  		if time.Since(start) > 5*time.Second {
   477  			t.Fatalf("timed out waiting for new SrvVschema")
   478  		}
   479  		time.Sleep(10 * time.Millisecond)
   480  	}
   481  
   482  	// Update value, wait for it.
   483  	updatedValue := &vschemapb.SrvVSchema{
   484  		Keyspaces: map[string]*vschemapb.Keyspace{
   485  			"ks2": {},
   486  		},
   487  	}
   488  	if err := ts.UpdateSrvVSchema(ctx, "test_cell", updatedValue); err != nil {
   489  		t.Fatalf("UpdateSrvVSchema failed: %v", err)
   490  	}
   491  	start = time.Now()
   492  	for {
   493  		if v, err := get(); err == nil && proto.Equal(updatedValue, v) {
   494  			break
   495  		}
   496  		if time.Since(start) > 5*time.Second {
   497  			t.Fatalf("timed out waiting for updated SrvVschema")
   498  		}
   499  		time.Sleep(10 * time.Millisecond)
   500  	}
   501  
   502  	// Delete the value, wait for topo.ErrNoNode
   503  	if err := ts.DeleteSrvVSchema(ctx, "test_cell"); err != nil {
   504  		t.Fatalf("DeleteSrvVSchema failed: %v", err)
   505  	}
   506  	start = time.Now()
   507  	for {
   508  		if _, err := get(); topo.IsErrType(err, topo.NoNode) {
   509  			break
   510  		}
   511  		if time.Since(start) > 5*time.Second {
   512  			t.Fatalf("timed out waiting for deleted SrvVschema")
   513  		}
   514  		time.Sleep(10 * time.Millisecond)
   515  	}
   516  }
   517  
   518  func TestGetSrvKeyspaceNames(t *testing.T) {
   519  	ts, factory := memorytopo.NewServerAndFactory("test_cell")
   520  	srvTopoCacheTTL = 100 * time.Millisecond
   521  	srvTopoCacheRefresh = 40 * time.Millisecond
   522  	defer func() {
   523  		srvTopoCacheTTL = 1 * time.Second
   524  		srvTopoCacheRefresh = 1 * time.Second
   525  	}()
   526  	rs := NewResilientServer(ts, "TestGetSrvKeyspaceNames")
   527  
   528  	// Set SrvKeyspace with value
   529  	want := &topodatapb.SrvKeyspace{}
   530  	err := ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
   531  	require.NoError(t, err, "UpdateSrvKeyspace(test_cell, test_ks, %s) failed", want)
   532  
   533  	err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks2", want)
   534  	require.NoError(t, err, "UpdateSrvKeyspace(test_cell, test_ks2, %s) failed", want)
   535  
   536  	ctx := context.Background()
   537  	names, err := rs.GetSrvKeyspaceNames(ctx, "test_cell", false)
   538  	if err != nil {
   539  		t.Errorf("GetSrvKeyspaceNames unexpected error %v", err)
   540  	}
   541  	wantNames := []string{"test_ks", "test_ks2"}
   542  
   543  	if !reflect.DeepEqual(names, wantNames) {
   544  		t.Errorf("GetSrvKeyspaceNames got %v want %v", names, wantNames)
   545  	}
   546  
   547  	forceErr := fmt.Errorf("force test error")
   548  	factory.SetError(forceErr)
   549  
   550  	// Lock the topo for half the duration of the cache TTL to ensure our
   551  	// requests aren't blocked
   552  	factory.Lock()
   553  	go func() {
   554  		time.Sleep(srvTopoCacheTTL / 2)
   555  		factory.Unlock()
   556  	}()
   557  
   558  	// Check that we get the cached value until at least the refresh interval
   559  	// elapses but before the TTL expires
   560  	start := time.Now()
   561  	for {
   562  		names, err = rs.GetSrvKeyspaceNames(ctx, "test_cell", false)
   563  		if err != nil {
   564  			t.Errorf("GetSrvKeyspaceNames unexpected error %v", err)
   565  		}
   566  
   567  		if !reflect.DeepEqual(names, wantNames) {
   568  			t.Errorf("GetSrvKeyspaceNames got %v want %v", names, wantNames)
   569  		}
   570  
   571  		if time.Since(start) >= srvTopoCacheRefresh+10*time.Millisecond {
   572  			break
   573  		}
   574  
   575  		time.Sleep(time.Millisecond)
   576  	}
   577  
   578  	// Now wait for it to expire from cache
   579  	for {
   580  		_, err = rs.GetSrvKeyspaceNames(ctx, "test_cell", false)
   581  		if err != nil {
   582  			break
   583  		}
   584  
   585  		time.Sleep(2 * time.Millisecond)
   586  
   587  		if time.Since(start) > 2*time.Second {
   588  			t.Fatalf("expected error after TTL expires")
   589  		}
   590  	}
   591  
   592  	if err != forceErr {
   593  		t.Errorf("got error %v want %v", err, forceErr)
   594  	}
   595  
   596  	// Now, since the TTL has expired, check that when we ask for stale
   597  	// info, we'll get it.
   598  	_, err = rs.GetSrvKeyspaceNames(ctx, "test_cell", true)
   599  	if err != nil {
   600  		t.Fatalf("expected no error if asking for stale cache data")
   601  	}
   602  
   603  	// Now, wait long enough that with a stale ask, we'll get an error
   604  	time.Sleep(srvTopoCacheRefresh*2 + 2*time.Millisecond)
   605  	_, err = rs.GetSrvKeyspaceNames(ctx, "test_cell", true)
   606  	if err != forceErr {
   607  		t.Fatalf("expected an error if asking for really stale cache data")
   608  	}
   609  
   610  	// Check that we only checked the topo service 1 or 2 times during the
   611  	// period where we got the cached error.
   612  	cachedReqs, ok := rs.counts.Counts()[cachedCategory]
   613  	if !ok || cachedReqs > 2 {
   614  		t.Errorf("expected <= 2 cached requests got %v", cachedReqs)
   615  	}
   616  
   617  	// Clear the error and wait until the cached error state expires
   618  	factory.SetError(nil)
   619  
   620  	start = time.Now()
   621  	for {
   622  		names, err = rs.GetSrvKeyspaceNames(ctx, "test_cell", false)
   623  		if err == nil {
   624  			break
   625  		}
   626  
   627  		time.Sleep(2 * time.Millisecond)
   628  
   629  		if time.Since(start) > 2*time.Second {
   630  			t.Fatalf("expected error after TTL expires")
   631  		}
   632  	}
   633  
   634  	if !reflect.DeepEqual(names, wantNames) {
   635  		t.Errorf("GetSrvKeyspaceNames got %v want %v", names, wantNames)
   636  	}
   637  
   638  	errorReqs, ok := rs.counts.Counts()[errorCategory]
   639  	if !ok || errorReqs == 0 {
   640  		t.Errorf("expected non-zero error requests got %v", errorReqs)
   641  	}
   642  
   643  	// Force another error and lock the topo. Then wait for the TTL to
   644  	// expire and verify that the context timeout unblocks the request.
   645  	forceErr = fmt.Errorf("force long test error")
   646  	factory.SetError(forceErr)
   647  	factory.Lock()
   648  
   649  	time.Sleep(srvTopoCacheTTL)
   650  
   651  	timeoutCtx, cancel := context.WithTimeout(context.Background(), srvTopoCacheRefresh*2) //nolint
   652  	defer cancel()
   653  	_, err = rs.GetSrvKeyspaceNames(timeoutCtx, "test_cell", false)
   654  	if err != context.DeadlineExceeded {
   655  		t.Errorf("expected error '%v', got '%v'", context.DeadlineExceeded, err.Error())
   656  	}
   657  	factory.Unlock()
   658  }
   659  
   660  type watched struct {
   661  	keyspace *topodatapb.SrvKeyspace
   662  	err      error
   663  }
   664  
   665  func (w *watched) equals(other *watched) bool {
   666  	if w.keyspace != nil {
   667  		return other.keyspace != nil && proto.Equal(w.keyspace, other.keyspace)
   668  	}
   669  	return w.err == other.err
   670  }
   671  
   672  func TestSrvKeyspaceWatcher(t *testing.T) {
   673  	ts, factory := memorytopo.NewServerAndFactory("test_cell")
   674  	srvTopoCacheTTL = 100 * time.Millisecond
   675  	srvTopoCacheRefresh = 40 * time.Millisecond
   676  	defer func() {
   677  		srvTopoCacheTTL = 1 * time.Second
   678  		srvTopoCacheRefresh = 1 * time.Second
   679  	}()
   680  
   681  	rs := NewResilientServer(ts, "TestGetSrvKeyspaceWatcher")
   682  
   683  	var wmu sync.Mutex
   684  	var wseen []watched
   685  
   686  	allSeen := func() []watched {
   687  		wmu.Lock()
   688  		defer wmu.Unlock()
   689  
   690  		var result []watched
   691  		for _, w := range wseen {
   692  			if len(result) == 0 || !result[len(result)-1].equals(&w) {
   693  				result = append(result, w)
   694  			}
   695  		}
   696  		return result
   697  	}
   698  
   699  	waitForEntries := func(entryCount int) []watched {
   700  		var current []watched
   701  		var expire = time.Now().Add(5 * time.Second)
   702  
   703  		for time.Now().Before(expire) {
   704  			current = allSeen()
   705  			if len(current) >= entryCount {
   706  				return current
   707  			}
   708  			time.Sleep(2 * time.Millisecond)
   709  		}
   710  		t.Fatalf("Failed to receive %d entries after 5s (got %d entries so far)", entryCount, len(current))
   711  		return nil
   712  	}
   713  
   714  	rs.WatchSrvKeyspace(context.Background(), "test_cell", "test_ks", func(keyspace *topodatapb.SrvKeyspace, err error) bool {
   715  		wmu.Lock()
   716  		defer wmu.Unlock()
   717  		wseen = append(wseen, watched{keyspace: keyspace, err: err})
   718  		return true
   719  	})
   720  
   721  	seen1 := allSeen()
   722  	assert.Len(t, seen1, 1)
   723  	assert.Nil(t, seen1[0].keyspace)
   724  	assert.True(t, topo.IsErrType(seen1[0].err, topo.NoNode))
   725  
   726  	// Set SrvKeyspace with no values
   727  	want := &topodatapb.SrvKeyspace{}
   728  	err := ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
   729  	require.NoError(t, err)
   730  
   731  	seen2 := waitForEntries(2)
   732  	assert.Len(t, seen2, 2)
   733  	assert.NotNil(t, seen2[1].keyspace)
   734  	assert.Nil(t, seen2[1].err)
   735  	assert.True(t, proto.Equal(want, seen2[1].keyspace))
   736  
   737  	// Now delete the SrvKeyspace, wait until we get the error.
   738  	err = ts.DeleteSrvKeyspace(context.Background(), "test_cell", "test_ks")
   739  	require.NoError(t, err)
   740  
   741  	seen3 := waitForEntries(3)
   742  	assert.Len(t, seen3, 3)
   743  	assert.Nil(t, seen3[2].keyspace)
   744  	assert.True(t, topo.IsErrType(seen3[2].err, topo.NoNode))
   745  
   746  	keyRange, err := key.ParseShardingSpec("-")
   747  	if err != nil || len(keyRange) != 1 {
   748  		t.Fatalf("ParseShardingSpec failed. Expected non error and only one element. Got err: %v, len(%v)", err, len(keyRange))
   749  	}
   750  
   751  	for i := 0; i < 5; i++ {
   752  		want = &topodatapb.SrvKeyspace{
   753  			Partitions: []*topodatapb.SrvKeyspace_KeyspacePartition{
   754  				{
   755  					ServedType: topodatapb.TabletType_PRIMARY,
   756  					ShardReferences: []*topodatapb.ShardReference{
   757  						{
   758  							// This may not be a valid shard spec, but is fine for unit test purposes
   759  							Name:     fmt.Sprintf("%d", i),
   760  							KeyRange: keyRange[0],
   761  						},
   762  					},
   763  				},
   764  			},
   765  		}
   766  		err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
   767  		require.NoError(t, err)
   768  		time.Sleep(100 * time.Millisecond)
   769  	}
   770  
   771  	seen4 := waitForEntries(8)
   772  	assert.Len(t, seen4, 8)
   773  
   774  	for i := 0; i < 5; i++ {
   775  		w := seen4[3+i]
   776  		assert.Nil(t, w.err)
   777  	}
   778  
   779  	// Now simulate a topo service error
   780  	forceErr := topo.NewError(topo.Timeout, "test topo error")
   781  	factory.SetError(forceErr)
   782  
   783  	seen5 := waitForEntries(9)
   784  	assert.Len(t, seen5, 9)
   785  	assert.Nil(t, seen5[8].keyspace)
   786  	assert.True(t, topo.IsErrType(seen5[8].err, topo.Timeout))
   787  
   788  	factory.SetError(nil)
   789  
   790  	seen6 := waitForEntries(10)
   791  	assert.Len(t, seen6, 10)
   792  	assert.Nil(t, seen6[9].err)
   793  	assert.NotNil(t, seen6[9].keyspace)
   794  }
   795  
   796  func TestSrvKeyspaceListener(t *testing.T) {
   797  	ts, _ := memorytopo.NewServerAndFactory("test_cell")
   798  	srvTopoCacheTTL = 100 * time.Millisecond
   799  	srvTopoCacheRefresh = 40 * time.Millisecond
   800  	defer func() {
   801  		srvTopoCacheTTL = 1 * time.Second
   802  		srvTopoCacheRefresh = 1 * time.Second
   803  	}()
   804  
   805  	rs := NewResilientServer(ts, "TestGetSrvKeyspaceWatcher")
   806  
   807  	ctx, cancel := context.WithCancel(context.Background())
   808  	var callbackCount sync2.AtomicInt32
   809  
   810  	// adding listener will perform callback.
   811  	rs.WatchSrvKeyspace(context.Background(), "test_cell", "test_ks", func(srvKs *topodatapb.SrvKeyspace, err error) bool {
   812  		callbackCount.Add(1)
   813  		select {
   814  		case <-ctx.Done():
   815  			return false
   816  		default:
   817  			return true
   818  		}
   819  	})
   820  
   821  	// First update (callback - 2)
   822  	want := &topodatapb.SrvKeyspace{}
   823  	err := ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
   824  	require.NoError(t, err)
   825  
   826  	// Next callback to remove from listener
   827  	cancel()
   828  
   829  	// multi updates thereafter
   830  	for i := 0; i < 5; i++ {
   831  		want = &topodatapb.SrvKeyspace{}
   832  		err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
   833  		require.NoError(t, err)
   834  		time.Sleep(100 * time.Millisecond)
   835  	}
   836  
   837  	// only 3 times the callback called for the listener
   838  	assert.EqualValues(t, 3, callbackCount.Get())
   839  }