github.com/pingcap/tidb-lightning@v5.0.0-rc.0.20210428090220-84b649866577+incompatible/lightning/common/util.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package common
    15  
    16  import (
    17  	"context"
    18  	"database/sql"
    19  	"encoding/json"
    20  	stderrors "errors"
    21  	"fmt"
    22  	"io"
    23  	"io/ioutil"
    24  	"net"
    25  	"net/http"
    26  	"net/url"
    27  	"os"
    28  	"reflect"
    29  	"regexp"
    30  	"strings"
    31  	"syscall"
    32  	"time"
    33  
    34  	"github.com/go-sql-driver/mysql"
    35  	"github.com/pingcap/errors"
    36  	"github.com/pingcap/parser/model"
    37  	tmysql "github.com/pingcap/tidb/errno"
    38  	"go.uber.org/zap"
    39  	"google.golang.org/grpc/codes"
    40  	"google.golang.org/grpc/status"
    41  
    42  	"github.com/pingcap/tidb-lightning/lightning/log"
    43  )
    44  
    45  const (
    46  	retryTimeout = 3 * time.Second
    47  
    48  	defaultMaxRetry = 3
    49  )
    50  
    51  // MySQLConnectParam records the parameters needed to connect to a MySQL database.
    52  type MySQLConnectParam struct {
    53  	Host             string
    54  	Port             int
    55  	User             string
    56  	Password         string
    57  	SQLMode          string
    58  	MaxAllowedPacket uint64
    59  	TLS              string
    60  	Vars             map[string]string
    61  }
    62  
    63  func (param *MySQLConnectParam) ToDSN() string {
    64  	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s",
    65  		param.User, param.Password, param.Host, param.Port,
    66  		param.SQLMode, param.MaxAllowedPacket, param.TLS)
    67  
    68  	for k, v := range param.Vars {
    69  		dsn += fmt.Sprintf("&%s=%s", k, url.QueryEscape(v))
    70  	}
    71  
    72  	return dsn
    73  }
    74  
    75  func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
    76  	db, err := sql.Open("mysql", param.ToDSN())
    77  	if err != nil {
    78  		return nil, errors.Trace(err)
    79  	}
    80  
    81  	return db, errors.Trace(db.Ping())
    82  }
    83  
    84  // IsDirExists checks if dir exists.
    85  func IsDirExists(name string) bool {
    86  	f, err := os.Stat(name)
    87  	if err != nil {
    88  		return false
    89  	}
    90  	return f != nil && f.IsDir()
    91  }
    92  
    93  // IsEmptyDir checks if dir is empty.
    94  func IsEmptyDir(name string) bool {
    95  	entries, err := ioutil.ReadDir(name)
    96  	if err != nil {
    97  		return false
    98  	}
    99  	return len(entries) == 0
   100  }
   101  
