github.com/pingcap/br@v5.3.0-alpha.0.20220125034240-ec59c7b6ce30+incompatible/pkg/lightning/common/util.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package common
    15  
    16  import (
    17  	"context"
    18  	"database/sql"
    19  	"encoding/json"
    20  	stderrors "errors"
    21  	"fmt"
    22  	"io"
    23  	"net"
    24  	"net/http"
    25  	"net/url"
    26  	"os"
    27  	"reflect"
    28  	"regexp"
    29  	"strings"
    30  	"syscall"
    31  	"time"
    32  
    33  	"github.com/go-sql-driver/mysql"
    34  	"github.com/pingcap/errors"
    35  	"github.com/pingcap/parser/model"
    36  	tmysql "github.com/pingcap/tidb/errno"
    37  	"go.uber.org/zap"
    38  	"google.golang.org/grpc/codes"
    39  	"google.golang.org/grpc/status"
    40  
    41  	"github.com/pingcap/br/pkg/lightning/log"
    42  )
    43  
    44  const (
    45  	retryTimeout = 3 * time.Second
    46  
    47  	defaultMaxRetry = 3
    48  )
    49  
    50  // MySQLConnectParam records the parameters needed to connect to a MySQL database.
    51  type MySQLConnectParam struct {
    52  	Host             string
    53  	Port             int
    54  	User             string
    55  	Password         string
    56  	SQLMode          string
    57  	MaxAllowedPacket uint64
    58  	TLS              string
    59  	Vars             map[string]string
    60  }
    61  
    62  func (param *MySQLConnectParam) ToDSN() string {
    63  	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s",
    64  		param.User, param.Password, param.Host, param.Port,
    65  		param.SQLMode, param.MaxAllowedPacket, param.TLS)
    66  
    67  	for k, v := range param.Vars {
    68  		dsn += fmt.Sprintf("&%s=%s", k, url.QueryEscape(v))
    69  	}
    70  
    71  	return dsn
    72  }
    73  
    74  func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
    75  	db, err := sql.Open("mysql", param.ToDSN())
    76  	if err != nil {
    77  		return nil, errors.Trace(err)
    78  	}
    79  
    80  	return db, errors.Trace(db.Ping())
    81  }
    82  
    83  // IsDirExists checks if dir exists.
    84  func IsDirExists(name string) bool {
    85  	f, err := os.Stat(name)
    86  	if err != nil {
    87  		return false
    88  	}
    89  	return f != nil && f.IsDir()
    90  }
    91  
    92  // IsEmptyDir checks if dir is empty.
    93  func IsEmptyDir(name string) bool {
    94  	entries, err := os.ReadDir(name)
    95  	if err != nil {
    96  		return false
    97  	}
    98  	return len(entries) == 0
    99  }
   100  
   101  type QueryExecutor interface {
   102  	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
   103  	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
   104  }
   105  
   106  type DBExecutor interface {
   107  	QueryExecutor
   108  	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
   109  	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
   110  }
   111  
