github.com/wfusion/gofusion@v1.1.14/db/plugins/table_sharding.go (about)

     1  package plugins
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"hash/crc32"
     9  	"math"
    10  	"math/big"
    11  	"reflect"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"unsafe"
    16  
    17  	"github.com/PaesslerAG/gval"
    18  	"github.com/google/uuid"
    19  	"github.com/pkg/errors"
    20  	"github.com/spf13/cast"
    21  	"gorm.io/gorm"
    22  	"gorm.io/gorm/clause"
    23  	"gorm.io/gorm/schema"
    24  
    25  	"github.com/wfusion/gofusion/common/constant"
    26  	"github.com/wfusion/gofusion/common/infra/drivers/orm/idgen"
    27  	"github.com/wfusion/gofusion/common/utils"
    28  	"github.com/wfusion/gofusion/common/utils/clone"
    29  	"github.com/wfusion/gofusion/common/utils/inspect"
    30  	"github.com/wfusion/gofusion/common/utils/sqlparser"
    31  	"github.com/wfusion/gofusion/db/callbacks"
    32  )
    33  
const (
	// shardingIgnoreStoreKey is the key name used to mark a statement as exempt from
	// table sharding. NOTE(review): it is not referenced in this part of the file;
	// presumably read by isIgnored — confirm.
	shardingIgnoreStoreKey = "sharding_ignore"
)
    37  
var (
	// ErrInvalidID reports a sharding primary key value that is zero/blank, or a bound
	// primary key argument of an unsupported type.
	ErrInvalidID             = errors.New("invalid id format")
	// ErrIDGeneratorNotFound reports that an id has to be generated but no
	// PrimaryKeyGenerator is configured.
	ErrIDGeneratorNotFound   = errors.New("id generator not found")
	// ErrShardingModelNotFound is not referenced in this part of the file;
	// presumably returned by the migration code — confirm.
	ErrShardingModelNotFound = errors.New("sharding table model not found when migrating")
	// ErrDiffSuffixDML reports that the rows/models of one statement resolve to
	// different shard table suffixes.
	ErrDiffSuffixDML         = errors.New("can not query different suffix table in one sql")
	// ErrMissingShardingKey reports that a sharding key could not be found on a model,
	// in an INSERT column list, or in a `key = value` comparison of a condition.
	ErrMissingShardingKey    = errors.New("sharding key required and use operator =")
	// ErrColumnAndExprMisMatch reports an INSERT row whose number of values differs
	// from the number of listed columns.
	ErrColumnAndExprMisMatch = errors.New("column names and expressions mismatch")

	// gormSchemaEmbeddedNamer holds the type of gorm's unexported schema.embeddedNamer,
	// presumably resolved by its fully qualified name via inspect.TypeOf. It is not
	// referenced in this part of the file.
	gormSchemaEmbeddedNamer = inspect.TypeOf("gorm.io/gorm/schema.embeddedNamer")
)
    46  	gormSchemaEmbeddedNamer = inspect.TypeOf("gorm.io/gorm/schema.embeddedNamer")
    47  )
    48  
// TableShardingConfig configures the table sharding plugin built by DefaultTableSharding.
type TableShardingConfig struct {
	// Database name; it is part of the plugin name returned by Name.
	Database string

	// Table is the logical table name, required. Statements whose table equals it are
	// routed to the shard table named Table + suffix.
	Table string

	// ShardingKeys required, specifies the table columns you want to use for sharding the table rows.
	// For example, for a product order table, you may want to split the rows by `user_id`.
	// Keys are matched against struct fields case-insensitively, then against gorm column names.
	// If one key is the primary key ("id" in any letter case), PrimaryKeyGenerator is required.
	ShardingKeys []string

	// ShardingKeyExpr optional, specifies how to calculate sharding key by columns, e.g. tenant_id << 16 | user_id
	// NOTE(review): not referenced in this part of the file; presumably consumed by the sharding function.
	ShardingKeyExpr gval.Evaluable

	// ShardingKeyByRawValue optional, specifies the sharding suffix is built from the raw key values,
	// e.g. xxx_region1_az1, xxx_region1_az2. When set, NumberOfShards is not validated.
	ShardingKeyByRawValue bool

	// ShardingKeysForMigrating optional, specifies all sharding keys
	ShardingKeysForMigrating []string

	// NumberOfShards specifies how many tables you want to shard into. It is required
	// (must be between 1 and 99999) unless ShardingKeyByRawValue is set.
	NumberOfShards uint

	// CustomSuffix optional, specifies shard table a custom suffix, e.g. user_%02d means <main_table_name>_user_01
	CustomSuffix string

	// PrimaryKeyGenerator generates an id when the primary key is a sharding key and its value is zero.
	// It is required when ShardingKeys contains the primary key; ShardingIDGen also uses it.
	PrimaryKeyGenerator idgen.Generator
}
    78  
