github.com/wfusion/gofusion@v1.1.14/db/plugins/table_sharding.go (about) 1 package plugins 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/binary" 7 "fmt" 8 "hash/crc32" 9 "math" 10 "math/big" 11 "reflect" 12 "strconv" 13 "strings" 14 "sync" 15 "unsafe" 16 17 "github.com/PaesslerAG/gval" 18 "github.com/google/uuid" 19 "github.com/pkg/errors" 20 "github.com/spf13/cast" 21 "gorm.io/gorm" 22 "gorm.io/gorm/clause" 23 "gorm.io/gorm/schema" 24 25 "github.com/wfusion/gofusion/common/constant" 26 "github.com/wfusion/gofusion/common/infra/drivers/orm/idgen" 27 "github.com/wfusion/gofusion/common/utils" 28 "github.com/wfusion/gofusion/common/utils/clone" 29 "github.com/wfusion/gofusion/common/utils/inspect" 30 "github.com/wfusion/gofusion/common/utils/sqlparser" 31 "github.com/wfusion/gofusion/db/callbacks" 32 ) 33 34 const ( 35 shardingIgnoreStoreKey = "sharding_ignore" 36 ) 37 38 var ( 39 ErrInvalidID = errors.New("invalid id format") 40 ErrIDGeneratorNotFound = errors.New("id generator not found") 41 ErrShardingModelNotFound = errors.New("sharding table model not found when migrating") 42 ErrDiffSuffixDML = errors.New("can not query different suffix table in one sql") 43 ErrMissingShardingKey = errors.New("sharding key required and use operator =") 44 ErrColumnAndExprMisMatch = errors.New("column names and expressions mismatch") 45 46 gormSchemaEmbeddedNamer = inspect.TypeOf("gorm.io/gorm/schema.embeddedNamer") 47 ) 48 49 type TableShardingConfig struct { 50 // Database name 51 Database string 52 53 // Table name 54 Table string 55 56 // ShardingKeys required, specifies the table columns you want to use for sharding the table rows. 57 // For example, for a product order table, you may want to split the rows by `user_id`. 58 ShardingKeys []string 59 60 // ShardingKeyExpr optional, specifies how to calculate sharding key by columns, e.g. 
tenant_id << 16 | user_id 61 ShardingKeyExpr gval.Evaluable 62 63 // ShardingKeyByRawValue optional, specifies sharding key with snake values, e.g. xxx_region1_az1, xxx_region1_az2 64 ShardingKeyByRawValue bool 65 66 // ShardingKeysForMigrating optional, specifies all sharding keys 67 ShardingKeysForMigrating []string 68 69 // NumberOfShards required, specifies how many tables you want to sharding. 70 NumberOfShards uint 71 72 // CustomSuffix optional, specifies shard table a custom suffix, e.g. user_%02d means <main_table_name>_user_01 73 CustomSuffix string 74 75 // PrimaryKeyGenerator optional, generates id if id is a sharding key and is zero 76 PrimaryKeyGenerator idgen.Generator 77 } 78 79 // sharding plugin inspired by gorm.io/sharding@v0.5.3 80 type tableSharding struct { 81 *gorm.DB 82 83 config TableShardingConfig 84 85 shardingFunc func(ctx context.Context, values ...any) (suffix string, err error) 86 isShardingPrimaryKey bool 87 shardingPrimaryKey string 88 shardingTableModel any 89 shardingTableCreatedMutex sync.RWMutex 90 shardingTableCreated map[string]struct{} 91 92 suffixFormat string 93 } 94 95 func DefaultTableSharding(config TableShardingConfig) TableSharding { 96 if utils.IsStrBlank(config.Table) { 97 panic(errors.New("missing sharding table name")) 98 } 99 if len(config.ShardingKeys) == 0 { 100 panic(errors.New("missing sharding keys")) 101 } 102 if !config.ShardingKeyByRawValue && (config.NumberOfShards <= 0 || config.NumberOfShards >= 100000) { 103 panic(errors.New("invalid number of shards")) 104 } 105 106 shardingKeySet := utils.NewSet(config.ShardingKeys...) 
107 shardingPrimaryKey := "" 108 isShardingPrimaryKey := false 109 if shardingKeySet.Contains("id") || shardingKeySet.Contains("ID") || 110 shardingKeySet.Contains("iD") || shardingKeySet.Contains("Id") { 111 if config.PrimaryKeyGenerator == nil { 112 panic(errors.New("sharding by primary key but primary key generator not found")) 113 } 114 115 isShardingPrimaryKey = true 116 for _, key := range config.ShardingKeys { 117 if key == "id" || key == "ID" || key == "Id" || key == "iD" { 118 shardingPrimaryKey = key 119 break 120 } 121 } 122 } 123 124 return &tableSharding{ 125 config: config, 126 isShardingPrimaryKey: isShardingPrimaryKey, 127 shardingPrimaryKey: shardingPrimaryKey, 128 shardingTableCreated: make(map[string]struct{}, config.NumberOfShards), 129 } 130 } 131 132 func (t *tableSharding) Name() string { 133 return fmt.Sprintf("gorm:sharding:%s:%s", t.config.Database, t.config.Table) 134 } 135 136 func (t *tableSharding) Initialize(db *gorm.DB) (err error) { 137 db.Dialector = newShardingDialector(db.Dialector, t) 138 139 t.DB = db 140 t.shardingFunc = t.defaultShardingFunc() 141 t.registerCallbacks(db) 142 return 143 } 144 145 func (t *tableSharding) ShardingByModelList(ctx context.Context, src ...any) (dst map[string][]any, err error) { 146 dst = make(map[string][]any, len(t.config.ShardingKeys)) 147 for _, m := range src { 148 val := reflect.Indirect(reflect.ValueOf(m)) 149 shardingValues := make([]any, 0, len(t.config.ShardingKeys)) 150 for _, key := range t.config.ShardingKeys { 151 field := val.FieldByNameFunc(func(v string) bool { return strings.EqualFold(v, key) }) 152 if !field.IsValid() { 153 field, _ = utils.GetGormColumnValue(val, key) 154 } 155 if !field.IsValid() { 156 return dst, ErrMissingShardingKey 157 } 158 if key == t.shardingPrimaryKey && field.IsZero() { 159 return dst, ErrInvalidID 160 } 161 shardingValues = append(shardingValues, field.Interface()) 162 } 163 suffix, err := t.shardingFunc(ctx, shardingValues...) 
164 if err != nil { 165 return dst, err 166 } 167 dst[suffix] = append(dst[suffix], m) 168 } 169 return 170 } 171 172 func (t *tableSharding) ShardingByValues(ctx context.Context, src []map[string]any) ( 173 dst map[string][]map[string]any, err error) { 174 dst = make(map[string][]map[string]any, len(t.config.ShardingKeys)) 175 for _, col := range src { 176 values := make([]any, 0, len(col)) 177 for _, k := range t.config.ShardingKeys { 178 value, ok := col[k] 179 if !ok { 180 return dst, errors.Errorf("sharding key not found [column[%s]]", k) 181 } 182 if k == t.shardingPrimaryKey && utils.IsBlank(value) { 183 return dst, ErrInvalidID 184 } 185 values = append(values, value) 186 } 187 suffix, err := t.shardingFunc(ctx, values...) 188 if err != nil { 189 return dst, err 190 } 191 dst[suffix] = append(dst[suffix], col) 192 } 193 return 194 } 195 196 func (t *tableSharding) ShardingIDGen(ctx context.Context) (id uint64, err error) { 197 if t.config.PrimaryKeyGenerator == nil { 198 return 0, ErrIDGeneratorNotFound 199 } 200 return t.config.PrimaryKeyGenerator.Next() 201 } 202 203 func (t *tableSharding) registerCallbacks(db *gorm.DB) { 204 utils.MustSuccess(db.Callback(). 205 Create(). 206 After("gorm:before_create"). 207 Before("gorm:save_before_associations"). 208 Register(t.Name(), t.createCallback)) 209 210 utils.MustSuccess(db.Callback(). 211 Query(). 212 Before("gorm:query"). 213 Register(t.Name(), t.queryCallback)) 214 215 utils.MustSuccess(db.Callback(). 216 Update(). 217 After("gorm:before_update"). 218 Before("gorm:save_before_associations"). 219 Register(t.Name(), t.updateCallback)) 220 221 utils.MustSuccess(db.Callback(). 222 Delete(). 223 After("gorm:before_delete"). 224 Before("gorm:delete_before_associations"). 225 Register(t.Name(), t.deleteCallback)) 226 227 utils.MustSuccess(db.Callback(). 228 Row(). 229 Before("gorm:row"). 230 Register(t.Name(), t.queryCallback)) 231 232 utils.MustSuccess(db.Callback(). 233 Raw(). 234 Before("gorm:raw"). 
235 Register(t.Name(), t.rawCallback)) 236 } 237 func (t *tableSharding) createCallback(db *gorm.DB) { 238 utils.IfAny( 239 t.isIgnored(db), 240 func() bool { ok1, ok2 := t.dispatchTableByModel(db, tableShardingIsInsert()); return ok1 || ok2 }, 241 func() bool { 242 callbacks.BuildCreateSQL(db) 243 t.wrapDispatchTableBySQL(db, tableShardingIsInsert()) 244 return true 245 }, 246 ) 247 } 248 func (t *tableSharding) queryCallback(db *gorm.DB) { 249 utils.IfAny( 250 t.isIgnored(db), 251 func() bool { ok1, ok2 := t.dispatchTableByModel(db); return ok1 || ok2 }, 252 func() bool { 253 callbacks.BuildQuerySQL(db) 254 t.wrapDispatchTableBySQL(db) 255 return true 256 }, 257 ) 258 } 259 func (t *tableSharding) updateCallback(db *gorm.DB) { 260 utils.IfAny( 261 t.isIgnored(db), 262 func() bool { ok1, ok2 := t.dispatchTableByModel(db); return ok1 || ok2 }, 263 func() bool { 264 callbacks.BuildUpdateSQL(db) 265 t.wrapDispatchTableBySQL(db) 266 return true 267 }, 268 ) 269 } 270 func (t *tableSharding) deleteCallback(db *gorm.DB) { 271 utils.IfAny( 272 t.isIgnored(db), 273 func() bool { ok1, ok2 := t.dispatchTableByModel(db); return ok1 || ok2 }, 274 func() bool { 275 callbacks.BuildDeleteSQL(db) 276 t.wrapDispatchTableBySQL(db) 277 return true 278 }, 279 ) 280 } 281 func (t *tableSharding) rawCallback(db *gorm.DB) { 282 utils.IfAny( 283 t.isIgnored(db), 284 func() bool { ok1, ok2 := t.dispatchTableByModel(db); return ok1 || ok2 }, 285 func() bool { t.wrapDispatchTableBySQL(db); return true }, 286 ) 287 } 288 289 type tableShardingDispatchOption struct { 290 isInsert bool 291 } 292 293 func tableShardingIsInsert() utils.OptionFunc[tableShardingDispatchOption] { 294 return func(t *tableShardingDispatchOption) { 295 t.isInsert = true 296 } 297 } 298 299 func (t *tableSharding) dispatchTableByModel(db *gorm.DB, opts ...utils.OptionExtender) (otherTable, ok bool) { 300 if db.Statement.Model == nil || utils.IsBlank(db.Statement.ReflectValue.Interface()) { 301 return 302 } 303 if 
db.Statement.Table != t.config.Table { 304 otherTable = true 305 return 306 } 307 if t.shardingTableModel == nil { 308 if _, ok := db.Statement.Model.(schema.Tabler); ok { 309 cloneModel := clone.Clone(db.Statement.Model) 310 t.shardingTableModel = cloneModel 311 } 312 } 313 314 opt := utils.ApplyOptions[tableShardingDispatchOption](opts...) 315 if t.isShardingPrimaryKey { 316 if err := t.setPrimaryKeyByModel(db, opt); err != nil { 317 _ = db.AddError(err) 318 return 319 } 320 } 321 322 reflectVal, ok := t.getModelReflectValue(db) 323 if !ok { 324 return 325 } 326 if err := t.checkDiffSuffixesByModel(db); err != nil { 327 return 328 } 329 330 values := make([]any, 0, len(t.config.ShardingKeys)) 331 for _, key := range t.config.ShardingKeys { 332 val := reflectVal.FieldByNameFunc(func(v string) bool { return strings.EqualFold(v, key) }) 333 if !val.IsValid() { 334 val, _ = utils.GetGormColumnValue(reflectVal, key) 335 } 336 if !val.IsValid() { 337 _ = db.AddError(ErrMissingShardingKey) 338 return 339 } 340 values = append(values, val.Interface()) 341 } 342 343 suffix, err := t.shardingFunc(db.Statement.Context, values...) 
344 if err != nil { 345 _ = db.AddError(err) 346 return 347 } 348 // cannot parse suffix from model 349 if utils.IsStrBlank(suffix) || suffix == constant.Underline { 350 return false, false 351 } 352 if err = t.createTableIfNotExists(db, db.Statement.Table, suffix); err != nil { 353 _ = db.AddError(err) 354 return 355 } 356 357 db.Statement.Table = db.Statement.Table + suffix 358 t.replaceStatementClauseAndSchema(db, opt) 359 ok = true 360 return 361 } 362 363 //nolint: revive // sql parser issue 364 func (t *tableSharding) dispatchTableBySQL(db *gorm.DB, opts ...utils.OptionExtender) (ok bool, err error) { 365 expr, err := sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String())).ParseStatement() 366 if err != nil { 367 // maybe not a dml, so we ignore this error 368 return 369 } 370 371 getSuffix := func(condition sqlparser.Node, tableName string, vars ...any) (suffix string, err error) { 372 values := make([]any, 0, len(t.config.ShardingKeys)) 373 for _, key := range t.config.ShardingKeys { 374 val, err := t.nonInsertValue(condition, key, tableName, vars...) 375 if err != nil { 376 return "", db.AddError(err) 377 } 378 values = append(values, val) 379 } 380 381 suffix, err = t.shardingFunc(db.Statement.Context, values...) 382 if err != nil { 383 return "", db.AddError(err) 384 } 385 return 386 } 387 388 newSQL := "" 389 switch stmt := expr.(type) { 390 case *sqlparser.InsertStatement: 391 if stmt.TableName.TableName() != t.config.Table { 392 return 393 } 394 395 suffix := "" 396 for _, insertExpression := range stmt.Expressions { 397 values, id, e := t.insertValue(t.config.ShardingKeys, stmt.ColumnNames, 398 insertExpression.Exprs, db.Statement.Vars...) 
399 if e != nil { 400 _ = db.AddError(e) 401 return 402 } 403 if t.isShardingPrimaryKey && id == 0 { 404 if t.config.PrimaryKeyGenerator == nil { 405 _ = db.AddError(ErrIDGeneratorNotFound) 406 return 407 } 408 if id, e = t.config.PrimaryKeyGenerator.Next(idgen.GormTx(db)); e != nil { 409 _ = db.AddError(e) 410 return 411 } 412 stmt.ColumnNames = append(stmt.ColumnNames, &sqlparser.Ident{Name: "id"}) 413 insertExpression.Exprs = append(insertExpression.Exprs, &sqlparser.NumberLit{Value: cast.ToString(id)}) 414 values, _, _ = t.insertValue(t.config.ShardingKeys, stmt.ColumnNames, 415 insertExpression.Exprs, db.Statement.Vars...) 416 } 417 418 subSuffix, e := t.shardingFunc(db.Statement.Context, values...) 419 if e != nil { 420 _ = db.AddError(e) 421 return 422 } 423 424 if suffix != "" && suffix != subSuffix { 425 _ = db.AddError(ErrDiffSuffixDML) 426 return 427 } 428 suffix = subSuffix 429 } 430 // FIXME: could not find the table schema to migrate 431 if e := t.createTableIfNotExists(db, db.Statement.Table, suffix); e != nil { 432 _ = db.AddError(e) 433 return 434 } 435 stmt.TableName = &sqlparser.TableName{Name: &sqlparser.Ident{Name: stmt.TableName.TableName() + suffix}} 436 newSQL = stmt.String() 437 case *sqlparser.SelectStatement: 438 parseSelectStatementFunc := func(stmt *sqlparser.SelectStatement) (ok bool, err error) { 439 if stmt.Hint != nil && stmt.Hint.Value == "nosharding" { 440 return false, nil 441 } 442 443 switch tbl := stmt.FromItems.(type) { 444 case *sqlparser.TableName: 445 if tbl.TableName() != t.config.Table { 446 return false, nil 447 } 448 suffix, e := getSuffix(stmt.Condition, t.config.Table, db.Statement.Vars...) 
449 if e != nil { 450 _ = db.AddError(e) 451 return false, nil 452 } 453 oldTableName := tbl.TableName() 454 newTableName := oldTableName + suffix 455 stmt.FromItems = &sqlparser.TableName{Name: &sqlparser.Ident{Name: newTableName}} 456 stmt.OrderBy = t.replaceOrderByTableName(stmt.OrderBy, oldTableName, newTableName) 457 if e := t.replaceCondition(stmt.Condition, oldTableName, newTableName); err != nil { 458 _ = db.AddError(e) 459 return false, nil 460 } 461 case *sqlparser.JoinClause: 462 tblx, _ := tbl.X.(*sqlparser.TableName) 463 tbly, _ := tbl.Y.(*sqlparser.TableName) 464 isXSharding := tblx != nil && tblx.TableName() == t.config.Table 465 isYSharding := tbly != nil && tbly.TableName() == t.config.Table 466 oldTableName := "" 467 switch { 468 case isXSharding: 469 oldTableName = tblx.TableName() 470 case isYSharding: 471 oldTableName = tbly.TableName() 472 default: 473 return false, nil 474 } 475 suffix, e := getSuffix(stmt.Condition, oldTableName, db.Statement.Vars...) 476 if e != nil { 477 _ = db.AddError(e) 478 return false, nil 479 } 480 newTableName := oldTableName + suffix 481 stmt.OrderBy = t.replaceOrderByTableName(stmt.OrderBy, oldTableName, newTableName) 482 if e := t.replaceCondition(stmt.Condition, oldTableName, newTableName); err != nil { 483 _ = db.AddError(e) 484 return false, nil 485 } 486 if e := t.replaceConstraint(tbl.Constraint, oldTableName, newTableName); err != nil { 487 _ = db.AddError(e) 488 return false, nil 489 } 490 if isXSharding { 491 tblx.Name.Name = newTableName 492 } else { 493 tbly.Name.Name = newTableName 494 } 495 if stmt.Columns != nil { 496 for _, column := range *stmt.Columns { 497 columnTbl, ok := column.Expr.(*sqlparser.QualifiedRef) 498 if !ok || columnTbl.Table.Name != oldTableName { 499 continue 500 } 501 columnTbl.Table.Name = newTableName 502 } 503 } 504 } 505 return true, nil 506 } 507 for compound := stmt; compound != nil; compound = compound.Compound { 508 if ok, err = parseSelectStatementFunc(compound); !ok || 
err != nil { 509 return 510 } 511 } 512 513 newSQL = stmt.String() 514 515 case *sqlparser.UpdateStatement: 516 if stmt.TableName.TableName() != t.config.Table { 517 return 518 } 519 520 suffix, e := getSuffix(stmt.Condition, t.config.Table, db.Statement.Vars...) 521 if e != nil { 522 _ = db.AddError(e) 523 return 524 } 525 526 oldTableName := stmt.TableName.TableName() 527 newTableName := oldTableName + suffix 528 stmt.TableName = &sqlparser.TableName{Name: &sqlparser.Ident{Name: newTableName}} 529 if e := t.replaceCondition(stmt.Condition, oldTableName, newTableName); err != nil { 530 _ = db.AddError(e) 531 return false, nil 532 } 533 newSQL = stmt.String() 534 case *sqlparser.DeleteStatement: 535 if stmt.TableName.TableName() != t.config.Table { 536 return 537 } 538 539 suffix, e := getSuffix(stmt.Condition, t.config.Table, db.Statement.Vars...) 540 if e != nil { 541 _ = db.AddError(e) 542 return 543 } 544 545 oldTableName := stmt.TableName.TableName() 546 newTableName := oldTableName + suffix 547 stmt.TableName = &sqlparser.TableName{Name: &sqlparser.Ident{Name: newTableName}} 548 if e := t.replaceCondition(stmt.Condition, oldTableName, newTableName); err != nil { 549 _ = db.AddError(e) 550 return false, nil 551 } 552 newSQL = stmt.String() 553 default: 554 _ = db.AddError(sqlparser.ErrNotImplemented) 555 return 556 } 557 558 sb := strings.Builder{} 559 sb.Grow(len(newSQL)) 560 sb.WriteString(newSQL) 561 db.Statement.SQL = sb 562 563 return true, nil 564 } 565 func (t *tableSharding) wrapDispatchTableBySQL(db *gorm.DB, opts ...utils.OptionExtender) { 566 if ok, err := t.dispatchTableBySQL(db, opts...); err != nil || !ok { 567 // not a dml 568 if err != nil { 569 return 570 } 571 // not a sharding table 572 if !ok { 573 // FIXME: reset sql parse result will get duplicated sql statement 574 // db.Statement.SQL = strings.Builder{} 575 // db.Statement.Vars = nil 576 } 577 } 578 } 579 func (t *tableSharding) replaceStatementClauseAndSchema(db *gorm.DB, opt 
*tableShardingDispatchOption) { 580 changeExprFunc := func(src []clause.Expression) (dst []clause.Expression) { 581 changeTableFunc := func(src any) (dst any, ok bool) { 582 switch col := src.(type) { 583 case clause.Column: 584 if col.Table == t.config.Table { 585 col.Table = db.Statement.Table 586 return col, true 587 } 588 case clause.Table: 589 if col.Name == t.config.Table { 590 col.Name = db.Statement.Table 591 return col, true 592 } 593 } 594 return 595 } 596 dst = make([]clause.Expression, 0, len(src)) 597 for _, srcExpr := range src { 598 switch expr := srcExpr.(type) { 599 case clause.IN: 600 if col, ok := changeTableFunc(expr.Column); ok { 601 expr.Column = col 602 } 603 dst = append(dst, expr) 604 case clause.Eq: 605 if col, ok := changeTableFunc(expr.Column); ok { 606 expr.Column = col 607 } 608 dst = append(dst, expr) 609 case clause.Neq: 610 if col, ok := changeTableFunc(expr.Column); ok { 611 expr.Column = col 612 } 613 dst = append(dst, expr) 614 case clause.Gt: 615 if col, ok := changeTableFunc(expr.Column); ok { 616 expr.Column = col 617 } 618 dst = append(dst, expr) 619 case clause.Gte: 620 if col, ok := changeTableFunc(expr.Column); ok { 621 expr.Column = col 622 } 623 dst = append(dst, expr) 624 case clause.Lt: 625 if col, ok := changeTableFunc(expr.Column); ok { 626 expr.Column = col 627 } 628 dst = append(dst, expr) 629 case clause.Lte: 630 if col, ok := changeTableFunc(expr.Column); ok { 631 expr.Column = col 632 } 633 dst = append(dst, expr) 634 case clause.Like: 635 if col, ok := changeTableFunc(expr.Column); ok { 636 expr.Column = col 637 } 638 dst = append(dst, expr) 639 default: 640 dst = append(dst, expr) 641 } 642 } 643 return 644 } 645 changeClausesMapping := map[string]func(cls clause.Clause){ 646 "WHERE": func(cls clause.Clause) { 647 whereClause, ok := cls.Expression.(clause.Where) 648 if !ok { 649 return 650 } 651 whereClause.Exprs = changeExprFunc(whereClause.Exprs) 652 cls.Expression = whereClause 653 
db.Statement.Clauses["WHERE"] = cls 654 }, 655 "FROM": func(cls clause.Clause) { 656 fromClause, ok := cls.Expression.(clause.From) 657 if !ok { 658 return 659 } 660 tables := make([]clause.Table, 0, len(fromClause.Tables)) 661 for _, table := range fromClause.Tables { 662 if table.Name == t.config.Table { 663 table.Name = db.Statement.Table 664 tables = append(tables, table) 665 } else { 666 tables = append(tables, table) 667 } 668 } 669 fromClause.Tables = tables 670 cls.Expression = fromClause 671 db.Statement.Clauses["FROM"] = cls 672 }, 673 // TODO: check if order by contains table name 674 "ORDER BY": func(cls clause.Clause) { 675 _, ok := cls.Expression.(clause.OrderBy) 676 if !ok { 677 return 678 } 679 }, 680 } 681 682 for name, cls := range db.Statement.Clauses { 683 if mappingFunc, ok := changeClausesMapping[name]; ok { 684 mappingFunc(cls) 685 } 686 } 687 688 if opt.isInsert { 689 db.Clauses(clause.Insert{Table: clause.Table{Name: db.Statement.Table}}) 690 } else { 691 db.Clauses(clause.From{Tables: []clause.Table{{Name: db.Statement.Table}}}) 692 } 693 } 694 695 func (t *tableSharding) replaceCondition(conditions sqlparser.Expr, oldTableName, newTableName string) (err error) { 696 err = sqlparser.Walk( 697 sqlparser.VisitFunc(func(node sqlparser.Node) (err error) { 698 n, ok := node.(*sqlparser.BinaryExpr) 699 if !ok { 700 return 701 } 702 703 x, ok := n.X.(*sqlparser.QualifiedRef) 704 if !ok || x.Table == nil || x.Table.Name != oldTableName { 705 return 706 } 707 708 x.Table.Name = newTableName 709 return 710 }), 711 conditions, 712 ) 713 return 714 } 715 716 func (t *tableSharding) replaceConstraint(constraints sqlparser.Node, oldTableName, newTableName string) (err error) { 717 return sqlparser.Walk( 718 sqlparser.VisitFunc(func(node sqlparser.Node) (err error) { 719 n, ok := node.(*sqlparser.QualifiedRef) 720 if !ok || n.Table == nil || n.Table.Name != oldTableName { 721 return 722 } 723 724 n.Table.Name = newTableName 725 return 726 }), 727 
constraints, 728 ) 729 } 730 731 func (t *tableSharding) insertValue(keys []string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...any) ( 732 values []any, id uint64, err error) { 733 if len(names) != len(exprs) { 734 return nil, 0, ErrColumnAndExprMisMatch 735 } 736 737 for _, key := range keys { 738 found := false 739 isPrimaryKey := key == t.shardingPrimaryKey 740 for i, name := range names { 741 if name.Name != key { 742 continue 743 } 744 745 switch expr := exprs[i].(type) { 746 case *sqlparser.BindExpr: 747 if !isPrimaryKey { 748 values = append(values, args[expr.Pos]) 749 } else { 750 switch v := args[expr.Pos].(type) { 751 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, string: 752 if id, err = cast.ToUint64E(v); err != nil { 753 return nil, 0, errors.Wrapf(err, "parse id as uint64 failed [%v]", v) 754 } 755 default: 756 return nil, 0, ErrInvalidID 757 } 758 if id != 0 { 759 values = append(values, args[expr.Pos]) 760 } 761 } 762 case *sqlparser.StringLit: 763 if !isPrimaryKey { 764 values = append(values, expr.Value) 765 } else { 766 if id, err = cast.ToUint64E(expr.Value); err != nil { 767 return nil, 0, errors.Wrapf(err, "parse id as uint64 failed [%s]", expr.Value) 768 } 769 if id != 0 { 770 values = append(values, expr.Value) 771 } 772 } 773 case *sqlparser.NumberLit: 774 if !isPrimaryKey { 775 values = append(values, expr.Value) 776 } else { 777 if id, err = strconv.ParseUint(expr.Value, 10, 64); err != nil { 778 return nil, 0, errors.Wrapf(err, 779 "parse id as uint64 failed [%s]", expr.Value) 780 } 781 if id != 0 { 782 values = append(values, expr.Value) 783 } 784 } 785 default: 786 return nil, 0, sqlparser.ErrNotImplemented 787 } 788 789 found = true 790 break 791 } 792 if !found && !isPrimaryKey { 793 return nil, 0, ErrMissingShardingKey 794 } 795 } 796 797 return 798 } 799 800 func (t *tableSharding) nonInsertValue(condition sqlparser.Node, key, tableName string, args ...any) ( 801 value any, 
err error) { 802 found := false 803 err = sqlparser.Walk( 804 sqlparser.VisitFunc(func(node sqlparser.Node) (err error) { 805 n, ok := node.(*sqlparser.BinaryExpr) 806 if !ok { 807 return 808 } 809 if n.Op != sqlparser.EQ { 810 return 811 } 812 813 switch x := n.X.(type) { 814 case *sqlparser.Ident: 815 if x.Name != key { 816 return 817 } 818 case *sqlparser.QualifiedRef: 819 if !ok || x.Table.Name != tableName || x.Column.Name != key { 820 return 821 } 822 } 823 824 found = true 825 switch expr := n.Y.(type) { 826 case *sqlparser.BindExpr: 827 value = args[expr.Pos] 828 case *sqlparser.StringLit: 829 value = expr.Value 830 case *sqlparser.NumberLit: 831 value = expr.Value 832 default: 833 return sqlparser.ErrNotImplemented 834 } 835 836 return 837 }), 838 condition, 839 ) 840 if err != nil { 841 return 842 } 843 if !found { 844 return nil, ErrMissingShardingKey 845 } 846 return 847 } 848 849 func (t *tableSharding) setPrimaryKeyByModel(db *gorm.DB, opt *tableShardingDispatchOption) (err error) { 850 if !opt.isInsert || db.Statement.Model == nil || 851 db.Statement.Schema == nil || db.Statement.Schema.PrioritizedPrimaryField == nil { 852 return 853 } 854 setPrimaryKeyFunc := func(rv reflect.Value) (err error) { 855 _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) 856 if !isZero { 857 return 858 } 859 if t.config.PrimaryKeyGenerator == nil { 860 return ErrIDGeneratorNotFound 861 } 862 id, err := t.config.PrimaryKeyGenerator.Next(idgen.GormTx(db)) 863 if err != nil { 864 return 865 } 866 return db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, id) 867 } 868 869 switch db.Statement.ReflectValue.Kind() { 870 case reflect.Slice, reflect.Array: 871 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { 872 rv := db.Statement.ReflectValue.Index(i) 873 if reflect.Indirect(rv).Kind() != reflect.Struct { 874 break 875 } 876 877 if err = setPrimaryKeyFunc(rv); err != nil { 878 return 879 } 880 } 881 case 
reflect.Struct: 882 if err = setPrimaryKeyFunc(db.Statement.ReflectValue); err != nil { 883 return 884 } 885 } 886 887 return 888 } 889 890 func (t *tableSharding) getModelReflectValue(db *gorm.DB) (reflectVal reflect.Value, ok bool) { 891 reflectVal = utils.IndirectValue(db.Statement.ReflectValue) 892 if reflectVal.Kind() == reflect.Array || reflectVal.Kind() == reflect.Slice { 893 if reflectVal.Len() == 0 { 894 return 895 } 896 reflectVal = utils.IndirectValue(reflectVal.Index(0)) 897 } 898 899 if reflectVal.Kind() != reflect.Struct { 900 return 901 } 902 903 return reflectVal, !utils.IsBlank(reflectVal.Interface()) 904 } 905 906 func (t *tableSharding) checkDiffSuffixesByModel(db *gorm.DB) (err error) { 907 reflectVal := utils.IndirectValue(db.Statement.ReflectValue) 908 if reflectVal.Kind() != reflect.Array && reflectVal.Kind() != reflect.Slice { 909 return 910 } 911 912 suffix := "" 913 for i := 0; i < reflectVal.Len(); i++ { 914 reflectItemVal := reflect.Indirect(reflectVal.Index(i)) 915 values := make([]any, 0, len(t.config.ShardingKeys)) 916 for _, key := range t.config.ShardingKeys { 917 val := reflectItemVal.FieldByNameFunc(func(v string) bool { return strings.EqualFold(v, key) }) 918 if !val.IsValid() { 919 val, _ = utils.GetGormColumnValue(reflectItemVal, key) 920 } 921 if !val.IsValid() { 922 return db.AddError(ErrMissingShardingKey) 923 } 924 values = append(values, val.Interface()) 925 } 926 subSuffix, err := t.shardingFunc(db.Statement.Context, values...) 
927 if err != nil { 928 return db.AddError(err) 929 } 930 if suffix != "" && suffix != subSuffix { 931 return db.AddError(ErrDiffSuffixDML) 932 } 933 suffix = subSuffix 934 } 935 return 936 } 937 938 func (t *tableSharding) replaceOrderByTableName( 939 orderBy []*sqlparser.OrderingTerm, oldName, newName string) []*sqlparser.OrderingTerm { 940 for i, term := range orderBy { 941 if x, ok := term.X.(*sqlparser.QualifiedRef); ok { 942 if x.Table.Name == oldName { 943 x.Table.Name = newName 944 orderBy[i].X = x 945 } 946 } 947 } 948 return orderBy 949 } 950 951 func (t *tableSharding) createTableIfNotExists(db *gorm.DB, tableName, suffix string) (err error) { 952 shardingTableName := tableName + suffix 953 t.shardingTableCreatedMutex.RLock() 954 if _, ok := t.shardingTableCreated[shardingTableName]; ok { 955 t.shardingTableCreatedMutex.RUnlock() 956 return 957 } 958 t.shardingTableCreatedMutex.RUnlock() 959 t.shardingTableCreatedMutex.Lock() 960 defer t.shardingTableCreatedMutex.Unlock() 961 962 defer t.ignore(t.DB)() //nolint: revive // partial calling issue 963 if t.DB.Migrator().HasTable(shardingTableName) { 964 t.shardingTableCreated[shardingTableName] = struct{}{} 965 return 966 } 967 968 model := db.Statement.Model 969 if model == nil { 970 model = t.shardingTableModel 971 } 972 if model == nil { 973 return ErrShardingModelNotFound 974 } 975 tx := t.DB.Session(&gorm.Session{}).Table(shardingTableName) 976 if err = db.Dialector.Migrator(tx).AutoMigrate(db.Statement.Model); err != nil { 977 return err 978 } 979 t.shardingTableCreated[shardingTableName] = struct{}{} 980 return 981 } 982 983 func (t *tableSharding) suffixes() (suffixes []string, err error) { 984 switch { 985 case t.config.ShardingKeyByRawValue: 986 if len(t.config.ShardingKeysForMigrating) == 0 { 987 return nil, errors.New("sharding key by raw value but do not configure keys for migrating") 988 } 989 990 for _, shardingKey := range t.config.ShardingKeysForMigrating { 991 suffixes = append(suffixes, 
fmt.Sprintf(t.suffixFormat, shardingKey))
		}
	default:
		for i := 0; i < int(t.config.NumberOfShards); i++ {
			suffixes = append(suffixes, fmt.Sprintf(t.suffixFormat, i))
		}
	}
	return
}

