github.com/pingcap/tidb-lightning@v5.0.0-rc.0.20210428090220-84b649866577+incompatible/lightning/common/util_test.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package common_test
    15  
    16  import (
    17  	"context"
    18  	"encoding/json"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"time"
    25  
    26  	sqlmock "github.com/DATA-DOG/go-sqlmock"
    27  	"github.com/go-sql-driver/mysql"
    28  	. "github.com/pingcap/check"
    29  	"github.com/pingcap/errors"
    30  	tmysql "github.com/pingcap/tidb/errno"
    31  	"go.uber.org/multierr"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/status"
    34  
    35  	"github.com/pingcap/tidb-lightning/lightning/common"
    36  	"github.com/pingcap/tidb-lightning/lightning/log"
    37  )
    38  
    39  type utilSuite struct{}
    40  
    41  var _ = Suite(&utilSuite{})
    42  
    43  func (s *utilSuite) TestDirNotExist(c *C) {
    44  	c.Assert(common.IsDirExists("."), IsTrue)
    45  	c.Assert(common.IsDirExists("not-exists"), IsFalse)
    46  }
    47  
    48  func (s *utilSuite) TestGetJSON(c *C) {
    49  	type TestPayload struct {
    50  		Username string `json:"username"`
    51  		Password string `json:"password"`
    52  	}
    53  	var request = TestPayload{
    54  		Username: "lightning",
    55  		Password: "lightning-ctl",
    56  	}
    57  
    58  	ctx := context.Background()
    59  	// Mock success response
    60  	handle := func(res http.ResponseWriter, req *http.Request) {
    61  		res.WriteHeader(http.StatusOK)
    62  		err := json.NewEncoder(res).Encode(request)
    63  		c.Assert(err, IsNil)
    64  	}
    65  	testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
    66  		handle(res, req)
    67  	}))
    68  	defer testServer.Close()
    69  
    70  	client := &http.Client{Timeout: time.Second}
    71  
    72  	response := TestPayload{}
    73  	err := common.GetJSON(ctx, client, "http://not-exists", &response)
    74  	c.Assert(err, NotNil)
    75  	err = common.GetJSON(ctx, client, testServer.URL, &response)
    76  	c.Assert(err, IsNil)
    77  	c.Assert(request, DeepEquals, response)
    78  
    79  	// Mock `StatusNoContent` response
    80  	handle = func(res http.ResponseWriter, req *http.Request) {
    81  		res.WriteHeader(http.StatusNoContent)
    82  	}
    83  	err = common.GetJSON(ctx, client, testServer.URL, &response)
    84  	c.Assert(err, NotNil)
    85  	c.Assert(err, ErrorMatches, ".*http status code != 200.*")
    86  }
    87  
    88  func (s *utilSuite) TestIsRetryableError(c *C) {
    89  	c.Assert(common.IsRetryableError(context.Canceled), IsFalse)
    90  	c.Assert(common.IsRetryableError(context.DeadlineExceeded), IsFalse)
    91  	c.Assert(common.IsRetryableError(io.EOF), IsFalse)
    92  	c.Assert(common.IsRetryableError(&net.AddrError{}), IsFalse)
    93  	c.Assert(common.IsRetryableError(&net.DNSError{}), IsFalse)
    94  	c.Assert(common.IsRetryableError(&net.DNSError{IsTimeout: true}), IsTrue)
    95  
    96  	// MySQL Errors
    97  	c.Assert(common.IsRetryableError(&mysql.MySQLError{}), IsFalse)
    98  	c.Assert(common.IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrUnknown}), IsTrue)
    99  	c.Assert(common.IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrLockDeadlock}), IsTrue)
   100  	c.Assert(common.IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrPDServerTimeout}), IsTrue)
   101  	c.Assert(common.IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrTiKVServerTimeout}), IsTrue)
   102  	c.Assert(common.IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrTiKVServerBusy}), IsTrue)
   103  	c.Assert(common.IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrResolveLockTimeout}), IsTrue)
   104  	c.Assert(common.IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrRegionUnavailable}), IsTrue)
   105  	c.Assert(common.IsRetryableError(&mysql.MySQLError{Number: tmysql.ErrWriteConflictInTiDB}), IsTrue)
   106  
   107  	// gRPC Errors
   108  	c.Assert(common.IsRetryableError(status.Error(codes.Canceled, "")), IsFalse)
   109  	c.Assert(common.IsRetryableError(status.Error(codes.Unknown, "")), IsTrue)
   110  	c.Assert(common.IsRetryableError(status.Error(codes.DeadlineExceeded, "")), IsTrue)
   111  	c.Assert(common.IsRetryableError(status.Error(codes.NotFound, "")), IsTrue)
   112  	c.Assert(common.IsRetryableError(status.Error(codes.AlreadyExists, "")), IsTrue)
   113  	c.Assert(common.IsRetryableError(status.Error(codes.PermissionDenied, "")), IsTrue)
   114  	c.Assert(common.IsRetryableError(status.Error(codes.ResourceExhausted, "")), IsTrue)
   115  	c.Assert(common.IsRetryableError(status.Error(codes.Aborted, "")), IsTrue)
   116  	c.Assert(common.IsRetryableError(status.Error(codes.OutOfRange, "")), IsTrue)
   117  	c.Assert(common.IsRetryableError(status.Error(codes.Unavailable, "")), IsTrue)
   118  	c.Assert(common.IsRetryableError(status.Error(codes.DataLoss, "")), IsTrue)
   119  
   120  	// sqlmock errors
   121  	c.Assert(common.IsRetryableError(fmt.Errorf("call to database Close was not expected")), IsFalse)
   122  	c.Assert(common.IsRetryableError(errors.New("call to database Close was not expected")), IsTrue)
   123  
   124  	// multierr
   125  	c.Assert(common.IsRetryableError(multierr.Combine(context.Canceled, context.Canceled)), IsFalse)
   126  	c.Assert(common.IsRetryableError(multierr.Combine(&net.DNSError{IsTimeout: true}, &net.DNSError{IsTimeout: true})), IsTrue)
   127  	c.Assert(common.IsRetryableError(multierr.Combine(context.Canceled, &net.DNSError{IsTimeout: true})), IsFalse)
   128  }
   129  
   130  func (s *utilSuite) TestToDSN(c *C) {
   131  	param := common.MySQLConnectParam{
   132  		Host:             "127.0.0.1",
   133  		Port:             4000,
   134  		User:             "root",
   135  		Password:         "123456",
   136  		SQLMode:          "strict",
   137  		MaxAllowedPacket: 1234,
   138  		TLS:              "cluster",
   139  		Vars: map[string]string{
   140  			"tidb_distsql_scan_concurrency": "1",
   141  		},
   142  	}
   143  	c.Assert(param.ToDSN(), Equals, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency=1")
   144  }
   145  
   146  func (s *utilSuite) TestIsContextCanceledError(c *C) {
   147  	c.Assert(common.IsContextCanceledError(context.Canceled), IsTrue)
   148  	c.Assert(common.IsContextCanceledError(io.EOF), IsFalse)
   149  }
   150  
   151  func (s *utilSuite) TestUniqueTable(c *C) {
   152  	tableName := common.UniqueTable("test", "t1")
   153  	c.Assert(tableName, Equals, "`test`.`t1`")
   154  
   155  	tableName = common.UniqueTable("test", "t`1")
   156  	c.Assert(tableName, Equals, "`test`.`t``1`")
   157  }
   158  
   159  func (s *utilSuite) TestSQLWithRetry(c *C) {
   160  	db, mock, err := sqlmock.New()
   161  	c.Assert(err, IsNil)
   162  
   163  	sqlWithRetry := &common.SQLWithRetry{
   164  		DB:     db,
   165  		Logger: log.L(),
   166  	}
   167  	aValue := new(int)
   168  
   169  	// retry defaultMaxRetry times and still failed
   170  	for i := 0; i < 3; i++ {
   171  		mock.ExpectQuery("select a from test.t1").WillReturnError(errors.New("mock error"))
   172  	}
   173  	err = sqlWithRetry.QueryRow(context.Background(), "", "select a from test.t1", aValue)
   174  	c.Assert(err, ErrorMatches, ".*mock error")
   175  
   176  	// meet unretryable error and will return directly
   177  	mock.ExpectQuery("select a from test.t1").WillReturnError(context.Canceled)
   178  	err = sqlWithRetry.QueryRow(context.Background(), "", "select a from test.t1", aValue)
   179  	c.Assert(err, ErrorMatches, ".*context canceled")
   180  
   181  	// query success
   182  	rows := sqlmock.NewRows([]string{"a"}).AddRow("1")
   183  	mock.ExpectQuery("select a from test.t1").WillReturnRows(rows)
   184  
   185  	err = sqlWithRetry.QueryRow(context.Background(), "", "select a from test.t1", aValue)
   186  	c.Assert(err, IsNil)
   187  	c.Assert(*aValue, Equals, 1)
   188  
   189  	// test Exec
   190  	mock.ExpectExec("delete from").WillReturnError(context.Canceled)
   191  	err = sqlWithRetry.Exec(context.Background(), "", "delete from test.t1 where id = ?", 2)
   192  	c.Assert(err, ErrorMatches, ".*context canceled")
   193  
   194  	mock.ExpectExec("delete from").WillReturnResult(sqlmock.NewResult(0, 1))
   195  	err = sqlWithRetry.Exec(context.Background(), "", "delete from test.t1 where id = ?", 2)
   196  	c.Assert(err, IsNil)
   197  
   198  	c.Assert(mock.ExpectationsWereMet(), IsNil)
   199  }
   200  
   201  func (s *utilSuite) TestStringSliceEqual(c *C) {
   202  	c.Assert(common.StringSliceEqual(nil, nil), IsTrue)
   203  	c.Assert(common.StringSliceEqual(nil, []string{}), IsTrue)
   204  	c.Assert(common.StringSliceEqual(nil, []string{"a"}), IsFalse)
   205  	c.Assert(common.StringSliceEqual([]string{"a"}, nil), IsFalse)
   206  	c.Assert(common.StringSliceEqual([]string{"a"}, []string{"a"}), IsTrue)
   207  	c.Assert(common.StringSliceEqual([]string{"a"}, []string{"b"}), IsFalse)
   208  	c.Assert(common.StringSliceEqual([]string{"a", "b", "c"}, []string{"a", "b", "c"}), IsTrue)
   209  	c.Assert(common.StringSliceEqual([]string{"a"}, []string{"a", "b", "c"}), IsFalse)
   210  	c.Assert(common.StringSliceEqual([]string{"a", "b", "c"}, []string{"a", "b"}), IsFalse)
   211  	c.Assert(common.StringSliceEqual([]string{"a", "x", "y"}, []string{"a", "y", "x"}), IsFalse)
   212  }
   213  
   214  func (s *utilSuite) TestInterpolateMySQLString(c *C) {
   215  	c.Assert(common.InterpolateMySQLString("123"), Equals, "'123'")
   216  	c.Assert(common.InterpolateMySQLString("1'23"), Equals, "'1''23'")
   217  	c.Assert(common.InterpolateMySQLString("1'2''3"), Equals, "'1''2''''3'")
   218  }