// SQLWithRetry constructs a retryable transaction.
//
// It wraps a *sql.DB so that single-row queries (QueryRow), statements (Exec)
// and transactions (Transact) are run through Retry, and are re-attempted
// when they fail with an error that IsRetryableError accepts.
type SQLWithRetry struct {
	// DB is the handle every operation runs against.
	DB           *sql.DB
	// Logger receives the retry warnings and rollback-failure errors.
	Logger       log.Logger
	// HideQueryLog, when true, keeps the SQL text (and, for Exec, the
	// arguments) out of the log fields.
	HideQueryLog bool
}
   108  
   109  func (t SQLWithRetry) perform(ctx context.Context, parentLogger log.Logger, purpose string, action func() error) error {
   110  	return Retry(purpose, parentLogger, action)
   111  }
   112  
   113  // Retry is shared by SQLWithRetry.perform, implementation of GlueCheckpointsDB and TiDB's glue implementation
   114  func Retry(purpose string, parentLogger log.Logger, action func() error) error {
   115  	var err error
   116  outside:
   117  	for i := 0; i < defaultMaxRetry; i++ {
   118  		logger := parentLogger.With(zap.Int("retryCnt", i))
   119  
   120  		if i > 0 {
   121  			logger.Warn(purpose + " retry start")
   122  			time.Sleep(retryTimeout)
   123  		}
   124  
   125  		err = action()
   126  		switch {
   127  		case err == nil:
   128  			return nil
   129  		case IsRetryableError(err):
   130  			logger.Warn(purpose+" failed but going to try again", log.ShortError(err))
   131  			continue
   132  		default:
   133  			break outside
   134  		}
   135  	}
   136  
   137  	return errors.Annotatef(err, "%s failed", purpose)
   138  }
   139  
   140  func (t SQLWithRetry) QueryRow(ctx context.Context, purpose string, query string, dest ...interface{}) error {
   141  	logger := t.Logger
   142  	if !t.HideQueryLog {
   143  		logger = logger.With(zap.String("query", query))
   144  	}
   145  	return t.perform(ctx, logger, purpose, func() error {
   146  		return t.DB.QueryRowContext(ctx, query).Scan(dest...)
   147  	})
   148  }
   149  
   150  // Transact executes an action in a transaction, and retry if the
   151  // action failed with a retryable error.
   152  func (t SQLWithRetry) Transact(ctx context.Context, purpose string, action func(context.Context, *sql.Tx) error) error {
   153  	return t.perform(ctx, t.Logger, purpose, func() error {
   154  		txn, err := t.DB.BeginTx(ctx, nil)
   155  		if err != nil {
   156  			return errors.Annotate(err, "begin transaction failed")
   157  		}
   158  
   159  		err = action(ctx, txn)
   160  		if err != nil {
   161  			rerr := txn.Rollback()
   162  			if rerr != nil {
   163  				t.Logger.Error(purpose+" rollback transaction failed", log.ShortError(rerr))
   164  			}
   165  			// we should return the exec err, instead of the rollback rerr.
   166  			// no need to errors.Trace() it, as the error comes from user code anyway.
   167  			return err
   168  		}
   169  
   170  		err = txn.Commit()
   171  		if err != nil {
   172  			return errors.Annotate(err, "commit transaction failed")
   173  		}
   174  
   175  		return nil
   176  	})
   177  }
   178  
   179  // Exec executes a single SQL with optional retry.
   180  func (t SQLWithRetry) Exec(ctx context.Context, purpose string, query string, args ...interface{}) error {
   181  	logger := t.Logger
   182  	if !t.HideQueryLog {
   183  		logger = logger.With(zap.String("query", query), zap.Reflect("args", args))
   184  	}
   185  	return t.perform(ctx, logger, purpose, func() error {
   186  		_, err := t.DB.ExecContext(ctx, query, args...)
   187  		return errors.Trace(err)
   188  	})
   189  }
   190  
   191  // sqlmock uses fmt.Errorf to produce expectation failures, which will cause
   192  // unnecessary retry if not specially handled >:(
   193  var stdFatalErrorsRegexp = regexp.MustCompile(
   194  	`^call to (?s:.*) was not expected|arguments do not match:|could not match actual sql`,
   195  )
   196  var stdErrorType = reflect.TypeOf(stderrors.New(""))
   197  
   198  // IsRetryableError returns whether the error is transient (e.g. network
   199  // connection dropped) or irrecoverable (e.g. user pressing Ctrl+C). This
   200  // function returns `false` (irrecoverable) if `err == nil`.
   201  //
   202  // If the error is a multierr, returns true only if all suberrors are retryable.
   203  func IsRetryableError(err error) bool {
   204  	for _, singleError := range errors.Errors(err) {
   205  		if !isSingleRetryableError(singleError) {
   206  			return false
   207  		}
   208  	}
   209  	return true
   210  }
   211  
