github.com/pingcap/tidb-lightning@v5.0.0-rc.0.20210428090220-84b649866577+incompatible/lightning/backend/tidb.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package backend
    15  
    16  import (
    17  	"context"
    18  	"database/sql"
    19  	"encoding/hex"
    20  	"fmt"
    21  	"strconv"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/google/uuid"
    26  	"github.com/pingcap/errors"
    27  	"github.com/pingcap/failpoint"
    28  	"github.com/pingcap/parser/model"
    29  	"github.com/pingcap/parser/mysql"
    30  	"github.com/pingcap/tidb/table"
    31  	"github.com/pingcap/tidb/types"
    32  	"go.uber.org/zap"
    33  	"go.uber.org/zap/zapcore"
    34  
    35  	"github.com/pingcap/tidb-lightning/lightning/common"
    36  	"github.com/pingcap/tidb-lightning/lightning/config"
    37  	"github.com/pingcap/tidb-lightning/lightning/log"
    38  	"github.com/pingcap/tidb-lightning/lightning/verification"
    39  )
    40  
var (
	// extraHandleTableColumn wraps the implicit `_tidb_rowid` handle column
	// (extraHandleColumnInfo is declared elsewhere in this package) in a
	// *table.Column so it can be handled uniformly with ordinary columns.
	extraHandleTableColumn = &table.Column{
		ColumnInfo:    extraHandleColumnInfo,
		GeneratedExpr: nil,
		DefaultExpr:   nil,
	}
)
    48  
// tidbRow is one encoded data row, stored as the SQL text of a single
// parenthesized values tuple, e.g. `(1,'a',NULL)`.
type tidbRow string

// tidbRows is an ordered batch of encoded rows.
type tidbRows []tidbRow
    52  
    53  // MarshalLogArray implements the zapcore.ArrayMarshaler interface
    54  func (row tidbRows) MarshalLogArray(encoder zapcore.ArrayEncoder) error {
    55  	for _, r := range row {
    56  		encoder.AppendString(string(r))
    57  	}
    58  	return nil
    59  }
    60  
// tidbEncoder converts parsed datums into SQL value tuples understood by
// the TiDB backend.
type tidbEncoder struct {
	// mode controls SQL escaping (e.g. NO_BACKSLASH_ESCAPES) and strictness.
	mode mysql.SQLMode
	// tbl supplies the column metadata of the target table.
	tbl table.Table
	// se is the local session used while encoding.
	se *session
	// the index of table columns for each data field.
	// index == len(table.columns) means this field is `_tidb_rowid`
	columnIdx []int
	// columnCnt is the number of data fields expected per row.
	columnCnt int
}
    70  
