vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletconntest/tabletconntest.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package tabletconntest provides the test methods to make sure a
    18  // tabletconn/queryservice pair over RPC works correctly.
    19  package tabletconntest
    20  
    21  import (
    22  	"context"
    23  	"io"
    24  	"os"
    25  	"strings"
    26  	"testing"
    27  
    28  	"github.com/spf13/pflag"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"google.golang.org/protobuf/proto"
    32  
    33  	"vitess.io/vitess/go/sqltypes"
    34  	"vitess.io/vitess/go/vt/callerid"
    35  	"vitess.io/vitess/go/vt/grpcclient"
    36  	"vitess.io/vitess/go/vt/log"
    37  	"vitess.io/vitess/go/vt/servenv"
    38  	"vitess.io/vitess/go/vt/vterrors"
    39  	"vitess.io/vitess/go/vt/vttablet/queryservice"
    40  	"vitess.io/vitess/go/vt/vttablet/tabletconn"
    41  
    42  	querypb "vitess.io/vitess/go/vt/proto/query"
    43  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    44  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    45  )
    46  
    47  // testErrorHelper will check one instance of each error type,
    48  // to make sure we propagate the errors properly.
    49  func testErrorHelper(t *testing.T, f *FakeQueryService, name string, ef func(context.Context) error) {
    50  	errors := []error{
    51  		// A few generic errors
    52  		vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "generic error"),
    53  		vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "uncaught panic"),
    54  		vterrors.Errorf(vtrpcpb.Code_UNAUTHENTICATED, "missing caller id"),
    55  		vterrors.Errorf(vtrpcpb.Code_PERMISSION_DENIED, "table acl error: nil acl"),
    56  
    57  		// Client will retry on this specific error
    58  		vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "query disallowed due to rule: %v", "cool rule"),
    59  
    60  		// Client may retry on another server on this specific error
    61  		vterrors.Errorf(vtrpcpb.Code_INTERNAL, "could not verify strict mode"),
    62  
    63  		// This is usually transaction pool full
    64  		vterrors.Errorf(vtrpcpb.Code_RESOURCE_EXHAUSTED, "transaction pool connection limit exceeded"),
    65  
    66  		// Transaction expired or was unknown
    67  		vterrors.Errorf(vtrpcpb.Code_ABORTED, "transaction 12"),
    68  	}
    69  	for _, e := range errors {
    70  		f.TabletError = e
    71  		ctx := context.Background()
    72  		err := ef(ctx)
    73  		if err == nil {
    74  			t.Errorf("error wasn't returned for %v?", name)
    75  			continue
    76  		}
    77  
    78  		// First we check the recoverable vtrpc code is right.
    79  		code := vterrors.Code(err)
    80  		wantcode := vterrors.Code(e)
    81  		if code != wantcode {
    82  			t.Errorf("unexpected server code from %v: got %v, wanted %v", name, code, wantcode)
    83  		}
    84  
    85  		if !strings.Contains(err.Error(), e.Error()) {
    86  			t.Errorf("client error message '%v' for %v doesn't contain expected server text message '%v'", err.Error(), name, e)
    87  		}
    88  	}
    89  	f.TabletError = nil
    90  }
    91  
    92  func testPanicHelper(t *testing.T, f *FakeQueryService, name string, pf func(context.Context) error) {
    93  	f.Panics = true
    94  	ctx := context.Background()
    95  	if err := pf(ctx); err == nil || !strings.Contains(err.Error(), "caught test panic") {
    96  		t.Fatalf("unexpected panic error for %v: %v", name, err)
    97  	}
    98  	f.Panics = false
    99  }
   100  
   101  func testBegin(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   102  	t.Log("testBegin")
   103  	ctx := context.Background()
   104  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   105  	state, err := conn.Begin(ctx, TestTarget, TestExecuteOptions)
   106  	if err != nil {
   107  		t.Fatalf("Begin failed: %v", err)
   108  	}
   109  	if state.TransactionID != beginTransactionID {
   110  		t.Errorf("Unexpected result from Begin: got %v wanted %v", state.TransactionID, beginTransactionID)
   111  	}
   112  	assert.Equal(t, TestAlias, state.TabletAlias, "Unexpected tablet alias from Begin")
   113  }
   114  
   115  func testBeginError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   116  	t.Log("testBeginError")
   117  	f.HasBeginError = true
   118  	testErrorHelper(t, f, "Begin", func(ctx context.Context) error {
   119  		_, err := conn.Begin(ctx, TestTarget, nil)
   120  		return err
   121  	})
   122  	f.HasBeginError = false
   123  }
   124  
   125  func testBeginPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   126  	t.Log("testBeginPanics")
   127  	testPanicHelper(t, f, "Begin", func(ctx context.Context) error {
   128  		_, err := conn.Begin(ctx, TestTarget, nil)
   129  		return err
   130  	})
   131  }
   132  
   133  func testCommit(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   134  	t.Log("testCommit")
   135  	ctx := context.Background()
   136  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   137  	_, err := conn.Commit(ctx, TestTarget, commitTransactionID)
   138  	if err != nil {
   139  		t.Fatalf("Commit failed: %v", err)
   140  	}
   141  }
   142  
   143  func testCommitError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   144  	t.Log("testCommitError")
   145  	f.HasError = true
   146  	testErrorHelper(t, f, "Commit", func(ctx context.Context) error {
   147  		_, err := conn.Commit(ctx, TestTarget, commitTransactionID)
   148  		return err
   149  	})
   150  	f.HasError = false
   151  }
   152  
   153  func testCommitPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   154  	t.Log("testCommitPanics")
   155  	testPanicHelper(t, f, "Commit", func(ctx context.Context) error {
   156  		_, err := conn.Commit(ctx, TestTarget, commitTransactionID)
   157  		return err
   158  	})
   159  }
   160  
   161  func testRollback(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   162  	t.Log("testRollback")
   163  	ctx := context.Background()
   164  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   165  	_, err := conn.Rollback(ctx, TestTarget, rollbackTransactionID)
   166  	if err != nil {
   167  		t.Fatalf("Rollback failed: %v", err)
   168  	}
   169  }
   170  
   171  func testRollbackError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   172  	t.Log("testRollbackError")
   173  	f.HasError = true
   174  	testErrorHelper(t, f, "Rollback", func(ctx context.Context) error {
   175  		_, err := conn.Rollback(ctx, TestTarget, commitTransactionID)
   176  		return err
   177  	})
   178  	f.HasError = false
   179  }
   180  
   181  func testRollbackPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   182  	t.Log("testRollbackPanics")
   183  	testPanicHelper(t, f, "Rollback", func(ctx context.Context) error {
   184  		_, err := conn.Rollback(ctx, TestTarget, rollbackTransactionID)
   185  		return err
   186  	})
   187  }
   188  
   189  func testPrepare(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   190  	t.Log("testPrepare")
   191  	ctx := context.Background()
   192  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   193  	err := conn.Prepare(ctx, TestTarget, commitTransactionID, Dtid)
   194  	if err != nil {
   195  		t.Fatalf("Prepare failed: %v", err)
   196  	}
   197  }
   198  
   199  func testPrepareError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   200  	t.Log("testPrepareError")
   201  	f.HasError = true
   202  	testErrorHelper(t, f, "Prepare", func(ctx context.Context) error {
   203  		return conn.Prepare(ctx, TestTarget, commitTransactionID, Dtid)
   204  	})
   205  	f.HasError = false
   206  }
   207  
   208  func testPreparePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   209  	t.Log("testPreparePanics")
   210  	testPanicHelper(t, f, "Prepare", func(ctx context.Context) error {
   211  		return conn.Prepare(ctx, TestTarget, commitTransactionID, Dtid)
   212  	})
   213  }
   214  
   215  func testCommitPrepared(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   216  	t.Log("testCommitPrepared")
   217  	ctx := context.Background()
   218  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   219  	err := conn.CommitPrepared(ctx, TestTarget, Dtid)
   220  	if err != nil {
   221  		t.Fatalf("CommitPrepared failed: %v", err)
   222  	}
   223  }
   224  
   225  func testCommitPreparedError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   226  	t.Log("testCommitPreparedError")
   227  	f.HasError = true
   228  	testErrorHelper(t, f, "CommitPrepared", func(ctx context.Context) error {
   229  		return conn.CommitPrepared(ctx, TestTarget, Dtid)
   230  	})
   231  	f.HasError = false
   232  }
   233  
   234  func testCommitPreparedPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   235  	t.Log("testCommitPreparedPanics")
   236  	testPanicHelper(t, f, "CommitPrepared", func(ctx context.Context) error {
   237  		return conn.CommitPrepared(ctx, TestTarget, Dtid)
   238  	})
   239  }
   240  
   241  func testRollbackPrepared(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   242  	t.Log("testRollbackPrepared")
   243  	ctx := context.Background()
   244  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   245  	err := conn.RollbackPrepared(ctx, TestTarget, Dtid, rollbackTransactionID)
   246  	if err != nil {
   247  		t.Fatalf("RollbackPrepared failed: %v", err)
   248  	}
   249  }
   250  
   251  func testRollbackPreparedError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   252  	t.Log("testRollbackPreparedError")
   253  	f.HasError = true
   254  	testErrorHelper(t, f, "RollbackPrepared", func(ctx context.Context) error {
   255  		return conn.RollbackPrepared(ctx, TestTarget, Dtid, rollbackTransactionID)
   256  	})
   257  	f.HasError = false
   258  }
   259  
   260  func testRollbackPreparedPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   261  	t.Log("testRollbackPreparedPanics")
   262  	testPanicHelper(t, f, "RollbackPrepared", func(ctx context.Context) error {
   263  		return conn.RollbackPrepared(ctx, TestTarget, Dtid, rollbackTransactionID)
   264  	})
   265  }
   266  
   267  func testCreateTransaction(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   268  	t.Log("testCreateTransaction")
   269  	ctx := context.Background()
   270  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   271  	err := conn.CreateTransaction(ctx, TestTarget, Dtid, Participants)
   272  	if err != nil {
   273  		t.Fatalf("CreateTransaction failed: %v", err)
   274  	}
   275  }
   276  
   277  func testCreateTransactionError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   278  	t.Log("testCreateTransactionError")
   279  	f.HasError = true
   280  	testErrorHelper(t, f, "CreateTransaction", func(ctx context.Context) error {
   281  		return conn.CreateTransaction(ctx, TestTarget, Dtid, Participants)
   282  	})
   283  	f.HasError = false
   284  }
   285  
   286  func testCreateTransactionPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   287  	t.Log("testCreateTransactionPanics")
   288  	testPanicHelper(t, f, "CreateTransaction", func(ctx context.Context) error {
   289  		return conn.CreateTransaction(ctx, TestTarget, Dtid, Participants)
   290  	})
   291  }
   292  
   293  func testStartCommit(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   294  	t.Log("testStartCommit")
   295  	ctx := context.Background()
   296  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   297  	err := conn.StartCommit(ctx, TestTarget, commitTransactionID, Dtid)
   298  	if err != nil {
   299  		t.Fatalf("StartCommit failed: %v", err)
   300  	}
   301  }
   302  
   303  func testStartCommitError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   304  	t.Log("testStartCommitError")
   305  	f.HasError = true
   306  	testErrorHelper(t, f, "StartCommit", func(ctx context.Context) error {
   307  		return conn.StartCommit(ctx, TestTarget, commitTransactionID, Dtid)
   308  	})
   309  	f.HasError = false
   310  }
   311  
   312  func testStartCommitPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   313  	t.Log("testStartCommitPanics")
   314  	testPanicHelper(t, f, "StartCommit", func(ctx context.Context) error {
   315  		return conn.StartCommit(ctx, TestTarget, commitTransactionID, Dtid)
   316  	})
   317  }
   318  
   319  func testSetRollback(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   320  	t.Log("testSetRollback")
   321  	ctx := context.Background()
   322  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   323  	err := conn.SetRollback(ctx, TestTarget, Dtid, rollbackTransactionID)
   324  	if err != nil {
   325  		t.Fatalf("SetRollback failed: %v", err)
   326  	}
   327  }
   328  
   329  func testSetRollbackError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   330  	t.Log("testSetRollbackError")
   331  	f.HasError = true
   332  	testErrorHelper(t, f, "SetRollback", func(ctx context.Context) error {
   333  		return conn.SetRollback(ctx, TestTarget, Dtid, rollbackTransactionID)
   334  	})
   335  	f.HasError = false
   336  }
   337  
   338  func testSetRollbackPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   339  	t.Log("testSetRollbackPanics")
   340  	testPanicHelper(t, f, "SetRollback", func(ctx context.Context) error {
   341  		return conn.SetRollback(ctx, TestTarget, Dtid, rollbackTransactionID)
   342  	})
   343  }
   344  
   345  func testConcludeTransaction(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   346  	t.Log("testConcludeTransaction")
   347  	ctx := context.Background()
   348  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   349  	err := conn.ConcludeTransaction(ctx, TestTarget, Dtid)
   350  	if err != nil {
   351  		t.Fatalf("ConcludeTransaction failed: %v", err)
   352  	}
   353  }
   354  
   355  func testConcludeTransactionError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   356  	t.Log("testConcludeTransactionError")
   357  	f.HasError = true
   358  	testErrorHelper(t, f, "ConcludeTransaction", func(ctx context.Context) error {
   359  		return conn.ConcludeTransaction(ctx, TestTarget, Dtid)
   360  	})
   361  	f.HasError = false
   362  }
   363  
   364  func testConcludeTransactionPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   365  	t.Log("testConcludeTransactionPanics")
   366  	testPanicHelper(t, f, "ConcludeTransaction", func(ctx context.Context) error {
   367  		return conn.ConcludeTransaction(ctx, TestTarget, Dtid)
   368  	})
   369  }
   370  
   371  func testReadTransaction(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   372  	t.Log("testReadTransaction")
   373  	ctx := context.Background()
   374  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   375  	metadata, err := conn.ReadTransaction(ctx, TestTarget, Dtid)
   376  	if err != nil {
   377  		t.Fatalf("ReadTransaction failed: %v", err)
   378  	}
   379  	if !proto.Equal(metadata, Metadata) {
   380  		t.Errorf("Unexpected result from Execute: got %v wanted %v", metadata, Metadata)
   381  	}
   382  }
   383  
   384  func testReadTransactionError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   385  	t.Log("testReadTransactionError")
   386  	f.HasError = true
   387  	testErrorHelper(t, f, "ReadTransaction", func(ctx context.Context) error {
   388  		_, err := conn.ReadTransaction(ctx, TestTarget, Dtid)
   389  		return err
   390  	})
   391  	f.HasError = false
   392  }
   393  
   394  func testReadTransactionPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   395  	t.Log("testReadTransactionPanics")
   396  	testPanicHelper(t, f, "ReadTransaction", func(ctx context.Context) error {
   397  		_, err := conn.ReadTransaction(ctx, TestTarget, Dtid)
   398  		return err
   399  	})
   400  }
   401  
   402  func testExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   403  	t.Log("testExecute")
   404  	f.ExpectedTransactionID = ExecuteTransactionID
   405  	ctx := context.Background()
   406  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   407  	qr, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions)
   408  	if err != nil {
   409  		t.Fatalf("Execute failed: %v", err)
   410  	}
   411  	if !qr.Equal(&ExecuteQueryResult) {
   412  		t.Errorf("Unexpected result from Execute: got %v wanted %v", qr, ExecuteQueryResult)
   413  	}
   414  }
   415  
   416  func testExecuteError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   417  	t.Log("testExecuteError")
   418  	f.HasError = true
   419  	testErrorHelper(t, f, "Execute", func(ctx context.Context) error {
   420  		_, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions)
   421  		return err
   422  	})
   423  	f.HasError = false
   424  }
   425  
   426  func testExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   427  	t.Log("testExecutePanics")
   428  	testPanicHelper(t, f, "Execute", func(ctx context.Context) error {
   429  		_, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions)
   430  		return err
   431  	})
   432  }
   433  
   434  func testBeginExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   435  	t.Log("testBeginExecute")
   436  	f.ExpectedTransactionID = beginTransactionID
   437  	ctx := context.Background()
   438  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   439  	state, qr, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions)
   440  	if err != nil {
   441  		t.Fatalf("BeginExecute failed: %v", err)
   442  	}
   443  	if state.TransactionID != beginTransactionID {
   444  		t.Errorf("Unexpected result from BeginExecute: got %v wanted %v", state.TransactionID, beginTransactionID)
   445  	}
   446  	if !qr.Equal(&ExecuteQueryResult) {
   447  		t.Errorf("Unexpected result from BeginExecute: got %v wanted %v", qr, ExecuteQueryResult)
   448  	}
   449  	assert.Equal(t, TestAlias, state.TabletAlias, "Unexpected tablet alias from Begin")
   450  }
   451  
   452  func testBeginExecuteErrorInBegin(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   453  	t.Log("testBeginExecuteErrorInBegin")
   454  	f.HasBeginError = true
   455  	testErrorHelper(t, f, "BeginExecute.Begin", func(ctx context.Context) error {
   456  		state, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions)
   457  		if state.TransactionID != 0 {
   458  			t.Errorf("Unexpected transactionID from BeginExecute: got %v wanted 0", state.TransactionID)
   459  		}
   460  		return err
   461  	})
   462  	f.HasBeginError = false
   463  }
   464  
   465  func testBeginExecuteErrorInExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   466  	t.Log("testBeginExecuteErrorInExecute")
   467  	f.HasError = true
   468  	testErrorHelper(t, f, "BeginExecute.Execute", func(ctx context.Context) error {
   469  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   470  		state, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions)
   471  		if state.TransactionID != beginTransactionID {
   472  			t.Errorf("Unexpected transactionID from BeginExecute: got %v wanted %v", state.TransactionID, beginTransactionID)
   473  		}
   474  		return err
   475  	})
   476  	f.HasError = false
   477  }
   478  
   479  func testBeginExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   480  	t.Log("testBeginExecutePanics")
   481  	testPanicHelper(t, f, "BeginExecute", func(ctx context.Context) error {
   482  		_, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions)
   483  		return err
   484  	})
   485  }
   486  
   487  func testStreamExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   488  	t.Log("testStreamExecute")
   489  	ctx := context.Background()
   490  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   491  	i := 0
   492  	err := conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   493  		switch i {
   494  		case 0:
   495  			if len(qr.Rows) == 0 {
   496  				qr.Rows = nil
   497  			}
   498  			if !qr.Equal(&StreamExecuteQueryResult1) {
   499  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   500  			}
   501  		case 1:
   502  			if len(qr.Fields) == 0 {
   503  				qr.Fields = nil
   504  			}
   505  			if !qr.Equal(&StreamExecuteQueryResult2) {
   506  				t.Errorf("Unexpected result2 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult2)
   507  			}
   508  		default:
   509  			t.Fatal("callback should not be called any more")
   510  		}
   511  		i++
   512  		if i >= 2 {
   513  			return io.EOF
   514  		}
   515  		return nil
   516  	})
   517  	if err != nil {
   518  		t.Fatal(err)
   519  	}
   520  }
   521  
   522  func testStreamExecuteError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   523  	t.Log("testStreamExecuteError")
   524  	f.HasError = true
   525  	testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error {
   526  		f.ErrorWait = make(chan struct{})
   527  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   528  		return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   529  			// For some errors, the call can be retried.
   530  			select {
   531  			case <-f.ErrorWait:
   532  				return nil
   533  			default:
   534  			}
   535  			if len(qr.Rows) == 0 {
   536  				qr.Rows = nil
   537  			}
   538  			if !qr.Equal(&StreamExecuteQueryResult1) {
   539  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   540  			}
   541  			// signal to the server that the first result has been received
   542  			close(f.ErrorWait)
   543  			return nil
   544  		})
   545  	})
   546  	f.HasError = false
   547  }
   548  
   549  func testStreamExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   550  	t.Log("testStreamExecutePanics")
   551  	// early panic is before sending the Fields, that is returned
   552  	// by the StreamExecute call itself, or as the first error
   553  	// by ErrFunc
   554  	f.StreamExecutePanicsEarly = true
   555  	testPanicHelper(t, f, "StreamExecute.Early", func(ctx context.Context) error {
   556  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   557  		return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   558  			return nil
   559  		})
   560  	})
   561  
   562  	// late panic is after sending Fields
   563  	f.StreamExecutePanicsEarly = false
   564  	testPanicHelper(t, f, "StreamExecute.Late", func(ctx context.Context) error {
   565  		f.PanicWait = make(chan struct{})
   566  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   567  		return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   568  			// For some errors, the call can be retried.
   569  			select {
   570  			case <-f.PanicWait:
   571  				return nil
   572  			default:
   573  			}
   574  			if len(qr.Rows) == 0 {
   575  				qr.Rows = nil
   576  			}
   577  			if !qr.Equal(&StreamExecuteQueryResult1) {
   578  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   579  			}
   580  			// signal to the server that the first result has been received
   581  			close(f.PanicWait)
   582  			return nil
   583  		})
   584  	})
   585  }
   586  
   587  func testBeginStreamExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   588  	t.Log("testBeginStreamExecute")
   589  	ctx := context.Background()
   590  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   591  	i := 0
   592  	_, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   593  		switch i {
   594  		case 0:
   595  			if len(qr.Rows) == 0 {
   596  				qr.Rows = nil
   597  			}
   598  			if !qr.Equal(&StreamExecuteQueryResult1) {
   599  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   600  			}
   601  		case 1:
   602  			if len(qr.Fields) == 0 {
   603  				qr.Fields = nil
   604  			}
   605  			if !qr.Equal(&StreamExecuteQueryResult2) {
   606  				t.Errorf("Unexpected result2 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult2)
   607  			}
   608  		default:
   609  			t.Fatal("callback should not be called any more")
   610  		}
   611  		i++
   612  		if i >= 2 {
   613  			return io.EOF
   614  		}
   615  		return nil
   616  	})
   617  	if err != nil {
   618  		t.Fatal(err)
   619  	}
   620  }
   621  
   622  func testReserveStreamExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   623  	t.Log("testReserveStreamExecute")
   624  	ctx := context.Background()
   625  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   626  	i := 0
   627  	_, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   628  		switch i {
   629  		case 0:
   630  			if len(qr.Rows) == 0 {
   631  				qr.Rows = nil
   632  			}
   633  			if !qr.Equal(&StreamExecuteQueryResult1) {
   634  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   635  			}
   636  		case 1:
   637  			if len(qr.Fields) == 0 {
   638  				qr.Fields = nil
   639  			}
   640  			if !qr.Equal(&StreamExecuteQueryResult2) {
   641  				t.Errorf("Unexpected result2 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult2)
   642  			}
   643  		default:
   644  			t.Fatal("callback should not be called any more")
   645  		}
   646  		i++
   647  		if i >= 2 {
   648  			return io.EOF
   649  		}
   650  		return nil
   651  	})
   652  	if err != nil {
   653  		t.Fatal(err)
   654  	}
   655  }
   656  
   657  func testBeginStreamExecuteErrorInBegin(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   658  	t.Log("testBeginExecuteErrorInBegin")
   659  	f.HasBeginError = true
   660  	testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error {
   661  		f.ErrorWait = make(chan struct{})
   662  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   663  		_, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   664  			// For some errors, the call can be retried.
   665  			select {
   666  			case <-f.ErrorWait:
   667  				return nil
   668  			default:
   669  			}
   670  			if len(qr.Rows) == 0 {
   671  				qr.Rows = nil
   672  			}
   673  			if !qr.Equal(&StreamExecuteQueryResult1) {
   674  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   675  			}
   676  			// signal to the server that the first result has been received
   677  			close(f.ErrorWait)
   678  			return nil
   679  		})
   680  		return err
   681  	})
   682  	f.HasBeginError = false
   683  }
   684  
   685  func testBeginStreamExecuteErrorInExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   686  	t.Log("testBeginStreamExecuteErrorInExecute")
   687  	f.HasError = true
   688  	testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error {
   689  		f.ErrorWait = make(chan struct{})
   690  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   691  		state, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   692  			// For some errors, the call can be retried.
   693  			select {
   694  			case <-f.ErrorWait:
   695  				return nil
   696  			default:
   697  			}
   698  			if len(qr.Rows) == 0 {
   699  				qr.Rows = nil
   700  			}
   701  			if !qr.Equal(&StreamExecuteQueryResult1) {
   702  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   703  			}
   704  			// signal to the server that the first result has been received
   705  			close(f.ErrorWait)
   706  			return nil
   707  		})
   708  		require.NotZero(t, state.TransactionID)
   709  		return err
   710  	})
   711  	f.HasError = false
   712  }
   713  
   714  func testReserveStreamExecuteErrorInReserve(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   715  	t.Log("testReserveExecuteErrorInReserve")
   716  	f.HasReserveError = true
   717  	testErrorHelper(t, f, "ReserveStreamExecute", func(ctx context.Context) error {
   718  		f.ErrorWait = make(chan struct{})
   719  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   720  		_, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   721  			// For some errors, the call can be retried.
   722  			select {
   723  			case <-f.ErrorWait:
   724  				return nil
   725  			default:
   726  			}
   727  			if len(qr.Rows) == 0 {
   728  				qr.Rows = nil
   729  			}
   730  			if !qr.Equal(&StreamExecuteQueryResult1) {
   731  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   732  			}
   733  			// signal to the server that the first result has been received
   734  			close(f.ErrorWait)
   735  			return nil
   736  		})
   737  		return err
   738  	})
   739  	f.HasReserveError = false
   740  }
   741  
   742  func testReserveStreamExecuteErrorInExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   743  	t.Log("testReserveStreamExecuteErrorInExecute")
   744  	f.HasError = true
   745  	testErrorHelper(t, f, "ReserveStreamExecute", func(ctx context.Context) error {
   746  		f.ErrorWait = make(chan struct{})
   747  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   748  		state, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   749  			// For some errors, the call can be retried.
   750  			select {
   751  			case <-f.ErrorWait:
   752  				return nil
   753  			default:
   754  			}
   755  			if len(qr.Rows) == 0 {
   756  				qr.Rows = nil
   757  			}
   758  			if !qr.Equal(&StreamExecuteQueryResult1) {
   759  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   760  			}
   761  			// signal to the server that the first result has been received
   762  			close(f.ErrorWait)
   763  			return nil
   764  		})
   765  		require.NotZero(t, state.ReservedID)
   766  		return err
   767  	})
   768  	f.HasError = false
   769  }
   770  
   771  func testBeginStreamExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   772  	t.Log("testStreamExecutePanics")
   773  	// early panic is before sending the Fields, that is returned
   774  	// by the StreamExecute call itself, or as the first error
   775  	// by ErrFunc
   776  	f.StreamExecutePanicsEarly = true
   777  	testPanicHelper(t, f, "StreamExecute.Early", func(ctx context.Context) error {
   778  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   779  		return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   780  			return nil
   781  		})
   782  	})
   783  
   784  	// late panic is after sending Fields
   785  	f.StreamExecutePanicsEarly = false
   786  	testPanicHelper(t, f, "StreamExecute.Late", func(ctx context.Context) error {
   787  		f.PanicWait = make(chan struct{})
   788  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   789  		_, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error {
   790  			// For some errors, the call can be retried.
   791  			select {
   792  			case <-f.PanicWait:
   793  				return nil
   794  			default:
   795  			}
   796  			if len(qr.Rows) == 0 {
   797  				qr.Rows = nil
   798  			}
   799  			if !qr.Equal(&StreamExecuteQueryResult1) {
   800  				t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, StreamExecuteQueryResult1)
   801  			}
   802  			// signal to the server that the first result has been received
   803  			close(f.PanicWait)
   804  			return nil
   805  		})
   806  		return err
   807  	})
   808  }
   809  
   810  func testMessageStream(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   811  	t.Log("testMessageStream")
   812  	ctx := context.Background()
   813  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   814  	var got *sqltypes.Result
   815  	err := conn.MessageStream(ctx, TestTarget, MessageName, func(qr *sqltypes.Result) error {
   816  		got = qr
   817  		return nil
   818  	})
   819  	if err != nil {
   820  		t.Fatalf("MessageStream failed: %v", err)
   821  	}
   822  	if !got.Equal(MessageStreamResult) {
   823  		t.Errorf("Unexpected result from MessageStream: got %v wanted %v", got, MessageStreamResult)
   824  	}
   825  }
   826  
   827  func testMessageStreamError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   828  	t.Log("testMessageStreamError")
   829  	f.HasError = true
   830  	testErrorHelper(t, f, "MessageStream", func(ctx context.Context) error {
   831  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   832  		return conn.MessageStream(ctx, TestTarget, MessageName, func(qr *sqltypes.Result) error { return nil })
   833  	})
   834  	f.HasError = false
   835  }
   836  
   837  func testMessageStreamPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   838  	t.Log("testMessageStreamPanics")
   839  	testPanicHelper(t, f, "MessageStream", func(ctx context.Context) error {
   840  		err := conn.MessageStream(ctx, TestTarget, MessageName, func(qr *sqltypes.Result) error { return nil })
   841  		return err
   842  	})
   843  }
   844  
   845  func testMessageAck(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   846  	t.Log("testMessageAck")
   847  	ctx := context.Background()
   848  	ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   849  	count, err := conn.MessageAck(ctx, TestTarget, MessageName, MessageIDs)
   850  	if err != nil {
   851  		t.Fatalf("MessageAck failed: %v", err)
   852  	}
   853  	if count != 1 {
   854  		t.Errorf("Unexpected result from MessageAck: got %v wanted 1", count)
   855  	}
   856  }
   857  
   858  func testMessageAckError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   859  	t.Log("testMessageAckError")
   860  	f.HasError = true
   861  	testErrorHelper(t, f, "MessageAck", func(ctx context.Context) error {
   862  		ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID)
   863  		_, err := conn.MessageAck(ctx, TestTarget, MessageName, MessageIDs)
   864  		return err
   865  	})
   866  	f.HasError = false
   867  }
   868  
   869  func testMessageAckPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   870  	t.Log("testMessageAckPanics")
   871  	testPanicHelper(t, f, "MessageAck", func(ctx context.Context) error {
   872  		_, err := conn.MessageAck(ctx, TestTarget, MessageName, MessageIDs)
   873  		return err
   874  	})
   875  }
   876  
   877  // this test is a bit of a hack: we write something on the channel
   878  // upon registration, and we also return an error, so the streaming query
   879  // ends right there. Otherwise we have no real way to trigger a real
   880  // communication error, that ends the streaming.
   881  func testStreamHealth(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   882  	t.Log("testStreamHealth")
   883  	ctx := context.Background()
   884  
   885  	var health *querypb.StreamHealthResponse
   886  	err := conn.StreamHealth(ctx, func(shr *querypb.StreamHealthResponse) error {
   887  		health = shr
   888  		return io.EOF
   889  	})
   890  	if err != nil {
   891  		t.Fatalf("StreamHealth failed: %v", err)
   892  	}
   893  	if !proto.Equal(health, TestStreamHealthStreamHealthResponse) {
   894  		t.Errorf("invalid StreamHealthResponse: got %v expected %v", health, TestStreamHealthStreamHealthResponse)
   895  	}
   896  }
   897  
   898  func testStreamHealthError(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   899  	t.Log("testStreamHealthError")
   900  	f.HasError = true
   901  	ctx := context.Background()
   902  	err := conn.StreamHealth(ctx, func(shr *querypb.StreamHealthResponse) error {
   903  		t.Fatalf("Unexpected call to callback")
   904  		return nil
   905  	})
   906  	if err == nil || !strings.Contains(err.Error(), TestStreamHealthErrorMsg) {
   907  		t.Fatalf("StreamHealth failed with the wrong error: %v", err)
   908  	}
   909  	f.HasError = false
   910  }
   911  
   912  func testStreamHealthPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) {
   913  	t.Log("testStreamHealthPanics")
   914  	testPanicHelper(t, f, "StreamHealth", func(ctx context.Context) error {
   915  		return conn.StreamHealth(ctx, func(shr *querypb.StreamHealthResponse) error {
   916  			t.Fatalf("Unexpected call to callback")
   917  			return nil
   918  		})
   919  	})
   920  }
   921  
   922  // TestSuite runs all the tests.
   923  // If fake.TestingGateway is set, we only test the calls that can go through
   924  // a gateway.
   925  func TestSuite(t *testing.T, protocol string, tablet *topodatapb.Tablet, fake *FakeQueryService, clientCreds *os.File) {
   926  	tests := []func(*testing.T, queryservice.QueryService, *FakeQueryService){
   927  		// positive test cases
   928  		testBegin,
   929  		testCommit,
   930  		testRollback,
   931  		testPrepare,
   932  		testCommitPrepared,
   933  		testRollbackPrepared,
   934  		testCreateTransaction,
   935  		testStartCommit,
   936  		testSetRollback,
   937  		testConcludeTransaction,
   938  		testReadTransaction,
   939  		testExecute,
   940  		testBeginExecute,
   941  		testStreamExecute,
   942  		testBeginStreamExecute,
   943  		testMessageStream,
   944  		testMessageAck,
   945  		testReserveStreamExecute,
   946  
   947  		// error test cases
   948  		testBeginError,
   949  		testCommitError,
   950  		testRollbackError,
   951  		testPrepareError,
   952  		testCommitPreparedError,
   953  		testRollbackPreparedError,
   954  		testCreateTransactionError,
   955  		testStartCommitError,
   956  		testSetRollbackError,
   957  		testConcludeTransactionError,
   958  		testReadTransactionError,
   959  		testExecuteError,
   960  		testBeginExecuteErrorInBegin,
   961  		testBeginExecuteErrorInExecute,
   962  		testStreamExecuteError,
   963  		testBeginStreamExecuteErrorInBegin,
   964  		testBeginStreamExecuteErrorInExecute,
   965  		testReserveStreamExecuteErrorInReserve,
   966  		testReserveStreamExecuteErrorInExecute,
   967  		testMessageStreamError,
   968  		testMessageAckError,
   969  
   970  		// panic test cases
   971  		testBeginPanics,
   972  		testCommitPanics,
   973  		testRollbackPanics,
   974  		testPreparePanics,
   975  		testCommitPreparedPanics,
   976  		testRollbackPreparedPanics,
   977  		testCreateTransactionPanics,
   978  		testStartCommitPanics,
   979  		testSetRollbackPanics,
   980  		testConcludeTransactionPanics,
   981  		testReadTransactionPanics,
   982  		testExecutePanics,
   983  		testBeginExecutePanics,
   984  		testStreamExecutePanics,
   985  		testBeginStreamExecutePanics,
   986  		testMessageStreamPanics,
   987  		testMessageAckPanics,
   988  	}
   989  
   990  	if !fake.TestingGateway {
   991  		tests = append(tests, []func(*testing.T, queryservice.QueryService, *FakeQueryService){
   992  			// positive test cases
   993  			testStreamHealth,
   994  
   995  			// error test cases
   996  			testStreamHealthError,
   997  
   998  			// panic test cases
   999  			testStreamHealthPanics,
  1000  		}...)
  1001  	}
  1002  
  1003  	// make sure we use the right client
  1004  	SetProtocol(t.Name(), protocol)
  1005  
  1006  	// create a connection
  1007  	if clientCreds != nil {
  1008  		fs := pflag.NewFlagSet("", pflag.ContinueOnError)
  1009  		grpcclient.RegisterFlags(fs)
  1010  
  1011  		err := fs.Parse([]string{
  1012  			"--grpc_auth_static_client_creds",
  1013  			clientCreds.Name(),
  1014  		})
  1015  		require.NoError(t, err, "failed to set `--grpc_auth_static_client_creds=%s`", clientCreds.Name())
  1016  	}
  1017  
  1018  	conn, err := tabletconn.GetDialer()(tablet, grpcclient.FailFast(false))
  1019  	if err != nil {
  1020  		t.Fatalf("dial failed: %v", err)
  1021  	}
  1022  
  1023  	// run the tests
  1024  	for _, c := range tests {
  1025  		c(t, conn, fake)
  1026  	}
  1027  
  1028  	// and we're done
  1029  	conn.Close(context.Background())
  1030  }
  1031  
  1032  const tabletProtocolFlagName = "tablet_protocol"
  1033  
  1034  // SetProtocol is a helper function to set the tabletconn --tablet_protocol flag
  1035  // value for tests.
  1036  //
  1037  // Note that because this variable is bound to a flag, the effects of this
  1038  // function are global, not scoped to the calling test-case. Therefore it should
  1039  // not be used in conjunction with t.Parallel.
  1040  func SetProtocol(name string, protocol string) {
  1041  	var tmp []string
  1042  	tmp, os.Args = os.Args[:], []string{name}
  1043  	defer func() { os.Args = tmp }()
  1044  
  1045  	servenv.OnParseFor(name, func(fs *pflag.FlagSet) {
  1046  		if fs.Lookup(tabletProtocolFlagName) != nil {
  1047  			return
  1048  		}
  1049  
  1050  		tabletconn.RegisterFlags(fs)
  1051  	})
  1052  	servenv.ParseFlags(name)
  1053  
  1054  	if err := pflag.Set(tabletProtocolFlagName, protocol); err != nil {
  1055  		msg := "failed to set flag %q to %q: %v"
  1056  		log.Errorf(msg, tabletProtocolFlagName, protocol, err)
  1057  	}
  1058  }