// tableSharding is a gorm plugin that routes statements on one logical table to its
// shard tables (<table><suffix>), either from model values or by rewriting built SQL.
// sharding plugin inspired by gorm.io/sharding@v0.5.3
type tableSharding struct {
	// DB is the database the plugin was initialized on (set in Initialize).
	*gorm.DB

	config TableShardingConfig

	// shardingFunc maps sharding key values (in ShardingKeys order) to a table suffix.
	// isShardingPrimaryKey/shardingPrimaryKey tell whether, and under which spelling,
	// the primary key is one of the sharding keys.
	// shardingTableModel is a lazily captured copy of a model implementing schema.Tabler.
	// NOTE(review): dispatchTableByModel sets shardingTableModel without holding any lock.
	// shardingTableCreated presumably records shard tables already created, guarded by
	// shardingTableCreatedMutex — confirm in createTableIfNotExists.
	shardingFunc              func(ctx context.Context, values ...any) (suffix string, err error)
	isShardingPrimaryKey      bool
	shardingPrimaryKey        string
	shardingTableModel        any
	shardingTableCreatedMutex sync.RWMutex
	shardingTableCreated      map[string]struct{}

	suffixFormat string
}
    94  
    95  func DefaultTableSharding(config TableShardingConfig) TableSharding {
    96  	if utils.IsStrBlank(config.Table) {
    97  		panic(errors.New("missing sharding table name"))
    98  	}
    99  	if len(config.ShardingKeys) == 0 {
   100  		panic(errors.New("missing sharding keys"))
   101  	}
   102  	if !config.ShardingKeyByRawValue && (config.NumberOfShards <= 0 || config.NumberOfShards >= 100000) {
   103  		panic(errors.New("invalid number of shards"))
   104  	}
   105  
   106  	shardingKeySet := utils.NewSet(config.ShardingKeys...)
   107  	shardingPrimaryKey := ""
   108  	isShardingPrimaryKey := false
   109  	if shardingKeySet.Contains("id") || shardingKeySet.Contains("ID") ||
   110  		shardingKeySet.Contains("iD") || shardingKeySet.Contains("Id") {
   111  		if config.PrimaryKeyGenerator == nil {
   112  			panic(errors.New("sharding by primary key but primary key generator not found"))
   113  		}
   114  
   115  		isShardingPrimaryKey = true
   116  		for _, key := range config.ShardingKeys {
   117  			if key == "id" || key == "ID" || key == "Id" || key == "iD" {
   118  				shardingPrimaryKey = key
   119  				break
   120  			}
   121  		}
   122  	}
   123  
   124  	return &tableSharding{
   125  		config:               config,
   126  		isShardingPrimaryKey: isShardingPrimaryKey,
   127  		shardingPrimaryKey:   shardingPrimaryKey,
   128  		shardingTableCreated: make(map[string]struct{}, config.NumberOfShards),
   129  	}
   130  }
   131  
   132  func (t *tableSharding) Name() string {
   133  	return fmt.Sprintf("gorm:sharding:%s:%s", t.config.Database, t.config.Table)
   134  }
   135  
   136  func (t *tableSharding) Initialize(db *gorm.DB) (err error) {
   137  	db.Dialector = newShardingDialector(db.Dialector, t)
   138  
   139  	t.DB = db
   140  	t.shardingFunc = t.defaultShardingFunc()
   141  	t.registerCallbacks(db)
   142  	return
   143  }
   144  
   145  func (t *tableSharding) ShardingByModelList(ctx context.Context, src ...any) (dst map[string][]any, err error) {
   146  	dst = make(map[string][]any, len(t.config.ShardingKeys))
   147  	for _, m := range src {
   148  		val := reflect.Indirect(reflect.ValueOf(m))
   149  		shardingValues := make([]any, 0, len(t.config.ShardingKeys))
   150  		for _, key := range t.config.ShardingKeys {
   151  			field := val.FieldByNameFunc(func(v string) bool { return strings.EqualFold(v, key) })
   152  			if !field.IsValid() {
   153  				field, _ = utils.GetGormColumnValue(val, key)
   154  			}
   155  			if !field.IsValid() {
   156  				return dst, ErrMissingShardingKey
   157  			}
   158  			if key == t.shardingPrimaryKey && field.IsZero() {
   159  				return dst, ErrInvalidID
   160  			}
   161  			shardingValues = append(shardingValues, field.Interface())
   162  		}
   163  		suffix, err := t.shardingFunc(ctx, shardingValues...)
   164  		if err != nil {
   165  			return dst, err
   166  		}
   167  		dst[suffix] = append(dst[suffix], m)
   168  	}
   169  	return
   170  }
   171  
   172  func (t *tableSharding) ShardingByValues(ctx context.Context, src []map[string]any) (
   173  	dst map[string][]map[string]any, err error) {
   174  	dst = make(map[string][]map[string]any, len(t.config.ShardingKeys))
   175  	for _, col := range src {
   176  		values := make([]any, 0, len(col))
   177  		for _, k := range t.config.ShardingKeys {
   178  			value, ok := col[k]
   179  			if !ok {
   180  				return dst, errors.Errorf("sharding key not found [column[%s]]", k)
   181  			}
   182  			if k == t.shardingPrimaryKey && utils.IsBlank(value) {
   183  				return dst, ErrInvalidID
   184  			}
   185  			values = append(values, value)
   186  		}
   187  		suffix, err := t.shardingFunc(ctx, values...)
   188  		if err != nil {
   189  			return dst, err
   190  		}
   191  		dst[suffix] = append(dst[suffix], col)
   192  	}
   193  	return
   194  }
   195  
   196  func (t *tableSharding) ShardingIDGen(ctx context.Context) (id uint64, err error) {
   197  	if t.config.PrimaryKeyGenerator == nil {
   198  		return 0, ErrIDGeneratorNotFound
   199  	}
   200  	return t.config.PrimaryKeyGenerator.Next()
   201  }
   202  
   203  func (t *tableSharding) registerCallbacks(db *gorm.DB) {
   204  	utils.MustSuccess(db.Callback().
   205  		Create().
   206  		After("gorm:before_create").
   207  		Before("gorm:save_before_associations").
   208  		Register(t.Name(), t.createCallback))
   209  
   210  	utils.MustSuccess(db.Callback().
   211  		Query().
   212  		Before("gorm:query").
   213  		Register(t.Name(), t.queryCallback))
   214  
   215  	utils.MustSuccess(db.Callback().
   216  		Update().
   217  		After("gorm:before_update").
   218  		Before("gorm:save_before_associations").
   219  		Register(t.Name(), t.updateCallback))
   220  
   221  	utils.MustSuccess(db.Callback().
   222  		Delete().
   223  		After("gorm:before_delete").
   224  		Before("gorm:delete_before_associations").
   225  		Register(t.Name(), t.deleteCallback))
   226  
   227  	utils.MustSuccess(db.Callback().
   228  		Row().
   229  		Before("gorm:row").
   230  		Register(t.Name(), t.queryCallback))
   231  
   232  	utils.MustSuccess(db.Callback().
   233  		Raw().
   234  		Before("gorm:raw").
   235  		Register(t.Name(), t.rawCallback))
   236  }
   237  func (t *tableSharding) createCallback(db *gorm.DB) {
   238  	utils.IfAny(
   239  		t.isIgnored(db),
   240  		func() bool { ok1, ok2 := t.dispatchTableByModel(db, tableShardingIsInsert()); return ok1 || ok2 },
   241  		func() bool {
   242  			callbacks.BuildCreateSQL(db)
   243  			t.wrapDispatchTableBySQL(db, tableShardingIsInsert())
   244  			return true
   245  		},
   246  	)
   247  }
   248  func (t *tableSharding) queryCallback(db *gorm.DB) {
   249  	utils.IfAny(
   250  		t.isIgnored(db),
   251  		func() bool { ok1, ok2 := t.dispatchTableByModel(db); return ok1 || ok2 },
   252  		func() bool {
   253  			callbacks.BuildQuerySQL(db)
   254  			t.wrapDispatchTableBySQL(db)
   255  			return true
   256  		},
   257  	)
   258  }
   259  func (t *tableSharding) updateCallback(db *gorm.DB) {
   260  	utils.IfAny(
   261  		t.isIgnored(db),
   262  		func() bool { ok1, ok2 := t.dispatchTableByModel(db); return ok1 || ok2 },
   263  		func() bool {
   264  			callbacks.BuildUpdateSQL(db)
   265  			t.wrapDispatchTableBySQL(db)
   266  			return true
   267  		},
   268  	)
   269  }
   270  func (t *tableSharding) deleteCallback(db *gorm.DB) {
   271  	utils.IfAny(
   272  		t.isIgnored(db),
   273  		func() bool { ok1, ok2 := t.dispatchTableByModel(db); return ok1 || ok2 },
   274  		func() bool {
   275  			callbacks.BuildDeleteSQL(db)
   276  			t.wrapDispatchTableBySQL(db)
   277  			return true
   278  		},
   279  	)
   280  }
   281  func (t *tableSharding) rawCallback(db *gorm.DB) {
   282  	utils.IfAny(
   283  		t.isIgnored(db),
   284  		func() bool { ok1, ok2 := t.dispatchTableByModel(db); return ok1 || ok2 },
   285  		func() bool { t.wrapDispatchTableBySQL(db); return true },
   286  	)
   287  }
   288  