// tidbBackend imports data by executing SQL statements directly against a
// TiDB (or MySQL-compatible) server.
type tidbBackend struct {
	// db is the target connection pool; not owned by the backend.
	db *sql.DB
	// onDuplicate selects the DML verb: REPLACE / INSERT IGNORE / INSERT.
	onDuplicate string
}
    75  
    76  // NewTiDBBackend creates a new TiDB backend using the given database.
    77  //
    78  // The backend does not take ownership of `db`. Caller should close `db`
    79  // manually after the backend expired.
    80  func NewTiDBBackend(db *sql.DB, onDuplicate string) Backend {
    81  	switch onDuplicate {
    82  	case config.ReplaceOnDup, config.IgnoreOnDup, config.ErrorOnDup:
    83  	default:
    84  		log.L().Warn("unsupported action on duplicate, overwrite with `replace`")
    85  		onDuplicate = config.ReplaceOnDup
    86  	}
    87  	return MakeBackend(&tidbBackend{db: db, onDuplicate: onDuplicate})
    88  }
    89  
    90  func (row tidbRow) ClassifyAndAppend(data *Rows, checksum *verification.KVChecksum, _ *Rows, _ *verification.KVChecksum) {
    91  	rows := (*data).(tidbRows)
    92  	*data = tidbRows(append(rows, row))
    93  	cs := verification.MakeKVChecksum(uint64(len(row)), 1, 0)
    94  	checksum.Add(&cs)
    95  }
    96  
    97  func (rows tidbRows) SplitIntoChunks(splitSize int) []Rows {
    98  	if len(rows) == 0 {
    99  		return nil
   100  	}
   101  
   102  	res := make([]Rows, 0, 1)
   103  	i := 0
   104  	cumSize := 0
   105  
   106  	for j, row := range rows {
   107  		if i < j && cumSize+len(row) > splitSize {
   108  			res = append(res, rows[i:j])
   109  			i = j
   110  			cumSize = 0
   111  		}
   112  		cumSize += len(row)
   113  	}
   114  
   115  	return append(res, rows[i:])
   116  }
   117  
   118  func (rows tidbRows) Clear() Rows {
   119  	return rows[:0]
   120  }
   121  
   122  func (enc *tidbEncoder) appendSQLBytes(sb *strings.Builder, value []byte) {
   123  	sb.Grow(2 + len(value))
   124  	sb.WriteByte('\'')
   125  	if enc.mode.HasNoBackslashEscapesMode() {
   126  		for _, b := range value {
   127  			if b == '\'' {
   128  				sb.WriteString(`''`)
   129  			} else {
   130  				sb.WriteByte(b)
   131  			}
   132  		}
   133  	} else {
   134  		for _, b := range value {
   135  			switch b {
   136  			case 0:
   137  				sb.WriteString(`\0`)
   138  			case '\b':
   139  				sb.WriteString(`\b`)
   140  			case '\n':
   141  				sb.WriteString(`\n`)
   142  			case '\r':
   143  				sb.WriteString(`\r`)
   144  			case '\t':
   145  				sb.WriteString(`\t`)
   146  			case 26:
   147  				sb.WriteString(`\Z`)
   148  			case '\'':
   149  				sb.WriteString(`''`)
   150  			case '\\':
   151  				sb.WriteString(`\\`)
   152  			default:
   153  				sb.WriteByte(b)
   154  			}
   155  		}
   156  	}
   157  	sb.WriteByte('\'')
   158  }
   159  