// isSingleRetryableError classifies one (non-grouped) error as retryable or
// not. The error is first unwrapped with errors.Cause, and then:
//   - nil, context.Canceled, context.DeadlineExceeded and io.EOF are never
//     retryable;
//   - a net.Error is retryable only when it reports a timeout;
//   - a *mysql.MySQLError is retryable only for the error numbers listed in
//     the inner switch below;
//   - any other error is classified by its gRPC status code. status.Code
//     reports codes.Unknown for errors that carry no gRPC status, so such
//     errors are retried, except when their dynamic type is the errors.New
//     type and their message matches stdFatalErrorsRegexp.
func isSingleRetryableError(err error) bool {
	err = errors.Cause(err)

	switch err {
	case nil, context.Canceled, context.DeadlineExceeded, io.EOF:
		return false
	}

	switch nerr := err.(type) {
	case net.Error:
		// Only timeouts are treated as transient network failures.
		return nerr.Timeout()
	case *mysql.MySQLError:
		switch nerr.Number {
		// A deadlock (ErrLockDeadlock) can be resolved by retrying the commit.
		case tmysql.ErrUnknown, tmysql.ErrLockDeadlock, tmysql.ErrWriteConflictInTiDB, tmysql.ErrPDServerTimeout, tmysql.ErrTiKVServerTimeout, tmysql.ErrTiKVServerBusy, tmysql.ErrResolveLockTimeout, tmysql.ErrRegionUnavailable:
			return true
		default:
			return false
		}
	default:
		switch status.Code(err) {
		case codes.DeadlineExceeded, codes.NotFound, codes.AlreadyExists, codes.PermissionDenied, codes.ResourceExhausted, codes.Aborted, codes.OutOfRange, codes.Unavailable, codes.DataLoss:
			return true
		case codes.Unknown:
			// Plain errors.New-style errors are fatal when they look like
			// sqlmock expectation failures; every other unknown error is
			// retried.
			if reflect.TypeOf(err) == stdErrorType {
				return !stdFatalErrorsRegexp.MatchString(err.Error())
			}
			return true
		default:
			return false
		}
	}
}
   245  
   246  // IsContextCanceledError returns whether the error is caused by context
   247  // cancellation. This function should only be used when the code logic is
   248  // affected by whether the error is canceling or not.
   249  //
   250  // This function returns `false` (not a context-canceled error) if `err == nil`.
   251  func IsContextCanceledError(err error) bool {
   252  	return log.IsContextCanceledError(err)
   253  }
   254  
   255  // UniqueTable returns an unique table name.
   256  func UniqueTable(schema string, table string) string {
   257  	var builder strings.Builder
   258  	WriteMySQLIdentifier(&builder, schema)
   259  	builder.WriteByte('.')
   260  	WriteMySQLIdentifier(&builder, table)
   261  	return builder.String()
   262  }
   263  
   264  // Writes a MySQL identifier into the string builder.
   265  // The identifier is always escaped into the form "`foo`".
   266  func WriteMySQLIdentifier(builder *strings.Builder, identifier string) {
   267  	builder.Grow(len(identifier) + 2)
   268  	builder.WriteByte('`')
   269  
   270  	// use a C-style loop instead of range loop to avoid UTF-8 decoding
   271  	for i := 0; i < len(identifier); i++ {
   272  		b := identifier[i]
   273  		if b == '`' {
   274  			builder.WriteString("``")
   275  		} else {
   276  			builder.WriteByte(b)
   277  		}
   278  	}
   279  
   280  	builder.WriteByte('`')
   281  }
   282  
   283  func InterpolateMySQLString(s string) string {
   284  	var builder strings.Builder
   285  	builder.Grow(len(s) + 2)
   286  	builder.WriteByte('\'')
   287  	for i := 0; i < len(s); i++ {
   288  		b := s[i]
   289  		if b == '\'' {
   290  			builder.WriteString("''")
   291  		} else {
   292  			builder.WriteByte(b)
   293  		}
   294  	}
   295  	builder.WriteByte('\'')
   296  	return builder.String()
   297  }
   298  
   299  // GetJSON fetches a page and parses it as JSON. The parsed result will be
   300  // stored into the `v`. The variable `v` must be a pointer to a type that can be
   301  // unmarshalled from JSON.
   302  //
   303  // Example:
   304  //
   305  //	client := &http.Client{}
   306  //	var resp struct { IP string }
   307  //	if err := util.GetJSON(client, "http://api.ipify.org/?format=json", &resp); err != nil {
   308  //		return errors.Trace(err)
   309  //	}
   310  //	fmt.Println(resp.IP)
   311  func GetJSON(ctx context.Context, client *http.Client, url string, v interface{}) error {
   312  	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
   313  	if err != nil {
   314  		return errors.Trace(err)
   315  	}
   316  
   317  	resp, err := client.Do(req)
   318  	if err != nil {
   319  		return errors.Trace(err)
   320  	}
   321  
   322  	defer resp.Body.Close()
   323  
   324  	if resp.StatusCode != http.StatusOK {
   325  		body, err := ioutil.ReadAll(resp.Body)
   326  		if err != nil {
   327  			return errors.Trace(err)
   328  		}
   329  		return errors.Errorf("get %s http status code != 200, message %s", url, string(body))
   330  	}
   331  
   332  	return errors.Trace(json.NewDecoder(resp.Body).Decode(v))
   333  }
   334  
   335  // KillMySelf sends sigint to current process, used in integration test only
   336  //
   337  // Only works on Unix. Signaling on Windows is not supported.
   338  func KillMySelf() error {
   339  	proc, err := os.FindProcess(os.Getpid())
   340  	if err == nil {
   341  		err = proc.Signal(syscall.SIGINT)
   342  	}
   343  	return errors.Trace(err)
   344  }
   345  
// KvPair is a pair of key and value, both held as raw byte slices.
type KvPair struct {
	// Key is the key of the KV pair.
	Key []byte
	// Val is the value of the KV pair.
	Val []byte
}
   353  
   354  // TableHasAutoRowID return whether table has auto generated row id
   355  func TableHasAutoRowID(info *model.TableInfo) bool {
   356  	return !info.PKIsHandle && !info.IsCommonHandle
   357  }
   358  
   359  // StringSliceEqual checks if two string slices are equal.
   360  func StringSliceEqual(a, b []string) bool {
   361  	if len(a) != len(b) {
   362  		return false
   363  	}
   364  	for i, v := range a {
   365  		if v != b[i] {
   366  			return false
   367  		}
   368  	}
   369  	return true
   370  }