// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
	"context"
	"database/sql"
	"encoding/json"
	stderrors "errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"os"
	"reflect"
	"regexp"
	"strings"
	"syscall"
	"time"

	"github.com/go-sql-driver/mysql"
	"github.com/pingcap/errors"
	"github.com/pingcap/parser/model"
	tmysql "github.com/pingcap/tidb/errno"
	"go.uber.org/zap"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/pingcap/tidb-lightning/lightning/log"
)

const (
	// retryTimeout is how long Retry pauses before every attempt after the
	// first one. Despite the name it is a delay between attempts, not a
	// deadline on an individual attempt.
	retryTimeout = 3 * time.Second

	// defaultMaxRetry is the total number of times Retry invokes its action,
	// the first attempt included (not the number of extra retries).
	defaultMaxRetry = 3
)

// MySQLConnectParam records the parameters needed to connect to a MySQL database.
52 type MySQLConnectParam struct { 53 Host string 54 Port int 55 User string 56 Password string 57 SQLMode string 58 MaxAllowedPacket uint64 59 TLS string 60 Vars map[string]string 61 } 62 63 func (param *MySQLConnectParam) ToDSN() string { 64 dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&sql_mode='%s'&maxAllowedPacket=%d&tls=%s", 65 param.User, param.Password, param.Host, param.Port, 66 param.SQLMode, param.MaxAllowedPacket, param.TLS) 67 68 for k, v := range param.Vars { 69 dsn += fmt.Sprintf("&%s=%s", k, url.QueryEscape(v)) 70 } 71 72 return dsn 73 } 74 75 func (param *MySQLConnectParam) Connect() (*sql.DB, error) { 76 db, err := sql.Open("mysql", param.ToDSN()) 77 if err != nil { 78 return nil, errors.Trace(err) 79 } 80 81 return db, errors.Trace(db.Ping()) 82 } 83 84 // IsDirExists checks if dir exists. 85 func IsDirExists(name string) bool { 86 f, err := os.Stat(name) 87 if err != nil { 88 return false 89 } 90 return f != nil && f.IsDir() 91 } 92 93 // IsEmptyDir checks if dir is empty. 94 func IsEmptyDir(name string) bool { 95 entries, err := ioutil.ReadDir(name) 96 if err != nil { 97 return false 98 } 99 return len(entries) == 0 100 } 101 102 // SQLWithRetry constructs a retryable transaction. 
103 type SQLWithRetry struct { 104 DB *sql.DB 105 Logger log.Logger 106 HideQueryLog bool 107 } 108 109 func (t SQLWithRetry) perform(ctx context.Context, parentLogger log.Logger, purpose string, action func() error) error { 110 return Retry(purpose, parentLogger, action) 111 } 112 113 // Retry is shared by SQLWithRetry.perform, implementation of GlueCheckpointsDB and TiDB's glue implementation 114 func Retry(purpose string, parentLogger log.Logger, action func() error) error { 115 var err error 116 outside: 117 for i := 0; i < defaultMaxRetry; i++ { 118 logger := parentLogger.With(zap.Int("retryCnt", i)) 119 120 if i > 0 { 121 logger.Warn(purpose + " retry start") 122 time.Sleep(retryTimeout) 123 } 124 125 err = action() 126 switch { 127 case err == nil: 128 return nil 129 case IsRetryableError(err): 130 logger.Warn(purpose+" failed but going to try again", log.ShortError(err)) 131 continue 132 default: 133 break outside 134 } 135 } 136 137 return errors.Annotatef(err, "%s failed", purpose) 138 } 139 140 func (t SQLWithRetry) QueryRow(ctx context.Context, purpose string, query string, dest ...interface{}) error { 141 logger := t.Logger 142 if !t.HideQueryLog { 143 logger = logger.With(zap.String("query", query)) 144 } 145 return t.perform(ctx, logger, purpose, func() error { 146 return t.DB.QueryRowContext(ctx, query).Scan(dest...) 147 }) 148 } 149 150 // Transact executes an action in a transaction, and retry if the 151 // action failed with a retryable error. 
152 func (t SQLWithRetry) Transact(ctx context.Context, purpose string, action func(context.Context, *sql.Tx) error) error { 153 return t.perform(ctx, t.Logger, purpose, func() error { 154 txn, err := t.DB.BeginTx(ctx, nil) 155 if err != nil { 156 return errors.Annotate(err, "begin transaction failed") 157 } 158 159 err = action(ctx, txn) 160 if err != nil { 161 rerr := txn.Rollback() 162 if rerr != nil { 163 t.Logger.Error(purpose+" rollback transaction failed", log.ShortError(rerr)) 164 } 165 // we should return the exec err, instead of the rollback rerr. 166 // no need to errors.Trace() it, as the error comes from user code anyway. 167 return err 168 } 169 170 err = txn.Commit() 171 if err != nil { 172 return errors.Annotate(err, "commit transaction failed") 173 } 174 175 return nil 176 }) 177 } 178 179 // Exec executes a single SQL with optional retry. 180 func (t SQLWithRetry) Exec(ctx context.Context, purpose string, query string, args ...interface{}) error { 181 logger := t.Logger 182 if !t.HideQueryLog { 183 logger = logger.With(zap.String("query", query), zap.Reflect("args", args)) 184 } 185 return t.perform(ctx, logger, purpose, func() error { 186 _, err := t.DB.ExecContext(ctx, query, args...) 187 return errors.Trace(err) 188 }) 189 } 190 191 // sqlmock uses fmt.Errorf to produce expectation failures, which will cause 192 // unnecessary retry if not specially handled >:( 193 var stdFatalErrorsRegexp = regexp.MustCompile( 194 `^call to (?s:.*) was not expected|arguments do not match:|could not match actual sql`, 195 ) 196 var stdErrorType = reflect.TypeOf(stderrors.New("")) 197 198 // IsRetryableError returns whether the error is transient (e.g. network 199 // connection dropped) or irrecoverable (e.g. user pressing Ctrl+C). This 200 // function returns `false` (irrecoverable) if `err == nil`. 201 // 202 // If the error is a multierr, returns true only if all suberrors are retryable. 
203 func IsRetryableError(err error) bool { 204 for _, singleError := range errors.Errors(err) { 205 if !isSingleRetryableError(singleError) { 206 return false 207 } 208 } 209 return true 210 } 211 212 func isSingleRetryableError(err error) bool { 213 err = errors.Cause(err) 214 215 switch err { 216 case nil, context.Canceled, context.DeadlineExceeded, io.EOF: 217 return false 218 } 219 220 switch nerr := err.(type) { 221 case net.Error: 222 return nerr.Timeout() 223 case *mysql.MySQLError: 224 switch nerr.Number { 225 // ErrLockDeadlock can retry to commit while meet deadlock 226 case tmysql.ErrUnknown, tmysql.ErrLockDeadlock, tmysql.ErrWriteConflictInTiDB, tmysql.ErrPDServerTimeout, tmysql.ErrTiKVServerTimeout, tmysql.ErrTiKVServerBusy, tmysql.ErrResolveLockTimeout, tmysql.ErrRegionUnavailable: 227 return true 228 default: 229 return false 230 } 231 default: 232 switch status.Code(err) { 233 case codes.DeadlineExceeded, codes.NotFound, codes.AlreadyExists, codes.PermissionDenied, codes.ResourceExhausted, codes.Aborted, codes.OutOfRange, codes.Unavailable, codes.DataLoss: 234 return true 235 case codes.Unknown: 236 if reflect.TypeOf(err) == stdErrorType { 237 return !stdFatalErrorsRegexp.MatchString(err.Error()) 238 } 239 return true 240 default: 241 return false 242 } 243 } 244 } 245 246 // IsContextCanceledError returns whether the error is caused by context 247 // cancellation. This function should only be used when the code logic is 248 // affected by whether the error is canceling or not. 249 // 250 // This function returns `false` (not a context-canceled error) if `err == nil`. 251 func IsContextCanceledError(err error) bool { 252 return log.IsContextCanceledError(err) 253 } 254 255 // UniqueTable returns an unique table name. 
func UniqueTable(schema string, table string) string {
	var builder strings.Builder
	// Reserve room for both names, four backticks and the dot up front. Only
	// escaped backticks can make the builder grow beyond this.
	builder.Grow(len(schema) + len(table) + 5)
	WriteMySQLIdentifier(&builder, schema)
	builder.WriteByte('.')
	WriteMySQLIdentifier(&builder, table)
	return builder.String()
}

// WriteMySQLIdentifier writes identifier into builder as a quoted MySQL
// identifier. The output always has the form "`foo`". Every backtick inside
// identifier is doubled, so the quoting cannot be broken out of.
func WriteMySQLIdentifier(builder *strings.Builder, identifier string) {
	builder.Grow(len(identifier) + 2)
	builder.WriteByte('`')

	// use a C-style loop instead of range loop to avoid UTF-8 decoding
	for i := 0; i < len(identifier); i++ {
		b := identifier[i]
		if b == '`' {
			builder.WriteString("``")
		} else {
			builder.WriteByte(b)
		}
	}

	builder.WriteByte('`')
}

// InterpolateMySQLString returns s as a single-quoted MySQL string literal,
// doubling every embedded single quote (so "it's" becomes 'it''s').
//
// NOTE: backslashes are copied unchanged. The result is only a faithful
// literal when the server treats backslash as an ordinary character (sql_mode
// NO_BACKSLASH_ESCAPES). Otherwise a backslash in s, for example one at the
// end, changes how the literal is parsed.
func InterpolateMySQLString(s string) string {
	var builder strings.Builder
	builder.Grow(len(s) + 2)
	builder.WriteByte('\'')
	for i := 0; i < len(s); i++ {
		b := s[i]
		if b == '\'' {
			builder.WriteString("''")
		} else {
			builder.WriteByte(b)
		}
	}
	builder.WriteByte('\'')
	return builder.String()
}

// GetJSON fetches a page and parses it as JSON. The parsed result will be
// stored into the `v`. The variable `v` must be a pointer to a type that can be
// unmarshalled from JSON.
302 // 303 // Example: 304 // 305 // client := &http.Client{} 306 // var resp struct { IP string } 307 // if err := util.GetJSON(client, "http://api.ipify.org/?format=json", &resp); err != nil { 308 // return errors.Trace(err) 309 // } 310 // fmt.Println(resp.IP) 311 func GetJSON(ctx context.Context, client *http.Client, url string, v interface{}) error { 312 req, err := http.NewRequestWithContext(ctx, "GET", url, nil) 313 if err != nil { 314 return errors.Trace(err) 315 } 316 317 resp, err := client.Do(req) 318 if err != nil { 319 return errors.Trace(err) 320 } 321 322 defer resp.Body.Close() 323 324 if resp.StatusCode != http.StatusOK { 325 body, err := ioutil.ReadAll(resp.Body) 326 if err != nil { 327 return errors.Trace(err) 328 } 329 return errors.Errorf("get %s http status code != 200, message %s", url, string(body)) 330 } 331 332 return errors.Trace(json.NewDecoder(resp.Body).Decode(v)) 333 } 334 335 // KillMySelf sends sigint to current process, used in integration test only 336 // 337 // Only works on Unix. Signaling on Windows is not supported. 338 func KillMySelf() error { 339 proc, err := os.FindProcess(os.Getpid()) 340 if err == nil { 341 err = proc.Signal(syscall.SIGINT) 342 } 343 return errors.Trace(err) 344 } 345 346 // KvPair is a pair of key and value. 347 type KvPair struct { 348 // Key is the key of the KV pair 349 Key []byte 350 // Val is the value of the KV pair 351 Val []byte 352 } 353 354 // TableHasAutoRowID return whether table has auto generated row id 355 func TableHasAutoRowID(info *model.TableInfo) bool { 356 return !info.PKIsHandle && !info.IsCommonHandle 357 } 358 359 // StringSliceEqual checks if two string slices are equal. 360 func StringSliceEqual(a, b []string) bool { 361 if len(a) != len(b) { 362 return false 363 } 364 for i, v := range a { 365 if v != b[i] { 366 return false 367 } 368 } 369 return true 370 }