// appendSQL appends the SQL representation of the Datum into the string builder.
// Note that we cannot use Datum.ToString since it doesn't perform SQL escaping.
// The `col` parameter is currently unused (see the disabled strict-mode cast
// below) but kept for interface stability.
func (enc *tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum, col *table.Column) error {
	switch datum.Kind() {
	case types.KindNull:
		sb.WriteString("NULL")

	case types.KindMinNotNull:
		sb.WriteString("MINVALUE")

	case types.KindMaxValue:
		sb.WriteString("MAXVALUE")

	case types.KindInt64:
		// longest int64 = -9223372036854775808 which has 20 characters
		var buffer [20]byte
		value := strconv.AppendInt(buffer[:0], datum.GetInt64(), 10)
		sb.Write(value)

	case types.KindUint64, types.KindMysqlEnum, types.KindMysqlSet:
		// longest uint64 = 18446744073709551615 which has 20 characters
		var buffer [20]byte
		value := strconv.AppendUint(buffer[:0], datum.GetUint64(), 10)
		sb.Write(value)

	case types.KindFloat32, types.KindFloat64:
		// float64 has 16 digits of precision, so a buffer size of 32 is more than enough...
		var buffer [32]byte
		value := strconv.AppendFloat(buffer[:0], datum.GetFloat64(), 'g', -1, 64)
		sb.Write(value)
	case types.KindString:
		// Strict-mode cast disabled.
		// See: https://github.com/pingcap/tidb-lightning/issues/550
		//if enc.mode.HasStrictMode() {
		//	d, err := table.CastValue(enc.se, *datum, col.ToInfo(), false, false)
		//	if err != nil {
		//		return errors.Trace(err)
		//	}
		//	datum = &d
		//}

		enc.appendSQLBytes(sb, datum.GetBytes())
	case types.KindBytes:
		enc.appendSQLBytes(sb, datum.GetBytes())

	case types.KindMysqlJSON:
		// Serialize the JSON document, then quote/escape it like a string.
		value, err := datum.GetMysqlJSON().MarshalJSON()
		if err != nil {
			return err
		}
		enc.appendSQLBytes(sb, value)

	case types.KindBinaryLiteral:
		// Emit binary literals as hex, e.g. x'deadbeef'.
		value := datum.GetBinaryLiteral()
		sb.Grow(3 + 2*len(value))
		sb.WriteString("x'")
		hex.NewEncoder(sb).Write(value)
		sb.WriteByte('\'')

	case types.KindMysqlBit:
		// BIT values are rendered as their unsigned integer value.
		var buffer [20]byte
		intValue, err := datum.GetBinaryLiteral().ToInt(nil)
		if err != nil {
			return err
		}
		value := strconv.AppendUint(buffer[:0], intValue, 10)
		sb.Write(value)

		// time, duration, decimal and remaining kinds: ToString is assumed to
		// produce text with no characters needing escaping beyond the quotes.
	default:
		value, err := datum.ToString()
		if err != nil {
			return err
		}
		sb.WriteByte('\'')
		sb.WriteString(value)
		sb.WriteByte('\'')
	}

	return nil
}
   240  
   241  func (*tidbEncoder) Close() {}
   242  
   243  func getColumnByIndex(cols []*table.Column, index int) *table.Column {
   244  	if index == len(cols) {
   245  		return extraHandleTableColumn
   246  	}
   247  	return cols[index]
   248  }
   249  
   250  func (enc *tidbEncoder) Encode(logger log.Logger, row []types.Datum, _ int64, columnPermutation []int) (Row, error) {
   251  	cols := enc.tbl.Cols()
   252  
   253  	if len(enc.columnIdx) == 0 {
   254  		columnCount := 0
   255  		columnIdx := make([]int, len(columnPermutation))
   256  		for i, idx := range columnPermutation {
   257  			if idx >= 0 {
   258  				columnIdx[idx] = i
   259  				columnCount++
   260  			}
   261  		}
   262  		enc.columnIdx = columnIdx
   263  		enc.columnCnt = columnCount
   264  	}
   265  
   266  	// TODO: since the column count doesn't exactly reflect the real column names, we only check the upper bound currently.
   267  	// See: tests/generated_columns/data/gencol.various_types.0.sql this sql has no columns, so encodeLoop will fill the
   268  	// column permutation with default, thus enc.columnCnt > len(row).
   269  	if len(row) > enc.columnCnt {
   270  		logger.Error("column count mismatch", zap.Ints("column_permutation", columnPermutation),
   271  			zap.Array("data", rowArrayMarshaler(row)))
   272  		return nil, errors.Errorf("column count mismatch, expected %d, got %d", enc.columnCnt, len(row))
   273  	}
   274  
   275  	var encoded strings.Builder
   276  	encoded.Grow(8 * len(row))
   277  	encoded.WriteByte('(')
   278  	for i, field := range row {
   279  		if i != 0 {
   280  			encoded.WriteByte(',')
   281  		}
   282  		if err := enc.appendSQL(&encoded, &field, getColumnByIndex(cols, enc.columnIdx[i])); err != nil {
   283  			logger.Error("tidb encode failed",
   284  				zap.Array("original", rowArrayMarshaler(row)),
   285  				zap.Int("originalCol", i),
   286  				log.ShortError(err),
   287  			)
   288  			return nil, err
   289  		}
   290  	}
   291  	encoded.WriteByte(')')
   292  	return tidbRow(encoded.String()), nil
   293  }
   294  
// Close implements the Backend interface.
func (be *tidbBackend) Close() {
	// *Not* going to close `be.db`. The db object is normally borrowed from a
	// TidbManager, so we let the manager to close it.
}
   299  
   300  func (be *tidbBackend) MakeEmptyRows() Rows {
   301  	return tidbRows(nil)
   302  }
   303  
// RetryImportDelay returns the wait time between import retries; the TiDB
// backend retries without delay.
func (be *tidbBackend) RetryImportDelay() time.Duration {
	return 0
}
   307  
