github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/tests/integration_tests/cdc/dailytest/db.go (about) 1 // Copyright 2020 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package dailytest 15 16 import ( 17 "bytes" 18 "database/sql" 19 "fmt" 20 "math" 21 "strconv" 22 "time" 23 24 "github.com/pingcap/errors" 25 "github.com/pingcap/log" 26 "github.com/pingcap/tidb/pkg/parser/mysql" 27 "github.com/pingcap/tiflow/tests/integration_tests/util" 28 "go.uber.org/zap/zapcore" 29 ) 30 31 func intRangeValue(column *column, min int64, max int64) (int64, int64) { 32 var err error 33 if len(column.min) > 0 { 34 min, err = strconv.ParseInt(column.min, 10, 64) 35 if err != nil { 36 log.S().Fatal(err) 37 } 38 39 if len(column.max) > 0 { 40 max, err = strconv.ParseInt(column.max, 10, 64) 41 if err != nil { 42 log.S().Fatal(err) 43 } 44 } 45 } 46 47 return min, max 48 } 49 50 func randInt64Value(column *column, min int64, max int64) int64 { 51 if len(column.set) > 0 { 52 idx := randInt(0, len(column.set)-1) 53 data, _ := strconv.ParseInt(column.set[idx], 10, 64) 54 return data 55 } 56 57 min, max = intRangeValue(column, min, max) 58 return randInt64(min, max) 59 } 60 61 func uniqInt64Value(column *column, max int64) int64 { 62 min, max := intRangeValue(column, 0, max) 63 column.data.setInitInt64Value(column.step, min, max) 64 return column.data.uniqInt64() 65 } 66 67 func queryCount(table *table, db *sql.DB) (int, error) { 68 rows, err := db.Query(fmt.Sprintf("SELECT COUNT(*) as count FROM %s", table.name)) 69 if err != nil { 70 return 0, errors.Trace(err) 71 } 72 73 var nums int 74 for rows.Next() { 75 err = rows.Scan(&nums) 76 if err != nil { 77 return 0, errors.Trace(err) 78 } 79 } 80 81 return nums, nil 82 } 83 84 func genDeleteSqls(table *table, db *sql.DB, count int) ([]string, [][]interface{}, error) { 85 nums, err := queryCount(table, db) 86 if err != nil { 87 return nil, nil, errors.Trace(err) 88 } 89 90 var sqls []string 91 var args [][]interface{} 92 93 if nums == 0 || nums-count < 1 { 94 return sqls, args, nil 95 } 96 97 start := randInt(1, nums-count) 98 length := len(table.columns) 99 where := genWhere(table.columns) 100 101 rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s limit %d, %d", table.name, start, count)) 102 if err != nil { 103 return nil, nil, errors.Trace(err) 104 } 105 106 for rows.Next() { 107 data := make([]interface{}, length) 108 dbArgs := make([]interface{}, length) 109 110 for i := 0; i < length; i++ { 111 dbArgs[i] = &data[i] 112 } 113 114 err = rows.Scan(dbArgs...) 115 if err != nil { 116 return nil, nil, errors.Trace(err) 117 } 118 119 sqls = append(sqls, fmt.Sprintf("delete from %s where %s", table.name, where)) 120 args = append(args, data) 121 } 122 123 return sqls, args, nil 124 } 125 126 func genUpdateSqls(table *table, db *sql.DB, count int) ([]string, [][]interface{}, error) { 127 nums, err := queryCount(table, db) 128 if err != nil { 129 return nil, nil, errors.Trace(err) 130 } 131 132 var sqls []string 133 var args [][]interface{} 134 135 if nums == 0 || nums-count < 1 { 136 return sqls, args, nil 137 } 138 139 start := randInt(1, nums-count) 140 length := len(table.columns) 141 where := genWhere(table.columns) 142 143 rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s limit %d, %d", table.name, start, count)) 144 if err != nil { 145 return nil, nil, errors.Trace(err) 146 } 147 148 for rows.Next() { 149 data := make([]interface{}, length) 150 dbArgs := make([]interface{}, length) 151 152 for i := 0; i < length; i++ { 153 dbArgs[i] = &data[i] 154 } 155 156 err = rows.Scan(dbArgs...) 157 if err != nil { 158 return nil, nil, errors.Trace(err) 159 } 160 161 index := randInt(2, length-1) 162 column := table.columns[index] 163 updateData, err := genColumnData(table, column) 164 if err != nil { 165 return nil, nil, errors.Trace(err) 166 } 167 168 sqls = append(sqls, fmt.Sprintf("update %s set `%s` = %s where %s", table.name, column.name, updateData, where)) 169 args = append(args, data) 170 } 171 172 return sqls, args, nil 173 } 174 175 func genInsertSqls(table *table, count int) ([]string, [][]interface{}, error) { 176 datas := make([]string, 0, count) 177 args := make([][]interface{}, 0, count) 178 for i := 0; i < count; i++ { 179 data, err := genRowData(table) 180 if err != nil { 181 return nil, nil, errors.Trace(err) 182 } 183 datas = append(datas, data) 184 args = append(args, nil) 185 } 186 187 return datas, args, nil 188 } 189 190 func genWhere(columns []*column) string { 191 var kvs bytes.Buffer 192 for i := range columns { 193 if i == len(columns)-1 { 194 fmt.Fprintf(&kvs, "`%s` = ?", columns[i].name) 195 } else { 196 fmt.Fprintf(&kvs, "`%s` = ? and ", columns[i].name) 197 } 198 } 199 200 return kvs.String() 201 } 202 203 func genRowData(table *table) (string, error) { 204 var values []byte 205 for _, column := range table.columns { 206 data, err := genColumnData(table, column) 207 if err != nil { 208 return "", errors.Trace(err) 209 } 210 values = append(values, []byte(data)...) 211 values = append(values, ',') 212 } 213 214 values = values[:len(values)-1] 215 sql := fmt.Sprintf("insert into %s values (%s);", table.name, string(values)) 216 return sql, nil 217 } 218 219 func genColumnData(table *table, column *column) (string, error) { 220 tp := column.tp 221 _, isUnique := table.uniqIndices[column.name] 222 isUnsigned := mysql.HasUnsignedFlag(tp.GetFlag()) 223 224 switch tp.GetType() { 225 case mysql.TypeTiny: 226 var data int64 227 if isUnique { 228 data = uniqInt64Value(column, math.MaxUint8) 229 } else { 230 if isUnsigned { 231 data = randInt64Value(column, 0, math.MaxUint8) 232 } else { 233 data = randInt64Value(column, math.MinInt8, math.MaxInt8) 234 } 235 } 236 return strconv.FormatInt(data, 10), nil 237 case mysql.TypeShort: 238 var data int64 239 if isUnique { 240 data = uniqInt64Value(column, math.MaxUint16) 241 } else { 242 if isUnsigned { 243 data = randInt64Value(column, 0, math.MaxUint16) 244 } else { 245 data = randInt64Value(column, math.MinInt16, math.MaxInt16) 246 } 247 } 248 return strconv.FormatInt(data, 10), nil 249 case mysql.TypeLong: 250 var data int64 251 if isUnique { 252 data = uniqInt64Value(column, math.MaxUint32) 253 } else { 254 if isUnsigned { 255 data = randInt64Value(column, 0, math.MaxUint32) 256 } else { 257 data = randInt64Value(column, math.MinInt32, math.MaxInt32) 258 } 259 } 260 return strconv.FormatInt(data, 10), nil 261 case mysql.TypeLonglong: 262 var data int64 263 if isUnique { 264 data = uniqInt64Value(column, math.MaxInt64) 265 } else { 266 if isUnsigned { 267 data = randInt64Value(column, 0, math.MaxInt64) 268 } else { 269 data = randInt64Value(column, math.MinInt32, math.MaxInt32) 270 } 271 } 272 return strconv.FormatInt(data, 10), nil 273 case mysql.TypeVarchar, mysql.TypeString, mysql.TypeTinyBlob, mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: 274 data := []byte{'\''} 275 if isUnique { 276 data = append(data, []byte(column.data.uniqString(tp.GetFlen()))...) 277 } else { 278 data = append(data, []byte(randString(randInt(1, tp.GetFlen())))...) 279 } 280 281 data = append(data, '\'') 282 return string(data), nil 283 case mysql.TypeFloat, mysql.TypeDouble: 284 var data float64 285 if isUnique { 286 data = float64(uniqInt64Value(column, math.MaxInt64)) 287 } else { 288 if isUnsigned { 289 data = float64(randInt64Value(column, 0, math.MaxInt64)) 290 } else { 291 data = float64(randInt64Value(column, math.MinInt32, math.MaxInt32)) 292 } 293 } 294 return strconv.FormatFloat(data, 'f', -1, 64), nil 295 case mysql.TypeDate: 296 data := []byte{'\''} 297 if isUnique { 298 data = append(data, []byte(column.data.uniqDate())...) 299 } else { 300 data = append(data, []byte(randDate(column.min, column.max))...) 301 } 302 303 data = append(data, '\'') 304 return string(data), nil 305 case mysql.TypeDatetime, mysql.TypeTimestamp: 306 data := []byte{'\''} 307 if isUnique { 308 data = append(data, []byte(column.data.uniqTimestamp())...) 309 } else { 310 data = append(data, []byte(randTimestamp(column.min, column.max))...) 311 } 312 313 data = append(data, '\'') 314 return string(data), nil 315 case mysql.TypeDuration: 316 data := []byte{'\''} 317 if isUnique { 318 data = append(data, []byte(column.data.uniqTime())...) 319 } else { 320 data = append(data, []byte(randTime(column.min, column.max))...) 321 } 322 323 data = append(data, '\'') 324 return string(data), nil 325 case mysql.TypeYear: 326 data := []byte{'\''} 327 if isUnique { 328 data = append(data, []byte(column.data.uniqYear())...) 329 } else { 330 data = append(data, []byte(randYear(column.min, column.max))...) 331 } 332 333 data = append(data, '\'') 334 return string(data), nil 335 default: 336 return "", errors.Errorf("unsupported column type - %v", column) 337 } 338 } 339 340 func execSQLs(db *sql.DB, sqls []string) error { 341 for _, sql := range sqls { 342 err := execSQL(db, sql) 343 if err != nil { 344 return errors.Trace(err) 345 } 346 } 347 return nil 348 } 349 350 func execSQL(db *sql.DB, sql string) error { 351 if len(sql) == 0 { 352 return nil 353 } 354 355 _, err := db.Exec(sql) 356 if err != nil { 357 return errors.Trace(err) 358 } 359 360 return nil 361 } 362 363 // RunTest will call writeSrc and check if src is contisitent with dst 364 func RunTest(src *sql.DB, dst *sql.DB, schema string, writeSrc func(src *sql.DB)) { 365 writeSrc(src) 366 367 tick := time.NewTicker(time.Second * 5) 368 defer tick.Stop() 369 timeout := time.After(time.Second * 240) 370 371 oldLevel := log.GetLevel() 372 defer log.SetLevel(oldLevel) 373 374 for { 375 select { 376 case <-tick.C: 377 log.SetLevel(zapcore.WarnLevel) 378 if util.CheckSyncState(src, dst, schema) { 379 return 380 } 381 case <-timeout: 382 // check last time 383 log.SetLevel(zapcore.InfoLevel) 384 if !util.CheckSyncState(src, dst, schema) { 385 log.S().Fatal("sourceDB don't equal targetDB") 386 } 387 388 return 389 } 390 } 391 }