vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletserver/connpool/dbconn_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package connpool
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  
    30  	"vitess.io/vitess/go/mysql"
    31  	"vitess.io/vitess/go/mysql/fakesqldb"
    32  	"vitess.io/vitess/go/pools"
    33  	"vitess.io/vitess/go/sqltypes"
    34  	querypb "vitess.io/vitess/go/vt/proto/query"
    35  )
    36  
    37  func compareTimingCounts(t *testing.T, op string, delta int64, before, after map[string]int64) {
    38  	t.Helper()
    39  	countBefore := before[op]
    40  	countAfter := after[op]
    41  	if countAfter-countBefore != delta {
    42  		t.Errorf("Expected %s to increase by %d, got %d (%d => %d)", op, delta, countAfter-countBefore, countBefore, countAfter)
    43  	}
    44  }
    45  
    46  func TestDBConnExec(t *testing.T) {
    47  	db := fakesqldb.New(t)
    48  	defer db.Close()
    49  
    50  	sql := "select * from test_table limit 1000"
    51  	expectedResult := &sqltypes.Result{
    52  		Fields: []*querypb.Field{
    53  			{Type: sqltypes.VarChar},
    54  		},
    55  		RowsAffected: 0,
    56  		Rows: [][]sqltypes.Value{
    57  			{sqltypes.NewVarChar("123")},
    58  		},
    59  	}
    60  	db.AddQuery(sql, expectedResult)
    61  	connPool := newPool()
    62  	mysqlTimings := connPool.env.Stats().MySQLTimings
    63  	startCounts := mysqlTimings.Counts()
    64  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
    65  	defer connPool.Close()
    66  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second))
    67  	defer cancel()
    68  	dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams())
    69  	if dbConn != nil {
    70  		defer dbConn.Close()
    71  	}
    72  	if err != nil {
    73  		t.Fatalf("should not get an error, err: %v", err)
    74  	}
    75  	// Exec succeed, not asking for fields.
    76  	result, err := dbConn.Exec(ctx, sql, 1, false)
    77  	if err != nil {
    78  		t.Fatalf("should not get an error, err: %v", err)
    79  	}
    80  	expectedResult.Fields = nil
    81  	if !expectedResult.Equal(result) {
    82  		t.Errorf("Exec: %v, want %v", expectedResult, result)
    83  	}
    84  
    85  	compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts())
    86  
    87  	startCounts = mysqlTimings.Counts()
    88  
    89  	// Exec fail due to client side error
    90  	db.AddRejectedQuery(sql, &mysql.SQLError{
    91  		Num:     2012,
    92  		Message: "connection fail",
    93  		Query:   "",
    94  	})
    95  	_, err = dbConn.Exec(ctx, sql, 1, false)
    96  	want := "connection fail"
    97  	if err == nil || !strings.Contains(err.Error(), want) {
    98  		t.Errorf("Exec: %v, want %s", err, want)
    99  	}
   100  
   101  	// The client side error triggers a retry in exec.
   102  	compareTimingCounts(t, "PoolTest.Exec", 2, startCounts, mysqlTimings.Counts())
   103  
   104  	startCounts = mysqlTimings.Counts()
   105  
   106  	// Set the connection fail flag and try again.
   107  	// This time the initial query fails as does the reconnect attempt.
   108  	db.EnableConnFail()
   109  	_, err = dbConn.Exec(ctx, sql, 1, false)
   110  	want = "packet read failed"
   111  	if err == nil || !strings.Contains(err.Error(), want) {
   112  		t.Errorf("Exec: %v, want %s", err, want)
   113  	}
   114  	db.DisableConnFail()
   115  
   116  	compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts())
   117  }
   118  
   119  func TestDBConnExecLost(t *testing.T) {
   120  	db := fakesqldb.New(t)
   121  	defer db.Close()
   122  
   123  	sql := "select * from test_table limit 1000"
   124  	expectedResult := &sqltypes.Result{
   125  		Fields: []*querypb.Field{
   126  			{Type: sqltypes.VarChar},
   127  		},
   128  		RowsAffected: 0,
   129  		Rows: [][]sqltypes.Value{
   130  			{sqltypes.NewVarChar("123")},
   131  		},
   132  	}
   133  	db.AddQuery(sql, expectedResult)
   134  	connPool := newPool()
   135  	mysqlTimings := connPool.env.Stats().MySQLTimings
   136  	startCounts := mysqlTimings.Counts()
   137  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   138  	defer connPool.Close()
   139  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second))
   140  	defer cancel()
   141  	dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams())
   142  	if dbConn != nil {
   143  		defer dbConn.Close()
   144  	}
   145  	if err != nil {
   146  		t.Fatalf("should not get an error, err: %v", err)
   147  	}
   148  	// Exec succeed, not asking for fields.
   149  	result, err := dbConn.Exec(ctx, sql, 1, false)
   150  	if err != nil {
   151  		t.Fatalf("should not get an error, err: %v", err)
   152  	}
   153  	expectedResult.Fields = nil
   154  	if !expectedResult.Equal(result) {
   155  		t.Errorf("Exec: %v, want %v", expectedResult, result)
   156  	}
   157  
   158  	compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts())
   159  
   160  	// Exec fail due to server side error (e.g. query kill)
   161  	startCounts = mysqlTimings.Counts()
   162  	db.AddRejectedQuery(sql, &mysql.SQLError{
   163  		Num:     2013,
   164  		Message: "Lost connection to MySQL server during query",
   165  		Query:   "",
   166  	})
   167  	_, err = dbConn.Exec(ctx, sql, 1, false)
   168  	want := "Lost connection to MySQL server during query"
   169  	if err == nil || !strings.Contains(err.Error(), want) {
   170  		t.Errorf("Exec: %v, want %s", err, want)
   171  	}
   172  
   173  	// Should *not* see a retry, so only increment by 1
   174  	compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts())
   175  }
   176  
   177  func TestDBConnDeadline(t *testing.T) {
   178  	db := fakesqldb.New(t)
   179  	defer db.Close()
   180  	sql := "select * from test_table limit 1000"
   181  	expectedResult := &sqltypes.Result{
   182  		Fields: []*querypb.Field{
   183  			{Type: sqltypes.VarChar},
   184  		},
   185  		RowsAffected: 0,
   186  		Rows: [][]sqltypes.Value{
   187  			{sqltypes.NewVarChar("123")},
   188  		},
   189  	}
   190  	db.AddQuery(sql, expectedResult)
   191  
   192  	connPool := newPool()
   193  	mysqlTimings := connPool.env.Stats().MySQLTimings
   194  	startCounts := mysqlTimings.Counts()
   195  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   196  	defer connPool.Close()
   197  
   198  	db.SetConnDelay(100 * time.Millisecond)
   199  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond))
   200  	defer cancel()
   201  
   202  	dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams())
   203  	if dbConn != nil {
   204  		defer dbConn.Close()
   205  	}
   206  	if err != nil {
   207  		t.Fatalf("should not get an error, err: %v", err)
   208  	}
   209  
   210  	_, err = dbConn.Exec(ctx, sql, 1, false)
   211  	want := "context deadline exceeded before execution started"
   212  	if err == nil || !strings.Contains(err.Error(), want) {
   213  		t.Errorf("Exec: %v, want %s", err, want)
   214  	}
   215  
   216  	compareTimingCounts(t, "PoolTest.Exec", 0, startCounts, mysqlTimings.Counts())
   217  
   218  	startCounts = mysqlTimings.Counts()
   219  
   220  	ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(10*time.Second))
   221  	defer cancel()
   222  
   223  	result, err := dbConn.Exec(ctx, sql, 1, false)
   224  	if err != nil {
   225  		t.Fatalf("should not get an error, err: %v", err)
   226  	}
   227  	expectedResult.Fields = nil
   228  	if !expectedResult.Equal(result) {
   229  		t.Errorf("Exec: %v, want %v", expectedResult, result)
   230  	}
   231  
   232  	compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts())
   233  
   234  	startCounts = mysqlTimings.Counts()
   235  
   236  	// Test with just the Background context (with no deadline)
   237  	result, err = dbConn.Exec(context.Background(), sql, 1, false)
   238  	if err != nil {
   239  		t.Fatalf("should not get an error, err: %v", err)
   240  	}
   241  	expectedResult.Fields = nil
   242  	if !expectedResult.Equal(result) {
   243  		t.Errorf("Exec: %v, want %v", expectedResult, result)
   244  	}
   245  
   246  	compareTimingCounts(t, "PoolTest.Exec", 1, startCounts, mysqlTimings.Counts())
   247  }
   248  
   249  func TestDBConnKill(t *testing.T) {
   250  	db := fakesqldb.New(t)
   251  	defer db.Close()
   252  	connPool := newPool()
   253  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   254  	defer connPool.Close()
   255  	dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams())
   256  	if dbConn != nil {
   257  		defer dbConn.Close()
   258  	}
   259  	if err != nil {
   260  		t.Fatalf("should not get an error, err: %v", err)
   261  	}
   262  	query := fmt.Sprintf("kill %d", dbConn.ID())
   263  	db.AddQuery(query, &sqltypes.Result{})
   264  	// Kill failed because we are not able to connect to the database
   265  	db.EnableConnFail()
   266  	err = dbConn.Kill("test kill", 0)
   267  	want := "errno 2013"
   268  	if err == nil || !strings.Contains(err.Error(), want) {
   269  		t.Errorf("Exec: %v, want %s", err, want)
   270  	}
   271  	db.DisableConnFail()
   272  
   273  	// Kill succeed
   274  	err = dbConn.Kill("test kill", 0)
   275  	if err != nil {
   276  		t.Fatalf("kill should succeed, but got error: %v", err)
   277  	}
   278  
   279  	err = dbConn.reconnect(context.Background())
   280  	if err != nil {
   281  		t.Fatalf("reconnect should succeed, but got error: %v", err)
   282  	}
   283  	newKillQuery := fmt.Sprintf("kill %d", dbConn.ID())
   284  	// Kill failed because "kill query_id" failed
   285  	db.AddRejectedQuery(newKillQuery, errors.New("rejected"))
   286  	err = dbConn.Kill("test kill", 0)
   287  	want = "rejected"
   288  	if err == nil || !strings.Contains(err.Error(), want) {
   289  		t.Errorf("Exec: %v, want %s", err, want)
   290  	}
   291  }
   292  
   293  // TestDBConnClose tests that an Exec returns immediately if a connection
   294  // is asynchronously killed (and closed) in the middle of an execution.
   295  func TestDBConnClose(t *testing.T) {
   296  	db := fakesqldb.New(t)
   297  	defer db.Close()
   298  	connPool := newPool()
   299  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   300  	defer connPool.Close()
   301  	dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams())
   302  	require.NoError(t, err)
   303  	defer dbConn.Close()
   304  
   305  	query := "sleep"
   306  	db.AddQuery(query, &sqltypes.Result{})
   307  	db.SetBeforeFunc(query, func() {
   308  		time.Sleep(100 * time.Millisecond)
   309  	})
   310  
   311  	start := time.Now()
   312  	go func() {
   313  		time.Sleep(10 * time.Millisecond)
   314  		dbConn.Kill("test kill", 0)
   315  	}()
   316  	_, err = dbConn.Exec(context.Background(), query, 1, false)
   317  	assert.Contains(t, err.Error(), "(errno 2013) due to")
   318  	assert.True(t, time.Since(start) < 100*time.Millisecond, "%v %v", time.Since(start), 100*time.Millisecond)
   319  }
   320  
   321  func TestDBNoPoolConnKill(t *testing.T) {
   322  	db := fakesqldb.New(t)
   323  	connPool := newPool()
   324  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   325  	defer connPool.Close()
   326  	dbConn, err := NewDBConnNoPool(context.Background(), db.ConnParams(), connPool.dbaPool, nil)
   327  	if dbConn != nil {
   328  		defer dbConn.Close()
   329  	}
   330  	if err != nil {
   331  		t.Fatalf("should not get an error, err: %v", err)
   332  	}
   333  	query := fmt.Sprintf("kill %d", dbConn.ID())
   334  	db.AddQuery(query, &sqltypes.Result{})
   335  	// Kill failed because we are not able to connect to the database
   336  	db.EnableConnFail()
   337  	err = dbConn.Kill("test kill", 0)
   338  	want := "errno 2013"
   339  	if err == nil || !strings.Contains(err.Error(), want) {
   340  		t.Errorf("Exec: %v, want %s", err, want)
   341  	}
   342  	db.DisableConnFail()
   343  
   344  	// Kill succeed
   345  	err = dbConn.Kill("test kill", 0)
   346  	if err != nil {
   347  		t.Fatalf("kill should succeed, but got error: %v", err)
   348  	}
   349  
   350  	err = dbConn.reconnect(context.Background())
   351  	if err != nil {
   352  		t.Fatalf("reconnect should succeed, but got error: %v", err)
   353  	}
   354  	newKillQuery := fmt.Sprintf("kill %d", dbConn.ID())
   355  	// Kill failed because "kill query_id" failed
   356  	db.AddRejectedQuery(newKillQuery, errors.New("rejected"))
   357  	err = dbConn.Kill("test kill", 0)
   358  	want = "rejected"
   359  	if err == nil || !strings.Contains(err.Error(), want) {
   360  		t.Errorf("Exec: %v, want %s", err, want)
   361  	}
   362  }
   363  
   364  func TestDBConnStream(t *testing.T) {
   365  	db := fakesqldb.New(t)
   366  	defer db.Close()
   367  	sql := "select * from test_table limit 1000"
   368  	expectedResult := &sqltypes.Result{
   369  		Fields: []*querypb.Field{
   370  			{Type: sqltypes.VarChar},
   371  		},
   372  		Rows: [][]sqltypes.Value{
   373  			{sqltypes.NewVarChar("123")},
   374  		},
   375  	}
   376  	db.AddQuery(sql, expectedResult)
   377  	connPool := newPool()
   378  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   379  	defer connPool.Close()
   380  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Second))
   381  	defer cancel()
   382  	dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams())
   383  	if dbConn != nil {
   384  		defer dbConn.Close()
   385  	}
   386  	if err != nil {
   387  		t.Fatalf("should not get an error, err: %v", err)
   388  	}
   389  	var result sqltypes.Result
   390  	err = dbConn.Stream(
   391  		ctx, sql, func(r *sqltypes.Result) error {
   392  			// Aggregate Fields and Rows
   393  			if r.Fields != nil {
   394  				result.Fields = r.Fields
   395  			}
   396  			if r.Rows != nil {
   397  				result.Rows = append(result.Rows, r.Rows...)
   398  			}
   399  			return nil
   400  		}, func() *sqltypes.Result {
   401  			return &sqltypes.Result{}
   402  		},
   403  		10, querypb.ExecuteOptions_ALL)
   404  	if err != nil {
   405  		t.Fatalf("should not get an error, err: %v", err)
   406  	}
   407  	if !expectedResult.Equal(&result) {
   408  		t.Errorf("Exec: %v, want %v", expectedResult, &result)
   409  	}
   410  	// Stream fail
   411  	db.Close()
   412  	dbConn.Close()
   413  	err = dbConn.Stream(
   414  		ctx, sql, func(r *sqltypes.Result) error {
   415  			return nil
   416  		}, func() *sqltypes.Result {
   417  			return &sqltypes.Result{}
   418  		},
   419  		10, querypb.ExecuteOptions_ALL)
   420  	db.DisableConnFail()
   421  	want := "no such file or directory (errno 2002)"
   422  	if err == nil || !strings.Contains(err.Error(), want) {
   423  		t.Errorf("Error: '%v', must contain '%s'", err, want)
   424  	}
   425  }
   426  
   427  func TestDBConnStreamKill(t *testing.T) {
   428  	db := fakesqldb.New(t)
   429  	defer db.Close()
   430  	sql := "select * from test_table limit 1000"
   431  	expectedResult := &sqltypes.Result{
   432  		Fields: []*querypb.Field{
   433  			{Type: sqltypes.VarChar},
   434  		},
   435  	}
   436  	db.AddQuery(sql, expectedResult)
   437  	connPool := newPool()
   438  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   439  	defer connPool.Close()
   440  	dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams())
   441  	require.NoError(t, err)
   442  	defer dbConn.Close()
   443  
   444  	go func() {
   445  		time.Sleep(10 * time.Millisecond)
   446  		dbConn.Kill("test kill", 0)
   447  	}()
   448  
   449  	err = dbConn.Stream(context.Background(), sql,
   450  		func(r *sqltypes.Result) error {
   451  			time.Sleep(100 * time.Millisecond)
   452  			return nil
   453  		},
   454  		func() *sqltypes.Result {
   455  			return &sqltypes.Result{}
   456  		},
   457  		10, querypb.ExecuteOptions_ALL)
   458  
   459  	assert.Contains(t, err.Error(), "(errno 2013) due to")
   460  }
   461  
   462  func TestDBConnReconnect(t *testing.T) {
   463  	db := fakesqldb.New(t)
   464  	defer db.Close()
   465  
   466  	connPool := newPool()
   467  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   468  	defer connPool.Close()
   469  
   470  	dbConn, err := NewDBConn(context.Background(), connPool, db.ConnParams())
   471  	require.NoError(t, err)
   472  	defer dbConn.Close()
   473  
   474  	oldConnID := dbConn.conn.ID()
   475  	// close the connection and let the dbconn reconnect to start a new connection when required.
   476  	dbConn.conn.Close()
   477  
   478  	query := "select 1"
   479  	db.AddQuery(query, &sqltypes.Result{})
   480  
   481  	_, err = dbConn.Exec(context.Background(), query, 1, false)
   482  	require.NoError(t, err)
   483  	require.NotEqual(t, oldConnID, dbConn.conn.ID())
   484  }
   485  
   486  func TestDBConnReApplySetting(t *testing.T) {
   487  	db := fakesqldb.New(t)
   488  	defer db.Close()
   489  	db.OrderMatters()
   490  
   491  	connPool := newPool()
   492  	connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
   493  	defer connPool.Close()
   494  
   495  	ctx := context.Background()
   496  	dbConn, err := NewDBConn(ctx, connPool, db.ConnParams())
   497  	require.NoError(t, err)
   498  	defer dbConn.Close()
   499  
   500  	// apply system settings.
   501  	setQ := "set @@sql_mode='ANSI_QUOTES'"
   502  	db.AddExpectedQuery(setQ, nil)
   503  	err = dbConn.ApplySetting(ctx, pools.NewSetting(setQ, "set @@sql_mode = default"))
   504  	require.NoError(t, err)
   505  
   506  	// close the connection and let the dbconn reconnect to start a new connection when required.
   507  	oldConnID := dbConn.conn.ID()
   508  	dbConn.conn.Close()
   509  
   510  	// new conn should also have the same settings.
   511  	// set query will be executed first on the new connection and then the query.
   512  	db.AddExpectedQuery(setQ, nil)
   513  	query := "select 1"
   514  	db.AddExpectedQuery(query, nil)
   515  	_, err = dbConn.Exec(ctx, query, 1, false)
   516  	require.NoError(t, err)
   517  	require.NotEqual(t, oldConnID, dbConn.conn.ID())
   518  
   519  	db.VerifyAllExecutedOrFail()
   520  }