// MaxChunkSize returns the maximum total byte size of rows placed into a
// single statement by WriteRowsToDB (1 MiB). The failpoint shrinks it to 1
// byte so tests can force one row per statement.
func (be *tidbBackend) MaxChunkSize() int {
	failpoint.Inject("FailIfImportedSomeRows", func() {
		failpoint.Return(1)
	})
	return 1048576
}
   314  
// ShouldPostProcess reports whether post-import processing (e.g. checksum /
// analyze phases) is needed; writing through SQL requires none.
func (be *tidbBackend) ShouldPostProcess() bool {
	return false
}
   318  
// CheckRequirements implements the Backend interface; the TiDB backend has
// no cluster-version or environment prerequisites to verify.
func (be *tidbBackend) CheckRequirements(ctx context.Context) error {
	log.L().Info("skipping check requirements for tidb backend")
	return nil
}
   323  
// NewEncoder builds a SQL-text encoder for the given table. Under strict
// SQL mode the encoding session enables UTF-8 and ASCII validity checks on
// string data; otherwise those checks stay skipped.
func (be *tidbBackend) NewEncoder(tbl table.Table, options *SessionOptions) (Encoder, error) {
	se := newSession(options)
	if options.SQLMode.HasStrictMode() {
		se.vars.SkipUTF8Check = false
		se.vars.SkipASCIICheck = false
	}

	return &tidbEncoder{mode: options.SQLMode, tbl: tbl, se: se}, nil
}
   333  
// OpenEngine is a no-op: the TiDB backend has no per-engine state to open.
func (be *tidbBackend) OpenEngine(context.Context, uuid.UUID) error {
	return nil
}
   337  
// CloseEngine is a no-op: rows are written straight to the database, so
// there is nothing to flush or seal.
func (be *tidbBackend) CloseEngine(context.Context, uuid.UUID) error {
	return nil
}
   341  
// CleanupEngine is a no-op: no local engine files exist to remove.
func (be *tidbBackend) CleanupEngine(context.Context, uuid.UUID) error {
	return nil
}
   345  
// ImportEngine is a no-op: data was already committed during WriteRows.
func (be *tidbBackend) ImportEngine(context.Context, uuid.UUID) error {
	return nil
}
   349  
// WriteRows splits the rows into size-bounded chunks and writes each chunk
// to the database. A chunk is retried up to maxRetryTimes on retryable
// errors; a non-retryable error aborts immediately, and exhausting the
// retries returns the last error annotated with the table name.
func (be *tidbBackend) WriteRows(ctx context.Context, _ uuid.UUID, tableName string, columnNames []string, _ uint64, rows Rows) error {
	var err error
outside:
	for _, r := range rows.SplitIntoChunks(be.MaxChunkSize()) {
		for i := 0; i < maxRetryTimes; i++ {
			err = be.WriteRowsToDB(ctx, tableName, columnNames, r)
			switch {
			case err == nil:
				// This chunk succeeded; move on to the next one.
				continue outside
			case common.IsRetryableError(err):
				// retry next loop
			default:
				return err
			}
		}
		return errors.Annotatef(err, "[%s] write rows reach max retry %d and still failed", tableName, maxRetryTimes)
	}
	return nil
}
   369  
   370  func (be *tidbBackend) WriteRowsToDB(ctx context.Context, tableName string, columnNames []string, r Rows) error {
   371  	rows := r.(tidbRows)
   372  	if len(rows) == 0 {
   373  		return nil
   374  	}
   375  
   376  	var insertStmt strings.Builder
   377  	switch be.onDuplicate {
   378  	case config.ReplaceOnDup:
   379  		insertStmt.WriteString("REPLACE INTO ")
   380  	case config.IgnoreOnDup:
   381  		insertStmt.WriteString("INSERT IGNORE INTO ")
   382  	case config.ErrorOnDup:
   383  		insertStmt.WriteString("INSERT INTO ")
   384  	}
   385  
   386  	insertStmt.WriteString(tableName)
   387  	if len(columnNames) > 0 {
   388  		insertStmt.WriteByte('(')
   389  		for i, colName := range columnNames {
   390  			if i != 0 {
   391  				insertStmt.WriteByte(',')
   392  			}
   393  			common.WriteMySQLIdentifier(&insertStmt, colName)
   394  		}
   395  		insertStmt.WriteByte(')')
   396  	}
   397  	insertStmt.WriteString(" VALUES")
   398  
   399  	// Note: we are not going to do interpolation (prepared statements) to avoid
   400  	// complication arise from data length overflow of BIT and BINARY columns
   401  
   402  	for i, row := range rows {
   403  		if i != 0 {
   404  			insertStmt.WriteByte(',')
   405  		}
   406  		insertStmt.WriteString(string(row))
   407  	}
   408  
   409  	// Retry will be done externally, so we're not going to retry here.
   410  	_, err := be.db.ExecContext(ctx, insertStmt.String())
   411  	if err != nil {
   412  		log.L().Error("execute statement failed", log.ZapRedactString("stmt", insertStmt.String()),
   413  			log.ZapRedactArray("rows", rows), zap.Error(err))
   414  	}
   415  	failpoint.Inject("FailIfImportedSomeRows", func() {
   416  		panic("forcing failure due to FailIfImportedSomeRows, before saving checkpoint")
   417  	})
   418  	return errors.Trace(err)
   419  }
   420  