// SQLWithRetry runs SQL statements and transactions and re-attempts them when
// they fail with an error the retry policy treats as transient.
type SQLWithRetry struct {
	// DB is the executor statements are sent to: either *sql.DB or *sql.Conn.
	DB           DBExecutor
	// Logger is the base logger for the warnings emitted while retrying.
	Logger       log.Logger
	// HideQueryLog, when true, keeps the query text (and, for Exec, its
	// arguments) out of the retry log fields.
	HideQueryLog bool
}
   119  
   120  func (t SQLWithRetry) perform(_ context.Context, parentLogger log.Logger, purpose string, action func() error) error {
   121  	return Retry(purpose, parentLogger, action)
   122  }
   123  
   124  // Retry is shared by SQLWithRetry.perform, implementation of GlueCheckpointsDB and TiDB's glue implementation
   125  func Retry(purpose string, parentLogger log.Logger, action func() error) error {
   126  	var err error
   127  outside:
   128  	for i := 0; i < defaultMaxRetry; i++ {
   129  		logger := parentLogger.With(zap.Int("retryCnt", i))
   130  
   131  		if i > 0 {
   132  			logger.Warn(purpose + " retry start")
   133  			time.Sleep(retryTimeout)
   134  		}
   135  
   136  		err = action()
   137  		switch {
   138  		case err == nil:
   139  			return nil
   140  		// do not retry NotFound error
   141  		case errors.IsNotFound(err):
   142  			break outside
   143  		case IsRetryableError(err):
   144  			logger.Warn(purpose+" failed but going to try again", log.ShortError(err))
   145  			continue
   146  		default:
   147  			logger.Warn(purpose+" failed with no retry", log.ShortError(err))
   148  			break outside
   149  		}
   150  	}
   151  
   152  	return errors.Annotatef(err, "%s failed", purpose)
   153  }
   154  
   155  func (t SQLWithRetry) QueryRow(ctx context.Context, purpose string, query string, dest ...interface{}) error {
   156  	logger := t.Logger
   157  	if !t.HideQueryLog {
   158  		logger = logger.With(zap.String("query", query))
   159  	}
   160  	return t.perform(ctx, logger, purpose, func() error {
   161  		return t.DB.QueryRowContext(ctx, query).Scan(dest...)
   162  	})
   163  }
   164  
   165  // Transact executes an action in a transaction, and retry if the
   166  // action failed with a retryable error.
   167  func (t SQLWithRetry) Transact(ctx context.Context, purpose string, action func(context.Context, *sql.Tx) error) error {
   168  	return t.perform(ctx, t.Logger, purpose, func() error {
   169  		txn, err := t.DB.BeginTx(ctx, nil)
   170  		if err != nil {
   171  			return errors.Annotate(err, "begin transaction failed")
   172  		}
   173  
   174  		err = action(ctx, txn)
   175  		if err != nil {
   176  			rerr := txn.Rollback()
   177  			if rerr != nil {
   178  				t.Logger.Error(purpose+" rollback transaction failed", log.ShortError(rerr))
   179  			}
   180  			// we should return the exec err, instead of the rollback rerr.
   181  			// no need to errors.Trace() it, as the error comes from user code anyway.
   182  			return err
   183  		}
   184  
   185  		err = txn.Commit()
   186  		if err != nil {
   187  			return errors.Annotate(err, "commit transaction failed")
   188  		}
   189  
   190  		return nil
   191  	})
   192  }
   193  
   194  // Exec executes a single SQL with optional retry.
   195  func (t SQLWithRetry) Exec(ctx context.Context, purpose string, query string, args ...interface{}) error {
   196  	logger := t.Logger
   197  	if !t.HideQueryLog {
   198  		logger = logger.With(zap.String("query", query), zap.Reflect("args", args))
   199  	}
   200  	return t.perform(ctx, logger, purpose, func() error {
   201  		_, err := t.DB.ExecContext(ctx, query, args...)
   202  		return errors.Trace(err)
   203  	})
   204  }
   205  
   206  // sqlmock uses fmt.Errorf to produce expectation failures, which will cause
   207  // unnecessary retry if not specially handled >:(
   208  var stdFatalErrorsRegexp = regexp.MustCompile(
   209  	`^call to (?s:.*) was not expected|arguments do not match:|could not match actual sql`,
   210  )
   211  var stdErrorType = reflect.TypeOf(stderrors.New(""))
   212  
   213  // IsRetryableError returns whether the error is transient (e.g. network
   214  // connection dropped) or irrecoverable (e.g. user pressing Ctrl+C). This
   215  // function returns `false` (irrecoverable) if `err == nil`.
   216  //
   217  // If the error is a multierr, returns true only if all suberrors are retryable.
   218  func IsRetryableError(err error) bool {
   219  	for _, singleError := range errors.Errors(err) {
   220  		if !isSingleRetryableError(singleError) {
   221  			return false
   222  		}
   223  	}
   224  	return true
   225  }
   226  