// ignore marks db's statement so that the sharding plugin skips it and returns a function that
// removes the mark again. If the mark is already present, the returned function is a no-op, so
// only the caller that actually set the mark clears it. The receiver is not read, so it may be
// called on a nil *tableSharding.
func (t *tableSharding) ignore(db *gorm.DB) func() {
	if _, ok := db.Statement.Settings.Load(shardingIgnoreStoreKey); ok {
		return func() {}
	}
	db.Statement.Settings.Store(shardingIgnoreStoreKey, nil)
	return func() {
		db.Statement.Settings.Delete(shardingIgnoreStoreKey)
	}
}

// isIgnored returns a predicate reporting whether db's statement currently carries the mark set
// by ignore. The receiver is not read, so it may be called on a nil *tableSharding.
func (t *tableSharding) isIgnored(db *gorm.DB) func() bool {
	return func() bool {
		_, ok := db.Statement.Settings.Load(shardingIgnoreStoreKey)
		return ok
	}
}

// defaultShardingFunc initialises t.suffixFormat and returns the function that maps sharding key
// values to a table suffix.
//
// The suffix format is constant.Underline (presumably "_") followed by CustomSuffix when set,
// "%s" for raw-value sharding, or the sharding key names joined by constant.Underline. If the
// format contains no verb yet, "_%s" is appended for raw-value sharding, otherwise a zero-padded
// "_%0Nd" where N is the number of decimal digits of NumberOfShards.
//
// The returned function routes by, in order of precedence:
//   - ShardingKeyByRawValue: the values, stringified and joined by constant.Underline, formatted
//     with the suffix format (the same format suffixes() uses for migration);
//   - ShardingKeyExpr: the expression result (evaluated with ShardingKeys as parameter names)
//     modulo NumberOfShards;
//   - otherwise: the CRC-32 (IEEE) checksum of the big-endian encoding of all values, modulo
//     NumberOfShards.
//
// It panics when NumberOfShards is zero and raw-value sharding is disabled.
func (t *tableSharding) defaultShardingFunc() func(ctx context.Context, values ...any) (suffix string, err error) {
	if !t.config.ShardingKeyByRawValue && t.config.NumberOfShards == 0 {
		panic(errors.New("missing number_of_shards config"))
	}
	t.suffixFormat = constant.Underline

	switch {
	case utils.IsStrNotBlank(t.config.CustomSuffix):
		t.suffixFormat += t.config.CustomSuffix
	case t.config.ShardingKeyByRawValue:
		t.suffixFormat += "%s"
	default:
		t.suffixFormat += strings.Join(t.config.ShardingKeys, constant.Underline)
	}

	numberOfShards := t.config.NumberOfShards
	if !strings.Contains(t.suffixFormat, "%") {
		if t.config.ShardingKeyByRawValue {
			t.suffixFormat += "_%s"
		} else {
			// Pad to the digit count of NumberOfShards (1-9 -> %01d, 10-99 -> %02d, ...), which
			// covers every shard count the config accepts (DefaultTableSharding allows < 100000)
			// instead of leaving the format without a verb above 9999 shards.
			width := len(strconv.FormatUint(uint64(numberOfShards), 10))
			t.suffixFormat += "_%0" + strconv.Itoa(width) + "d"
		}
	}

	switch {
	case t.config.ShardingKeyByRawValue:
		return func(ctx context.Context, values ...any) (suffix string, err error) {
			parts := make([]string, 0, len(values))
			for _, value := range values {
				v, err := cast.ToStringE(value)
				if err != nil {
					return "", err
				}
				parts = append(parts, v)
			}
			// Format with t.suffixFormat, exactly as suffixes() does, so that routed tables
			// match the migrated ones even when CustomSuffix is configured.
			return fmt.Sprintf(t.suffixFormat, strings.Join(parts, constant.Underline)), nil
		}
	case t.config.ShardingKeyExpr != nil:
		numberOfShardsFloat64 := float64(numberOfShards)
		return func(ctx context.Context, values ...any) (suffix string, err error) {
			// One value per configured sharding key is required; fail instead of panicking
			// on an out-of-range index.
			if len(values) < len(t.config.ShardingKeys) {
				return "", ErrMissingShardingKey
			}
			params := make(map[string]any, len(t.config.ShardingKeys))
			for idx, column := range t.config.ShardingKeys {
				params[column] = values[idx]
			}

			result, err := t.config.ShardingKeyExpr(ctx, params)
			if err != nil {
				return "", err
			}
			// NOTE(review): math.Mod keeps the sign of the dividend, so a negative expression
			// result yields a negative suffix.
			shardingKey := int64(math.Mod(cast.ToFloat64(result), numberOfShardsFloat64))
			return fmt.Sprintf(t.suffixFormat, shardingKey), nil
		}
	default:
		// toBytes canonicalises a textual key before hashing: numbers are gob-encoded as a
		// big.Float, UUIDs become their 16 raw bytes, anything else is used verbatim.
		// utils.IfAny presumably stops at the first candidate returning true.
		toBytes := func(v string) (data []byte) {
			utils.IfAny(
				// number
				func() (ok bool) {
					num := new(big.Float)
					if _, ok = num.SetString(v); !ok {
						return
					}
					gobEncoded, err := num.GobEncode()
					if err != nil {
						return false
					}
					data = gobEncoded
					return
				},
				// uuid
				func() bool {
					uid, err := uuid.Parse(v)
					if err != nil {
						return false
					}
					data = uid[:]
					return true
				},
				// bytes
				func() bool { data = []byte(v); return true },
			)
			return
		}

		// write appends the big-endian encoding of a single sharding value to w.
		write := func(w *bytes.Buffer, value any) error {
			var data any
			switch v := value.(type) {
			case int, *int:
				data = utils.IntNarrow(cast.ToInt(v))
			case uint, *uint:
				data = utils.UintNarrow(cast.ToUint(v))
			case []int:
				// binary.Write cannot encode a []any, so write each narrowed element in turn.
				for _, item := range v {
					if err := binary.Write(w, binary.BigEndian, utils.IntNarrow(item)); err != nil {
						return err
					}
				}
				return nil
			case []uint:
				for _, item := range v {
					if err := binary.Write(w, binary.BigEndian, utils.UintNarrow(item)); err != nil {
						return err
					}
				}
				return nil
			case string:
				data = toBytes(v)
			case []byte:
				data = toBytes(utils.UnsafeBytesToString(v))
			case uuid.UUID:
				data = v[:]
			default:
				data = v
			}
			return binary.Write(w, binary.BigEndian, data)
		}

		return func(ctx context.Context, values ...any) (suffix string, err error) {
			// Size hint only; values without a fixed binary size count as one interface word.
			size := 0
			for _, value := range values {
				s := binary.Size(value)
				if s <= 0 {
					s = int(unsafe.Sizeof(value))
				}
				size += s
			}
			w := new(bytes.Buffer)
			w.Grow(size)

			for _, value := range values {
				if err = write(w, value); err != nil {
					return "", err
				}
			}

			// checksum mod shards
			checksum := crc32.ChecksumIEEE(w.Bytes())
			shardingKey := uint64(checksum) % uint64(numberOfShards)
			return fmt.Sprintf(t.suffixFormat, shardingKey), nil
		}
	}
}

