vitess.io/vitess@v0.16.2/go/mysql/endtoend/client_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package endtoend
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strings"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  
    28  	"github.com/stretchr/testify/require"
    29  
    30  	"vitess.io/vitess/go/mysql"
    31  )
    32  
    33  // TestKill opens a connection, issues a command that
    34  // will sleep for a few seconds, waits a bit for MySQL to start
    35  // executing it, then kills the connection (using another
    36  // connection). We make sure we get the right error code.
    37  func TestKill(t *testing.T) {
    38  	ctx := context.Background()
    39  	conn, err := mysql.Connect(ctx, &connParams)
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  
    44  	// Create the kill connection first. It sometimes takes longer
    45  	// than 10s
    46  	killConn, err := mysql.Connect(ctx, &connParams)
    47  	if err != nil {
    48  		t.Fatal(err)
    49  	}
    50  	defer killConn.Close()
    51  
    52  	errChan := make(chan error)
    53  	go func() {
    54  		_, err = conn.ExecuteFetch("select sleep(10) from dual", 1000, false)
    55  		errChan <- err
    56  		close(errChan)
    57  	}()
    58  
    59  	// Give extra time for the query to start executing.
    60  	time.Sleep(2 * time.Second)
    61  	if _, err := killConn.ExecuteFetch(fmt.Sprintf("kill %v", conn.ConnectionID), 1000, false); err != nil {
    62  		t.Fatalf("Kill(%v) failed: %v", conn.ConnectionID, err)
    63  	}
    64  
    65  	// The error text will depend on what ExecuteFetch in the go
    66  	// routine managed to do. Two cases:
    67  	// 1. the connection was closed before the go routine's ExecuteFetch
    68  	//   entered the ReadFull call. Then we get an error with
    69  	//   'connection reset by peer' in it.
    70  	// 2. the connection was closed while the go routine's ExecuteFetch
    71  	//   was stuck on the read. Then we get io.EOF.
    72  	// The code and sqlState needs to be right in any case, the text
    73  	// will differ.
    74  	err = <-errChan
    75  	if strings.Contains(err.Error(), "EOF") {
    76  		assertSQLError(t, err, mysql.CRServerLost, mysql.SSUnknownSQLState, "EOF", "select sleep(10) from dual")
    77  	} else {
    78  		assertSQLError(t, err, mysql.CRServerLost, mysql.SSUnknownSQLState, "", "connection reset by peer")
    79  	}
    80  }
    81  
    82  // TestKill2006 opens a connection, kills the
    83  // connection from the server side, then waits a bit, and tries to
    84  // execute a command. We make sure we get the right error code.
    85  func TestKill2006(t *testing.T) {
    86  	ctx := context.Background()
    87  	conn, err := mysql.Connect(ctx, &connParams)
    88  	if err != nil {
    89  		t.Fatal(err)
    90  	}
    91  
    92  	// Kill the connection from the server side.
    93  	killConn, err := mysql.Connect(ctx, &connParams)
    94  	if err != nil {
    95  		t.Fatal(err)
    96  	}
    97  	defer killConn.Close()
    98  
    99  	if _, err := killConn.ExecuteFetch(fmt.Sprintf("kill %v", conn.ConnectionID), 1000, false); err != nil {
   100  		t.Fatalf("Kill(%v) failed: %v", conn.ConnectionID, err)
   101  	}
   102  
   103  	// Now we should get a CRServerGone.  Since we are using a
   104  	// unix socket, we will get a broken pipe when the server
   105  	// closes the connection and we are trying to write the command.
   106  	_, err = conn.ExecuteFetch("select sleep(10) from dual", 1000, false)
   107  	assertSQLError(t, err, mysql.CRServerGone, mysql.SSUnknownSQLState, "broken pipe", "select sleep(10) from dual")
   108  }
   109  
   110  // TestDupEntry tests a duplicate key is properly raised.
   111  func TestDupEntry(t *testing.T) {
   112  	ctx := context.Background()
   113  	conn, err := mysql.Connect(ctx, &connParams)
   114  	if err != nil {
   115  		t.Fatal(err)
   116  	}
   117  	defer conn.Close()
   118  
   119  	if _, err := conn.ExecuteFetch("create table dup_entry(id int, name int, primary key(id), unique index(name))", 0, false); err != nil {
   120  		t.Fatalf("create table failed: %v", err)
   121  	}
   122  	if _, err := conn.ExecuteFetch("insert into dup_entry(id, name) values(1, 10)", 0, false); err != nil {
   123  		t.Fatalf("first insert failed: %v", err)
   124  	}
   125  	_, err = conn.ExecuteFetch("insert into dup_entry(id, name) values(2, 10)", 0, false)
   126  	assertSQLError(t, err, mysql.ERDupEntry, mysql.SSConstraintViolation, "Duplicate entry", "insert into dup_entry(id, name) values(2, 10)")
   127  }
   128  
   129  // TestClientFoundRows tests if the CLIENT_FOUND_ROWS flag works.
   130  func TestClientFoundRows(t *testing.T) {
   131  	params := connParams
   132  	params.EnableClientFoundRows()
   133  
   134  	ctx := context.Background()
   135  	conn, err := mysql.Connect(ctx, &params)
   136  	if err != nil {
   137  		t.Fatal(err)
   138  	}
   139  	defer conn.Close()
   140  
   141  	if _, err := conn.ExecuteFetch("create table found_rows(id int, val int, primary key(id))", 0, false); err != nil {
   142  		t.Fatalf("create table failed: %v", err)
   143  	}
   144  	if _, err := conn.ExecuteFetch("insert into found_rows(id, val) values(1, 10)", 0, false); err != nil {
   145  		t.Fatalf("insert failed: %v", err)
   146  	}
   147  	qr, err := conn.ExecuteFetch("update found_rows set val=11 where id=1", 0, false)
   148  	require.NoError(t, err)
   149  	assert.EqualValues(t, 1, qr.RowsAffected, "RowsAffected")
   150  
   151  	qr, err = conn.ExecuteFetch("update found_rows set val=11 where id=1", 0, false)
   152  	require.NoError(t, err)
   153  	assert.EqualValues(t, 1, qr.RowsAffected, "RowsAffected")
   154  }
   155  
   156  func doTestMultiResult(t *testing.T, disableClientDeprecateEOF bool) {
   157  	ctx := context.Background()
   158  	connParams.DisableClientDeprecateEOF = disableClientDeprecateEOF
   159  
   160  	conn, err := mysql.Connect(ctx, &connParams)
   161  	expectNoError(t, err)
   162  	defer conn.Close()
   163  
   164  	qr, more, err := conn.ExecuteFetchMulti("select 1 from dual; set autocommit=1; select 1 from dual", 10, true)
   165  	expectNoError(t, err)
   166  	expectFlag(t, "ExecuteMultiFetch(multi result)", more, true)
   167  	assert.EqualValues(t, 1, len(qr.Rows))
   168  
   169  	qr, more, _, err = conn.ReadQueryResult(10, true)
   170  	expectNoError(t, err)
   171  	expectFlag(t, "ReadQueryResult(1)", more, true)
   172  	assert.EqualValues(t, 0, len(qr.Rows))
   173  
   174  	qr, more, _, err = conn.ReadQueryResult(10, true)
   175  	expectNoError(t, err)
   176  	expectFlag(t, "ReadQueryResult(2)", more, false)
   177  	assert.EqualValues(t, 1, len(qr.Rows))
   178  
   179  	qr, more, err = conn.ExecuteFetchMulti("select 1 from dual", 10, true)
   180  	expectNoError(t, err)
   181  	expectFlag(t, "ExecuteMultiFetch(single result)", more, false)
   182  	assert.EqualValues(t, 1, len(qr.Rows))
   183  
   184  	qr, more, err = conn.ExecuteFetchMulti("set autocommit=1", 10, true)
   185  	expectNoError(t, err)
   186  	expectFlag(t, "ExecuteMultiFetch(no result)", more, false)
   187  	assert.EqualValues(t, 0, len(qr.Rows))
   188  
   189  	// The ClientDeprecateEOF protocol change has a subtle twist in which an EOF or OK
   190  	// packet happens to have the status flags in the same position if the affected_rows
   191  	// and last_insert_id are both one byte long:
   192  	//
   193  	// https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html
   194  	// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
   195  	//
   196  	// It turns out that there are no actual cases in which clients end up needing to make
   197  	// this distinction. If either affected_rows or last_insert_id are non-zero, the protocol
   198  	// sends an OK packet unilaterally which is properly parsed. If not, then regardless of the
   199  	// negotiated version, it can properly send the status flags.
   200  	//
   201  	result, err := conn.ExecuteFetch("create table a(id int, name varchar(128), primary key(id))", 0, false)
   202  	require.NoError(t, err)
   203  	assert.Zero(t, result.RowsAffected, "create table RowsAffected ")
   204  
   205  	for i := 0; i < 255; i++ {
   206  		result, err := conn.ExecuteFetch(fmt.Sprintf("insert into a(id, name) values(%v, 'nice name %v')", 1000+i, i), 1000, true)
   207  		require.NoError(t, err)
   208  		assert.EqualValues(t, 1, result.RowsAffected, "insert into returned RowsAffected")
   209  	}
   210  
   211  	qr, more, err = conn.ExecuteFetchMulti("update a set name = concat(name, ' updated'); select * from a; select count(*) from a", 300, true)
   212  	expectNoError(t, err)
   213  	expectFlag(t, "ExecuteMultiFetch(multi result)", more, true)
   214  	assert.EqualValues(t, 255, qr.RowsAffected)
   215  
   216  	qr, more, _, err = conn.ReadQueryResult(300, true)
   217  	expectNoError(t, err)
   218  	expectFlag(t, "ReadQueryResult(1)", more, true)
   219  	assert.EqualValues(t, 255, len(qr.Rows), "ReadQueryResult(1)")
   220  
   221  	qr, more, _, err = conn.ReadQueryResult(300, true)
   222  	expectNoError(t, err)
   223  	expectFlag(t, "ReadQueryResult(2)", more, false)
   224  	assert.EqualValues(t, 1, len(qr.Rows), "ReadQueryResult(1)")
   225  
   226  	_, err = conn.ExecuteFetch("drop table a", 10, true)
   227  	require.NoError(t, err)
   228  }
   229  
   230  func TestMultiResultDeprecateEOF(t *testing.T) {
   231  	doTestMultiResult(t, false)
   232  }
   233  func TestMultiResultNoDeprecateEOF(t *testing.T) {
   234  	doTestMultiResult(t, true)
   235  }
   236  
   237  func expectNoError(t *testing.T, err error) {
   238  	t.Helper()
   239  	if err != nil {
   240  		t.Fatal(err)
   241  	}
   242  }
   243  
   244  func expectFlag(t *testing.T, msg string, flag, want bool) {
   245  	t.Helper()
   246  	require.Equal(t, want, flag, "%s: %v, want: %v", msg, flag, want)
   247  
   248  }
   249  
   250  // TestTLS tests our client can connect via SSL.
   251  func TestTLS(t *testing.T) {
   252  	params := connParams
   253  	params.EnableSSL()
   254  
   255  	// First make sure the official 'mysql' client can connect.
   256  	output, ok := runMysql(t, &params, "status")
   257  	require.True(t, ok, "'mysql -e status' failed: %v", output)
   258  	require.True(t, strings.Contains(output, "Cipher in use is"), "cannot connect via SSL: %v", output)
   259  
   260  	// Now connect with our client.
   261  	ctx := context.Background()
   262  	conn, err := mysql.Connect(ctx, &params)
   263  	if err != nil {
   264  		t.Fatal(err)
   265  	}
   266  	defer conn.Close()
   267  
   268  	result, err := conn.ExecuteFetch("SHOW STATUS LIKE 'Ssl_cipher'", 10, true)
   269  	require.NoError(t, err, "SHOW STATUS LIKE 'Ssl_cipher' failed: %v", err)
   270  
   271  	if len(result.Rows) != 1 || result.Rows[0][0].ToString() != "Ssl_cipher" ||
   272  		result.Rows[0][1].ToString() == "" {
   273  		t.Fatalf("SHOW STATUS LIKE 'Ssl_cipher' returned unexpected result: %v", result)
   274  	}
   275  }
   276  
   277  func TestReplicationStatus(t *testing.T) {
   278  	params := connParams
   279  	ctx := context.Background()
   280  	conn, err := mysql.Connect(ctx, &params)
   281  	if err != nil {
   282  		t.Fatal(err)
   283  	}
   284  	defer conn.Close()
   285  
   286  	status, err := conn.ShowReplicationStatus()
   287  	assert.Equal(t, mysql.ErrNotReplica, err, "Got unexpected result for ShowReplicationStatus: %v %v", status, err)
   288  
   289  }
   290  
   291  func TestSessionTrackGTIDs(t *testing.T) {
   292  	ctx := context.Background()
   293  	params := connParams
   294  	params.Flags |= mysql.CapabilityClientSessionTrack
   295  	conn, err := mysql.Connect(ctx, &params)
   296  	require.NoError(t, err)
   297  
   298  	qr, err := conn.ExecuteFetch(`set session session_track_gtids='own_gtid'`, 1000, false)
   299  	require.NoError(t, err)
   300  	require.Empty(t, qr.SessionStateChanges)
   301  
   302  	qr, err = conn.ExecuteFetch(`create table vttest.t1(id bigint primary key)`, 1000, false)
   303  	require.NoError(t, err)
   304  	require.NotEmpty(t, qr.SessionStateChanges)
   305  
   306  	qr, err = conn.ExecuteFetch(`insert into vttest.t1 values (1)`, 1000, false)
   307  	require.NoError(t, err)
   308  	require.NotEmpty(t, qr.SessionStateChanges)
   309  }
   310  
   311  func TestCachingSha2Password(t *testing.T) {
   312  	ctx := context.Background()
   313  
   314  	// connect as an existing user to create a user account with caching_sha2_password
   315  	params := connParams
   316  	conn, err := mysql.Connect(ctx, &params)
   317  	expectNoError(t, err)
   318  	defer conn.Close()
   319  
   320  	qr, err := conn.ExecuteFetch(`select true from information_schema.PLUGINS where PLUGIN_NAME='caching_sha2_password' and PLUGIN_STATUS='ACTIVE'`, 1, false)
   321  	assert.NoError(t, err, "select true from information_schema.PLUGINS failed: %v", err)
   322  
   323  	if len(qr.Rows) != 1 {
   324  		t.Skip("Server does not support caching_sha2_password plugin")
   325  	}
   326  
   327  	// create a user using caching_sha2_password password
   328  	if _, err = conn.ExecuteFetch(`create user 'sha2user'@'localhost' identified with caching_sha2_password by 'password';`, 0, false); err != nil {
   329  		t.Fatalf("Create user with caching_sha2_password failed: %v", err)
   330  	}
   331  	conn.Close()
   332  
   333  	// connect as sha2user
   334  	params.Uname = "sha2user"
   335  	params.Pass = "password"
   336  	params.DbName = "information_schema"
   337  	conn, err = mysql.Connect(ctx, &params)
   338  	expectNoError(t, err)
   339  	defer conn.Close()
   340  
   341  	if qr, err = conn.ExecuteFetch(`select user()`, 1, true); err != nil {
   342  		t.Fatalf("select user() failed: %v", err)
   343  	}
   344  
   345  	if len(qr.Rows) != 1 || qr.Rows[0][0].ToString() != "sha2user@localhost" {
   346  		t.Errorf("Logged in user is not sha2user")
   347  	}
   348  }
   349  
   350  func TestClientInfo(t *testing.T) {
   351  	const infoPrepared = "Statement prepared"
   352  
   353  	ctx := context.Background()
   354  	params := connParams
   355  	params.EnableQueryInfo = true
   356  	conn, err := mysql.Connect(ctx, &params)
   357  	require.NoError(t, err)
   358  
   359  	defer conn.Close()
   360  
   361  	// This is the simplest query that would return some textual data in the 'info' field
   362  	result, err := conn.ExecuteFetch(`PREPARE stmt1 FROM 'SELECT 1 = 1'`, -1, true)
   363  	require.NoError(t, err, "select failed: %v", err)
   364  	require.Equal(t, infoPrepared, result.Info, "expected result.Info=%q, got=%q", infoPrepared, result.Info)
   365  }
   366  
   367  func TestBaseShowTables(t *testing.T) {
   368  	params := connParams
   369  	ctx := context.Background()
   370  	conn, err := mysql.Connect(ctx, &params)
   371  	require.NoError(t, err)
   372  	defer conn.Close()
   373  
   374  	sql := conn.BaseShowTables()
   375  	// An improved test would make assertions about the results. This test just
   376  	// makes sure there aren't any errors.
   377  	_, err = conn.ExecuteFetch(sql, -1, true)
   378  	require.NoError(t, err)
   379  }
   380  
   381  func TestBaseShowTablesFilePos(t *testing.T) {
   382  	params := connParams
   383  	params.Flavor = "FilePos"
   384  	ctx := context.Background()
   385  	conn, err := mysql.Connect(ctx, &params)
   386  	require.NoError(t, err)
   387  	defer conn.Close()
   388  
   389  	sql := conn.BaseShowTables()
   390  	// An improved test would make assertions about the results. This test just
   391  	// makes sure there aren't any errors.
   392  	_, err = conn.ExecuteFetch(sql, -1, true)
   393  	require.NoError(t, err)
   394  }