// FetchRemoteTableModels reconstructs minimal *model.TableInfo values for
// every table in schemaName by querying information_schema.columns, and —
// on TiDB >= v4 — enriches auto-increment / auto-random column flags via
// `SHOW TABLE ... NEXT_ROW_ID`. The result only carries names, offsets and
// type flags, not full column types.
func (be *tidbBackend) FetchRemoteTableModels(ctx context.Context, schemaName string) (tables []*model.TableInfo, err error) {
	s := common.SQLWithRetry{
		DB:     be.db,
		Logger: log.L(),
	}

	err = s.Transact(ctx, "fetch table columns", func(c context.Context, tx *sql.Tx) error {
		// The server version decides whether NEXT_ROW_ID is queryable below.
		var versionStr string
		if err = tx.QueryRowContext(ctx, "SELECT version()").Scan(&versionStr); err != nil {
			return err
		}
		tidbVersion, err := common.ExtractTiDBVersion(versionStr)
		if err != nil {
			return err
		}

		// Columns arrive grouped by table and ordered by position, so a
		// simple change-of-table check is enough to build one TableInfo per
		// table.
		rows, e := tx.Query(`
			SELECT table_name, column_name, column_type, extra
			FROM information_schema.columns
			WHERE table_schema = ?
			ORDER BY table_name, ordinal_position;
		`, schemaName)
		if e != nil {
			return e
		}
		defer rows.Close()

		var (
			curTableName string
			curColOffset int
			curTable     *model.TableInfo
		)
		for rows.Next() {
			var tableName, columnName, columnType, columnExtra string
			if e := rows.Scan(&tableName, &columnName, &columnType, &columnExtra); e != nil {
				return e
			}
			if tableName != curTableName {
				// Start a new table record.
				curTable = &model.TableInfo{
					Name:       model.NewCIStr(tableName),
					State:      model.StatePublic,
					PKIsHandle: true,
				}
				tables = append(tables, curTable)
				curTableName = tableName
				curColOffset = 0
			}

			// see: https://github.com/pingcap/parser/blob/3b2fb4b41d73710bc6c4e1f4e8679d8be6a4863e/types/field_type.go#L185-L191
			var flag uint
			if strings.HasSuffix(columnType, "unsigned") {
				flag |= mysql.UnsignedFlag
			}
			if strings.Contains(columnExtra, "auto_increment") {
				flag |= mysql.AutoIncrementFlag
			}
			curTable.Columns = append(curTable.Columns, &model.ColumnInfo{
				Name:   model.NewCIStr(columnName),
				Offset: curColOffset,
				State:  model.StatePublic,
				FieldType: types.FieldType{
					Flag: flag,
				},
			})
			curColOffset++
		}
		if rows.Err() != nil {
			return rows.Err()
		}
		// for version < v4.0.0 we can use `show table next_row_id` to fetch auto id info,
		// so the information gathered above should be enough
		if tidbVersion.Major < 4 {
			return nil
		}
		// init auto id column for each table
		for _, tbl := range tables {
			tblName := common.UniqueTable(schemaName, tbl.Name.O)
			rows, e = tx.Query(fmt.Sprintf("SHOW TABLE %s NEXT_ROW_ID", tblName))
			if e != nil {
				return e
			}
			for rows.Next() {
				var (
					dbName, tblName, columnName, idType string
					nextID                              int64
				)
				columns, err := rows.Columns()
				if err != nil {
					return err
				}

				//+--------------+------------+-------------+--------------------+----------------+
				//| DB_NAME      | TABLE_NAME | COLUMN_NAME | NEXT_GLOBAL_ROW_ID | ID_TYPE        |
				//+--------------+------------+-------------+--------------------+----------------+
				//| testsysbench | t          | _tidb_rowid |                  1 | AUTO_INCREMENT |
				//+--------------+------------+-------------+--------------------+----------------+

				// if columns length is 4, it doesn't contains the last column `ID_TYPE`, and it will always be 'AUTO_INCREMENT'
				// for v4.0.0~v4.0.2 show table t next_row_id only returns 4 columns.
				if len(columns) == 4 {
					err = rows.Scan(&dbName, &tblName, &columnName, &nextID)
					idType = "AUTO_INCREMENT"
				} else {
					err = rows.Scan(&dbName, &tblName, &columnName, &nextID, &idType)
				}
				if err != nil {
					return err
				}

				// Propagate the id type onto the matching column flags.
				for _, col := range tbl.Columns {
					if col.Name.O == columnName {
						switch idType {
						case "AUTO_INCREMENT":
							col.Flag |= mysql.AutoIncrementFlag
						case "AUTO_RANDOM":
							col.Flag |= mysql.PriKeyFlag
							tbl.PKIsHandle = true
							// set a stub here, since we don't really need the real value
							tbl.AutoRandomBits = 1
						}
					}
				}
			}
			// Close each per-table result set eagerly; the deferred Close
			// above only covers the first query's rows.
			rows.Close()
			if rows.Err() != nil {
				return rows.Err()
			}
		}
		return nil
	})
	return
}
   552  