// tableShardingDispatchOption carries per-dispatch options built from utils.OptionExtender values.
type tableShardingDispatchOption struct {
	// isInsert marks an INSERT: primary keys may be generated by setPrimaryKeyByModel and
	// replaceStatementClauseAndSchema adds an INSERT clause instead of a FROM clause.
	isInsert bool
}
   292  
   293  func tableShardingIsInsert() utils.OptionFunc[tableShardingDispatchOption] {
   294  	return func(t *tableShardingDispatchOption) {
   295  		t.isInsert = true
   296  	}
   297  }
   298  
// dispatchTableByModel tries to route the statement using the sharding key values held by
// the statement's model value (db.Statement.ReflectValue).
//
// otherTable is true when the statement targets a table other than the configured one.
// ok is true when db.Statement.Table was switched to the shard table and the statement
// clauses were rewritten accordingly.
//
// Both are false in these cases:
//   - there is no usable model value (nil model, blank value, empty slice, non-struct);
//   - the suffix is blank or equals constant.Underline;
//   - an error occurred. Errors are recorded with db.AddError.
//
// For inserts sharded by primary key, zero primary keys are filled in first.
func (t *tableSharding) dispatchTableByModel(db *gorm.DB, opts ...utils.OptionExtender) (otherTable, ok bool) {
	if db.Statement.Model == nil || utils.IsBlank(db.Statement.ReflectValue.Interface()) {
		return
	}
	if db.Statement.Table != t.config.Table {
		otherTable = true
		return
	}
	// Lazily keep a copy of the first model that implements schema.Tabler.
	// NOTE(review): this check-then-set is not guarded by a lock, so concurrent callbacks
	// race on t.shardingTableModel.
	if t.shardingTableModel == nil {
		if _, ok := db.Statement.Model.(schema.Tabler); ok {
			cloneModel := clone.Clone(db.Statement.Model)
			t.shardingTableModel = cloneModel
		}
	}

	opt := utils.ApplyOptions[tableShardingDispatchOption](opts...)
	// ids must exist before they can be used as sharding values (no-op unless inserting)
	if t.isShardingPrimaryKey {
		if err := t.setPrimaryKeyByModel(db, opt); err != nil {
			_ = db.AddError(err)
			return
		}
	}

	// for slices/arrays the sharding values are read from the first element
	reflectVal, ok := t.getModelReflectValue(db)
	if !ok {
		return
	}
	// all elements must map to the same suffix; the check already added its error to db
	if err := t.checkDiffSuffixesByModel(db); err != nil {
		return
	}

	values := make([]any, 0, len(t.config.ShardingKeys))
	for _, key := range t.config.ShardingKeys {
		// match the struct field case-insensitively, then fall back to the gorm column name
		val := reflectVal.FieldByNameFunc(func(v string) bool { return strings.EqualFold(v, key) })
		if !val.IsValid() {
			val, _ = utils.GetGormColumnValue(reflectVal, key)
		}
		if !val.IsValid() {
			_ = db.AddError(ErrMissingShardingKey)
			return
		}
		values = append(values, val.Interface())
	}

	suffix, err := t.shardingFunc(db.Statement.Context, values...)
	if err != nil {
		_ = db.AddError(err)
		return
	}
	// cannot parse suffix from model
	if utils.IsStrBlank(suffix) || suffix == constant.Underline {
		return false, false
	}
	if err = t.createTableIfNotExists(db, db.Statement.Table, suffix); err != nil {
		_ = db.AddError(err)
		return
	}

	db.Statement.Table = db.Statement.Table + suffix
	t.replaceStatementClauseAndSchema(db, opt)
	ok = true
	return
}
   362  
   363  //nolint: revive // sql parser issue
   364  func (t *tableSharding) dispatchTableBySQL(db *gorm.DB, opts ...utils.OptionExtender) (ok bool, err error) {
   365  	expr, err := sqlparser.NewParser(strings.NewReader(db.Statement.SQL.String())).ParseStatement()
   366  	if err != nil {
   367  		// maybe not a dml, so we ignore this error
   368  		return
   369  	}
   370  
   371  	getSuffix := func(condition sqlparser.Node, tableName string, vars ...any) (suffix string, err error) {
   372  		values := make([]any, 0, len(t.config.ShardingKeys))
   373  		for _, key := range t.config.ShardingKeys {
   374  			val, err := t.nonInsertValue(condition, key, tableName, vars...)
   375  			if err != nil {
   376  				return "", db.AddError(err)
   377  			}
   378  			values = append(values, val)
   379  		}
   380  
   381  		suffix, err = t.shardingFunc(db.Statement.Context, values...)
   382  		if err != nil {
   383  			return "", db.AddError(err)
   384  		}
   385  		return
   386  	}
   387  
   388  	newSQL := ""
   389  	switch stmt := expr.(type) {
   390  	case *sqlparser.InsertStatement:
   391  		if stmt.TableName.TableName() != t.config.Table {
   392  			return
   393  		}
   394  
   395  		suffix := ""
   396  		for _, insertExpression := range stmt.Expressions {
   397  			values, id, e := t.insertValue(t.config.ShardingKeys, stmt.ColumnNames,
   398  				insertExpression.Exprs, db.Statement.Vars...)
   399  			if e != nil {
   400  				_ = db.AddError(e)
   401  				return
   402  			}
   403  			if t.isShardingPrimaryKey && id == 0 {
   404  				if t.config.PrimaryKeyGenerator == nil {
   405  					_ = db.AddError(ErrIDGeneratorNotFound)
   406  					return
   407  				}
   408  				if id, e = t.config.PrimaryKeyGenerator.Next(idgen.GormTx(db)); e != nil {
   409  					_ = db.AddError(e)
   410  					return
   411  				}
   412  				stmt.ColumnNames = append(stmt.ColumnNames, &sqlparser.Ident{Name: "id"})
   413  				insertExpression.Exprs = append(insertExpression.Exprs, &sqlparser.NumberLit{Value: cast.ToString(id)})
   414  				values, _, _ = t.insertValue(t.config.ShardingKeys, stmt.ColumnNames,
   415  					insertExpression.Exprs, db.Statement.Vars...)
   416  			}
   417  
   418  			subSuffix, e := t.shardingFunc(db.Statement.Context, values...)
   419  			if e != nil {
   420  				_ = db.AddError(e)
   421  				return
   422  			}
   423  
   424  			if suffix != "" && suffix != subSuffix {
   425  				_ = db.AddError(ErrDiffSuffixDML)
   426  				return
   427  			}
   428  			suffix = subSuffix
   429  		}
   430  		// FIXME: could not find the table schema to migrate
   431  		if e := t.createTableIfNotExists(db, db.Statement.Table, suffix); e != nil {
   432  			_ = db.AddError(e)
   433  			return
   434  		}
   435  		stmt.TableName = &sqlparser.TableName{Name: &sqlparser.Ident{Name: stmt.TableName.TableName() + suffix}}
   436  		newSQL = stmt.String()
   437  	case *sqlparser.SelectStatement:
   438  		parseSelectStatementFunc := func(stmt *sqlparser.SelectStatement) (ok bool, err error) {
   439  			if stmt.Hint != nil && stmt.Hint.Value == "nosharding" {
   440  				return false, nil
   441  			}
   442  
   443  			switch tbl := stmt.FromItems.(type) {
   444  			case *sqlparser.TableName:
   445  				if tbl.TableName() != t.config.Table {
   446  					return false, nil
   447  				}
   448  				suffix, e := getSuffix(stmt.Condition, t.config.Table, db.Statement.Vars...)
   449  				if e != nil {
   450  					_ = db.AddError(e)
   451  					return false, nil
   452  				}
   453  				oldTableName := tbl.TableName()
   454  				newTableName := oldTableName + suffix
   455  				stmt.FromItems = &sqlparser.TableName{Name: &sqlparser.Ident{Name: newTableName}}
   456  				stmt.OrderBy = t.replaceOrderByTableName(stmt.OrderBy, oldTableName, newTableName)
   457  				if e := t.replaceCondition(stmt.Condition, oldTableName, newTableName); err != nil {
   458  					_ = db.AddError(e)
   459  					return false, nil
   460  				}
   461  			case *sqlparser.JoinClause:
   462  				tblx, _ := tbl.X.(*sqlparser.TableName)
   463  				tbly, _ := tbl.Y.(*sqlparser.TableName)
   464  				isXSharding := tblx != nil && tblx.TableName() == t.config.Table
   465  				isYSharding := tbly != nil && tbly.TableName() == t.config.Table
   466  				oldTableName := ""
   467  				switch {
   468  				case isXSharding:
   469  					oldTableName = tblx.TableName()
   470  				case isYSharding:
   471  					oldTableName = tbly.TableName()
   472  				default:
   473  					return false, nil
   474  				}
   475  				suffix, e := getSuffix(stmt.Condition, oldTableName, db.Statement.Vars...)
   476  				if e != nil {
   477  					_ = db.AddError(e)
   478  					return false, nil
   479  				}
   480  				newTableName := oldTableName + suffix
   481  				stmt.OrderBy = t.replaceOrderByTableName(stmt.OrderBy, oldTableName, newTableName)
   482  				if e := t.replaceCondition(stmt.Condition, oldTableName, newTableName); err != nil {
   483  					_ = db.AddError(e)
   484  					return false, nil
   485  				}
   486  				if e := t.replaceConstraint(tbl.Constraint, oldTableName, newTableName); err != nil {
   487  					_ = db.AddError(e)
   488  					return false, nil
   489  				}
   490  				if isXSharding {
   491  					tblx.Name.Name = newTableName
   492  				} else {
   493  					tbly.Name.Name = newTableName
   494  				}
   495  				if stmt.Columns != nil {
   496  					for _, column := range *stmt.Columns {
   497  						columnTbl, ok := column.Expr.(*sqlparser.QualifiedRef)
   498  						if !ok || columnTbl.Table.Name != oldTableName {
   499  							continue
   500  						}
   501  						columnTbl.Table.Name = newTableName
   502  					}
   503  				}
   504  			}
   505  			return true, nil
   506  		}
   507  		for compound := stmt; compound != nil; compound = compound.Compound {
   508  			if ok, err = parseSelectStatementFunc(compound); !ok || err != nil {
   509  				return
   510  			}
   511  		}
   512  
   513  		newSQL = stmt.String()
   514  
   515  	case *sqlparser.UpdateStatement:
   516  		if stmt.TableName.TableName() != t.config.Table {
   517  			return
   518  		}
   519  
   520  		suffix, e := getSuffix(stmt.Condition, t.config.Table, db.Statement.Vars...)
   521  		if e != nil {
   522  			_ = db.AddError(e)
   523  			return
   524  		}
   525  
   526  		oldTableName := stmt.TableName.TableName()
   527  		newTableName := oldTableName + suffix
   528  		stmt.TableName = &sqlparser.TableName{Name: &sqlparser.Ident{Name: newTableName}}
   529  		if e := t.replaceCondition(stmt.Condition, oldTableName, newTableName); err != nil {
   530  			_ = db.AddError(e)
   531  			return false, nil
   532  		}
   533  		newSQL = stmt.String()
   534  	case *sqlparser.DeleteStatement:
   535  		if stmt.TableName.TableName() != t.config.Table {
   536  			return
   537  		}
   538  
   539  		suffix, e := getSuffix(stmt.Condition, t.config.Table, db.Statement.Vars...)
   540  		if e != nil {
   541  			_ = db.AddError(e)
   542  			return
   543  		}
   544  
   545  		oldTableName := stmt.TableName.TableName()
   546  		newTableName := oldTableName + suffix
   547  		stmt.TableName = &sqlparser.TableName{Name: &sqlparser.Ident{Name: newTableName}}
   548  		if e := t.replaceCondition(stmt.Condition, oldTableName, newTableName); err != nil {
   549  			_ = db.AddError(e)
   550  			return false, nil
   551  		}
   552  		newSQL = stmt.String()
   553  	default:
   554  		_ = db.AddError(sqlparser.ErrNotImplemented)
   555  		return
   556  	}
   557  
   558  	sb := strings.Builder{}
   559  	sb.Grow(len(newSQL))
   560  	sb.WriteString(newSQL)
   561  	db.Statement.SQL = sb
   562  
   563  	return true, nil
   564  }
   565  func (t *tableSharding) wrapDispatchTableBySQL(db *gorm.DB, opts ...utils.OptionExtender) {
   566  	if ok, err := t.dispatchTableBySQL(db, opts...); err != nil || !ok {
   567  		// not a dml
   568  		if err != nil {
   569  			return
   570  		}
   571  		// not a sharding table
   572  		if !ok {
   573  			// FIXME: reset sql parse result will get duplicated sql statement
   574  			// db.Statement.SQL = strings.Builder{}
   575  			// db.Statement.Vars = nil
   576  		}
   577  	}
   578  }
   579  func (t *tableSharding) replaceStatementClauseAndSchema(db *gorm.DB, opt *tableShardingDispatchOption) {
   580  	changeExprFunc := func(src []clause.Expression) (dst []clause.Expression) {
   581  		changeTableFunc := func(src any) (dst any, ok bool) {
   582  			switch col := src.(type) {
   583  			case clause.Column:
   584  				if col.Table == t.config.Table {
   585  					col.Table = db.Statement.Table
   586  					return col, true
   587  				}
   588  			case clause.Table:
   589  				if col.Name == t.config.Table {
   590  					col.Name = db.Statement.Table
   591  					return col, true
   592  				}
   593  			}
   594  			return
   595  		}
   596  		dst = make([]clause.Expression, 0, len(src))
   597  		for _, srcExpr := range src {
   598  			switch expr := srcExpr.(type) {
   599  			case clause.IN:
   600  				if col, ok := changeTableFunc(expr.Column); ok {
   601  					expr.Column = col
   602  				}
   603  				dst = append(dst, expr)
   604  			case clause.Eq:
   605  				if col, ok := changeTableFunc(expr.Column); ok {
   606  					expr.Column = col
   607  				}
   608  				dst = append(dst, expr)
   609  			case clause.Neq:
   610  				if col, ok := changeTableFunc(expr.Column); ok {
   611  					expr.Column = col
   612  				}
   613  				dst = append(dst, expr)
   614  			case clause.Gt:
   615  				if col, ok := changeTableFunc(expr.Column); ok {
   616  					expr.Column = col
   617  				}
   618  				dst = append(dst, expr)
   619  			case clause.Gte:
   620  				if col, ok := changeTableFunc(expr.Column); ok {
   621  					expr.Column = col
   622  				}
   623  				dst = append(dst, expr)
   624  			case clause.Lt:
   625  				if col, ok := changeTableFunc(expr.Column); ok {
   626  					expr.Column = col
   627  				}
   628  				dst = append(dst, expr)
   629  			case clause.Lte:
   630  				if col, ok := changeTableFunc(expr.Column); ok {
   631  					expr.Column = col
   632  				}
   633  				dst = append(dst, expr)
   634  			case clause.Like:
   635  				if col, ok := changeTableFunc(expr.Column); ok {
   636  					expr.Column = col
   637  				}
   638  				dst = append(dst, expr)
   639  			default:
   640  				dst = append(dst, expr)
   641  			}
   642  		}
   643  		return
   644  	}
   645  	changeClausesMapping := map[string]func(cls clause.Clause){
   646  		"WHERE": func(cls clause.Clause) {
   647  			whereClause, ok := cls.Expression.(clause.Where)
   648  			if !ok {
   649  				return
   650  			}
   651  			whereClause.Exprs = changeExprFunc(whereClause.Exprs)
   652  			cls.Expression = whereClause
   653  			db.Statement.Clauses["WHERE"] = cls
   654  		},
   655  		"FROM": func(cls clause.Clause) {
   656  			fromClause, ok := cls.Expression.(clause.From)
   657  			if !ok {
   658  				return
   659  			}
   660  			tables := make([]clause.Table, 0, len(fromClause.Tables))
   661  			for _, table := range fromClause.Tables {
   662  				if table.Name == t.config.Table {
   663  					table.Name = db.Statement.Table
   664  					tables = append(tables, table)
   665  				} else {
   666  					tables = append(tables, table)
   667  				}
   668  			}
   669  			fromClause.Tables = tables
   670  			cls.Expression = fromClause
   671  			db.Statement.Clauses["FROM"] = cls
   672  		},
   673  		// TODO: check if order by contains table name
   674  		"ORDER BY": func(cls clause.Clause) {
   675  			_, ok := cls.Expression.(clause.OrderBy)
   676  			if !ok {
   677  				return
   678  			}
   679  		},
   680  	}
   681  
   682  	for name, cls := range db.Statement.Clauses {
   683  		if mappingFunc, ok := changeClausesMapping[name]; ok {
   684  			mappingFunc(cls)
   685  		}
   686  	}
   687  
   688  	if opt.isInsert {
   689  		db.Clauses(clause.Insert{Table: clause.Table{Name: db.Statement.Table}})
   690  	} else {
   691  		db.Clauses(clause.From{Tables: []clause.Table{{Name: db.Statement.Table}}})
   692  	}
   693  }
   694  
   695  func (t *tableSharding) replaceCondition(conditions sqlparser.Expr, oldTableName, newTableName string) (err error) {
   696  	err = sqlparser.Walk(
   697  		sqlparser.VisitFunc(func(node sqlparser.Node) (err error) {
   698  			n, ok := node.(*sqlparser.BinaryExpr)
   699  			if !ok {
   700  				return
   701  			}
   702  
   703  			x, ok := n.X.(*sqlparser.QualifiedRef)
   704  			if !ok || x.Table == nil || x.Table.Name != oldTableName {
   705  				return
   706  			}
   707  
   708  			x.Table.Name = newTableName
   709  			return
   710  		}),
   711  		conditions,
   712  	)
   713  	return
   714  }
   715  
   716  func (t *tableSharding) replaceConstraint(constraints sqlparser.Node, oldTableName, newTableName string) (err error) {
   717  	return sqlparser.Walk(
   718  		sqlparser.VisitFunc(func(node sqlparser.Node) (err error) {
   719  			n, ok := node.(*sqlparser.QualifiedRef)
   720  			if !ok || n.Table == nil || n.Table.Name != oldTableName {
   721  				return
   722  			}
   723  
   724  			n.Table.Name = newTableName
   725  			return
   726  		}),
   727  		constraints,
   728  	)
   729  }
   730  
   731  func (t *tableSharding) insertValue(keys []string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...any) (
   732  	values []any, id uint64, err error) {
   733  	if len(names) != len(exprs) {
   734  		return nil, 0, ErrColumnAndExprMisMatch
   735  	}
   736  
   737  	for _, key := range keys {
   738  		found := false
   739  		isPrimaryKey := key == t.shardingPrimaryKey
   740  		for i, name := range names {
   741  			if name.Name != key {
   742  				continue
   743  			}
   744  
   745  			switch expr := exprs[i].(type) {
   746  			case *sqlparser.BindExpr:
   747  				if !isPrimaryKey {
   748  					values = append(values, args[expr.Pos])
   749  				} else {
   750  					switch v := args[expr.Pos].(type) {
   751  					case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, string:
   752  						if id, err = cast.ToUint64E(v); err != nil {
   753  							return nil, 0, errors.Wrapf(err, "parse id as uint64 failed [%v]", v)
   754  						}
   755  					default:
   756  						return nil, 0, ErrInvalidID
   757  					}
   758  					if id != 0 {
   759  						values = append(values, args[expr.Pos])
   760  					}
   761  				}
   762  			case *sqlparser.StringLit:
   763  				if !isPrimaryKey {
   764  					values = append(values, expr.Value)
   765  				} else {
   766  					if id, err = cast.ToUint64E(expr.Value); err != nil {
   767  						return nil, 0, errors.Wrapf(err, "parse id as uint64 failed [%s]", expr.Value)
   768  					}
   769  					if id != 0 {
   770  						values = append(values, expr.Value)
   771  					}
   772  				}
   773  			case *sqlparser.NumberLit:
   774  				if !isPrimaryKey {
   775  					values = append(values, expr.Value)
   776  				} else {
   777  					if id, err = strconv.ParseUint(expr.Value, 10, 64); err != nil {
   778  						return nil, 0, errors.Wrapf(err,
   779  							"parse id as uint64 failed [%s]", expr.Value)
   780  					}
   781  					if id != 0 {
   782  						values = append(values, expr.Value)
   783  					}
   784  				}
   785  			default:
   786  				return nil, 0, sqlparser.ErrNotImplemented
   787  			}
   788  
   789  			found = true
   790  			break
   791  		}
   792  		if !found && !isPrimaryKey {
   793  			return nil, 0, ErrMissingShardingKey
   794  		}
   795  	}
   796  
   797  	return
   798  }
   799  
   800  func (t *tableSharding) nonInsertValue(condition sqlparser.Node, key, tableName string, args ...any) (
   801  	value any, err error) {
   802  	found := false
   803  	err = sqlparser.Walk(
   804  		sqlparser.VisitFunc(func(node sqlparser.Node) (err error) {
   805  			n, ok := node.(*sqlparser.BinaryExpr)
   806  			if !ok {
   807  				return
   808  			}
   809  			if n.Op != sqlparser.EQ {
   810  				return
   811  			}
   812  
   813  			switch x := n.X.(type) {
   814  			case *sqlparser.Ident:
   815  				if x.Name != key {
   816  					return
   817  				}
   818  			case *sqlparser.QualifiedRef:
   819  				if !ok || x.Table.Name != tableName || x.Column.Name != key {
   820  					return
   821  				}
   822  			}
   823  
   824  			found = true
   825  			switch expr := n.Y.(type) {
   826  			case *sqlparser.BindExpr:
   827  				value = args[expr.Pos]
   828  			case *sqlparser.StringLit:
   829  				value = expr.Value
   830  			case *sqlparser.NumberLit:
   831  				value = expr.Value
   832  			default:
   833  				return sqlparser.ErrNotImplemented
   834  			}
   835  
   836  			return
   837  		}),
   838  		condition,
   839  	)
   840  	if err != nil {
   841  		return
   842  	}
   843  	if !found {
   844  		return nil, ErrMissingShardingKey
   845  	}
   846  	return
   847  }
   848  
   849  func (t *tableSharding) setPrimaryKeyByModel(db *gorm.DB, opt *tableShardingDispatchOption) (err error) {
   850  	if !opt.isInsert || db.Statement.Model == nil ||
   851  		db.Statement.Schema == nil || db.Statement.Schema.PrioritizedPrimaryField == nil {
   852  		return
   853  	}
   854  	setPrimaryKeyFunc := func(rv reflect.Value) (err error) {
   855  		_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
   856  		if !isZero {
   857  			return
   858  		}
   859  		if t.config.PrimaryKeyGenerator == nil {
   860  			return ErrIDGeneratorNotFound
   861  		}
   862  		id, err := t.config.PrimaryKeyGenerator.Next(idgen.GormTx(db))
   863  		if err != nil {
   864  			return
   865  		}
   866  		return db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, id)
   867  	}
   868  
   869  	switch db.Statement.ReflectValue.Kind() {
   870  	case reflect.Slice, reflect.Array:
   871  		for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
   872  			rv := db.Statement.ReflectValue.Index(i)
   873  			if reflect.Indirect(rv).Kind() != reflect.Struct {
   874  				break
   875  			}
   876  
   877  			if err = setPrimaryKeyFunc(rv); err != nil {
   878  				return
   879  			}
   880  		}
   881  	case reflect.Struct:
   882  		if err = setPrimaryKeyFunc(db.Statement.ReflectValue); err != nil {
   883  			return
   884  		}
   885  	}
   886  
   887  	return
   888  }
   889  
   890  func (t *tableSharding) getModelReflectValue(db *gorm.DB) (reflectVal reflect.Value, ok bool) {
   891  	reflectVal = utils.IndirectValue(db.Statement.ReflectValue)
   892  	if reflectVal.Kind() == reflect.Array || reflectVal.Kind() == reflect.Slice {
   893  		if reflectVal.Len() == 0 {
   894  			return
   895  		}
   896  		reflectVal = utils.IndirectValue(reflectVal.Index(0))
   897  	}
   898  
   899  	if reflectVal.Kind() != reflect.Struct {
   900  		return
   901  	}
   902  
   903  	return reflectVal, !utils.IsBlank(reflectVal.Interface())
   904  }
   905  
   906  func (t *tableSharding) checkDiffSuffixesByModel(db *gorm.DB) (err error) {
   907  	reflectVal := utils.IndirectValue(db.Statement.ReflectValue)
   908  	if reflectVal.Kind() != reflect.Array && reflectVal.Kind() != reflect.Slice {
   909  		return
   910  	}
   911  
   912  	suffix := ""
   913  	for i := 0; i < reflectVal.Len(); i++ {
   914  		reflectItemVal := reflect.Indirect(reflectVal.Index(i))
   915  		values := make([]any, 0, len(t.config.ShardingKeys))
   916  		for _, key := range t.config.ShardingKeys {
   917  			val := reflectItemVal.FieldByNameFunc(func(v string) bool { return strings.EqualFold(v, key) })
   918  			if !val.IsValid() {
   919  				val, _ = utils.GetGormColumnValue(reflectItemVal, key)
   920  			}
   921  			if !val.IsValid() {
   922  				return db.AddError(ErrMissingShardingKey)
   923  			}
   924  			values = append(values, val.Interface())
   925  		}
   926  		subSuffix, err := t.shardingFunc(db.Statement.Context, values...)
   927  		if err != nil {
   928  			return db.AddError(err)
   929  		}
   930  		if suffix != "" && suffix != subSuffix {
   931  			return db.AddError(ErrDiffSuffixDML)
   932  		}
   933  		suffix = subSuffix
   934  	}
   935  	return
   936  }
   937  
   938  func (t *tableSharding) replaceOrderByTableName(
   939  	orderBy []*sqlparser.OrderingTerm, oldName, newName string) []*sqlparser.OrderingTerm {
   940  	for i, term := range orderBy {
   941  		if x, ok := term.X.(*sqlparser.QualifiedRef); ok {
   942  			if x.Table.Name == oldName {
   943  				x.Table.Name = newName
   944  				orderBy[i].X = x
   945  			}
   946  		}
   947  	}
   948  	return orderBy
   949  }
   950  
   951  func (t *tableSharding) createTableIfNotExists(db *gorm.DB, tableName, suffix string) (err error) {
   952  	shardingTableName := tableName + suffix
   953  	t.shardingTableCreatedMutex.RLock()
   954  	if _, ok := t.shardingTableCreated[shardingTableName]; ok {
   955  		t.shardingTableCreatedMutex.RUnlock()
   956  		return
   957  	}
   958  	t.shardingTableCreatedMutex.RUnlock()
   959  	t.shardingTableCreatedMutex.Lock()
   960  	defer t.shardingTableCreatedMutex.Unlock()
   961  
   962  	defer t.ignore(t.DB)() //nolint: revive // partial calling issue
   963  	if t.DB.Migrator().HasTable(shardingTableName) {
   964  		t.shardingTableCreated[shardingTableName] = struct{}{}
   965  		return
   966  	}
   967  
   968  	model := db.Statement.Model
   969  	if model == nil {
   970  		model = t.shardingTableModel
   971  	}
   972  	if model == nil {
   973  		return ErrShardingModelNotFound
   974  	}
   975  	tx := t.DB.Session(&gorm.Session{}).Table(shardingTableName)
   976  	if err = db.Dialector.Migrator(tx).AutoMigrate(db.Statement.Model); err != nil {
   977  		return err
   978  	}
   979  	t.shardingTableCreated[shardingTableName] = struct{}{}
   980  	return
   981  }
   982  
   983  func (t *tableSharding) suffixes() (suffixes []string, err error) {
   984  	switch {
   985  	case t.config.ShardingKeyByRawValue:
   986  		if len(t.config.ShardingKeysForMigrating) == 0 {
   987  			return nil, errors.New("sharding key by raw value but do not configure keys for migrating")
   988  		}
   989  
   990  		for _, shardingKey := range t.config.ShardingKeysForMigrating {
   991  			suffixes = append(suffixes, fmt.Sprintf(t.suffixFormat, shardingKey))
   992  		}
   993  	default:
   994  		for i := 0; i < int(t.config.NumberOfShards); i++ {
   995  			suffixes = append(suffixes, fmt.Sprintf(t.suffixFormat, i))
   996  		}
   997  	}
   998  	return
   999  }
  1000  
  1001  func (t *tableSharding) ignore(db *gorm.DB) func() {
  1002  	if _, ok := db.Statement.Settings.Load(shardingIgnoreStoreKey); ok {
  1003  		return func() {}
  1004  	}
  1005  	db.Statement.Settings.Store(shardingIgnoreStoreKey, nil)
  1006  	return func() {
  1007  		db.Statement.Settings.Delete(shardingIgnoreStoreKey)
  1008  	}
  1009  }
  1010  func (t *tableSharding) isIgnored(db *gorm.DB) func() bool {
  1011  	return func() bool {
  1012  		_, ok := db.Statement.Settings.Load(shardingIgnoreStoreKey)
  1013  		return ok
  1014  	}
  1015  }
  1016  
  1017  func (t *tableSharding) defaultShardingFunc() func(ctx context.Context, values ...any) (suffix string, err error) {
  1018  	if !t.config.ShardingKeyByRawValue && t.config.NumberOfShards == 0 {
  1019  		panic(errors.New("missing number_of_shards config"))
  1020  	}
  1021  	t.suffixFormat = constant.Underline
  1022  
  1023  	switch {
  1024  	case utils.IsStrNotBlank(t.config.CustomSuffix):
  1025  		t.suffixFormat += t.config.CustomSuffix
  1026  	case t.config.ShardingKeyByRawValue:
  1027  		t.suffixFormat += "%s"
  1028  	default:
  1029  		t.suffixFormat += strings.Join(t.config.ShardingKeys, constant.Underline)
  1030  	}
  1031  
  1032  	numberOfShards := t.config.NumberOfShards
  1033  	if !strings.Contains(t.suffixFormat, "%") {
  1034  		if t.config.ShardingKeyByRawValue {
  1035  			t.suffixFormat += "_%s"
  1036  		} else if numberOfShards < 10 {
  1037  			t.suffixFormat += "_%01d"
  1038  		} else if numberOfShards < 100 {
  1039  			t.suffixFormat += "_%02d"
  1040  		} else if numberOfShards < 1000 {
  1041  			t.suffixFormat += "_%03d"
  1042  		} else if numberOfShards < 10000 {
  1043  			t.suffixFormat += "_%04d"
  1044  		}
  1045  	}
  1046  
  1047  	switch {
  1048  	case t.config.ShardingKeyByRawValue:
  1049  		return func(ctx context.Context, values ...any) (suffix string, err error) {
  1050  			data := make([]string, 0, len(values))
  1051  			for _, value := range values {
  1052  				v, err := cast.ToStringE(value)
  1053  				if err != nil {
  1054  					return "", err
  1055  				}
  1056  				data = append(data, v)
  1057  			}
  1058  			shardingKey := strings.Join(data, constant.Underline)
  1059  			return fmt.Sprintf("_%s", shardingKey), nil
  1060  		}
  1061  	case t.config.ShardingKeyExpr != nil:
  1062  		numberOfShardsFloat64 := float64(numberOfShards)
  1063  		return func(ctx context.Context, values ...any) (suffix string, err error) {
  1064  			params := make(map[string]any, len(t.config.ShardingKeys))
  1065  			for idx, column := range t.config.ShardingKeys {
  1066  				params[column] = values[idx]
  1067  			}
  1068  
  1069  			result, err := t.config.ShardingKeyExpr(ctx, params)
  1070  			if err != nil {
  1071  				return
  1072  			}
  1073  			shardingKey := int64(math.Mod(cast.ToFloat64(result), numberOfShardsFloat64))
  1074  			return fmt.Sprintf(t.suffixFormat, shardingKey), nil
  1075  		}
  1076  	default:
  1077  		stringToByteSliceFunc := func(v string) (data []byte) {
  1078  			utils.IfAny(
  1079  				// number
  1080  				func() (ok bool) {
  1081  					num := new(big.Float)
  1082  					if _, ok = num.SetString(v); !ok {
  1083  						return
  1084  					}
  1085  					gobEncoded, err := num.GobEncode()
  1086  					if err != nil {
  1087  						return false
  1088  					}
  1089  					data = gobEncoded
  1090  					return
  1091  				},
  1092  				// uuid
  1093  				func() bool {
  1094  					uid, err := uuid.Parse(v)
  1095  					if err != nil {
  1096  						return false
  1097  					}
  1098  					data = uid[:]
  1099  					return true
  1100  				},
  1101  				// bytes
  1102  				func() bool { data = []byte(v); return true },
  1103  			)
  1104  			return
  1105  		}
  1106  		return func(ctx context.Context, values ...any) (suffix string, err error) {
  1107  			size := 0
  1108  			for _, value := range values {
  1109  				s := binary.Size(value)
  1110  				if s <= 0 {
  1111  					s = int(unsafe.Sizeof(value))
  1112  				}
  1113  				size += s
  1114  			}
  1115  			w := new(bytes.Buffer)
  1116  			w.Grow(size)
  1117  
  1118  			for _, value := range values {
  1119  				var data any
  1120  				switch v := value.(type) {
  1121  				case int, *int:
  1122  					data = utils.IntNarrow(cast.ToInt(v))
  1123  				case uint, *uint:
  1124  					data = utils.UintNarrow(cast.ToUint(v))
  1125  				case []int:
  1126  					data = make([]any, len(v))
  1127  					for i := 0; i < len(v); i++ {
  1128  						data.([]any)[i] = utils.IntNarrow(cast.ToInt(v))
  1129  					}
  1130  				case []uint:
  1131  					data = make([]any, len(v))
  1132  					for i := 0; i < len(v); i++ {
  1133  						data.([]any)[i] = utils.UintNarrow(cast.ToUint(v))
  1134  					}
  1135  				case string:
  1136  					data = stringToByteSliceFunc(v)
  1137  				case []byte:
  1138  					data = stringToByteSliceFunc(utils.UnsafeBytesToString(v))
  1139  				case uuid.UUID:
  1140  					data = v[:]
  1141  				default:
  1142  					data = v
  1143  				}
  1144  				if err = binary.Write(w, binary.BigEndian, data); err != nil {
  1145  					return
  1146  				}
  1147  			}
  1148  
  1149  			// checksum mod shards
  1150  			checksum := crc32.ChecksumIEEE(w.Bytes())
  1151  			shardingKey := uint64(checksum) % uint64(numberOfShards)
  1152  			suffix = fmt.Sprintf(t.suffixFormat, shardingKey)
  1153  			return
  1154  		}
  1155  	}
  1156  }
  1157  
