github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/tests/mq_protocol_tests/framework/sql_helper.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package framework
    15  
    16  import (
    17  	"context"
    18  	"database/sql"
    19  	"fmt"
    20  	"strings"
    21  	"sync/atomic"
    22  	"time"
    23  
    24  	"github.com/jmoiron/sqlx"
    25  	"github.com/pingcap/errors"
    26  	"github.com/pingcap/log"
    27  	"github.com/pingcap/tiflow/pkg/quotes"
    28  	"go.uber.org/zap"
    29  	"go.uber.org/zap/zapcore"
    30  	"upper.io/db.v3/lib/sqlbuilder"
    31  	_ "upper.io/db.v3/mysql" // imported for side effects
    32  )
    33  
    34  // SQLHelper provides basic utilities for manipulating data
    35  type SQLHelper struct {
    36  	upstream   *sql.DB
    37  	downstream *sql.DB
    38  	ctx        context.Context
    39  }
    40  
// Table represents the handle of a table in the upstream.
//
// Tables are obtained from SQLHelper.GetTable. Fields:
//   - err: set when GetTable failed; makeSQLRequest returns it, so every
//     request built from a failed handle reports the original error.
//   - tableName: the table name (GetTable resolves it in the "testdb" schema).
//   - uniqueIndex: the columns used to identify a single row in WHERE
//     conditions (as returned by getUniqueIndexColumn).
//   - helper: the SQLHelper whose connections requests are sent through.
type Table struct {
	err         error
	tableName   string
	uniqueIndex []string
	helper      *SQLHelper
}
    48  
    49  // GetTable returns the handle of the given table
    50  func (h *SQLHelper) GetTable(tableName string) *Table {
    51  	db, err := sqlbuilder.New("mysql", h.upstream)
    52  	if err != nil {
    53  		return &Table{err: errors.AddStack(err)}
    54  	}
    55  
    56  	idxCol, err := getUniqueIndexColumn(h.ctx, db, "testdb", tableName)
    57  	if err != nil {
    58  		return &Table{err: errors.AddStack(err)}
    59  	}
    60  
    61  	return &Table{tableName: tableName, uniqueIndex: idxCol, helper: h}
    62  }
    63  
    64  func (t *Table) makeSQLRequest(requestType sqlRequestType, rowData map[string]interface{}) (*sqlRequest, error) {
    65  	if t.err != nil {
    66  		return nil, t.err
    67  	}
    68  
    69  	return &sqlRequest{
    70  		tableName:   t.tableName,
    71  		data:        rowData,
    72  		result:      nil,
    73  		uniqueIndex: t.uniqueIndex,
    74  		helper:      t.helper,
    75  		requestType: requestType,
    76  	}, nil
    77  }
    78  
    79  // Insert returns a Sendable object that represents an Insert clause
    80  func (t *Table) Insert(rowData map[string]interface{}) Sendable {
    81  	basicReq, err := t.makeSQLRequest(sqlRequestTypeInsert, rowData)
    82  	if err != nil {
    83  		return &errorSender{err: err}
    84  	}
    85  
    86  	return &syncSQLRequest{*basicReq}
    87  }
    88  
    89  // Upsert returns a Sendable object that represents a Replace Into clause
    90  func (t *Table) Upsert(rowData map[string]interface{}) Sendable {
    91  	basicReq, err := t.makeSQLRequest(sqlRequestTypeUpsert, rowData)
    92  	if err != nil {
    93  		return &errorSender{err: err}
    94  	}
    95  
    96  	return &syncSQLRequest{*basicReq}
    97  }
    98  
    99  // Delete returns a Sendable object that represents a Delete from clause
   100  func (t *Table) Delete(rowData map[string]interface{}) Sendable {
   101  	basicReq, err := t.makeSQLRequest(sqlRequestTypeDelete, rowData)
   102  	if err != nil {
   103  		return &errorSender{err: err}
   104  	}
   105  
   106  	return &syncSQLRequest{*basicReq}
   107  }
   108  
   109  type sqlRowContainer interface {
   110  	getData() map[string]interface{}
   111  	getComparableKey() string
   112  	getTable() *Table
   113  }
   114  
// awaitableSQLRowContainer pairs an Awaitable with the sqlRowContainer it
// waits on, so a caller can both wait for replication and inspect the row.
// getAwaitableSQLRowContainer builds one whose Awaitable polls the same
// *sqlRequest that serves as the sqlRowContainer.
type awaitableSQLRowContainer struct {
	Awaitable
	sqlRowContainer
}
   119  
   120  type sqlRequestType int32
   121  
   122  const (
   123  	sqlRequestTypeInsert sqlRequestType = iota
   124  	sqlRequestTypeUpsert
   125  	sqlRequestTypeDelete
   126  )
   127  