// EngineFileSizes reports local engine file usage; the TiDB backend keeps
// no local files, so there is nothing to report.
func (be *tidbBackend) EngineFileSizes() []EngineFileSize {
	return nil
}
   556  
// FlushEngine is a no-op: writes go directly to the database.
func (be *tidbBackend) FlushEngine(context.Context, uuid.UUID) error {
	return nil
}
   560  
// FlushAllEngines is a no-op: there are no buffered engines to flush.
func (be *tidbBackend) FlushAllEngines(context.Context) error {
	return nil
}
   564  
// ResetEngine always fails: already-committed SQL writes cannot be undone.
func (be *tidbBackend) ResetEngine(context.Context, uuid.UUID) error {
	return errors.New("cannot reset an engine in TiDB backend")
}
   568  
// LocalWriter returns a thin writer that forwards rows to this backend;
// maxCacheSize is ignored because nothing is cached locally.
func (be *tidbBackend) LocalWriter(ctx context.Context, engineUUID uuid.UUID, maxCacheSize int64) (EngineWriter, error) {
	return &TiDBWriter{be: be, engineUUID: engineUUID}, nil
}
   572  
// TiDBWriter is a stateless EngineWriter that delegates all appended rows
// to its backend's WriteRows.
type TiDBWriter struct {
	be         *tidbBackend
	engineUUID uuid.UUID
}
   577  
// Close is a no-op: the writer holds no resources of its own.
func (w *TiDBWriter) Close() error {
	return nil
}
   581  
// AppendRows forwards the rows to the backend's WriteRows for the engine
// this writer was created for.
func (w *TiDBWriter) AppendRows(ctx context.Context, tableName string, columnNames []string, arg1 uint64, rows Rows) error {
	return w.be.WriteRows(ctx, w.engineUUID, tableName, columnNames, arg1, rows)
}