// isSingleRetryableError classifies one (non-multi) error as retryable or not.
// The error is first reduced with errors.Cause. The checks are then applied in
// this order:
//   - nil, context.Canceled, context.DeadlineExceeded, io.EOF and
//     sql.ErrNoRows are never retried;
//   - net.Error values are retried only when they report a timeout;
//   - *mysql.MySQLError values are retried only for the listed error numbers;
//   - any other error is classified by its gRPC status code.
func isSingleRetryableError(err error) bool {
	err = errors.Cause(err)

	// Sentinel errors that must never be retried.
	switch err {
	case nil, context.Canceled, context.DeadlineExceeded, io.EOF, sql.ErrNoRows:
		return false
	}

	switch nerr := err.(type) {
	case net.Error:
		return nerr.Timeout()
	case *mysql.MySQLError:
		switch nerr.Number {
		// ErrLockDeadlock is retried: re-running the commit can succeed after a deadlock.
		case tmysql.ErrUnknown, tmysql.ErrLockDeadlock, tmysql.ErrWriteConflictInTiDB, tmysql.ErrPDServerTimeout, tmysql.ErrTiKVServerTimeout, tmysql.ErrTiKVServerBusy, tmysql.ErrResolveLockTimeout, tmysql.ErrRegionUnavailable:
			return true
		default:
			return false
		}
	default:
		switch status.Code(err) {
		case codes.DeadlineExceeded, codes.NotFound, codes.AlreadyExists, codes.PermissionDenied, codes.ResourceExhausted, codes.Aborted, codes.OutOfRange, codes.Unavailable, codes.DataLoss:
			return true
		case codes.Unknown:
			// NOTE(review): status.Code presumably reports codes.Unknown for
			// errors that carry no gRPC status, so ordinary errors land here.
			// Plain errors.New/fmt.Errorf values are retried unless their
			// message is a known sqlmock expectation failure; every other
			// error type in this branch is retried.
			if reflect.TypeOf(err) == stdErrorType {
				return !stdFatalErrorsRegexp.MatchString(err.Error())
			}
			return true
		default:
			return false
		}
	}
}
   260  
   261  // IsContextCanceledError returns whether the error is caused by context
   262  // cancellation. This function should only be used when the code logic is
   263  // affected by whether the error is canceling or not.
   264  //
   265  // This function returns `false` (not a context-canceled error) if `err == nil`.
   266  func IsContextCanceledError(err error) bool {
   267  	return log.IsContextCanceledError(err)
   268  }
   269  
   270  // UniqueTable returns an unique table name.
   271  func UniqueTable(schema string, table string) string {
   272  	var builder strings.Builder
   273  	WriteMySQLIdentifier(&builder, schema)
   274  	builder.WriteByte('.')
   275  	WriteMySQLIdentifier(&builder, table)
   276  	return builder.String()
   277  }
   278  
   279  // EscapeIdentifier quote and escape an sql identifier
   280  func EscapeIdentifier(identifier string) string {
   281  	var builder strings.Builder
   282  	WriteMySQLIdentifier(&builder, identifier)
   283  	return builder.String()
   284  }
   285  
   286  // Writes a MySQL identifier into the string builder.
   287  // The identifier is always escaped into the form "`foo`".
   288  func WriteMySQLIdentifier(builder *strings.Builder, identifier string) {
   289  	builder.Grow(len(identifier) + 2)
   290  	builder.WriteByte('`')
   291  
   292  	// use a C-style loop instead of range loop to avoid UTF-8 decoding
   293  	for i := 0; i < len(identifier); i++ {
   294  		b := identifier[i]
   295  		if b == '`' {
   296  			builder.WriteString("``")
   297  		} else {
   298  			builder.WriteByte(b)
   299  		}
   300  	}
   301  
   302  	builder.WriteByte('`')
   303  }
   304  
   305  func InterpolateMySQLString(s string) string {
   306  	var builder strings.Builder
   307  	builder.Grow(len(s) + 2)
   308  	builder.WriteByte('\'')
   309  	for i := 0; i < len(s); i++ {
   310  		b := s[i]
   311  		if b == '\'' {
   312  			builder.WriteString("''")
   313  		} else {
   314  			builder.WriteByte(b)
   315  		}
   316  	}
   317  	builder.WriteByte('\'')
   318  	return builder.String()
   319  }
   320  
   321  // GetJSON fetches a page and parses it as JSON. The parsed result will be
   322  // stored into the `v`. The variable `v` must be a pointer to a type that can be
   323  // unmarshalled from JSON.
   324  //
   325  // Example:
   326  //
   327  //	client := &http.Client{}
   328  //	var resp struct { IP string }
   329  //	if err := util.GetJSON(client, "http://api.ipify.org/?format=json", &resp); err != nil {
   330  //		return errors.Trace(err)
   331  //	}
   332  //	fmt.Println(resp.IP)
   333  func GetJSON(ctx context.Context, client *http.Client, url string, v interface{}) error {
   334  	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
   335  	if err != nil {
   336  		return errors.Trace(err)
   337  	}
   338  
   339  	resp, err := client.Do(req)
   340  	if err != nil {
   341  		return errors.Trace(err)
   342  	}
   343  
   344  	defer resp.Body.Close()
   345  
   346  	if resp.StatusCode != http.StatusOK {
   347  		body, err := io.ReadAll(resp.Body)
   348  		if err != nil {
   349  			return errors.Trace(err)
   350  		}
   351  		return errors.Errorf("get %s http status code != 200, message %s", url, string(body))
   352  	}
   353  
   354  	return errors.Trace(json.NewDecoder(resp.Body).Decode(v))
   355  }
   356  
   357  // KillMySelf sends sigint to current process, used in integration test only
   358  //
   359  // Only works on Unix. Signaling on Windows is not supported.
   360  func KillMySelf() error {
   361  	proc, err := os.FindProcess(os.Getpid())
   362  	if err == nil {
   363  		err = proc.Signal(syscall.SIGINT)
   364  	}
   365  	return errors.Trace(err)
   366  }
   367  
// KvPair is a key/value pair, tagged with the row ID and the in-file offset of
// the row it belongs to.
type KvPair struct {
	// Key is the key of the KV pair.
	Key []byte
	// Val is the value of the KV pair.
	Val []byte
	// RowID is the row id of the KV pair.
	RowID int64
	// Offset is the row's offset in the source file.
	Offset int64
}
   379  
   380  // TableHasAutoRowID return whether table has auto generated row id
   381  func TableHasAutoRowID(info *model.TableInfo) bool {
   382  	return !info.PKIsHandle && !info.IsCommonHandle
   383  }
   384  
   385  // StringSliceEqual checks if two string slices are equal.
   386  func StringSliceEqual(a, b []string) bool {
   387  	if len(a) != len(b) {
   388  		return false
   389  	}
   390  	for i, v := range a {
   391  		if v != b[i] {
   392  			return false
   393  		}
   394  	}
   395  	return true
   396  }