vitess.io/vitess@v0.16.2/go/mysql/server_flaky_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mysql
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"fmt"
    23  	"net"
    24  	"os"
    25  	"os/exec"
    26  	"path"
    27  	"strings"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  
    35  	"vitess.io/vitess/go/sqltypes"
    36  	"vitess.io/vitess/go/test/utils"
    37  	vtenv "vitess.io/vitess/go/vt/env"
    38  	"vitess.io/vitess/go/vt/tlstest"
    39  	"vitess.io/vitess/go/vt/vterrors"
    40  	"vitess.io/vitess/go/vt/vttls"
    41  
    42  	querypb "vitess.io/vitess/go/vt/proto/query"
    43  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    44  )
    45  
    46  var selectRowsResult = &sqltypes.Result{
    47  	Fields: []*querypb.Field{
    48  		{
    49  			Name: "id",
    50  			Type: querypb.Type_INT32,
    51  		},
    52  		{
    53  			Name: "name",
    54  			Type: querypb.Type_VARCHAR,
    55  		},
    56  	},
    57  	Rows: [][]sqltypes.Value{
    58  		{
    59  			sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")),
    60  			sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")),
    61  		},
    62  		{
    63  			sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")),
    64  			sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nicer name")),
    65  		},
    66  	},
    67  }
    68  
    69  type testHandler struct {
    70  	UnimplementedHandler
    71  	mu       sync.Mutex
    72  	lastConn *Conn
    73  	result   *sqltypes.Result
    74  	err      error
    75  	warnings uint16
    76  }
    77  
    78  func (th *testHandler) LastConn() *Conn {
    79  	th.mu.Lock()
    80  	defer th.mu.Unlock()
    81  	return th.lastConn
    82  }
    83  
    84  func (th *testHandler) Result() *sqltypes.Result {
    85  	th.mu.Lock()
    86  	defer th.mu.Unlock()
    87  	return th.result
    88  }
    89  
    90  func (th *testHandler) SetErr(err error) {
    91  	th.mu.Lock()
    92  	defer th.mu.Unlock()
    93  	th.err = err
    94  }
    95  
    96  func (th *testHandler) Err() error {
    97  	th.mu.Lock()
    98  	defer th.mu.Unlock()
    99  	return th.err
   100  }
   101  
   102  func (th *testHandler) SetWarnings(count uint16) {
   103  	th.mu.Lock()
   104  	defer th.mu.Unlock()
   105  	th.warnings = count
   106  }
   107  
   108  func (th *testHandler) NewConnection(c *Conn) {
   109  	th.mu.Lock()
   110  	defer th.mu.Unlock()
   111  	th.lastConn = c
   112  }
   113  
   114  func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error {
   115  	if result := th.Result(); result != nil {
   116  		callback(result)
   117  		return nil
   118  	}
   119  
   120  	switch query {
   121  	case "error":
   122  		return th.Err()
   123  	case "panic":
   124  		panic("test panic attack!")
   125  	case "select rows":
   126  		callback(selectRowsResult)
   127  	case "error after send":
   128  		callback(selectRowsResult)
   129  		return th.Err()
   130  	case "insert":
   131  		callback(&sqltypes.Result{
   132  			RowsAffected: 123,
   133  			InsertID:     123456789,
   134  		})
   135  	case "schema echo":
   136  		callback(&sqltypes.Result{
   137  			Fields: []*querypb.Field{
   138  				{
   139  					Name: "schema_name",
   140  					Type: querypb.Type_VARCHAR,
   141  				},
   142  			},
   143  			Rows: [][]sqltypes.Value{
   144  				{
   145  					sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(c.schemaName)),
   146  				},
   147  			},
   148  		})
   149  	case "ssl echo":
   150  		value := "OFF"
   151  		if c.Capabilities&CapabilityClientSSL > 0 {
   152  			value = "ON"
   153  		}
   154  		callback(&sqltypes.Result{
   155  			Fields: []*querypb.Field{
   156  				{
   157  					Name: "ssl_flag",
   158  					Type: querypb.Type_VARCHAR,
   159  				},
   160  			},
   161  			Rows: [][]sqltypes.Value{
   162  				{
   163  					sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(value)),
   164  				},
   165  			},
   166  		})
   167  	case "userData echo":
   168  		callback(&sqltypes.Result{
   169  			Fields: []*querypb.Field{
   170  				{
   171  					Name: "user",
   172  					Type: querypb.Type_VARCHAR,
   173  				},
   174  				{
   175  					Name: "user_data",
   176  					Type: querypb.Type_VARCHAR,
   177  				},
   178  			},
   179  			Rows: [][]sqltypes.Value{
   180  				{
   181  					sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(c.User)),
   182  					sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(c.UserData.Get().Username)),
   183  				},
   184  			},
   185  		})
   186  	case "50ms delay":
   187  		callback(&sqltypes.Result{
   188  			Fields: []*querypb.Field{{
   189  				Name: "result",
   190  				Type: querypb.Type_VARCHAR,
   191  			}},
   192  		})
   193  		time.Sleep(50 * time.Millisecond)
   194  		callback(&sqltypes.Result{
   195  			Rows: [][]sqltypes.Value{{
   196  				sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("delayed")),
   197  			}},
   198  		})
   199  	default:
   200  		if strings.HasPrefix(query, benchmarkQueryPrefix) {
   201  			callback(&sqltypes.Result{
   202  				Fields: []*querypb.Field{
   203  					{
   204  						Name: "result",
   205  						Type: querypb.Type_VARCHAR,
   206  					},
   207  				},
   208  				Rows: [][]sqltypes.Value{
   209  					{
   210  						sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(query)),
   211  					},
   212  				},
   213  			})
   214  		}
   215  
   216  		callback(&sqltypes.Result{})
   217  	}
   218  	return nil
   219  }
   220  
   221  func (th *testHandler) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
   222  	return nil, nil
   223  }
   224  
   225  func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error {
   226  	return nil
   227  }
   228  
   229  func (th *testHandler) ComRegisterReplica(c *Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error {
   230  	return nil
   231  }
   232  func (th *testHandler) ComBinlogDump(c *Conn, logFile string, binlogPos uint32) error {
   233  	return nil
   234  }
   235  func (th *testHandler) ComBinlogDumpGTID(c *Conn, logFile string, logPos uint64, gtidSet GTIDSet) error {
   236  	return nil
   237  }
   238  
   239  func (th *testHandler) WarningCount(c *Conn) uint16 {
   240  	th.mu.Lock()
   241  	defer th.mu.Unlock()
   242  	return th.warnings
   243  }
   244  
   245  func getHostPort(t *testing.T, a net.Addr) (string, int) {
   246  	host := a.(*net.TCPAddr).IP.String()
   247  	port := a.(*net.TCPAddr).Port
   248  	t.Logf("listening on address '%v' port %v", host, port)
   249  	return host, port
   250  }
   251  
   252  func TestConnectionFromListener(t *testing.T) {
   253  	th := &testHandler{}
   254  
   255  	authServer := NewAuthServerStatic("", "", 0)
   256  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   257  		Password: "password1",
   258  		UserData: "userData1",
   259  	}}
   260  	defer authServer.close()
   261  	// Make sure we can create our own net.Listener for use with the mysql
   262  	// listener
   263  	listener, err := net.Listen("tcp", "127.0.0.1:")
   264  	require.NoError(t, err, "net.Listener failed")
   265  
   266  	l, err := NewFromListener(listener, authServer, th, 0, 0, false)
   267  	require.NoError(t, err, "NewListener failed")
   268  	defer l.Close()
   269  	go l.Accept()
   270  
   271  	host, port := getHostPort(t, l.Addr())
   272  	fmt.Printf("host: %s, port: %d\n", host, port)
   273  	// Setup the right parameters.
   274  	params := &ConnParams{
   275  		Host:  host,
   276  		Port:  port,
   277  		Uname: "user1",
   278  		Pass:  "password1",
   279  	}
   280  
   281  	c, err := Connect(context.Background(), params)
   282  	require.NoError(t, err, "Should be able to connect to server")
   283  	c.Close()
   284  }
   285  
   286  func TestConnectionWithoutSourceHost(t *testing.T) {
   287  	th := &testHandler{}
   288  
   289  	authServer := NewAuthServerStatic("", "", 0)
   290  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   291  		Password: "password1",
   292  		UserData: "userData1",
   293  	}}
   294  	defer authServer.close()
   295  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   296  	require.NoError(t, err, "NewListener failed")
   297  	defer l.Close()
   298  	go l.Accept()
   299  
   300  	host, port := getHostPort(t, l.Addr())
   301  
   302  	// Setup the right parameters.
   303  	params := &ConnParams{
   304  		Host:  host,
   305  		Port:  port,
   306  		Uname: "user1",
   307  		Pass:  "password1",
   308  	}
   309  
   310  	c, err := Connect(context.Background(), params)
   311  	require.NoError(t, err, "Should be able to connect to server")
   312  	c.Close()
   313  }
   314  
   315  func TestConnectionWithSourceHost(t *testing.T) {
   316  	th := &testHandler{}
   317  
   318  	authServer := NewAuthServerStatic("", "", 0)
   319  	authServer.entries["user1"] = []*AuthServerStaticEntry{
   320  		{
   321  			Password:   "password1",
   322  			UserData:   "userData1",
   323  			SourceHost: "localhost",
   324  		},
   325  	}
   326  	defer authServer.close()
   327  
   328  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   329  	require.NoError(t, err, "NewListener failed")
   330  	defer l.Close()
   331  	go l.Accept()
   332  
   333  	host, port := getHostPort(t, l.Addr())
   334  
   335  	// Setup the right parameters.
   336  	params := &ConnParams{
   337  		Host:  host,
   338  		Port:  port,
   339  		Uname: "user1",
   340  		Pass:  "password1",
   341  	}
   342  
   343  	_, err = Connect(context.Background(), params)
   344  	// target is localhost, should not work from tcp connection
   345  	require.EqualError(t, err, "Access denied for user 'user1' (errno 1045) (sqlstate 28000)", "Should not be able to connect to server")
   346  }
   347  
   348  func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) {
   349  	th := &testHandler{}
   350  
   351  	authServer := NewAuthServerStatic("", "", 0)
   352  	authServer.entries["user1"] = []*AuthServerStaticEntry{
   353  		{
   354  			MysqlNativePassword: "*9E128DA0C64A6FCCCDCFBDD0FC0A2C967C6DB36F",
   355  			UserData:            "userData1",
   356  			SourceHost:          "localhost",
   357  		},
   358  	}
   359  	defer authServer.close()
   360  
   361  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   362  	require.NoError(t, err, "NewListener failed")
   363  	defer l.Close()
   364  	go l.Accept()
   365  
   366  	host, port := getHostPort(t, l.Addr())
   367  
   368  	// Setup the right parameters.
   369  	params := &ConnParams{
   370  		Host:  host,
   371  		Port:  port,
   372  		Uname: "user1",
   373  		Pass:  "mysql_password",
   374  	}
   375  
   376  	_, err = Connect(context.Background(), params)
   377  	// target is localhost, should not work from tcp connection
   378  	require.EqualError(t, err, "Access denied for user 'user1' (errno 1045) (sqlstate 28000)", "Should not be able to connect to server")
   379  }
   380  
   381  func TestConnectionUnixSocket(t *testing.T) {
   382  	th := &testHandler{}
   383  
   384  	authServer := NewAuthServerStatic("", "", 0)
   385  	authServer.entries["user1"] = []*AuthServerStaticEntry{
   386  		{
   387  			Password:   "password1",
   388  			UserData:   "userData1",
   389  			SourceHost: "localhost",
   390  		},
   391  	}
   392  	defer authServer.close()
   393  
   394  	unixSocket, err := os.CreateTemp("", "mysql_vitess_test.sock")
   395  	require.NoError(t, err, "Failed to create temp file")
   396  
   397  	os.Remove(unixSocket.Name())
   398  
   399  	l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false)
   400  	require.NoError(t, err, "NewListener failed")
   401  	defer l.Close()
   402  	go l.Accept()
   403  
   404  	// Setup the right parameters.
   405  	params := &ConnParams{
   406  		UnixSocket: unixSocket.Name(),
   407  		Uname:      "user1",
   408  		Pass:       "password1",
   409  	}
   410  
   411  	c, err := Connect(context.Background(), params)
   412  	require.NoError(t, err, "Should be able to connect to server")
   413  	c.Close()
   414  }
   415  
   416  func TestClientFoundRows(t *testing.T) {
   417  	th := &testHandler{}
   418  
   419  	authServer := NewAuthServerStatic("", "", 0)
   420  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   421  		Password: "password1",
   422  		UserData: "userData1",
   423  	}}
   424  	defer authServer.close()
   425  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   426  	require.NoError(t, err, "NewListener failed")
   427  	defer l.Close()
   428  	go l.Accept()
   429  
   430  	host, port := getHostPort(t, l.Addr())
   431  
   432  	// Setup the right parameters.
   433  	params := &ConnParams{
   434  		Host:  host,
   435  		Port:  port,
   436  		Uname: "user1",
   437  		Pass:  "password1",
   438  	}
   439  
   440  	// Test without flag.
   441  	c, err := Connect(context.Background(), params)
   442  	require.NoError(t, err, "Connect failed")
   443  	foundRows := th.LastConn().Capabilities & CapabilityClientFoundRows
   444  	assert.Equal(t, uint32(0), foundRows, "FoundRows flag: %x, second bit must be 0", th.LastConn().Capabilities)
   445  	c.Close()
   446  	assert.True(t, c.IsClosed(), "IsClosed should be true on Close-d connection.")
   447  
   448  	// Test with flag.
   449  	params.Flags |= CapabilityClientFoundRows
   450  	c, err = Connect(context.Background(), params)
   451  	require.NoError(t, err, "Connect failed")
   452  	foundRows = th.LastConn().Capabilities & CapabilityClientFoundRows
   453  	assert.NotZero(t, foundRows, "FoundRows flag: %x, second bit must be set", th.LastConn().Capabilities)
   454  	c.Close()
   455  }
   456  
   457  func TestConnCounts(t *testing.T) {
   458  	th := &testHandler{}
   459  
   460  	initialNumUsers := len(connCountPerUser.Counts())
   461  
   462  	// FIXME: we should be able to ResetAll counters instead of computing a delta, but it doesn't work for some reason
   463  	// connCountPerUser.ResetAll()
   464  
   465  	user := "anotherNotYetConnectedUser1"
   466  	passwd := "password1"
   467  
   468  	authServer := NewAuthServerStatic("", "", 0)
   469  	authServer.entries[user] = []*AuthServerStaticEntry{{
   470  		Password: passwd,
   471  		UserData: "userData1",
   472  	}}
   473  	defer authServer.close()
   474  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   475  	require.NoError(t, err, "NewListener failed")
   476  	defer l.Close()
   477  	go l.Accept()
   478  
   479  	host, port := getHostPort(t, l.Addr())
   480  
   481  	// Test with one new connection.
   482  	params := &ConnParams{
   483  		Host:  host,
   484  		Port:  port,
   485  		Uname: user,
   486  		Pass:  passwd,
   487  	}
   488  
   489  	c, err := Connect(context.Background(), params)
   490  	require.NoError(t, err, "Connect failed")
   491  
   492  	connCounts := connCountPerUser.Counts()
   493  	assert.Equal(t, 1, len(connCounts)-initialNumUsers)
   494  	checkCountsForUser(t, user, 1)
   495  
   496  	// Test with a second new connection.
   497  	c2, err := Connect(context.Background(), params)
   498  	require.NoError(t, err)
   499  	connCounts = connCountPerUser.Counts()
   500  	// There is still only one new user.
   501  	assert.Equal(t, 1, len(connCounts)-initialNumUsers)
   502  	checkCountsForUser(t, user, 2)
   503  
   504  	// Test after closing connections. time.Sleep lets it work, but seems flakey.
   505  	c.Close()
   506  	//time.Sleep(10 * time.Millisecond)
   507  	//checkCountsForUser(t, user, 1)
   508  
   509  	c2.Close()
   510  	//time.Sleep(10 * time.Millisecond)
   511  	//checkCountsForUser(t, user, 0)
   512  }
   513  
   514  func checkCountsForUser(t *testing.T, user string, expected int64) {
   515  	connCounts := connCountPerUser.Counts()
   516  
   517  	userCount, ok := connCounts[user]
   518  	assert.True(t, ok, "No count found for user %s", user)
   519  	assert.Equal(t, expected, userCount)
   520  }
   521  
   522  func TestServer(t *testing.T) {
   523  	th := &testHandler{}
   524  
   525  	authServer := NewAuthServerStatic("", "", 0)
   526  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   527  		Password: "password1",
   528  		UserData: "userData1",
   529  	}}
   530  	defer authServer.close()
   531  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   532  	require.NoError(t, err)
   533  	l.SlowConnectWarnThreshold.Set(time.Nanosecond * 1)
   534  	defer l.Close()
   535  	go l.Accept()
   536  
   537  	host, port := getHostPort(t, l.Addr())
   538  
   539  	// Setup the right parameters.
   540  	params := &ConnParams{
   541  		Host:  host,
   542  		Port:  port,
   543  		Uname: "user1",
   544  		Pass:  "password1",
   545  	}
   546  
   547  	// Run a 'select rows' command with results.
   548  	output, err := runMysqlWithErr(t, params, "select rows")
   549  	require.NoError(t, err)
   550  
   551  	assert.Contains(t, output, "nice name", "Unexpected output for 'select rows'")
   552  	assert.Contains(t, output, "nicer name", "Unexpected output for 'select rows'")
   553  	assert.Contains(t, output, "2 rows in set", "Unexpected output for 'select rows'")
   554  	assert.NotContains(t, output, "warnings")
   555  
   556  	// Run a 'select rows' command with warnings
   557  	th.SetWarnings(13)
   558  	output, err = runMysqlWithErr(t, params, "select rows")
   559  	require.NoError(t, err)
   560  	assert.Contains(t, output, "nice name", "Unexpected output for 'select rows'")
   561  	assert.Contains(t, output, "nicer name", "Unexpected output for 'select rows'")
   562  	assert.Contains(t, output, "2 rows in set", "Unexpected output for 'select rows'")
   563  	assert.Contains(t, output, "13 warnings", "Unexpected output for 'select rows'")
   564  	th.SetWarnings(0)
   565  
   566  	// If there's an error after streaming has started,
   567  	// we should get a 2013
   568  	th.SetErr(NewSQLError(ERUnknownComError, SSNetError, "forced error after send"))
   569  	output, err = runMysqlWithErr(t, params, "error after send")
   570  	require.Error(t, err)
   571  	assert.Contains(t, output, "ERROR 2013 (HY000)", "Unexpected output for 'panic'")
   572  	// MariaDB might not print the MySQL bit here
   573  	assert.Regexp(t, `Lost connection to( MySQL)? server during query`, output, "Unexpected output for 'panic': %v", output)
   574  
   575  	// Run an 'insert' command, no rows, but rows affected.
   576  	output, err = runMysqlWithErr(t, params, "insert")
   577  	require.NoError(t, err)
   578  	assert.Contains(t, output, "Query OK, 123 rows affected", "Unexpected output for 'insert'")
   579  
   580  	// Run a 'schema echo' command, to make sure db name is right.
   581  	params.DbName = "XXXfancyXXX"
   582  	output, err = runMysqlWithErr(t, params, "schema echo")
   583  	require.NoError(t, err)
   584  	assert.Contains(t, output, params.DbName, "Unexpected output for 'schema echo'")
   585  
   586  	// Sanity check: make sure this didn't go through SSL
   587  	output, err = runMysqlWithErr(t, params, "ssl echo")
   588  	require.NoError(t, err)
   589  	assert.Contains(t, output, "ssl_flag")
   590  	assert.Contains(t, output, "OFF")
   591  	assert.Contains(t, output, "1 row in set", "Unexpected output for 'ssl echo': %v", output)
   592  
   593  	// UserData check: checks the server user data is correct.
   594  	output, err = runMysqlWithErr(t, params, "userData echo")
   595  	require.NoError(t, err)
   596  	assert.Contains(t, output, "user1")
   597  	assert.Contains(t, output, "user_data")
   598  	assert.Contains(t, output, "userData1", "Unexpected output for 'userData echo': %v", output)
   599  
   600  	// Permissions check: check a bad password is rejected.
   601  	params.Pass = "bad"
   602  	output, err = runMysqlWithErr(t, params, "select rows")
   603  	require.Error(t, err)
   604  	assert.Contains(t, output, "1045")
   605  	assert.Contains(t, output, "28000")
   606  	assert.Contains(t, output, "Access denied", "Unexpected output for invalid password: %v", output)
   607  
   608  	// Permissions check: check an unknown user is rejected.
   609  	params.Pass = "password1"
   610  	params.Uname = "user2"
   611  	output, err = runMysqlWithErr(t, params, "select rows")
   612  	require.Error(t, err)
   613  	assert.Contains(t, output, "1045")
   614  	assert.Contains(t, output, "28000")
   615  	assert.Contains(t, output, "Access denied", "Unexpected output for invalid password: %v", output)
   616  
   617  	// Uncomment to leave setup up for a while, to run tests manually.
   618  	//	fmt.Printf("Listening to server on host '%v' port '%v'.\n", host, port)
   619  	//	time.Sleep(60 * time.Minute)
   620  }
   621  
   622  func TestServerStats(t *testing.T) {
   623  	th := &testHandler{}
   624  
   625  	authServer := NewAuthServerStatic("", "", 0)
   626  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   627  		Password: "password1",
   628  		UserData: "userData1",
   629  	}}
   630  	defer authServer.close()
   631  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   632  	require.NoError(t, err)
   633  	l.SlowConnectWarnThreshold.Set(time.Nanosecond * 1)
   634  	defer l.Close()
   635  	go l.Accept()
   636  
   637  	host, port := getHostPort(t, l.Addr())
   638  
   639  	// Setup the right parameters.
   640  	params := &ConnParams{
   641  		Host:  host,
   642  		Port:  port,
   643  		Uname: "user1",
   644  		Pass:  "password1",
   645  	}
   646  
   647  	timings.Reset()
   648  	connAccept.Reset()
   649  	connCount.Reset()
   650  	connSlow.Reset()
   651  	connRefuse.Reset()
   652  
   653  	// Run an 'error' command.
   654  	th.SetErr(NewSQLError(ERUnknownComError, SSNetError, "forced query error"))
   655  	output, ok := runMysql(t, params, "error")
   656  	require.False(t, ok, "mysql should have failed: %v", output)
   657  
   658  	assert.Contains(t, output, "ERROR 1047 (08S01)")
   659  	assert.Contains(t, output, "forced query error", "Unexpected output for 'error': %v", output)
   660  
   661  	assert.EqualValues(t, 0, connCount.Get(), "connCount")
   662  	assert.EqualValues(t, 1, connAccept.Get(), "connAccept")
   663  	assert.EqualValues(t, 1, connSlow.Get(), "connSlow")
   664  	assert.EqualValues(t, 0, connRefuse.Get(), "connRefuse")
   665  
   666  	expectedTimingDeltas := map[string]int64{
   667  		"All":            2,
   668  		connectTimingKey: 1,
   669  		queryTimingKey:   1,
   670  	}
   671  	gotTimingCounts := timings.Counts()
   672  	for key, got := range gotTimingCounts {
   673  		expected := expectedTimingDeltas[key]
   674  		assert.GreaterOrEqual(t, got, expected, "Expected Timing count delta %s should be >= %d, got %d", key, expected, got)
   675  	}
   676  
   677  	// Set the slow connect threshold to something high that we don't expect to trigger
   678  	l.SlowConnectWarnThreshold.Set(time.Second * 1)
   679  
   680  	// Run a 'panic' command, other side should panic, recover and
   681  	// close the connection.
   682  	output, err = runMysqlWithErr(t, params, "panic")
   683  	require.Error(t, err)
   684  	assert.Contains(t, output, "ERROR 2013 (HY000)")
   685  	// MariaDB might not print the MySQL bit here
   686  	assert.Regexp(t, `Lost connection to( MySQL)? server during query`, output, "Unexpected output for 'panic': %v", output)
   687  
   688  	assert.EqualValues(t, 0, connCount.Get(), "connCount")
   689  	assert.EqualValues(t, 2, connAccept.Get(), "connAccept")
   690  	assert.EqualValues(t, 1, connSlow.Get(), "connSlow")
   691  	assert.EqualValues(t, 0, connRefuse.Get(), "connRefuse")
   692  }
   693  
   694  // TestClearTextServer creates a Server that needs clear text
   695  // passwords from the client.
   696  func TestClearTextServer(t *testing.T) {
   697  	th := &testHandler{}
   698  
   699  	authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword)
   700  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   701  		Password: "password1",
   702  		UserData: "userData1",
   703  	}}
   704  	defer authServer.close()
   705  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   706  	require.NoError(t, err)
   707  	defer l.Close()
   708  	go l.Accept()
   709  
   710  	host, port := getHostPort(t, l.Addr())
   711  
   712  	version, _ := runMysql(t, nil, "--version")
   713  	isMariaDB := strings.Contains(version, "MariaDB")
   714  
   715  	// Setup the right parameters.
   716  	params := &ConnParams{
   717  		Host:  host,
   718  		Port:  port,
   719  		Uname: "user1",
   720  		Pass:  "password1",
   721  	}
   722  
   723  	// Run a 'select rows' command with results.  This should fail
   724  	// as clear text is not enabled by default on the client
   725  	// (except MariaDB).
   726  	l.AllowClearTextWithoutTLS.Set(true)
   727  	sql := "select rows"
   728  	output, ok := runMysql(t, params, sql)
   729  	if ok {
   730  		if isMariaDB {
   731  			t.Logf("mysql should have failed but returned: %v\nbut letting it go on MariaDB", output)
   732  		} else {
   733  			require.Fail(t, "mysql should have failed but returned: %v", output)
   734  		}
   735  	} else {
   736  		if strings.Contains(output, "No such file or directory") {
   737  			t.Logf("skipping mysql clear text tests, as the clear text plugin cannot be loaded: %v", err)
   738  			return
   739  		}
   740  		assert.Contains(t, output, "plugin not enabled", "Unexpected output for 'select rows': %v", output)
   741  	}
   742  
   743  	// Now enable clear text plugin in client, but server requires SSL.
   744  	l.AllowClearTextWithoutTLS.Set(false)
   745  	if !isMariaDB {
   746  		sql = enableCleartextPluginPrefix + sql
   747  	}
   748  	output, ok = runMysql(t, params, sql)
   749  	assert.False(t, ok, "mysql should have failed but returned: %v", output)
   750  	assert.Contains(t, output, "Cannot use clear text authentication over non-SSL connections", "Unexpected output for 'select rows': %v", output)
   751  
   752  	// Now enable clear text plugin, it should now work.
   753  	l.AllowClearTextWithoutTLS.Set(true)
   754  	output, ok = runMysql(t, params, sql)
   755  	require.True(t, ok, "mysql failed: %v", output)
   756  
   757  	assert.Contains(t, output, "nice name", "Unexpected output for 'select rows'")
   758  	assert.Contains(t, output, "nicer name", "Unexpected output for 'select rows'")
   759  	assert.Contains(t, output, "2 rows in set", "Unexpected output for 'select rows'")
   760  
   761  	// Change password, make sure server rejects us.
   762  	params.Pass = "bad"
   763  	output, ok = runMysql(t, params, sql)
   764  	assert.False(t, ok, "mysql should have failed but returned: %v", output)
   765  	assert.Contains(t, output, "Access denied for user 'user1'", "Unexpected output for 'select rows': %v", output)
   766  }
   767  
   768  // TestDialogServer creates a Server that uses the dialog plugin on the client.
   769  func TestDialogServer(t *testing.T) {
   770  	th := &testHandler{}
   771  
   772  	authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlDialog)
   773  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   774  		Password: "password1",
   775  		UserData: "userData1",
   776  	}}
   777  	defer authServer.close()
   778  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   779  	require.NoError(t, err)
   780  	l.AllowClearTextWithoutTLS.Set(true)
   781  	defer l.Close()
   782  	go l.Accept()
   783  
   784  	host, port := getHostPort(t, l.Addr())
   785  
   786  	// Setup the right parameters.
   787  	params := &ConnParams{
   788  		Host:    host,
   789  		Port:    port,
   790  		Uname:   "user1",
   791  		Pass:    "password1",
   792  		SslMode: vttls.Disabled,
   793  	}
   794  	sql := "select rows"
   795  	output, ok := runMysql(t, params, sql)
   796  	if strings.Contains(output, "No such file or directory") || strings.Contains(output, "Authentication plugin 'dialog' cannot be loaded") {
   797  		t.Logf("skipping dialog plugin tests, as the dialog plugin cannot be loaded: %v", err)
   798  		return
   799  	}
   800  	require.True(t, ok, "mysql failed: %v", output)
   801  	assert.Contains(t, output, "nice name", "Unexpected output for 'select rows': %v", output)
   802  	assert.Contains(t, output, "nicer name", "Unexpected output for 'select rows': %v", output)
   803  	assert.Contains(t, output, "2 rows in set", "Unexpected output for 'select rows': %v", output)
   804  }
   805  
   806  // TestTLSServer creates a Server with TLS support, then uses mysql
   807  // client to connect to it.
   808  func TestTLSServer(t *testing.T) {
   809  	th := &testHandler{}
   810  
   811  	authServer := NewAuthServerStatic("", "", 0)
   812  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   813  		Password: "password1",
   814  	}}
   815  	defer authServer.close()
   816  
   817  	// Create the listener, so we can get its host.
   818  	// Below, we are enabling --ssl-verify-server-cert, which adds
   819  	// a check that the common name of the certificate matches the
   820  	// server host name we connect to.
   821  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   822  	require.NoError(t, err)
   823  	defer l.Close()
   824  
   825  	host := l.Addr().(*net.TCPAddr).IP.String()
   826  	port := l.Addr().(*net.TCPAddr).Port
   827  
   828  	// Create the certs.
   829  	root := t.TempDir()
   830  	tlstest.CreateCA(root)
   831  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
   832  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   833  
   834  	// Create the server with TLS config.
   835  	serverConfig, err := vttls.ServerConfig(
   836  		path.Join(root, "server-cert.pem"),
   837  		path.Join(root, "server-key.pem"),
   838  		path.Join(root, "ca-cert.pem"),
   839  		"",
   840  		"",
   841  		tls.VersionTLS12)
   842  	require.NoError(t, err)
   843  	l.TLSConfig.Store(serverConfig)
   844  
   845  	var wg sync.WaitGroup
   846  	wg.Add(1)
   847  	go func(l *Listener) {
   848  		wg.Done()
   849  		l.Accept()
   850  	}(l)
   851  	// This is ensure the listener is called
   852  	wg.Wait()
   853  	// Sleep so that the Accept function is called as well.'
   854  	time.Sleep(3 * time.Second)
   855  
   856  	connCountByTLSVer.ResetAll()
   857  	// Setup the right parameters.
   858  	params := &ConnParams{
   859  		Host:  host,
   860  		Port:  port,
   861  		Uname: "user1",
   862  		Pass:  "password1",
   863  		// SSL flags.
   864  		SslMode:    vttls.VerifyIdentity,
   865  		SslCa:      path.Join(root, "ca-cert.pem"),
   866  		SslCert:    path.Join(root, "client-cert.pem"),
   867  		SslKey:     path.Join(root, "client-key.pem"),
   868  		ServerName: "server.example.com",
   869  	}
   870  
   871  	// Run a 'select rows' command with results.
   872  	conn, err := Connect(context.Background(), params)
   873  	//output, ok := runMysql(t, params, "select rows")
   874  	require.NoError(t, err)
   875  	results, err := conn.ExecuteFetch("select rows", 1000, true)
   876  	require.NoError(t, err)
   877  	output := ""
   878  	for _, row := range results.Rows {
   879  		r := make([]string, 0)
   880  		for _, col := range row {
   881  			r = append(r, col.String())
   882  		}
   883  		output = output + strings.Join(r, ",") + "\n"
   884  	}
   885  
   886  	assert.Equal(t, "nice name", results.Rows[0][1].ToString())
   887  	assert.Equal(t, "nicer name", results.Rows[1][1].ToString())
   888  	assert.Equal(t, 2, len(results.Rows))
   889  
   890  	// make sure this went through SSL
   891  	results, err = conn.ExecuteFetch("ssl echo", 1000, true)
   892  	require.NoError(t, err)
   893  	assert.Equal(t, "ON", results.Rows[0][0].ToString())
   894  
   895  	// Find out which TLS version the connection actually used,
   896  	// so we can check that the corresponding counter was incremented.
   897  	tlsVersion := conn.conn.(*tls.Conn).ConnectionState().Version
   898  
   899  	checkCountForTLSVer(t, tlsVersionToString(tlsVersion), 1)
   900  	conn.Close()
   901  
   902  }
   903  
   904  // TestTLSRequired creates a Server with TLS required, then tests that an insecure mysql
   905  // client is rejected
   906  func TestTLSRequired(t *testing.T) {
   907  	th := &testHandler{}
   908  
   909  	authServer := NewAuthServerStatic("", "", 0)
   910  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   911  		Password: "password1",
   912  	}}
   913  	defer authServer.close()
   914  
   915  	// Create the listener, so we can get its host.
   916  	// Below, we are enabling --ssl-verify-server-cert, which adds
   917  	// a check that the common name of the certificate matches the
   918  	// server host name we connect to.
   919  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   920  	require.NoError(t, err)
   921  	defer l.Close()
   922  
   923  	host := l.Addr().(*net.TCPAddr).IP.String()
   924  	port := l.Addr().(*net.TCPAddr).Port
   925  
   926  	// Create the certs.
   927  	root := t.TempDir()
   928  	tlstest.CreateCA(root)
   929  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
   930  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   931  	tlstest.CreateSignedCert(root, tlstest.CA, "03", "revoked-client", "Revoked Client Cert")
   932  	tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "revoked-client")
   933  
   934  	// Create the server with TLS config.
   935  	serverConfig, err := vttls.ServerConfig(
   936  		path.Join(root, "server-cert.pem"),
   937  		path.Join(root, "server-key.pem"),
   938  		path.Join(root, "ca-cert.pem"),
   939  		path.Join(root, "ca-crl.pem"),
   940  		"",
   941  		tls.VersionTLS12)
   942  	require.NoError(t, err)
   943  	l.TLSConfig.Store(serverConfig)
   944  	l.RequireSecureTransport = true
   945  
   946  	var wg sync.WaitGroup
   947  	wg.Add(1)
   948  	go func(l *Listener) {
   949  		wg.Done()
   950  		l.Accept()
   951  	}(l)
   952  	// This is ensure the listener is called
   953  	wg.Wait()
   954  	// Sleep so that the Accept function is called as well.'
   955  	time.Sleep(3 * time.Second)
   956  
   957  	// Setup conn params without SSL.
   958  	params := &ConnParams{
   959  		Host:       host,
   960  		Port:       port,
   961  		Uname:      "user1",
   962  		Pass:       "password1",
   963  		SslMode:    vttls.Disabled,
   964  		ServerName: "server.example.com",
   965  	}
   966  	conn, err := Connect(context.Background(), params)
   967  	require.NotNil(t, err)
   968  	require.Contains(t, err.Error(), "Code: UNAVAILABLE")
   969  	require.Contains(t, err.Error(), "server does not allow insecure connections, client must use SSL/TLS")
   970  	require.Contains(t, err.Error(), "(errno 1105) (sqlstate HY000)")
   971  	if conn != nil {
   972  		conn.Close()
   973  	}
   974  
   975  	// setup conn params with TLS
   976  	params.SslMode = vttls.VerifyIdentity
   977  	params.SslCa = path.Join(root, "ca-cert.pem")
   978  	params.SslCert = path.Join(root, "client-cert.pem")
   979  	params.SslKey = path.Join(root, "client-key.pem")
   980  
   981  	conn, err = Connect(context.Background(), params)
   982  	require.NoError(t, err)
   983  	if conn != nil {
   984  		conn.Close()
   985  	}
   986  
   987  	// setup conn params with TLS, but with a revoked client certificate
   988  	params.SslCert = path.Join(root, "revoked-client-cert.pem")
   989  	params.SslKey = path.Join(root, "revoked-client-key.pem")
   990  	conn, err = Connect(context.Background(), params)
   991  	require.NotNil(t, err)
   992  	require.Contains(t, err.Error(), "remote error: tls: bad certificate")
   993  	if conn != nil {
   994  		conn.Close()
   995  	}
   996  }
   997  
   998  func TestCachingSha2PasswordAuthWithTLS(t *testing.T) {
   999  	th := &testHandler{}
  1000  
  1001  	authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password)
  1002  	authServer.entries["user1"] = []*AuthServerStaticEntry{
  1003  		{Password: "password1"},
  1004  	}
  1005  	defer authServer.close()
  1006  
  1007  	// Create the listener, so we can get its host.
  1008  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
  1009  	require.NoError(t, err, "NewListener failed: %v", err)
  1010  	defer l.Close()
  1011  	host := l.Addr().(*net.TCPAddr).IP.String()
  1012  	port := l.Addr().(*net.TCPAddr).Port
  1013  
  1014  	// Create the certs.
  1015  	root := t.TempDir()
  1016  	tlstest.CreateCA(root)
  1017  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
  1018  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
  1019  
  1020  	// Create the server with TLS config.
  1021  	serverConfig, err := vttls.ServerConfig(
  1022  		path.Join(root, "server-cert.pem"),
  1023  		path.Join(root, "server-key.pem"),
  1024  		path.Join(root, "ca-cert.pem"),
  1025  		"",
  1026  		"",
  1027  		tls.VersionTLS12)
  1028  	require.NoError(t, err, "TLSServerConfig failed: %v", err)
  1029  
  1030  	l.TLSConfig.Store(serverConfig)
  1031  	go func() {
  1032  		l.Accept()
  1033  	}()
  1034  
  1035  	// Setup the right parameters.
  1036  	params := &ConnParams{
  1037  		Host:  host,
  1038  		Port:  port,
  1039  		Uname: "user1",
  1040  		Pass:  "password1",
  1041  		// SSL flags.
  1042  		SslMode:    vttls.VerifyIdentity,
  1043  		SslCa:      path.Join(root, "ca-cert.pem"),
  1044  		SslCert:    path.Join(root, "client-cert.pem"),
  1045  		SslKey:     path.Join(root, "client-key.pem"),
  1046  		ServerName: "server.example.com",
  1047  	}
  1048  
  1049  	// Connection should fail, as server requires SSL for caching_sha2_password.
  1050  	ctx := context.Background()
  1051  
  1052  	conn, err := Connect(ctx, params)
  1053  	require.NoError(t, err, "unexpected connection error: %v", err)
  1054  
  1055  	defer conn.Close()
  1056  
  1057  	// Run a 'select rows' command with results.
  1058  	result, err := conn.ExecuteFetch("select rows", 10000, true)
  1059  	require.NoError(t, err, "ExecuteFetch failed: %v", err)
  1060  
  1061  	utils.MustMatch(t, result, selectRowsResult)
  1062  
  1063  	// Send a ComQuit to avoid the error message on the server side.
  1064  	conn.writeComQuit()
  1065  }
  1066  
  1067  type alwaysFallbackAuth struct{}
  1068  
  1069  func (a *alwaysFallbackAuth) UserEntryWithCacheHash(conn *Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, CacheState, error) {
  1070  	return &StaticUserData{}, AuthNeedMoreData, nil
  1071  }
  1072  
  1073  // newAuthServerAlwaysFallback returns a new empty AuthServerStatic
  1074  // which will always request more data to trigger fallback auth path
  1075  // for caching sha2.
  1076  func newAuthServerAlwaysFallback(file, jsonConfig string, reloadInterval time.Duration) *AuthServerStatic {
  1077  	a := &AuthServerStatic{
  1078  		file:           file,
  1079  		jsonConfig:     jsonConfig,
  1080  		reloadInterval: reloadInterval,
  1081  		entries:        make(map[string][]*AuthServerStaticEntry),
  1082  	}
  1083  
  1084  	authMethod := NewSha2CachingAuthMethod(&alwaysFallbackAuth{}, a, a)
  1085  	a.methods = []AuthMethod{authMethod}
  1086  
  1087  	a.reload()
  1088  	a.installSignalHandlers()
  1089  	return a
  1090  }
  1091  
  1092  func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) {
  1093  	th := &testHandler{}
  1094  
  1095  	authServer := newAuthServerAlwaysFallback("", "", 0)
  1096  	authServer.entries["user1"] = []*AuthServerStaticEntry{
  1097  		{Password: "password1"},
  1098  	}
  1099  	defer authServer.close()
  1100  
  1101  	// Create the listener, so we can get its host.
  1102  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
  1103  	require.NoError(t, err, "NewListener failed: %v", err)
  1104  	defer l.Close()
  1105  	host := l.Addr().(*net.TCPAddr).IP.String()
  1106  	port := l.Addr().(*net.TCPAddr).Port
  1107  
  1108  	// Create the certs.
  1109  	root := t.TempDir()
  1110  	tlstest.CreateCA(root)
  1111  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
  1112  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
  1113  
  1114  	// Create the server with TLS config.
  1115  	serverConfig, err := vttls.ServerConfig(
  1116  		path.Join(root, "server-cert.pem"),
  1117  		path.Join(root, "server-key.pem"),
  1118  		path.Join(root, "ca-cert.pem"),
  1119  		"",
  1120  		"",
  1121  		tls.VersionTLS12)
  1122  	require.NoError(t, err, "TLSServerConfig failed: %v", err)
  1123  
  1124  	l.TLSConfig.Store(serverConfig)
  1125  	go func() {
  1126  		l.Accept()
  1127  	}()
  1128  
  1129  	// Setup the right parameters.
  1130  	params := &ConnParams{
  1131  		Host:  host,
  1132  		Port:  port,
  1133  		Uname: "user1",
  1134  		Pass:  "password1",
  1135  		// SSL flags.
  1136  		SslMode:    vttls.VerifyIdentity,
  1137  		SslCa:      path.Join(root, "ca-cert.pem"),
  1138  		SslCert:    path.Join(root, "client-cert.pem"),
  1139  		SslKey:     path.Join(root, "client-key.pem"),
  1140  		ServerName: "server.example.com",
  1141  	}
  1142  
  1143  	// Connection should fail, as server requires SSL for caching_sha2_password.
  1144  	ctx := context.Background()
  1145  
  1146  	conn, err := Connect(ctx, params)
  1147  	require.NoError(t, err, "unexpected connection error: %v", err)
  1148  
  1149  	defer conn.Close()
  1150  
  1151  	// Run a 'select rows' command with results.
  1152  	result, err := conn.ExecuteFetch("select rows", 10000, true)
  1153  	require.NoError(t, err, "ExecuteFetch failed: %v", err)
  1154  
  1155  	utils.MustMatch(t, result, selectRowsResult)
  1156  
  1157  	// Send a ComQuit to avoid the error message on the server side.
  1158  	conn.writeComQuit()
  1159  }
  1160  
  1161  func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) {
  1162  	th := &testHandler{}
  1163  
  1164  	authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, CachingSha2Password)
  1165  	authServer.entries["user1"] = []*AuthServerStaticEntry{
  1166  		{Password: "password1"},
  1167  	}
  1168  	defer authServer.close()
  1169  
  1170  	// Create the listener.
  1171  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
  1172  	require.NoError(t, err, "NewListener failed: %v", err)
  1173  	defer l.Close()
  1174  	host := l.Addr().(*net.TCPAddr).IP.String()
  1175  	port := l.Addr().(*net.TCPAddr).Port
  1176  	go func() {
  1177  		l.Accept()
  1178  	}()
  1179  
  1180  	// Setup the right parameters.
  1181  	params := &ConnParams{
  1182  		Host:    host,
  1183  		Port:    port,
  1184  		Uname:   "user1",
  1185  		Pass:    "password1",
  1186  		SslMode: vttls.Disabled,
  1187  	}
  1188  
  1189  	// Connection should fail, as server requires SSL for caching_sha2_password.
  1190  	ctx := context.Background()
  1191  	_, err = Connect(ctx, params)
  1192  	if err == nil || !strings.Contains(err.Error(), "No authentication methods available for authentication") {
  1193  		t.Fatalf("unexpected connection error: %v", err)
  1194  	}
  1195  }
  1196  
  1197  func checkCountForTLSVer(t *testing.T, version string, expected int64) {
  1198  	connCounts := connCountByTLSVer.Counts()
  1199  	count, ok := connCounts[version]
  1200  	assert.True(t, ok, "No count found for version %s", version)
  1201  	assert.Equal(t, expected, count, "Unexpected connection count for version %s", version)
  1202  }
  1203  
  1204  func TestErrorCodes(t *testing.T) {
  1205  	th := &testHandler{}
  1206  
  1207  	authServer := NewAuthServerStatic("", "", 0)
  1208  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
  1209  		Password: "password1",
  1210  		UserData: "userData1",
  1211  	}}
  1212  	defer authServer.close()
  1213  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
  1214  	require.NoError(t, err)
  1215  	defer l.Close()
  1216  	go l.Accept()
  1217  
  1218  	host, port := getHostPort(t, l.Addr())
  1219  
  1220  	// Setup the right parameters.
  1221  	params := &ConnParams{
  1222  		Host:  host,
  1223  		Port:  port,
  1224  		Uname: "user1",
  1225  		Pass:  "password1",
  1226  	}
  1227  
  1228  	ctx := context.Background()
  1229  	client, err := Connect(ctx, params)
  1230  	require.NoError(t, err)
  1231  
  1232  	// Test that the right mysql errno/sqlstate are returned for various
  1233  	// internal vitess errors
  1234  	tests := []struct {
  1235  		err      error
  1236  		code     int
  1237  		sqlState string
  1238  		text     string
  1239  	}{
  1240  		{
  1241  			err: vterrors.Errorf(
  1242  				vtrpcpb.Code_INVALID_ARGUMENT,
  1243  				"invalid argument"),
  1244  			code:     ERUnknownError,
  1245  			sqlState: SSUnknownSQLState,
  1246  			text:     "invalid argument",
  1247  		},
  1248  		{
  1249  			err: vterrors.Errorf(
  1250  				vtrpcpb.Code_INVALID_ARGUMENT,
  1251  				"(errno %v) (sqlstate %v) invalid argument with errno", ERDupEntry, SSConstraintViolation),
  1252  			code:     ERDupEntry,
  1253  			sqlState: SSConstraintViolation,
  1254  			text:     "invalid argument with errno",
  1255  		},
  1256  		{
  1257  			err: vterrors.Errorf(
  1258  				vtrpcpb.Code_DEADLINE_EXCEEDED,
  1259  				"connection deadline exceeded"),
  1260  			code:     ERQueryInterrupted,
  1261  			sqlState: SSQueryInterrupted,
  1262  			text:     "deadline exceeded",
  1263  		},
  1264  		{
  1265  			err: vterrors.Errorf(
  1266  				vtrpcpb.Code_RESOURCE_EXHAUSTED,
  1267  				"query pool timeout"),
  1268  			code:     ERTooManyUserConnections,
  1269  			sqlState: SSClientError,
  1270  			text:     "resource exhausted",
  1271  		},
  1272  		{
  1273  			err:      vterrors.Wrap(vterrors.Errorf(vtrpcpb.Code_ABORTED, "Row count exceeded 10000"), "wrapped"),
  1274  			code:     ERQueryInterrupted,
  1275  			sqlState: SSQueryInterrupted,
  1276  			text:     "aborted",
  1277  		},
  1278  	}
  1279  
  1280  	for _, test := range tests {
  1281  		t.Run(test.err.Error(), func(t *testing.T) {
  1282  			th.SetErr(NewSQLErrorFromError(test.err))
  1283  			rs, err := client.ExecuteFetch("error", 100, false)
  1284  			require.Error(t, err, "mysql should have failed but returned: %v", rs)
  1285  			serr, ok := err.(*SQLError)
  1286  			require.True(t, ok, "mysql should have returned a SQLError")
  1287  
  1288  			assert.Equal(t, test.code, serr.Number(), "error in %s: want code %v got %v", test.text, test.code, serr.Number())
  1289  			assert.Equal(t, test.sqlState, serr.SQLState(), "error in %s: want sqlState %v got %v", test.text, test.sqlState, serr.SQLState())
  1290  			assert.Contains(t, serr.Error(), test.err.Error())
  1291  		})
  1292  	}
  1293  }
  1294  
  1295  const enableCleartextPluginPrefix = "enable-cleartext-plugin: "
  1296  
  1297  // runMysql forks a mysql command line process connecting to the provided server.
  1298  func runMysql(t *testing.T, params *ConnParams, command string) (string, bool) {
  1299  	output, err := runMysqlWithErr(t, params, command)
  1300  	if err != nil {
  1301  		return output, false
  1302  	}
  1303  	return output, true
  1304  
  1305  }
  1306  func runMysqlWithErr(t *testing.T, params *ConnParams, command string) (string, error) {
  1307  	dir, err := vtenv.VtMysqlRoot()
  1308  	require.NoError(t, err)
  1309  	name, err := binaryPath(dir, "mysql")
  1310  	require.NoError(t, err)
  1311  	// The args contain '-v' 3 times, to switch to very verbose output.
  1312  	// In particular, it has the message:
  1313  	// Query OK, 1 row affected (0.00 sec)
  1314  	args := []string{
  1315  		"-v", "-v", "-v",
  1316  	}
  1317  	if strings.HasPrefix(command, enableCleartextPluginPrefix) {
  1318  		command = command[len(enableCleartextPluginPrefix):]
  1319  		args = append(args, "--enable-cleartext-plugin")
  1320  	}
  1321  	if command == "--version" {
  1322  		args = append(args, command)
  1323  	} else {
  1324  		args = append(args, "-e", command)
  1325  		if params.UnixSocket != "" {
  1326  			args = append(args, "-S", params.UnixSocket)
  1327  		} else {
  1328  			args = append(args,
  1329  				"-h", params.Host,
  1330  				"-P", fmt.Sprintf("%v", params.Port))
  1331  		}
  1332  		if params.Uname != "" {
  1333  			args = append(args, "-u", params.Uname)
  1334  		}
  1335  		if params.Pass != "" {
  1336  			args = append(args, "-p"+params.Pass)
  1337  		}
  1338  		if params.DbName != "" {
  1339  			args = append(args, "-D", params.DbName)
  1340  		}
  1341  		if params.SslEnabled() {
  1342  			args = append(args,
  1343  				"--ssl",
  1344  				"--ssl-ca", params.SslCa,
  1345  				"--ssl-cert", params.SslCert,
  1346  				"--ssl-key", params.SslKey,
  1347  				"--ssl-verify-server-cert")
  1348  		}
  1349  	}
  1350  	env := []string{
  1351  		"LD_LIBRARY_PATH=" + path.Join(dir, "lib/mysql"),
  1352  	}
  1353  
  1354  	t.Logf("Running mysql command: %v %v", name, args)
  1355  	cmd := exec.Command(name, args...)
  1356  	cmd.Env = env
  1357  	cmd.Dir = dir
  1358  	out, err := cmd.CombinedOutput()
  1359  	output := string(out)
  1360  	if err != nil {
  1361  		return output, err
  1362  	}
  1363  	return output, nil
  1364  }
  1365  
  1366  // binaryPath does a limited path lookup for a command,
  1367  // searching only within sbin and bin in the given root.
  1368  //
  1369  // FIXME(alainjobart) move this to vt/env, and use it from
  1370  // go/vt/mysqlctl too.
  1371  func binaryPath(root, binary string) (string, error) {
  1372  	subdirs := []string{"sbin", "bin"}
  1373  	for _, subdir := range subdirs {
  1374  		binPath := path.Join(root, subdir, binary)
  1375  		if _, err := os.Stat(binPath); err == nil {
  1376  			return binPath, nil
  1377  		}
  1378  	}
  1379  	return "", fmt.Errorf("%s not found in any of %s/{%s}",
  1380  		binary, root, strings.Join(subdirs, ","))
  1381  }
  1382  
  1383  func TestListenerShutdown(t *testing.T) {
  1384  	th := &testHandler{}
  1385  	authServer := NewAuthServerStatic("", "", 0)
  1386  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
  1387  		Password: "password1",
  1388  		UserData: "userData1",
  1389  	}}
  1390  	defer authServer.close()
  1391  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
  1392  	require.NoError(t, err)
  1393  	defer l.Close()
  1394  	go l.Accept()
  1395  
  1396  	host, port := getHostPort(t, l.Addr())
  1397  
  1398  	// Setup the right parameters.
  1399  	params := &ConnParams{
  1400  		Host:  host,
  1401  		Port:  port,
  1402  		Uname: "user1",
  1403  		Pass:  "password1",
  1404  	}
  1405  	connRefuse.Reset()
  1406  
  1407  	ctx, cancel := context.WithCancel(context.Background())
  1408  	defer cancel()
  1409  
  1410  	conn, err := Connect(ctx, params)
  1411  	require.NoError(t, err)
  1412  
  1413  	err = conn.Ping()
  1414  	require.NoError(t, err)
  1415  
  1416  	l.Shutdown()
  1417  
  1418  	assert.EqualValues(t, 1, connRefuse.Get(), "connRefuse")
  1419  
  1420  	err = conn.Ping()
  1421  	require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)")
  1422  	sqlErr, ok := err.(*SQLError)
  1423  	require.True(t, ok, "Wrong error type: %T", err)
  1424  
  1425  	require.Equal(t, ERServerShutdown, sqlErr.Number())
  1426  	require.Equal(t, SSNetError, sqlErr.SQLState())
  1427  	require.Equal(t, "Server shutdown in progress", sqlErr.Message)
  1428  }
  1429  
  1430  func TestParseConnAttrs(t *testing.T) {
  1431  	expected := map[string]string{
  1432  		"_client_version": "8.0.11",
  1433  		"program_name":    "mysql",
  1434  		"_pid":            "22850",
  1435  		"_platform":       "x86_64",
  1436  		"_os":             "linux-glibc2.12",
  1437  		"_client_name":    "libmysql",
  1438  	}
  1439  
  1440  	data := []byte{0x70, 0x04, 0x5f, 0x70, 0x69, 0x64, 0x05, 0x32, 0x32, 0x38, 0x35, 0x30, 0x09, 0x5f, 0x70, 0x6c,
  1441  		0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x06, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34, 0x03, 0x5f, 0x6f,
  1442  		0x73, 0x0f, 0x6c, 0x69, 0x6e, 0x75, 0x78, 0x2d, 0x67, 0x6c, 0x69, 0x62, 0x63, 0x32, 0x2e, 0x31,
  1443  		0x32, 0x0c, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x08, 0x6c,
  1444  		0x69, 0x62, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x0f, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f,
  1445  		0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x06, 0x38, 0x2e, 0x30, 0x2e, 0x31, 0x31, 0x0c, 0x70,
  1446  		0x72, 0x6f, 0x67, 0x72, 0x61, 0x6d, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x05, 0x6d, 0x79, 0x73, 0x71, 0x6c}
  1447  
  1448  	attrs, pos, err := parseConnAttrs(data, 0)
  1449  	require.NoError(t, err)
  1450  	require.Equal(t, 113, pos)
  1451  	for k, v := range expected {
  1452  		val, ok := attrs[k]
  1453  		require.True(t, ok, "Error reading key %s from connection attributes: attrs: %-v", k, attrs)
  1454  		require.Equal(t, v, val, "Unexpected value found in attrs for key %s", k)
  1455  	}
  1456  }
  1457  
  1458  func TestServerFlush(t *testing.T) {
  1459  	defer func(saved time.Duration) { mysqlServerFlushDelay = saved }(mysqlServerFlushDelay)
  1460  	mysqlServerFlushDelay = 10 * time.Millisecond
  1461  
  1462  	th := &testHandler{}
  1463  
  1464  	l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false)
  1465  	require.NoError(t, err)
  1466  	defer l.Close()
  1467  	go l.Accept()
  1468  
  1469  	host, port := getHostPort(t, l.Addr())
  1470  	params := &ConnParams{
  1471  		Host: host,
  1472  		Port: port,
  1473  	}
  1474  
  1475  	c, err := Connect(context.Background(), params)
  1476  	require.NoError(t, err)
  1477  	defer c.Close()
  1478  
  1479  	start := time.Now()
  1480  	err = c.ExecuteStreamFetch("50ms delay")
  1481  	require.NoError(t, err)
  1482  
  1483  	flds, err := c.Fields()
  1484  	require.NoError(t, err)
  1485  	if duration, want := time.Since(start), 20*time.Millisecond; duration < mysqlServerFlushDelay || duration > want {
  1486  		assert.Fail(t, "duration out of expected range", "duration: %v, want between %v and %v", duration.String(), (mysqlServerFlushDelay).String(), want.String())
  1487  	}
  1488  	want1 := []*querypb.Field{{
  1489  		Name: "result",
  1490  		Type: querypb.Type_VARCHAR,
  1491  	}}
  1492  	assert.Equal(t, want1, flds)
  1493  
  1494  	row, err := c.FetchNext(nil)
  1495  	require.NoError(t, err)
  1496  	if duration, want := time.Since(start), 50*time.Millisecond; duration < want {
  1497  		assert.Fail(t, "duration is too low", "duration: %v, want > %v", duration, want)
  1498  	}
  1499  	want2 := []sqltypes.Value{sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("delayed"))}
  1500  	assert.Equal(t, want2, row)
  1501  
  1502  	row, err = c.FetchNext(nil)
  1503  	require.NoError(t, err)
  1504  	assert.Nil(t, row)
  1505  }