// sqlRequest describes one DML statement against a single row of an upstream
// table, together with the state needed to verify its replication downstream.
//
// Fields:
//   - tableName, uniqueIndex, helper: copied from the Table the request was
//     built from.
//   - data: the row values to write. For a syncSQLRequest it is replaced by
//     the row read back from the upstream after the write.
//   - result: the row most recently read from the downstream by poll (nil if
//     the downstream had no such row).
//   - requestType: which statement (insert, upsert, delete) the request issues.
//   - hasReadBack: accessed atomically; 0 until the upstream read-back started
//     by syncSQLRequest.Send has finished, then 1. poll does not query the
//     downstream while it is 0.
type sqlRequest struct {
	tableName   string
	data        map[string]interface{}
	result      map[string]interface{}
	uniqueIndex []string
	helper      *SQLHelper
	requestType sqlRequestType
	hasReadBack uint32
}
   137  
   138  // MarshalLogObjects helps printing the sqlRequest
   139  func (s *sqlRequest) MarshalLogObject(encoder zapcore.ObjectEncoder) error {
   140  	encoder.AddString("upstream", fmt.Sprintf("%#v", s.data))
   141  	encoder.AddString("downstream", fmt.Sprintf("%#v", s.result))
   142  	return nil
   143  }
   144  
   145  func (s *sqlRequest) getPrimaryKeyTuple() string {
   146  	return makeColumnTuple(s.uniqueIndex)
   147  }
   148  
   149  func (s *sqlRequest) getWhereCondition() []interface{} {
   150  	builder := strings.Builder{}
   151  	args := make([]interface{}, 1, len(s.uniqueIndex)+1)
   152  	builder.WriteString(s.getPrimaryKeyTuple() + " = (")
   153  	for i, col := range s.uniqueIndex {
   154  		builder.WriteString("?")
   155  		if i != len(s.uniqueIndex)-1 {
   156  			builder.WriteString(",")
   157  		}
   158  
   159  		args = append(args, s.data[col])
   160  	}
   161  	builder.WriteString(")")
   162  	args[0] = builder.String()
   163  	return args
   164  }
   165  
   166  func (s *sqlRequest) getComparableKey() string {
   167  	if len(s.uniqueIndex) == 1 {
   168  		return s.uniqueIndex[0]
   169  	}
   170  
   171  	ret := make(map[string]interface{})
   172  	for k, v := range s.data {
   173  		for _, col := range s.uniqueIndex {
   174  			if k == col {
   175  				ret[k] = v
   176  			}
   177  		}
   178  	}
   179  	return fmt.Sprintf("%v", ret)
   180  }
   181  
   182  func (s *sqlRequest) getData() map[string]interface{} {
   183  	return s.data
   184  }
   185  
   186  func (s *sqlRequest) getTable() *Table {
   187  	return &Table{
   188  		err:         nil,
   189  		tableName:   s.tableName,
   190  		uniqueIndex: s.uniqueIndex,
   191  		helper:      s.helper,
   192  	}
   193  }
   194  
   195  func (s *sqlRequest) getAwaitableSQLRowContainer() *awaitableSQLRowContainer {
   196  	return &awaitableSQLRowContainer{
   197  		Awaitable: &basicAwaitable{
   198  			pollableAndCheckable: s,
   199  			timeout:              30 * time.Second,
   200  		},
   201  		sqlRowContainer: s,
   202  	}
   203  }
   204  
