// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package conn

import (
	"context"
	"database/sql"
	"fmt"
	"math"
	"math/rand"
	"strconv"
	"strings"
	"time"

	"github.com/coreos/go-semver/semver"
	gmysql "github.com/go-mysql-org/go-mysql/mysql"
	"github.com/go-sql-driver/mysql"
	"github.com/pingcap/errors"
	"github.com/pingcap/failpoint"
	"github.com/pingcap/tidb/dumpling/export"
	"github.com/pingcap/tidb/pkg/parser"
	tmysql "github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/util/dbutil"
	"github.com/pingcap/tidb/pkg/util/filter"
	regexprrouter "github.com/pingcap/tidb/pkg/util/regexpr-router"
	tcontext "github.com/pingcap/tiflow/dm/pkg/context"
	"github.com/pingcap/tiflow/dm/pkg/log"
	"github.com/pingcap/tiflow/dm/pkg/terror"
	"go.uber.org/zap"
)

const (
	// DefaultDBTimeout represents a DB operation timeout for common usages.
	DefaultDBTimeout = 30 * time.Second

	// for MariaDB, UUID set as `gtid_domain_id` + domainServerIDSeparator + `server_id`
	// (the MariaDB stand-in for `server_uuid`, e.g. "0-1").
	domainServerIDSeparator = "-"

	// the default base(min) server id generated by random:
	// math.MaxUint32 / 10 == 429496729; a random offset is added on top of it.
	defaultBaseServerID = math.MaxUint32 / 10
)