// shardingDialector wraps a gorm.Dialector so that its Migrator is sharding-aware.
// shardingMap is keyed by the configured (unsuffixed) table name.
type shardingDialector struct {
	gorm.Dialector
	shardingMap map[string]*tableSharding
}

// newShardingDialector registers s under s.config.Table. If d is already a shardingDialector,
// s is added to its existing map (maps are references, so every copy sees it) and d is returned;
// otherwise d is wrapped in a new shardingDialector.
func newShardingDialector(d gorm.Dialector, s *tableSharding) shardingDialector {
	if sd, ok := d.(shardingDialector); ok {
		sd.shardingMap[s.config.Table] = s
		return sd
	}

	return shardingDialector{
		Dialector:   d,
		shardingMap: map[string]*tableSharding{s.config.Table: s},
	}
}

// Migrator returns the wrapped dialector's migrator when db carries the ignore mark, and a
// shardingMigrator around it otherwise.
func (s shardingDialector) Migrator(db *gorm.DB) gorm.Migrator {
	m := s.Dialector.Migrator(db)
	if (*tableSharding)(nil).isIgnored(db)() {
		return m
	}
	return &shardingMigrator{
		Migrator:    m,
		db:          db,
		shardingMap: s.shardingMap,
		dialector:   s.Dialector,
	}
}