// Sendable is a request that can be sent to the upstream.
//
// Send performs the request and returns an Awaitable. The implementations in
// this file report failures through the returned Awaitable (an
// errorCheckableAndAwaitable) rather than through a separate error value.
type Sendable interface {
	Send() Awaitable
}
   209  
   210  type errorSender struct {
   211  	err error
   212  }
   213  
   214  // Send implements sender
   215  func (s *errorSender) Send() Awaitable {
   216  	return &errorCheckableAndAwaitable{s.err}
   217  }
   218  
   219  type syncSQLRequest struct {
   220  	sqlRequest
   221  }
   222  
   223  func (r *syncSQLRequest) Send() Awaitable {
   224  	atomic.StoreUint32(&r.hasReadBack, 0)
   225  	var err error
   226  	switch r.requestType {
   227  	case sqlRequestTypeInsert:
   228  		err = r.insert(r.helper.ctx)
   229  	case sqlRequestTypeUpsert:
   230  		err = r.upsert(r.helper.ctx)
   231  	case sqlRequestTypeDelete:
   232  		err = r.delete(r.helper.ctx)
   233  	}
   234  
   235  	go func() {
   236  		db, err := sqlbuilder.New("mysql", r.helper.upstream)
   237  		if err != nil {
   238  			log.Warn("ReadBack:", zap.Error(err))
   239  			return
   240  		}
   241  
   242  		cond := r.getWhereCondition()
   243  
   244  		rows, err := db.SelectFrom(r.tableName).Where(cond).QueryContext(r.helper.ctx)
   245  		if err != nil {
   246  			log.Warn("ReadBack:", zap.Error(err))
   247  			return
   248  		}
   249  		defer rows.Close()
   250  
   251  		if !rows.Next() {
   252  			// Upstream does not have the row
   253  			if r.requestType != sqlRequestTypeDelete {
   254  				log.Warn("ReadBack: no row, likely to be bug")
   255  			}
   256  		} else {
   257  			r.data, err = rowsToMap(rows)
   258  			if err != nil {
   259  				log.Warn("ReadBack", zap.Error(err))
   260  				return
   261  			}
   262  		}
   263  
   264  		atomic.StoreUint32(&r.hasReadBack, 1)
   265  	}()
   266  
   267  	if err != nil {
   268  		return &errorCheckableAndAwaitable{errors.AddStack(err)}
   269  	}
   270  	return r.getAwaitableSQLRowContainer()
   271  }
   272  
   273  /*
   274  type asyncSQLRequest struct {
   275  	sqlRequest
   276  }
   277  */
   278  
   279  func (s *sqlRequest) insert(ctx context.Context) error {
   280  	db, err := sqlbuilder.New("mysql", s.helper.upstream)
   281  	if err != nil {
   282  		return errors.AddStack(err)
   283  	}
   284  
   285  	keys := make([]string, len(s.data))
   286  	values := make([]interface{}, len(s.data))
   287  	i := 0
   288  	for k, v := range s.data {
   289  		keys[i] = k
   290  		values[i] = v
   291  		i++
   292  	}
   293  
   294  	_, err = db.InsertInto(s.tableName).Columns(keys...).Values(values...).ExecContext(ctx)
   295  	if err != nil {
   296  		return errors.AddStack(err)
   297  	}
   298  
   299  	s.requestType = sqlRequestTypeInsert
   300  	return nil
   301  }
   302  
   303  func (s *sqlRequest) upsert(ctx context.Context) error {
   304  	db := sqlx.NewDb(s.helper.upstream, "mysql")
   305  
   306  	keys := make([]string, len(s.data))
   307  	values := make([]interface{}, len(s.data))
   308  	i := 0
   309  	for k, v := range s.data {
   310  		keys[i] = k
   311  		values[i] = v
   312  		i++
   313  	}
   314  
   315  	query, args, err := sqlx.In("replace into `"+s.tableName+"` "+makeColumnTuple(keys)+" values (?)", values)
   316  	if err != nil {
   317  		return errors.AddStack(err)
   318  	}
   319  
   320  	query = db.Rebind(query)
   321  	_, err = s.helper.upstream.ExecContext(ctx, query, args...)
   322  	if err != nil {
   323  		return errors.AddStack(err)
   324  	}
   325  
   326  	s.requestType = sqlRequestTypeUpsert
   327  	return nil
   328  }
   329  
   330  func (s *sqlRequest) delete(ctx context.Context) error {
   331  	db, err := sqlbuilder.New("mysql", s.helper.upstream)
   332  	if err != nil {
   333  		return errors.AddStack(err)
   334  	}
   335  
   336  	_, err = db.DeleteFrom(s.tableName).Where(s.getWhereCondition()).ExecContext(ctx)
   337  	if err != nil {
   338  		return errors.AddStack(err)
   339  	}
   340  
   341  	s.requestType = sqlRequestTypeDelete
   342  	return nil
   343  }
   344  
   345  func (s *sqlRequest) read(ctx context.Context) (map[string]interface{}, error) {
   346  	db, err := sqlbuilder.New("mysql", s.helper.downstream)
   347  	if err != nil {
   348  		return nil, errors.AddStack(err)
   349  	}
   350  
   351  	rows, err := db.SelectFrom(s.tableName).Where(s.getWhereCondition()).QueryContext(ctx)
   352  	if err != nil {
   353  		return nil, errors.AddStack(err)
   354  	}
   355  	defer rows.Close()
   356  
   357  	if !rows.Next() {
   358  		return nil, nil
   359  	}
   360  	return rowsToMap(rows)
   361  }
   362  
   363  //nolint:unused
   364  func (s *sqlRequest) getBasicAwaitable() basicAwaitable {
   365  	return basicAwaitable{
   366  		pollableAndCheckable: s,
   367  		timeout:              0,
   368  	}
   369  }
   370  
   371  func (s *sqlRequest) poll(ctx context.Context) (bool, error) {
   372  	if atomic.LoadUint32(&s.hasReadBack) == 0 {
   373  		return false, nil
   374  	}
   375  	res, err := s.read(ctx)
   376  	if err != nil {
   377  		if strings.Contains(err.Error(), "Error 1146") {
   378  			return false, nil
   379  		}
   380  		return false, errors.AddStack(err)
   381  	}
   382  	s.result = res
   383  
   384  	switch s.requestType {
   385  	case sqlRequestTypeInsert:
   386  		if res == nil {
   387  			return false, nil
   388  		}
   389  		return true, nil
   390  	case sqlRequestTypeUpsert:
   391  		if res == nil {
   392  			return false, nil
   393  		}
   394  		if compareMaps(s.data, res) {
   395  			return true, nil
   396  		}
   397  		log.Debug("Upserted row does not match the expected")
   398  		return false, nil
   399  	case sqlRequestTypeDelete:
   400  		if res == nil {
   401  			return true, nil
   402  		}
   403  		log.Debug("Delete not successful yet", zap.Reflect("where", s.getWhereCondition()))
   404  		return false, nil
   405  	}
   406  	return true, nil
   407  }
   408  
   409  func (s *sqlRequest) Check() error {
   410  	if s.requestType == sqlRequestTypeUpsert || s.requestType == sqlRequestTypeDelete {
   411  		return nil
   412  	}
   413  	// TODO better comparator
   414  	if s.result == nil {
   415  		return errors.New("Check: nil result")
   416  	}
   417  	if compareMaps(s.data, s.result) {
   418  		return nil
   419  	}
   420  	log.Warn("Check failed", zap.Object("request", s))
   421  	return errors.New("Check failed")
   422  }
   423  
   424  func rowsToMap(rows *sql.Rows) (map[string]interface{}, error) {
   425  	colNames, err := rows.Columns()
   426  	if err != nil {
   427  		return nil, errors.AddStack(err)
   428  	}
   429  
   430  	colData := make([]interface{}, len(colNames))
   431  	colDataPtrs := make([]interface{}, len(colNames))
   432  	for i := range colData {
   433  		colDataPtrs[i] = &colData[i]
   434  	}
   435  
   436  	err = rows.Scan(colDataPtrs...)
   437  	if err != nil {
   438  		return nil, errors.AddStack(err)
   439  	}
   440  
   441  	ret := make(map[string]interface{}, len(colNames))
   442  	for i := 0; i < len(colNames); i++ {
   443  		ret[colNames[i]] = colData[i]
   444  	}
   445  	return ret, nil
   446  }
   447  
   448  func getUniqueIndexColumn(ctx context.Context, db sqlbuilder.Database, dbName string, tableName string) ([]string, error) {
   449  	row, err := db.QueryRowContext(ctx, `
   450  		SELECT GROUP_CONCAT(COLUMN_NAME SEPARATOR ' ') FROM INFORMATION_SCHEMA.STATISTICS
   451  		WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
   452  		GROUP BY INDEX_NAME
   453  		ORDER BY FIELD(INDEX_NAME,'PRIMARY') DESC
   454  	`, dbName, tableName)
   455  	if err != nil {
   456  		return nil, errors.AddStack(err)
   457  	}
   458  
   459  	colName := ""
   460  	err = row.Scan(&colName)
   461  	if err != nil {
   462  		return nil, errors.AddStack(err)
   463  	}
   464  
   465  	return strings.Split(colName, " "), nil
   466  }
   467  
   468  func compareMaps(m1 map[string]interface{}, m2 map[string]interface{}) bool {
   469  	// TODO better comparator
   470  	if m2 == nil {
   471  		return false
   472  	}
   473  	str1 := fmt.Sprintf("%v", m1)
   474  	str2 := fmt.Sprintf("%v", m2)
   475  	return str1 == str2
   476  }
   477  
   478  func makeColumnTuple(colNames []string) string {
   479  	colNamesQuoted := make([]string, len(colNames))
   480  	for i := range colNames {
   481  		colNamesQuoted[i] = quotes.QuoteName(colNames[i])
   482  	}
   483  	return "(" + strings.Join(colNamesQuoted, ",") + ")"
   484  }