// shardingDialector wraps a gorm.Dialector. Its Migrator method returns a shardingMigrator
// unless sharding is ignored on the session. shardingMap is keyed by TableShardingConfig.Table.
// The map is shared by every copy of the value, because newShardingDialector adds entries to an
// existing map in place.
type shardingDialector struct {
	gorm.Dialector
	shardingMap map[string]*tableSharding
}
  1162  
  1163  func newShardingDialector(d gorm.Dialector, s *tableSharding) shardingDialector {
  1164  	if sd, ok := d.(shardingDialector); ok {
  1165  		sd.shardingMap[s.config.Table] = s
  1166  		return sd
  1167  	}
  1168  
  1169  	return shardingDialector{
  1170  		Dialector:   d,
  1171  		shardingMap: map[string]*tableSharding{s.config.Table: s},
  1172  	}
  1173  }
  1174  
  1175  func (s shardingDialector) Migrator(db *gorm.DB) gorm.Migrator {
  1176  	m := s.Dialector.Migrator(db)
  1177  	if (*tableSharding)(nil).isIgnored(db)() {
  1178  		return m
  1179  	}
  1180  	return &shardingMigrator{
  1181  		Migrator:    m,
  1182  		db:          db,
  1183  		shardingMap: s.shardingMap,
  1184  		dialector:   s.Dialector,
  1185  	}
  1186  }
  1187  func (s shardingDialector) SavePoint(tx *gorm.DB, name string) error {
  1188  	if savePointer, ok := s.Dialector.(gorm.SavePointerDialectorInterface); ok {
  1189  		return savePointer.SavePoint(tx, name)
  1190  	} else {
  1191  		return gorm.ErrUnsupportedDriver
  1192  	}
  1193  }
  1194  func (s shardingDialector) RollbackTo(tx *gorm.DB, name string) error {
  1195  	if savePointer, ok := s.Dialector.(gorm.SavePointerDialectorInterface); ok {
  1196  		return savePointer.RollbackTo(tx, name)
  1197  	} else {
  1198  		return gorm.ErrUnsupportedDriver
  1199  	}
  1200  }
  1201  
