github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/tests/mq_protocol_tests/framework/sql_helper.go (about) 1 // Copyright 2020 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package framework 15 16 import ( 17 "context" 18 "database/sql" 19 "fmt" 20 "strings" 21 "sync/atomic" 22 "time" 23 24 "github.com/jmoiron/sqlx" 25 "github.com/pingcap/errors" 26 "github.com/pingcap/log" 27 "github.com/pingcap/tiflow/pkg/quotes" 28 "go.uber.org/zap" 29 "go.uber.org/zap/zapcore" 30 "upper.io/db.v3/lib/sqlbuilder" 31 _ "upper.io/db.v3/mysql" // imported for side effects 32 ) 33 34 // SQLHelper provides basic utilities for manipulating data 35 type SQLHelper struct { 36 upstream *sql.DB 37 downstream *sql.DB 38 ctx context.Context 39 } 40 41 // Table represents the handle of a table in the upstream 42 type Table struct { 43 err error 44 tableName string 45 uniqueIndex []string 46 helper *SQLHelper 47 } 48 49 // GetTable returns the handle of the given table 50 func (h *SQLHelper) GetTable(tableName string) *Table { 51 db, err := sqlbuilder.New("mysql", h.upstream) 52 if err != nil { 53 return &Table{err: errors.AddStack(err)} 54 } 55 56 idxCol, err := getUniqueIndexColumn(h.ctx, db, "testdb", tableName) 57 if err != nil { 58 return &Table{err: errors.AddStack(err)} 59 } 60 61 return &Table{tableName: tableName, uniqueIndex: idxCol, helper: h} 62 } 63 64 func (t *Table) makeSQLRequest(requestType sqlRequestType, rowData map[string]interface{}) (*sqlRequest, error) { 65 if t.err != nil { 66 return nil, t.err 67 } 68 69 return &sqlRequest{ 70 tableName: t.tableName, 71 data: rowData, 72 result: nil, 73 uniqueIndex: t.uniqueIndex, 74 helper: t.helper, 75 requestType: requestType, 76 }, nil 77 } 78 79 // Insert returns a Sendable object that represents an Insert clause 80 func (t *Table) Insert(rowData map[string]interface{}) Sendable { 81 basicReq, err := t.makeSQLRequest(sqlRequestTypeInsert, rowData) 82 if err != nil { 83 return &errorSender{err: err} 84 } 85 86 return &syncSQLRequest{*basicReq} 87 } 88 89 // Upsert returns a Sendable object that represents a Replace Into clause 90 func (t *Table) Upsert(rowData map[string]interface{}) Sendable { 91 basicReq, err := t.makeSQLRequest(sqlRequestTypeUpsert, rowData) 92 if err != nil { 93 return &errorSender{err: err} 94 } 95 96 return &syncSQLRequest{*basicReq} 97 } 98 99 // Delete returns a Sendable object that represents a Delete from clause 100 func (t *Table) Delete(rowData map[string]interface{}) Sendable { 101 basicReq, err := t.makeSQLRequest(sqlRequestTypeDelete, rowData) 102 if err != nil { 103 return &errorSender{err: err} 104 } 105 106 return &syncSQLRequest{*basicReq} 107 } 108 109 type sqlRowContainer interface { 110 getData() map[string]interface{} 111 getComparableKey() string 112 getTable() *Table 113 } 114 115 type awaitableSQLRowContainer struct { 116 Awaitable 117 sqlRowContainer 118 } 119 120 type sqlRequestType int32 121 122 const ( 123 sqlRequestTypeInsert sqlRequestType = iota 124 sqlRequestTypeUpsert 125 sqlRequestTypeDelete 126 ) 127 128 type sqlRequest struct { 129 tableName string 130 data map[string]interface{} 131 result map[string]interface{} 132 uniqueIndex []string 133 helper *SQLHelper 134 requestType sqlRequestType 135 hasReadBack uint32 136 } 137 138 // MarshalLogObjects helps printing the sqlRequest 139 func (s *sqlRequest) MarshalLogObject(encoder zapcore.ObjectEncoder) error { 140 encoder.AddString("upstream", fmt.Sprintf("%#v", s.data)) 141 encoder.AddString("downstream", fmt.Sprintf("%#v", s.result)) 142 return nil 143 } 144 145 func (s *sqlRequest) getPrimaryKeyTuple() string { 146 return makeColumnTuple(s.uniqueIndex) 147 } 148 149 func (s *sqlRequest) getWhereCondition() []interface{} { 150 builder := strings.Builder{} 151 args := make([]interface{}, 1, len(s.uniqueIndex)+1) 152 builder.WriteString(s.getPrimaryKeyTuple() + " = (") 153 for i, col := range s.uniqueIndex { 154 builder.WriteString("?") 155 if i != len(s.uniqueIndex)-1 { 156 builder.WriteString(",") 157 } 158 159 args = append(args, s.data[col]) 160 } 161 builder.WriteString(")") 162 args[0] = builder.String() 163 return args 164 } 165 166 func (s *sqlRequest) getComparableKey() string { 167 if len(s.uniqueIndex) == 1 { 168 return s.uniqueIndex[0] 169 } 170 171 ret := make(map[string]interface{}) 172 for k, v := range s.data { 173 for _, col := range s.uniqueIndex { 174 if k == col { 175 ret[k] = v 176 } 177 } 178 } 179 return fmt.Sprintf("%v", ret) 180 } 181 182 func (s *sqlRequest) getData() map[string]interface{} { 183 return s.data 184 } 185 186 func (s *sqlRequest) getTable() *Table { 187 return &Table{ 188 err: nil, 189 tableName: s.tableName, 190 uniqueIndex: s.uniqueIndex, 191 helper: s.helper, 192 } 193 } 194 195 func (s *sqlRequest) getAwaitableSQLRowContainer() *awaitableSQLRowContainer { 196 return &awaitableSQLRowContainer{ 197 Awaitable: &basicAwaitable{ 198 pollableAndCheckable: s, 199 timeout: 30 * time.Second, 200 }, 201 sqlRowContainer: s, 202 } 203 } 204 205 // Sendable is a sendable request to the upstream 206 type Sendable interface { 207 Send() Awaitable 208 } 209 210 type errorSender struct { 211 err error 212 } 213 214 // Send implements sender 215 func (s *errorSender) Send() Awaitable { 216 return &errorCheckableAndAwaitable{s.err} 217 } 218 219 type syncSQLRequest struct { 220 sqlRequest 221 } 222 223 func (r *syncSQLRequest) Send() Awaitable { 224 atomic.StoreUint32(&r.hasReadBack, 0) 225 var err error 226 switch r.requestType { 227 case sqlRequestTypeInsert: 228 err = r.insert(r.helper.ctx) 229 case sqlRequestTypeUpsert: 230 err = r.upsert(r.helper.ctx) 231 case sqlRequestTypeDelete: 232 err = r.delete(r.helper.ctx) 233 } 234 235 go func() { 236 db, err := sqlbuilder.New("mysql", r.helper.upstream) 237 if err != nil { 238 log.Warn("ReadBack:", zap.Error(err)) 239 return 240 } 241 242 cond := r.getWhereCondition() 243 244 rows, err := db.SelectFrom(r.tableName).Where(cond).QueryContext(r.helper.ctx) 245 if err != nil { 246 log.Warn("ReadBack:", zap.Error(err)) 247 return 248 } 249 defer rows.Close() 250 251 if !rows.Next() { 252 // Upstream does not have the row 253 if r.requestType != sqlRequestTypeDelete { 254 log.Warn("ReadBack: no row, likely to be bug") 255 } 256 } else { 257 r.data, err = rowsToMap(rows) 258 if err != nil { 259 log.Warn("ReadBack", zap.Error(err)) 260 return 261 } 262 } 263 264 atomic.StoreUint32(&r.hasReadBack, 1) 265 }() 266 267 if err != nil { 268 return &errorCheckableAndAwaitable{errors.AddStack(err)} 269 } 270 return r.getAwaitableSQLRowContainer() 271 } 272 273 /* 274 type asyncSQLRequest struct { 275 sqlRequest 276 } 277 */ 278 279 func (s *sqlRequest) insert(ctx context.Context) error { 280 db, err := sqlbuilder.New("mysql", s.helper.upstream) 281 if err != nil { 282 return errors.AddStack(err) 283 } 284 285 keys := make([]string, len(s.data)) 286 values := make([]interface{}, len(s.data)) 287 i := 0 288 for k, v := range s.data { 289 keys[i] = k 290 values[i] = v 291 i++ 292 } 293 294 _, err = db.InsertInto(s.tableName).Columns(keys...).Values(values...).ExecContext(ctx) 295 if err != nil { 296 return errors.AddStack(err) 297 } 298 299 s.requestType = sqlRequestTypeInsert 300 return nil 301 } 302 303 func (s *sqlRequest) upsert(ctx context.Context) error { 304 db := sqlx.NewDb(s.helper.upstream, "mysql") 305 306 keys := make([]string, len(s.data)) 307 values := make([]interface{}, len(s.data)) 308 i := 0 309 for k, v := range s.data { 310 keys[i] = k 311 values[i] = v 312 i++ 313 } 314 315 query, args, err := sqlx.In("replace into `"+s.tableName+"` "+makeColumnTuple(keys)+" values (?)", values) 316 if err != nil { 317 return errors.AddStack(err) 318 } 319 320 query = db.Rebind(query) 321 _, err = s.helper.upstream.ExecContext(ctx, query, args...) 322 if err != nil { 323 return errors.AddStack(err) 324 } 325 326 s.requestType = sqlRequestTypeUpsert 327 return nil 328 } 329 330 func (s *sqlRequest) delete(ctx context.Context) error { 331 db, err := sqlbuilder.New("mysql", s.helper.upstream) 332 if err != nil { 333 return errors.AddStack(err) 334 } 335 336 _, err = db.DeleteFrom(s.tableName).Where(s.getWhereCondition()).ExecContext(ctx) 337 if err != nil { 338 return errors.AddStack(err) 339 } 340 341 s.requestType = sqlRequestTypeDelete 342 return nil 343 } 344 345 func (s *sqlRequest) read(ctx context.Context) (map[string]interface{}, error) { 346 db, err := sqlbuilder.New("mysql", s.helper.downstream) 347 if err != nil { 348 return nil, errors.AddStack(err) 349 } 350 351 rows, err := db.SelectFrom(s.tableName).Where(s.getWhereCondition()).QueryContext(ctx) 352 if err != nil { 353 return nil, errors.AddStack(err) 354 } 355 defer rows.Close() 356 357 if !rows.Next() { 358 return nil, nil 359 } 360 return rowsToMap(rows) 361 } 362 363 //nolint:unused 364 func (s *sqlRequest) getBasicAwaitable() basicAwaitable { 365 return basicAwaitable{ 366 pollableAndCheckable: s, 367 timeout: 0, 368 } 369 } 370 371 func (s *sqlRequest) poll(ctx context.Context) (bool, error) { 372 if atomic.LoadUint32(&s.hasReadBack) == 0 { 373 return false, nil 374 } 375 res, err := s.read(ctx) 376 if err != nil { 377 if strings.Contains(err.Error(), "Error 1146") { 378 return false, nil 379 } 380 return false, errors.AddStack(err) 381 } 382 s.result = res 383 384 switch s.requestType { 385 case sqlRequestTypeInsert: 386 if res == nil { 387 return false, nil 388 } 389 return true, nil 390 case sqlRequestTypeUpsert: 391 if res == nil { 392 return false, nil 393 } 394 if compareMaps(s.data, res) { 395 return true, nil 396 } 397 log.Debug("Upserted row does not match the expected") 398 return false, nil 399 case sqlRequestTypeDelete: 400 if res == nil { 401 return true, nil 402 } 403 log.Debug("Delete not successful yet", zap.Reflect("where", s.getWhereCondition())) 404 return false, nil 405 } 406 return true, nil 407 } 408 409 func (s *sqlRequest) Check() error { 410 if s.requestType == sqlRequestTypeUpsert || s.requestType == sqlRequestTypeDelete { 411 return nil 412 } 413 // TODO better comparator 414 if s.result == nil { 415 return errors.New("Check: nil result") 416 } 417 if compareMaps(s.data, s.result) { 418 return nil 419 } 420 log.Warn("Check failed", zap.Object("request", s)) 421 return errors.New("Check failed") 422 } 423 424 func rowsToMap(rows *sql.Rows) (map[string]interface{}, error) { 425 colNames, err := rows.Columns() 426 if err != nil { 427 return nil, errors.AddStack(err) 428 } 429 430 colData := make([]interface{}, len(colNames)) 431 colDataPtrs := make([]interface{}, len(colNames)) 432 for i := range colData { 433 colDataPtrs[i] = &colData[i] 434 } 435 436 err = rows.Scan(colDataPtrs...) 437 if err != nil { 438 return nil, errors.AddStack(err) 439 } 440 441 ret := make(map[string]interface{}, len(colNames)) 442 for i := 0; i < len(colNames); i++ { 443 ret[colNames[i]] = colData[i] 444 } 445 return ret, nil 446 } 447 448 func getUniqueIndexColumn(ctx context.Context, db sqlbuilder.Database, dbName string, tableName string) ([]string, error) { 449 row, err := db.QueryRowContext(ctx, ` 450 SELECT GROUP_CONCAT(COLUMN_NAME SEPARATOR ' ') FROM INFORMATION_SCHEMA.STATISTICS 451 WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? 452 GROUP BY INDEX_NAME 453 ORDER BY FIELD(INDEX_NAME,'PRIMARY') DESC 454 `, dbName, tableName) 455 if err != nil { 456 return nil, errors.AddStack(err) 457 } 458 459 colName := "" 460 err = row.Scan(&colName) 461 if err != nil { 462 return nil, errors.AddStack(err) 463 } 464 465 return strings.Split(colName, " "), nil 466 } 467 468 func compareMaps(m1 map[string]interface{}, m2 map[string]interface{}) bool { 469 // TODO better comparator 470 if m2 == nil { 471 return false 472 } 473 str1 := fmt.Sprintf("%v", m1) 474 str2 := fmt.Sprintf("%v", m2) 475 return str1 == str2 476 } 477 478 func makeColumnTuple(colNames []string) string { 479 colNamesQuoted := make([]string, len(colNames)) 480 for i := range colNames { 481 colNamesQuoted[i] = quotes.QuoteName(colNames[i]) 482 } 483 return "(" + strings.Join(colNamesQuoted, ",") + ")" 484 }