// SavePoint forwards to the wrapped dialector. Embedding gorm.Dialector only promotes the
// methods of that interface, so optional interfaces must be forwarded explicitly.
// It returns gorm.ErrUnsupportedDriver when the wrapped dialector has no save point support.
func (s shardingDialector) SavePoint(tx *gorm.DB, name string) error {
	savePointer, ok := s.Dialector.(gorm.SavePointerDialectorInterface)
	if !ok {
		return gorm.ErrUnsupportedDriver
	}
	return savePointer.SavePoint(tx, name)
}

// RollbackTo forwards to the wrapped dialector, or returns gorm.ErrUnsupportedDriver when it
// has no save point support.
func (s shardingDialector) RollbackTo(tx *gorm.DB, name string) error {
	savePointer, ok := s.Dialector.(gorm.SavePointerDialectorInterface)
	if !ok {
		return gorm.ErrUnsupportedDriver
	}
	return savePointer.RollbackTo(tx, name)
}

// shardingMigrator decorates a gorm.Migrator: AutoMigrate and DropTable on a sharded model act
// on every suffixed table, all other methods go to the embedded Migrator.
type shardingMigrator struct {
	gorm.Migrator
	db          *gorm.DB
	dialector   gorm.Dialector
	shardingMap map[string]*tableSharding
}