// shardingMigrator wraps a dialect's gorm.Migrator. Its AutoMigrate and DropTable methods
// apply to every shard table of models registered in shardingMap; all other migrator methods
// are promoted from the embedded Migrator.
//
// db is the session the migrator was created for. dialector is the unwrapped dialector, used
// to build per-shard migrators. shardingMap maps a base table name to its sharding plugin.
type shardingMigrator struct {
	gorm.Migrator
	db          *gorm.DB
	dialector   gorm.Dialector
	shardingMap map[string]*tableSharding
}
  1208  
  1209  func (s *shardingMigrator) AutoMigrate(dst ...any) (err error) {
  1210  	sharding, ok := s.shardingMap[s.tableName(s.db, dst[0])]
  1211  	if !ok {
  1212  		defer (*tableSharding)(nil).ignore(s.db)() //nolint: revive // partial calling issue
  1213  		return s.Migrator.AutoMigrate(dst...)
  1214  	}
  1215  
  1216  	stmt := &gorm.Statement{DB: sharding.DB}
  1217  	if sharding.isIgnored(sharding.DB)() {
  1218  		return s.dialector.Migrator(stmt.DB.Session(&gorm.Session{})).AutoMigrate(dst...)
  1219  	}
  1220  
  1221  	shardingDst, err := s.getShardingDst(sharding, dst...)
  1222  	if err != nil {
  1223  		return err
  1224  	}
  1225  
  1226  	defer sharding.ignore(sharding.DB)() //nolint: revive // partial calling issue
  1227  	for _, sd := range shardingDst {
  1228  		tx := stmt.DB.Session(&gorm.Session{}).Table(sd.table)
  1229  		if err = s.dialector.Migrator(tx).AutoMigrate(sd.dst); err != nil {
  1230  			return err
  1231  		}
  1232  	}
  1233  
  1234  	return
  1235  }
  1236  func (s *shardingMigrator) DropTable(dst ...any) (err error) {
  1237  	sharding, ok := s.shardingMap[s.tableName(s.db, dst[0])]
  1238  	if !ok {
  1239  		defer (*tableSharding)(nil).ignore(s.db)() //nolint: revive // partial calling issue
  1240  		return s.Migrator.DropTable(dst...)
  1241  	}
  1242  
  1243  	stmt := &gorm.Statement{DB: sharding.DB}
  1244  	if sharding.isIgnored(sharding.DB)() {
  1245  		return s.dialector.Migrator(stmt.DB.Session(&gorm.Session{})).DropTable(dst...)
  1246  	}
  1247  	shardingDst, err := s.getShardingDst(sharding, dst...)
  1248  	if err != nil {
  1249  		return err
  1250  	}
  1251  
  1252  	defer sharding.ignore(sharding.DB)() //nolint: revive // partial calling issue
  1253  	for _, sd := range shardingDst {
  1254  		tx := stmt.DB.Session(&gorm.Session{}).Table(sd.table)
  1255  		if err = s.dialector.Migrator(tx).DropTable(sd.table); err != nil {
  1256  			return err
  1257  		}
  1258  	}
  1259  
  1260  	return
  1261  }
  1262  