// GetFlavor gets flavor from DB: gmysql.MariaDBFlavor when the server version
// string contains "MariaDB" (case-insensitive), gmysql.MySQLFlavor otherwise.
55 func GetFlavor(ctx context.Context, db *BaseDB) (string, error) { 56 value, err := dbutil.ShowVersion(ctx, db.DB) 57 if err != nil { 58 return "", terror.DBErrorAdapt(err, db.Scope, terror.ErrDBDriverError) 59 } 60 if IsMariaDB(value) { 61 return gmysql.MariaDBFlavor, nil 62 } 63 return gmysql.MySQLFlavor, nil 64 } 65 66 // GetAllServerID gets all slave server id and master server id. 67 func GetAllServerID(ctx *tcontext.Context, db *BaseDB) (map[uint32]struct{}, error) { 68 serverIDs, err := GetSlaveServerID(ctx, db) 69 if err != nil { 70 return nil, err 71 } 72 73 masterServerID, err := GetServerID(ctx, db) 74 if err != nil { 75 return nil, err 76 } 77 78 serverIDs[masterServerID] = struct{}{} 79 return serverIDs, nil 80 } 81 82 // GetRandomServerID gets a random server ID which is not used. 83 func GetRandomServerID(ctx *tcontext.Context, db *BaseDB) (uint32, error) { 84 rand.Seed(time.Now().UnixNano()) 85 86 serverIDs, err := GetAllServerID(ctx, db) 87 if err != nil { 88 return 0, err 89 } 90 91 for i := 0; i < 99999; i++ { 92 randomValue := uint32(rand.Intn(100000)) 93 randomServerID := uint32(defaultBaseServerID) + randomValue 94 if _, ok := serverIDs[randomServerID]; ok { 95 continue 96 } 97 98 return randomServerID, nil 99 } 100 101 // should never happened unless the master has too many slave. 102 return 0, terror.ErrInvalidServerID.Generatef("can't find a random available server ID") 103 } 104 105 // GetSlaveServerID gets all slave server id. 
106 func GetSlaveServerID(ctx *tcontext.Context, db *BaseDB) (map[uint32]struct{}, error) { 107 // need REPLICATION SLAVE privilege 108 rows, err := db.QueryContext(ctx, `SHOW SLAVE HOSTS`) 109 if err != nil { 110 return nil, terror.DBErrorAdapt(err, db.Scope, terror.ErrDBDriverError) 111 } 112 defer func() { 113 _ = rows.Close() 114 _ = rows.Err() 115 }() 116 117 /* 118 in MySQL: 119 mysql> SHOW SLAVE HOSTS; 120 +------------+-----------+------+-----------+--------------------------------------+ 121 | Server_id | Host | Port | Master_id | Slave_UUID | 122 +------------+-----------+------+-----------+--------------------------------------+ 123 | 192168010 | iconnect2 | 3306 | 192168011 | 14cb6624-7f93-11e0-b2c0-c80aa9429562 | 124 | 1921680101 | athena | 3306 | 192168011 | 07af4990-f41f-11df-a566-7ac56fdaf645 | 125 +------------+-----------+------+-----------+--------------------------------------+ 126 127 in MariaDB: 128 mysql> SHOW SLAVE HOSTS; 129 +------------+-----------+------+-----------+ 130 | Server_id | Host | Port | Master_id | 131 +------------+-----------+------+-----------+ 132 | 192168010 | iconnect2 | 3306 | 192168011 | 133 | 1921680101 | athena | 3306 | 192168011 | 134 +------------+-----------+------+-----------+ 135 */ 136 137 serverIDs := make(map[uint32]struct{}) 138 var rowsResult []string 139 rowsResult, err = export.GetSpecifiedColumnValueAndClose(rows, "Server_id") 140 if err != nil { 141 return nil, terror.DBErrorAdapt(err, db.Scope, terror.ErrDBDriverError) 142 } 143 for _, serverID := range rowsResult { 144 // serverID will not be null 145 serverIDUInt, err := strconv.ParseUint(serverID, 10, 32) 146 if err != nil { 147 return nil, terror.DBErrorAdapt(err, db.Scope, terror.ErrDBDriverError) 148 } 149 serverIDs[uint32(serverIDUInt)] = struct{}{} 150 } 151 return serverIDs, nil 152 } 153 154 // GetSessionVariable gets connection's session variable. 
155 func GetSessionVariable(ctx *tcontext.Context, conn *BaseConn, variable string) (value string, err error) { 156 failpoint.Inject("GetSessionVariableFailed", func(val failpoint.Value) { 157 items := strings.Split(val.(string), ",") 158 if len(items) != 2 { 159 log.L().Fatal("failpoint GetSessionVariableFailed's value is invalid", zap.String("val", val.(string))) 160 } 161 variableName := items[0] 162 errCode, err1 := strconv.ParseUint(items[1], 10, 16) 163 if err1 != nil { 164 log.L().Fatal("failpoint GetSessionVariableFailed's value is invalid", zap.String("val", val.(string))) 165 } 166 if variable == variableName { 167 err = tmysql.NewErr(uint16(errCode)) 168 log.L().Warn("GetSessionVariable failed", zap.String("variable", variable), zap.String("failpoint", "GetSessionVariableFailed"), zap.Error(err)) 169 failpoint.Return("", terror.DBErrorAdapt(err, conn.Scope, terror.ErrDBDriverError)) 170 } 171 }) 172 return getVariable(ctx, conn, variable, false) 173 } 174 175 // GetServerID gets server's `server_id`. 176 func GetServerID(ctx *tcontext.Context, db *BaseDB) (uint32, error) { 177 serverIDStr, err := GetGlobalVariable(ctx, db, "server_id") 178 if err != nil { 179 return 0, err 180 } 181 182 serverID, err := strconv.ParseUint(serverIDStr, 10, 32) 183 return uint32(serverID), terror.ErrInvalidServerID.Delegate(err, serverIDStr) 184 } 185 186 // GetMariaDBGtidDomainID gets MariaDB server's `gtid_domain_id`. 187 func GetMariaDBGtidDomainID(ctx *tcontext.Context, db *BaseDB) (uint32, error) { 188 domainIDStr, err := GetGlobalVariable(ctx, db, "gtid_domain_id") 189 if err != nil { 190 return 0, err 191 } 192 193 domainID, err := strconv.ParseUint(domainIDStr, 10, 32) 194 return uint32(domainID), terror.ErrMariaDBDomainID.Delegate(err, domainIDStr) 195 } 196 197 // GetServerUUID gets server's `server_uuid`. 
198 func GetServerUUID(ctx *tcontext.Context, db *BaseDB, flavor string) (string, error) { 199 if flavor == gmysql.MariaDBFlavor { 200 return GetMariaDBUUID(ctx, db) 201 } 202 serverUUID, err := GetGlobalVariable(ctx, db, "server_uuid") 203 return serverUUID, err 204 } 205 206 // GetServerUnixTS gets server's `UNIX_TIMESTAMP()`. 207 func GetServerUnixTS(ctx context.Context, db *BaseDB) (int64, error) { 208 var ts int64 209 row := db.DB.QueryRowContext(ctx, "SELECT UNIX_TIMESTAMP()") 210 err := row.Scan(&ts) 211 if err != nil { 212 log.L().Error("can't SELECT UNIX_TIMESTAMP()", zap.Error(err)) 213 return ts, terror.DBErrorAdapt(err, db.Scope, terror.ErrDBDriverError) 214 } 215 return ts, err 216 } 217 218 // GetMariaDBUUID gets equivalent `server_uuid` for MariaDB 219 // `gtid_domain_id` joined `server_id` with domainServerIDSeparator. 220 func GetMariaDBUUID(ctx *tcontext.Context, db *BaseDB) (string, error) { 221 domainID, err := GetMariaDBGtidDomainID(ctx, db) 222 if err != nil { 223 return "", err 224 } 225 serverID, err := GetServerID(ctx, db) 226 if err != nil { 227 return "", err 228 } 229 return fmt.Sprintf("%d%s%d", domainID, domainServerIDSeparator, serverID), nil 230 } 231 232 // GetParser gets a parser for sql.DB which is suitable for session variable sql_mode. 233 func GetParser(ctx *tcontext.Context, db *BaseDB) (*parser.Parser, error) { 234 c, err := db.GetBaseConn(ctx.Ctx) 235 if err != nil { 236 return nil, err 237 } 238 defer db.CloseConnWithoutErr(c) 239 return GetParserForConn(ctx, c) 240 } 241 242 // GetParserForConn gets a parser for BaseConn which is suitable for session variable sql_mode. 243 func GetParserForConn(ctx *tcontext.Context, conn *BaseConn) (*parser.Parser, error) { 244 sqlMode, err := GetSessionVariable(ctx, conn, "sql_mode") 245 if err != nil { 246 return nil, err 247 } 248 return GetParserFromSQLModeStr(sqlMode) 249 } 250 251 // GetParserFromSQLModeStr gets a parser and applies given sqlMode. 
252 func GetParserFromSQLModeStr(sqlMode string) (*parser.Parser, error) { 253 mode, err := tmysql.GetSQLMode(sqlMode) 254 if err != nil { 255 return nil, err 256 } 257 258 parser2 := parser.New() 259 parser2.SetSQLMode(mode) 260 return parser2, nil 261 } 262 263 // KillConn kills the DB connection (thread in mysqld). 264 func KillConn(ctx *tcontext.Context, db *BaseDB, connID uint32) error { 265 _, err := db.ExecContext(ctx, fmt.Sprintf("KILL %d", connID)) 266 return terror.DBErrorAdapt(err, db.Scope, terror.ErrDBDriverError) 267 } 268 269 // IsMySQLError checks whether err is MySQLError error. 270 func IsMySQLError(err error, code uint16) bool { 271 err = errors.Cause(err) 272 e, ok := err.(*mysql.MySQLError) 273 return ok && e.Number == code 274 } 275 276 // IsErrDuplicateEntry checks whether err is DuplicateEntry error. 277 func IsErrDuplicateEntry(err error) bool { 278 return IsMySQLError(err, tmysql.ErrDupEntry) 279 } 280 281 // IsErrBinlogPurged checks whether err is BinlogPurged error. 282 func IsErrBinlogPurged(err error) bool { 283 return IsMySQLError(err, tmysql.ErrMasterFatalErrorReadingBinlog) 284 } 285 286 // IsNoSuchThreadError checks whether err is NoSuchThreadError. 287 func IsNoSuchThreadError(err error) bool { 288 return IsMySQLError(err, tmysql.ErrNoSuchThread) 289 } 290 291 // GetGTIDMode return GTID_MODE. 292 func GetGTIDMode(ctx *tcontext.Context, db *BaseDB) (string, error) { 293 val, err := GetGlobalVariable(ctx, db, "GTID_MODE") 294 return val, err 295 } 296 297 // ExtractTiDBVersion extract tidb's version 298 // version format: "5.7.25-TiDB-v3.0.0-beta-211-g09beefbe0-dirty" 299 // - ^.......... 
300 func ExtractTiDBVersion(version string) (*semver.Version, error) { 301 versions := strings.Split(strings.TrimSuffix(version, "-dirty"), "-") 302 end := len(versions) 303 switch end { 304 case 3, 4: 305 case 5, 6: 306 end -= 2 307 default: 308 return nil, errors.Errorf("not a valid TiDB version: %s", version) 309 } 310 rawVersion := strings.Join(versions[2:end], "-") 311 rawVersion = strings.TrimPrefix(rawVersion, "v") 312 return semver.NewVersion(rawVersion) 313 } 314 315 // AddGSetWithPurged is used to handle this case: https://github.com/pingcap/dm/issues/1418 316 // we might get a gtid set from Previous_gtids event in binlog, but that gtid set can't be used to start a gtid sync 317 // because it doesn't cover all gtid_purged. The error of using it will be 318 // ERROR 1236 (HY000): The slave is connecting using CHANGE MASTER TO MASTER_AUTO_POSITION = 1, but the master has purged binary logs containing GTIDs that the slave requires. 319 // so we add gtid_purged to it. 320 func AddGSetWithPurged(ctx context.Context, gset gmysql.GTIDSet, conn *BaseConn) (gmysql.GTIDSet, error) { 321 if _, ok := gset.(*gmysql.MariadbGTIDSet); ok { 322 return gset, nil 323 } 324 325 var ( 326 gtidStr string 327 row *sql.Row 328 err error 329 ) 330 331 failpoint.Inject("GetGTIDPurged", func(val failpoint.Value) { 332 str := val.(string) 333 gtidStr = str 334 failpoint.Goto("bypass") 335 }) 336 row = conn.DBConn.QueryRowContext(ctx, "select @@GLOBAL.gtid_purged") 337 err = row.Scan(>idStr) 338 if err != nil { 339 log.L().Error("can't get @@GLOBAL.gtid_purged when try to add it to gtid set", zap.Error(err)) 340 return gset, terror.DBErrorAdapt(err, conn.Scope, terror.ErrDBDriverError) 341 } 342 failpoint.Label("bypass") 343 if gtidStr == "" { 344 return gset, nil 345 } 346 347 cloned := gset.Clone() 348 err = cloned.Update(gtidStr) 349 if err != nil { 350 return nil, err 351 } 352 return cloned, nil 353 } 354 355 // AdjustSQLModeCompatible adjust downstream sql mode to compatible. 
356 // TODO: When upstream's datatime is 2020-00-00, 2020-00-01, 2020-06-00 357 // and so on, downstream will be 2019-11-30, 2019-12-01, 2020-05-31, 358 // as if set the 'NO_ZERO_IN_DATE', 'NO_ZERO_DATE'. 359 // This is because the implementation of go-mysql, that you can see 360 // https://github.com/go-mysql-org/go-mysql/blob/master/replication/row_event.go#L1063-L1087 361 func AdjustSQLModeCompatible(sqlModes string) (string, error) { 362 needDisable := []string{ 363 "NO_ZERO_IN_DATE", 364 "NO_ZERO_DATE", 365 "ERROR_FOR_DIVISION_BY_ZERO", 366 "NO_AUTO_CREATE_USER", 367 "STRICT_TRANS_TABLES", 368 "STRICT_ALL_TABLES", 369 } 370 needEnable := []string{ 371 "IGNORE_SPACE", 372 "NO_AUTO_VALUE_ON_ZERO", 373 "ALLOW_INVALID_DATES", 374 } 375 disable := strings.Join(needDisable, ",") 376 enable := strings.Join(needEnable, ",") 377 378 mode, err := tmysql.GetSQLMode(sqlModes) 379 if err != nil { 380 return sqlModes, err 381 } 382 disableMode, err2 := tmysql.GetSQLMode(disable) 383 if err2 != nil { 384 return sqlModes, err2 385 } 386 enableMode, err3 := tmysql.GetSQLMode(enable) 387 if err3 != nil { 388 return sqlModes, err3 389 } 390 // About this bit manipulation, details can be seen 391 // https://github.com/pingcap/dm/pull/1869#discussion_r669771966 392 mode = (mode &^ disableMode) | enableMode 393 394 return GetSQLModeStrBySQLMode(mode), nil 395 } 396 397 // GetSQLModeStrBySQLMode get string represent of sql_mode by sql_mode. 398 func GetSQLModeStrBySQLMode(sqlMode tmysql.SQLMode) string { 399 var sqlModeStr []string 400 for str, SQLMode := range tmysql.Str2SQLMode { 401 if sqlMode&SQLMode != 0 { 402 sqlModeStr = append(sqlModeStr, str) 403 } 404 } 405 return strings.Join(sqlModeStr, ",") 406 } 407 408 // GetMaxConnections gets max_connections for sql.DB which is suitable for session variable max_connections. 
409 func GetMaxConnections(ctx *tcontext.Context, db *BaseDB) (int, error) { 410 c, err := db.GetBaseConn(ctx.Ctx) 411 if err != nil { 412 return 0, err 413 } 414 defer db.CloseConnWithoutErr(c) 415 return GetMaxConnectionsForConn(ctx, c) 416 } 417 418 // GetMaxConnectionsForConn gets max_connections for BaseConn which is suitable for session variable max_connections. 419 func GetMaxConnectionsForConn(ctx *tcontext.Context, conn *BaseConn) (int, error) { 420 maxConnectionsStr, err := GetSessionVariable(ctx, conn, "max_connections") 421 if err != nil { 422 return 0, err 423 } 424 maxConnections, err := strconv.ParseUint(maxConnectionsStr, 10, 32) 425 return int(maxConnections), err 426 } 427 428 // IsMariaDB tells whether the version is mariadb. 429 func IsMariaDB(version string) bool { 430 return strings.Contains(strings.ToUpper(version), "MARIADB") 431 } 432 433 // CreateTableSQLToOneRow formats the result of SHOW CREATE TABLE to one row. 434 func CreateTableSQLToOneRow(sql string) string { 435 sql = strings.ReplaceAll(sql, "\n", "") 436 sql = strings.ReplaceAll(sql, " ", " ") 437 return sql 438 } 439 440 // FetchAllDoTables returns all need to do tables after filtered (fetches from upstream MySQL). 
441 func FetchAllDoTables(ctx context.Context, db *BaseDB, bw *filter.Filter) (map[string][]string, error) { 442 schemas, err := dbutil.GetSchemas(ctx, db.DB) 443 444 failpoint.Inject("FetchAllDoTablesFailed", func(val failpoint.Value) { 445 err = tmysql.NewErr(uint16(val.(int))) 446 log.L().Warn("FetchAllDoTables failed", zap.String("failpoint", "FetchAllDoTablesFailed"), zap.Error(err)) 447 }) 448 449 if err != nil { 450 return nil, terror.WithScope(err, db.Scope) 451 } 452 453 ftSchemas := make([]*filter.Table, 0, len(schemas)) 454 for _, schema := range schemas { 455 if filter.IsSystemSchema(schema) { 456 continue 457 } 458 ftSchemas = append(ftSchemas, &filter.Table{ 459 Schema: schema, 460 Name: "", // schema level 461 }) 462 } 463 ftSchemas = bw.Apply(ftSchemas) 464 if len(ftSchemas) == 0 { 465 log.L().Warn("no schema need to sync") 466 return nil, nil 467 } 468 469 schemaToTables := make(map[string][]string) 470 for _, ftSchema := range ftSchemas { 471 schema := ftSchema.Schema 472 // use `GetTables` from tidb-tools, no view included 473 tables, err := dbutil.GetTables(ctx, db.DB, schema) 474 if err != nil { 475 return nil, terror.DBErrorAdapt(err, db.Scope, terror.ErrDBDriverError) 476 } 477 ftTables := make([]*filter.Table, 0, len(tables)) 478 for _, table := range tables { 479 ftTables = append(ftTables, &filter.Table{ 480 Schema: schema, 481 Name: table, 482 }) 483 } 484 ftTables = bw.Apply(ftTables) 485 if len(ftTables) == 0 { 486 log.L().Info("no tables need to sync", zap.String("schema", schema)) 487 continue // NOTE: should we still keep it as an empty elem? 488 } 489 tables = tables[:0] 490 for _, ftTable := range ftTables { 491 tables = append(tables, ftTable.Name) 492 } 493 schemaToTables[schema] = tables 494 } 495 496 return schemaToTables, nil 497 } 498 499 // FetchTargetDoTables returns all need to do tables after filtered and routed (fetches from upstream MySQL). 
500 func FetchTargetDoTables( 501 ctx context.Context, 502 source string, 503 db *BaseDB, 504 bw *filter.Filter, 505 router *regexprrouter.RouteTable, 506 ) (map[filter.Table][]filter.Table, map[filter.Table][]string, error) { 507 // fetch tables from source and filter them 508 sourceTables, err := FetchAllDoTables(ctx, db, bw) 509 510 failpoint.Inject("FetchTargetDoTablesFailed", func(val failpoint.Value) { 511 err = tmysql.NewErr(uint16(val.(int))) 512 log.L().Warn("FetchTargetDoTables failed", zap.String("failpoint", "FetchTargetDoTablesFailed"), zap.Error(err)) 513 }) 514 515 if err != nil { 516 return nil, nil, err 517 } 518 519 tableMapper := make(map[filter.Table][]filter.Table) 520 extendedColumnPerTable := make(map[filter.Table][]string) 521 for schema, tables := range sourceTables { 522 for _, table := range tables { 523 targetSchema, targetTable, err := router.Route(schema, table) 524 if err != nil { 525 return nil, nil, terror.ErrGenTableRouter.Delegate(err) 526 } 527 528 target := filter.Table{ 529 Schema: targetSchema, 530 Name: targetTable, 531 } 532 tableMapper[target] = append(tableMapper[target], filter.Table{ 533 Schema: schema, 534 Name: table, 535 }) 536 col, _ := router.FetchExtendColumn(schema, table, source) 537 if len(col) > 0 { 538 extendedColumnPerTable[target] = col 539 } 540 } 541 } 542 543 return tableMapper, extendedColumnPerTable, nil 544 }