github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/pkg/checker/utils.go (about)

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package checker
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"math"
    20  	"strconv"
    21  	"strings"
    22  
    23  	"github.com/go-sql-driver/mysql"
    24  	"github.com/pingcap/errors"
    25  	"github.com/pingcap/tidb-tools/pkg/utils"
    26  	"github.com/pingcap/tidb/pkg/parser"
    27  	"github.com/pingcap/tidb/pkg/parser/ast"
    28  	"github.com/pingcap/tidb/pkg/parser/format"
    29  	"github.com/pingcap/tiflow/dm/pkg/log"
    30  	"github.com/pingcap/tiflow/dm/pkg/terror"
    31  	"go.uber.org/zap"
    32  )
    33  
    34  // MySQLVersion represents MySQL version number.
    35  type MySQLVersion [3]uint
    36  
    37  // MinVersion define a mininum version.
    38  var MinVersion = MySQLVersion{0, 0, 0}
    39  
    40  // MaxVersion define a maximum version.
    41  var MaxVersion = MySQLVersion{math.MaxUint8, math.MaxUint8, math.MaxUint8}
    42  
    43  // version format:
    44  // mysql        5.7.18-log
    45  // mariadb      5.5.50-MariaDB-1~wheezy
    46  // percona      5.7.19-17-log
    47  // aliyun rds   5.7.18-log
    48  // aws rds      5.7.16-log
    49  // ref: https://dev.mysql.com/doc/refman/5.7/en/which-version.html
    50  
    51  // v is mysql version in string format.
    52  func toMySQLVersion(v string) (MySQLVersion, error) {
    53  	version := MySQLVersion{0, 0, 0}
    54  	tmp := strings.Split(v, "-")
    55  	if len(tmp) == 0 {
    56  		return version, errors.NotValidf("MySQL version %s", v)
    57  	}
    58  
    59  	tmp = strings.Split(tmp[0], ".")
    60  	if len(tmp) != 3 {
    61  		return version, errors.NotValidf("MySQL version %s", v)
    62  	}
    63  
    64  	for i := range tmp {
    65  		val, err := strconv.ParseUint(tmp[i], 10, 64)
    66  		if err != nil {
    67  			return version, errors.NotValidf("MySQL version %s", v)
    68  		}
    69  		version[i] = uint(val)
    70  	}
    71  	return version, nil
    72  }
    73  
    74  // Ge means v >= min.
    75  func (v MySQLVersion) Ge(min MySQLVersion) bool {
    76  	for i := range v {
    77  		if v[i] > min[i] {
    78  			return true
    79  		} else if v[i] < min[i] {
    80  			return false
    81  		}
    82  	}
    83  	return true
    84  }
    85  
    86  // Gt means v > min.
    87  func (v MySQLVersion) Gt(min MySQLVersion) bool {
    88  	for i := range v {
    89  		if v[i] > min[i] {
    90  			return true
    91  		} else if v[i] < min[i] {
    92  			return false
    93  		}
    94  	}
    95  	return false
    96  }
    97  
    98  // Lt means v < min.
    99  func (v MySQLVersion) Lt(max MySQLVersion) bool {
   100  	for i := range v {
   101  		if v[i] < max[i] {
   102  			return true
   103  		} else if v[i] > max[i] {
   104  			return false
   105  		}
   106  	}
   107  	return false
   108  }
   109  
   110  // Le means v <= min.
   111  func (v MySQLVersion) Le(max MySQLVersion) bool {
   112  	for i := range v {
   113  		if v[i] < max[i] {
   114  			return true
   115  		} else if v[i] > max[i] {
   116  			return false
   117  		}
   118  	}
   119  	return true
   120  }
   121  
   122  // String implements the Stringer interface.
   123  func (v MySQLVersion) String() string {
   124  	return fmt.Sprintf("%d.%d.%d", v[0], v[1], v[2])
   125  }
   126  
   127  // IsTiDBFromVersion tells whether the version is tidb.
   128  func IsTiDBFromVersion(version string) bool {
   129  	return strings.Contains(strings.ToUpper(version), "TIDB")
   130  }
   131  
   132  func markCheckError(result *Result, err error) {
   133  	if err != nil {
   134  		state := StateFailure
   135  		if utils.OriginError(err) == context.Canceled {
   136  			state = StateWarning
   137  		}
   138  		// `StateWarning` can't cover `StateFailure`.
   139  		if result.State != StateFailure {
   140  			result.State = state
   141  		}
   142  		if err2, ok := err.(*terror.Error); ok {
   143  			result.Errors = append(result.Errors, &Error{Severity: state, ShortErr: err2.ErrorWithoutWorkaround()})
   144  			result.Instruction = err2.Workaround()
   145  		} else {
   146  			result.Errors = append(result.Errors, &Error{Severity: state, ShortErr: err.Error()})
   147  		}
   148  	}
   149  }
   150  
   151  func markCheckErrorFromParser(result *Result, err error) {
   152  	if err != nil {
   153  		state := StateWarning
   154  		// `StateWarning` can't cover `StateFailure`.
   155  		if result.State != StateFailure {
   156  			result.State = state
   157  		}
   158  		result.Errors = append(result.Errors, &Error{Severity: state, ShortErr: err.Error()})
   159  	}
   160  }
   161  
   162  //nolint:unparam
   163  func isMySQLError(err error, code uint16) bool {
   164  	err = errors.Cause(err)
   165  	e, ok := err.(*mysql.MySQLError)
   166  	return ok && e.Number == code
   167  }
   168  
   169  func getCreateTableStmt(p *parser.Parser, statement string) (*ast.CreateTableStmt, error) {
   170  	stmt, err := p.ParseOneStmt(statement, "", "")
   171  	if err != nil {
   172  		return nil, errors.Annotatef(err, "statement %s", statement)
   173  	}
   174  
   175  	ctStmt, ok := stmt.(*ast.CreateTableStmt)
   176  	if !ok {
   177  		return nil, errors.Errorf("Expect CreateTableStmt but got %T", stmt)
   178  	}
   179  	return ctStmt, nil
   180  }
   181  
   182  func getCharset(stmt *ast.CreateTableStmt) string {
   183  	if stmt.Options != nil {
   184  		for _, option := range stmt.Options {
   185  			if option.Tp == ast.TableOptionCharset {
   186  				return option.StrValue
   187  			}
   188  		}
   189  	}
   190  	return ""
   191  }
   192  
   193  func getCollation(stmt *ast.CreateTableStmt) string {
   194  	if stmt.Options != nil {
   195  		for _, option := range stmt.Options {
   196  			if option.Tp == ast.TableOptionCollate {
   197  				return option.StrValue
   198  			}
   199  		}
   200  	}
   201  	return ""
   202  }
   203  
   204  // getPKAndUK returns a map of INDEX_NAME -> set of COLUMN_NAMEs.
   205  func getPKAndUK(stmt *ast.CreateTableStmt) map[string]map[string]struct{} {
   206  	ret := make(map[string]map[string]struct{})
   207  	var sb strings.Builder
   208  	restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
   209  
   210  	for _, constraint := range stmt.Constraints {
   211  		switch constraint.Tp {
   212  		case ast.ConstraintPrimaryKey:
   213  			ret["PRIMARY"] = make(map[string]struct{})
   214  			for _, key := range constraint.Keys {
   215  				ret["PRIMARY"][key.Column.Name.L] = struct{}{}
   216  			}
   217  		case ast.ConstraintUniq, ast.ConstraintUniqKey, ast.ConstraintUniqIndex:
   218  			ret[constraint.Name] = make(map[string]struct{})
   219  			for _, key := range constraint.Keys {
   220  				if key.Column != nil {
   221  					ret[constraint.Name][key.Column.Name.L] = struct{}{}
   222  				} else {
   223  					sb.Reset()
   224  					err := key.Expr.Restore(restoreCtx)
   225  					if err != nil {
   226  						log.L().Warn("failed to restore expression", zap.Error(err))
   227  						continue
   228  					}
   229  					ret[constraint.Name][sb.String()] = struct{}{}
   230  				}
   231  			}
   232  		}
   233  	}
   234  	return ret
   235  }
   236  
   237  func stringSetEqual(a, b map[string]struct{}) bool {
   238  	if len(a) != len(b) {
   239  		return false
   240  	}
   241  	for k := range a {
   242  		if _, ok := b[k]; !ok {
   243  			return false
   244  		}
   245  	}
   246  	return true
   247  }
   248  
   249  // getColumnsAndIgnorable return a map of COLUMN_NAME -> if this columns can be
   250  // ignored when inserting data, which means it has default value or can be null.
   251  func getColumnsAndIgnorable(stmt *ast.CreateTableStmt) map[string]bool {
   252  	ret := make(map[string]bool)
   253  	for _, col := range stmt.Cols {
   254  		notNull := false
   255  		hasDefaultValue := false
   256  		for _, opt := range col.Options {
   257  			switch opt.Tp {
   258  			case ast.ColumnOptionNotNull:
   259  				notNull = true
   260  			case ast.ColumnOptionDefaultValue,
   261  				ast.ColumnOptionAutoIncrement,
   262  				ast.ColumnOptionAutoRandom,
   263  				ast.ColumnOptionGenerated:
   264  				// if the generated column has NOT NULL, its referring columns
   265  				// must not be NULL. But even if we mark the referring columns
   266  				// as not ignorable, the data may still be NULL so replication
   267  				// is still failed. For simplicity, we just ignore this case.
   268  				hasDefaultValue = true
   269  			}
   270  		}
   271  		ret[col.Name.Name.L] = !notNull || hasDefaultValue
   272  	}
   273  	return ret
   274  }