// shardingDst is a single migration target: table is a concrete shard table name (base table
// name plus suffix), and dst is the model whose schema that table receives.
type shardingDst struct {
	table string
	dst   any
}
  1267  
  1268  func (s *shardingMigrator) getShardingDst(sharding *tableSharding, src ...any) (dst []shardingDst, err error) {
  1269  	for _, model := range src {
  1270  		stmt := &gorm.Statement{DB: sharding.DB}
  1271  		if err = stmt.Parse(model); err != nil {
  1272  			return
  1273  		}
  1274  
  1275  		// support sharding table
  1276  		suffixes, err := sharding.suffixes()
  1277  		if err != nil {
  1278  			return nil, err
  1279  		}
  1280  		if len(suffixes) == 0 {
  1281  			return nil, fmt.Errorf("sharding table:%s suffixes are empty", stmt.Table)
  1282  		}
  1283  		for _, suffix := range suffixes {
  1284  			dst = append(dst, shardingDst{
  1285  				table: stmt.Table + suffix,
  1286  				dst:   model,
  1287  			})
  1288  		}
  1289  	}
  1290  	return
  1291  }
  1292  func (s *shardingMigrator) tableName(db *gorm.DB, m any) (name string) {
  1293  	if tabler, ok := m.(schema.Tabler); ok {
  1294  		name = tabler.TableName()
  1295  	}
  1296  	if tabler, ok := m.(schema.TablerWithNamer); ok {
  1297  		name = tabler.TableName(db.NamingStrategy)
  1298  	}
  1299  	namingStrategy := reflect.ValueOf(db.NamingStrategy)
  1300  	if namingStrategy.CanConvert(gormSchemaEmbeddedNamer) {
  1301  		name = reflect.Indirect(namingStrategy.Convert(gormSchemaEmbeddedNamer)).FieldByName("Table").String()
  1302  	}
  1303  	return
  1304  }