// AutoMigrate migrates dst. The sharding config is selected by the table name of dst[0] only.
// When none matches (or dst is empty), all of dst is migrated by the embedded migrator with
// sharding ignored. Otherwise each model in dst is migrated once per shard suffix.
func (s *shardingMigrator) AutoMigrate(dst ...any) (err error) {
	var (
		sharding *tableSharding
		ok       bool
	)
	if len(dst) > 0 {
		sharding, ok = s.shardingMap[s.tableName(s.db, dst[0])]
	}
	if !ok {
		defer (*tableSharding)(nil).ignore(s.db)() //nolint: revive // partial calling issue
		return s.Migrator.AutoMigrate(dst...)
	}

	stmt := &gorm.Statement{DB: sharding.DB}
	if sharding.isIgnored(sharding.DB)() {
		return s.dialector.Migrator(stmt.DB.Session(&gorm.Session{})).AutoMigrate(dst...)
	}

	shardingDst, err := s.getShardingDst(sharding, dst...)
	if err != nil {
		return err
	}

	// Mark the sharding DB as ignored so the per-suffix migrations below are not re-routed.
	defer sharding.ignore(sharding.DB)() //nolint: revive // partial calling issue
	for _, sd := range shardingDst {
		tx := stmt.DB.Session(&gorm.Session{}).Table(sd.table)
		if err = s.dialector.Migrator(tx).AutoMigrate(sd.dst); err != nil {
			return err
		}
	}

	return nil
}

