vitess.io/vitess@v0.16.2/go/vt/vtadmin/vtctldclient/proxy_test.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vtctldclient
    18  
    19  import (
    20  	"context"
    21  	"net"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"google.golang.org/grpc"
    29  
    30  	"vitess.io/vitess/go/vt/vtadmin/cluster/discovery/fakediscovery"
    31  	"vitess.io/vitess/go/vt/vtadmin/cluster/resolver"
    32  
    33  	vtadminpb "vitess.io/vitess/go/vt/proto/vtadmin"
    34  	vtctldatapb "vitess.io/vitess/go/vt/proto/vtctldata"
    35  	vtctlservicepb "vitess.io/vitess/go/vt/proto/vtctlservice"
    36  )
    37  
    38  type fakeVtctld struct {
    39  	vtctlservicepb.VtctldServer
    40  	addr string
    41  }
    42  
    43  // GetKeyspace is used for tests to detect what addr the VtctldServer is
    44  // listening on. The addr will always be stored as resp.Keyspace.Name, and the
    45  // actual request is ignored.
    46  func (fake *fakeVtctld) GetKeyspace(ctx context.Context, req *vtctldatapb.GetKeyspaceRequest) (*vtctldatapb.GetKeyspaceResponse, error) {
    47  	return &vtctldatapb.GetKeyspaceResponse{
    48  		Keyspace: &vtctldatapb.Keyspace{
    49  			Name: fake.addr,
    50  		},
    51  	}, nil
    52  }
    53  
    54  func initVtctldServer() (net.Listener, *grpc.Server, error) {
    55  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    56  	if err != nil {
    57  		return nil, nil, err
    58  	}
    59  
    60  	vtctld := &fakeVtctld{
    61  		addr: listener.Addr().String(),
    62  	}
    63  	server := grpc.NewServer()
    64  	vtctlservicepb.RegisterVtctldServer(server, vtctld)
    65  
    66  	return listener, server, err
    67  }
    68  
    69  func TestDial(t *testing.T) {
    70  	listener, server, err := initVtctldServer()
    71  	require.NoError(t, err)
    72  
    73  	defer listener.Close()
    74  
    75  	go server.Serve(listener)
    76  	defer server.Stop()
    77  
    78  	disco := fakediscovery.New()
    79  	disco.AddTaggedVtctlds(nil, &vtadminpb.Vtctld{
    80  		Hostname: listener.Addr().String(),
    81  	})
    82  
    83  	proxy, err := New(context.Background(), &Config{
    84  		Cluster: &vtadminpb.Cluster{
    85  			Id:   "test",
    86  			Name: "testcluster",
    87  		},
    88  		ResolverOptions: &resolver.Options{
    89  			Discovery:        disco,
    90  			DiscoveryTimeout: 50 * time.Millisecond,
    91  		},
    92  	})
    93  	require.NoError(t, err)
    94  
    95  	defer proxy.Close() // prevents grpc-core from logging a bunch of "connection errors" after deferred listener.Close() above.
    96  
    97  	resp, err := proxy.GetKeyspace(context.Background(), &vtctldatapb.GetKeyspaceRequest{})
    98  	require.NoError(t, err)
    99  	assert.Equal(t, listener.Addr().String(), resp.Keyspace.Name)
   100  }
   101  
   102  type testdisco struct {
   103  	*fakediscovery.Fake
   104  	notify chan struct{}
   105  	fired  chan struct{}
   106  	m      sync.Mutex
   107  }
   108  
   109  func (d *testdisco) DiscoverVtctldAddrs(ctx context.Context, tags []string) ([]string, error) {
   110  	d.m.Lock()
   111  	defer d.m.Unlock()
   112  
   113  	select {
   114  	case <-d.notify:
   115  		defer func() {
   116  			go func() { d.fired <- struct{}{} }()
   117  		}()
   118  	default:
   119  	}
   120  	return d.Fake.DiscoverVtctldAddrs(ctx, tags)
   121  }
   122  
   123  // TestRedial tests that vtadmin-api is able to recover from a lost connection to
   124  // a vtctld by rediscovering and redialing a new one.
   125  func TestRedial(t *testing.T) {
   126  	// Initialize vtctld #1
   127  	listener1, server1, err := initVtctldServer()
   128  	require.NoError(t, err)
   129  
   130  	defer listener1.Close()
   131  
   132  	go server1.Serve(listener1)
   133  	defer server1.Stop()
   134  
   135  	// Initialize vtctld #2
   136  	listener2, server2, err := initVtctldServer()
   137  	require.NoError(t, err)
   138  
   139  	defer listener2.Close()
   140  
   141  	go server2.Serve(listener2)
   142  	defer server2.Stop()
   143  
   144  	reResolveFired := make(chan struct{}, 1)
   145  
   146  	// Register both vtctlds with VTAdmin
   147  	disco := &testdisco{
   148  		Fake:   fakediscovery.New(),
   149  		notify: make(chan struct{}),
   150  		fired:  reResolveFired,
   151  	}
   152  	disco.AddTaggedVtctlds(nil, &vtadminpb.Vtctld{
   153  		Hostname: listener1.Addr().String(),
   154  	}, &vtadminpb.Vtctld{
   155  		Hostname: listener2.Addr().String(),
   156  	})
   157  
   158  	proxy, err := New(context.Background(), &Config{
   159  		Cluster: &vtadminpb.Cluster{
   160  			Id:   "test",
   161  			Name: "testcluster",
   162  		},
   163  		ResolverOptions: &resolver.Options{
   164  			Discovery:            disco,
   165  			DiscoveryTimeout:     50 * time.Millisecond,
   166  			MinDiscoveryInterval: 0,
   167  			BackoffStrategy:      "none",
   168  		},
   169  	})
   170  	require.NoError(t, err)
   171  
   172  	// vtadmin's fakediscovery package discovers vtctlds in random order. Rather
   173  	// than force some cumbersome sequential logic, we can just do a switcheroo
   174  	// here in the test to determine our "current" and (expected) "next" vtctlds.
   175  	var currentVtctld *grpc.Server
   176  	var nextAddr string
   177  
   178  	// Check for a successful connection to whichever vtctld we discover first.
   179  	resp, err := proxy.GetKeyspace(context.Background(), &vtctldatapb.GetKeyspaceRequest{})
   180  	require.NoError(t, err)
   181  
   182  	proxyHost := resp.Keyspace.Name
   183  	switch proxyHost {
   184  	case listener1.Addr().String():
   185  		currentVtctld = server1
   186  		nextAddr = listener2.Addr().String()
   187  
   188  	case listener2.Addr().String():
   189  		currentVtctld = server2
   190  		nextAddr = listener1.Addr().String()
   191  	default:
   192  		t.Fatalf("invalid proxy host %s", proxyHost)
   193  	}
   194  
   195  	// Shut down the vtctld we're connected to, then await re-resolution.
   196  
   197  	// 1. First, block calls to DiscoverVtctldAddrs so we don't race with the
   198  	// background resolver watcher.
   199  	disco.m.Lock()
   200  
   201  	// 2. Force an ungraceful shutdown of the gRPC server to which we're
   202  	// connected.
   203  	currentVtctld.Stop()
   204  
   205  	// 3. Remove the shut down vtctld from VTAdmin's service discovery
   206  	// (clumsily). Otherwise, when redialing, we may redial the vtctld that we
   207  	// just shut down.
   208  	disco.Clear()
   209  	disco.AddTaggedVtctlds(nil, &vtadminpb.Vtctld{
   210  		Hostname: nextAddr,
   211  	})
   212  
   213  	// 4. Notify our wrapped DiscoverVtctldAddrs function to start signaling on
   214  	// its `fired` channel when called.
   215  	close(disco.notify)
   216  	// 5. Unblock calls to DiscoverVtctldAddrs, and move on to our assertions.
   217  	disco.m.Unlock()
   218  
   219  	maxWait := time.Second
   220  	select {
   221  	case <-reResolveFired:
   222  	case <-time.After(maxWait):
   223  		require.FailNowf(t, "forced shutdown of vtctld should trigger grpc re-resolution", "did not receive re-resolve signal within %s", maxWait)
   224  	}
   225  
   226  	// Finally, check that we discover + establish a new connection to the remaining vtctld.
   227  	resp, err = proxy.GetKeyspace(context.Background(), &vtctldatapb.GetKeyspaceRequest{})
   228  	require.NoError(t, err)
   229  	assert.Equal(t, nextAddr, resp.Keyspace.Name)
   230  }