// DropTable drops dst. Sharding selection follows AutoMigrate: by dst[0]'s table name, falling
// back to the embedded migrator (with sharding ignored) when nothing matches or dst is empty.
// Otherwise every suffixed table of each model is dropped.
func (s *shardingMigrator) DropTable(dst ...any) (err error) {
	var (
		sharding *tableSharding
		ok       bool
	)
	if len(dst) > 0 {
		sharding, ok = s.shardingMap[s.tableName(s.db, dst[0])]
	}
	if !ok {
		defer (*tableSharding)(nil).ignore(s.db)() //nolint: revive // partial calling issue
		return s.Migrator.DropTable(dst...)
	}

	stmt := &gorm.Statement{DB: sharding.DB}
	if sharding.isIgnored(sharding.DB)() {
		return s.dialector.Migrator(stmt.DB.Session(&gorm.Session{})).DropTable(dst...)
	}
	shardingDst, err := s.getShardingDst(sharding, dst...)
	if err != nil {
		return err
	}

	defer sharding.ignore(sharding.DB)() //nolint: revive // partial calling issue
	for _, sd := range shardingDst {
		tx := stmt.DB.Session(&gorm.Session{}).Table(sd.table)
		if err = s.dialector.Migrator(tx).DropTable(sd.table); err != nil {
			return err
		}
	}

	return nil
}

// shardingDst pairs a concrete suffixed table name with the model to migrate into it.
type shardingDst struct {
	table string
	dst   any
}

// getShardingDst expands every model in src into one shardingDst per shard suffix of sharding.
// It fails when a model cannot be parsed by gorm, when suffixes() fails, or when there are no
// suffixes. On error no partial result is returned.
func (s *shardingMigrator) getShardingDst(sharding *tableSharding, src ...any) (dst []shardingDst, err error) {
	// The suffix set depends only on the sharding config, so compute it once for all models.
	suffixes, err := sharding.suffixes()
	if err != nil {
		return nil, err
	}

	for _, model := range src {
		stmt := &gorm.Statement{DB: sharding.DB}
		if err = stmt.Parse(model); err != nil {
			return nil, err
		}

		// support sharding table
		if len(suffixes) == 0 {
			return nil, fmt.Errorf("sharding table:%s suffixes are empty", stmt.Table)
		}
		for _, suffix := range suffixes {
			dst = append(dst, shardingDst{
				table: stmt.Table + suffix,
				dst:   model,
			})
		}
	}
	return dst, nil
}

// tableName resolves the table name used to look up m in shardingMap: TableName() from
// schema.Tabler, overridden by schema.TablerWithNamer, overridden again by the Table field when
// db.NamingStrategy converts to gorm's embeddedNamer.
// NOTE(review): a model implementing neither Tabler interface resolves to "" (unless the naming
// strategy is an embeddedNamer) and is therefore treated as non-sharded.
func (s *shardingMigrator) tableName(db *gorm.DB, m any) (name string) {
	if tabler, ok := m.(schema.Tabler); ok {
		name = tabler.TableName()
	}
	if tabler, ok := m.(schema.TablerWithNamer); ok {
		name = tabler.TableName(db.NamingStrategy)
	}
	// Guard against a nil NamingStrategy: CanConvert on an invalid reflect.Value panics.
	namingStrategy := reflect.ValueOf(db.NamingStrategy)
	if namingStrategy.IsValid() && namingStrategy.CanConvert(gormSchemaEmbeddedNamer) {
		name = reflect.Indirect(namingStrategy.Convert(gormSchemaEmbeddedNamer)).FieldByName("Table").String()
	}
	return name
}