github.com/wfusion/gofusion@v1.1.14/db/dal.go (about)

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"reflect"
     6  
     7  	"github.com/pkg/errors"
     8  	"gorm.io/gorm"
     9  	"gorm.io/gorm/clause"
    10  	"gorm.io/gorm/schema"
    11  
    12  	"github.com/wfusion/gofusion/common/utils"
    13  	"github.com/wfusion/gofusion/common/utils/inspect"
    14  	"github.com/wfusion/gofusion/db/plugins"
    15  
    16  	ormDrv "github.com/wfusion/gofusion/common/infra/drivers/orm"
    17  	fusCtx "github.com/wfusion/gofusion/context"
    18  )
    19  
    20  // DalInterface
    21  //nolint: revive // interface issue
    22  type DalInterface[T any, TS ~[]*T] interface {
    23  	Query(ctx context.Context, query any, args ...any) (TS, error)
    24  	QueryFirst(ctx context.Context, query any, args ...any) (*T, error)
    25  	QueryLast(ctx context.Context, query any, args ...any) (*T, error)
    26  	QueryInBatches(ctx context.Context, batchSize int, fc func(tx *DB, batch int, found TS) error, query any, args ...any) error
    27  	Count(ctx context.Context, query any, args ...any) (int64, error)
    28  	Pluck(ctx context.Context, column string, dest any, query any, args ...any) error
    29  	Take(ctx context.Context, dest any, conds ...any) error
    30  	InsertOne(ctx context.Context, mod *T, opts ...utils.OptionExtender) error
    31  	InsertInBatches(ctx context.Context, modList TS, batchSize int, opts ...utils.OptionExtender) error
    32  	Save(ctx context.Context, mod any, opts ...utils.OptionExtender) error
    33  	Update(ctx context.Context, column string, value any, query any, args ...any) (int64, error)
    34  	Updates(ctx context.Context, columns map[string]any, query any, args ...any) (int64, error)
    35  	Delete(ctx context.Context, query any, args ...any) (int64, error)
    36  	FirstOrCreate(ctx context.Context, mod *T, conds ...any) (int64, error)
    37  	Transaction(ctx context.Context, fc func(tx context.Context) error, opts ...utils.OptionExtender) error
    38  	ReadDB(ctx context.Context) *gorm.DB
    39  	WriteDB(ctx context.Context) *gorm.DB
    40  	SetCtxReadDB(src context.Context) (dst context.Context)
    41  	SetCtxWriteDB(src context.Context) (dst context.Context)
    42  	Model() *T
    43  	ModelSlice() TS
    44  	IgnoreErr(err error) error
    45  	CanIgnore(err error) bool
    46  	ShardingByValues(ctx context.Context, src []map[string]any) (dst map[string][]map[string]any, err error)
    47  	ShardingIDGen(ctx context.Context) (id uint64, err error)
    48  	ShardingIDListGen(ctx context.Context, amount int) (idList []uint64, err error)
    49  	ShardingByModelList(ctx context.Context, src TS) (dst map[string]TS, err error)
    50  }
    51  
    52  type dal[T any, TS ~[]*T] struct {
    53  	appName     string
    54  	readDBName  string
    55  	writeDBName string
    56  }
    57  
    58  func NewDAL[T any, TS ~[]*T](readDBName, writeDBName string, opts ...utils.OptionExtender) DalInterface[T, TS] {
    59  	instance := new(T)
    60  	if _, ok := any(instance).(schema.Tabler); !ok {
    61  		panic(errors.Errorf("model unimplement schema.Tabler [model[%T] read_db[%s] write_db[%s]]",
    62  			instance, readDBName, writeDBName))
    63  	}
    64  	opt := utils.ApplyOptions[useOption](opts...)
    65  	return &dal[T, TS]{
    66  		appName:     opt.appName,
    67  		readDBName:  readDBName,
    68  		writeDBName: writeDBName,
    69  	}
    70  }
    71  
    72  func (d *dal[T, TS]) Query(ctx context.Context, query any, args ...any) (TS, error) {
    73  	o, args := d.parseOptionFromArgs(args...)
    74  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
    75  
    76  	found := d.ModelSlice()
    77  	result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).Find(&found)
    78  	if d.CanIgnore(result.Error) {
    79  		return nil, nil
    80  	}
    81  	return found, d.IgnoreErr(result.Error)
    82  }
    83  
    84  func (d *dal[T, TS]) QueryLast(ctx context.Context, query any, args ...any) (*T, error) {
    85  	o, args := d.parseOptionFromArgs(args...)
    86  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
    87  
    88  	found := d.Model()
    89  	result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).Last(found)
    90  	if d.CanIgnore(result.Error) {
    91  		return nil, nil
    92  	}
    93  	return found, d.IgnoreErr(result.Error)
    94  }
    95  
    96  func (d *dal[T, TS]) QueryFirst(ctx context.Context, query any, args ...any) (*T, error) {
    97  	o, args := d.parseOptionFromArgs(args...)
    98  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
    99  
   100  	found := d.Model()
   101  	result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).First(found)
   102  	if d.CanIgnore(result.Error) {
   103  		return nil, nil
   104  	}
   105  	return found, d.IgnoreErr(result.Error)
   106  }
   107  
   108  func (d *dal[T, TS]) QueryInBatches(ctx context.Context, batchSize int,
   109  	fc func(tx *DB, batch int, found TS) error, query any, args ...any) (err error) {
   110  	o, args := d.parseOptionFromArgs(args...)
   111  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   112  
   113  	orm := Use(ctx, d.readDBName, AppName(d.appName))
   114  	found := make(TS, 0, batchSize)
   115  	result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).FindInBatches(&found, batchSize,
   116  		func(tx *gorm.DB, batch int) error {
   117  			wrapper := &DB{
   118  				DB:                   &ormDrv.DB{DB: tx},
   119  				Name:                 orm.Name,
   120  				tableShardingPlugins: orm.tableShardingPlugins,
   121  			}
   122  			return fc(wrapper, batch, found)
   123  		},
   124  	)
   125  	if d.CanIgnore(result.Error) {
   126  		return
   127  	}
   128  	return d.IgnoreErr(result.Error)
   129  }
   130  
   131  func (d *dal[T, TS]) Count(ctx context.Context, query any, args ...any) (int64, error) {
   132  	var count int64
   133  
   134  	o, args := d.parseOptionFromArgs(args...)
   135  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   136  
   137  	result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).Count(&count)
   138  	if d.CanIgnore(result.Error) {
   139  		return 0, nil
   140  	}
   141  	return count, d.IgnoreErr(result.Error)
   142  }
   143  
   144  func (d *dal[T, TS]) Pluck(ctx context.Context, column string, dest any,
   145  	query any, args ...any) error {
   146  	o, args := d.parseOptionFromArgs(args...)
   147  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   148  
   149  	result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).Pluck(column, dest)
   150  	return d.IgnoreErr(result.Error)
   151  }
   152  
   153  func (d *dal[T, TS]) Take(ctx context.Context, dest any, conds ...any) error {
   154  	o, args := d.parseOptionFromArgs(conds...)
   155  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   156  
   157  	result := d.ReadDB(ctx).Clauses(o.clauses...).Take(dest, args...)
   158  	return d.IgnoreErr(result.Error)
   159  }
   160  
   161  func (d *dal[T, TS]) InsertOne(ctx context.Context, mod *T, opts ...utils.OptionExtender) error {
   162  	o := utils.ApplyOptions[mysqlDALOption](opts...)
   163  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   164  	return d.WriteDB(ctx).Clauses(o.clauses...).Create(mod).Error
   165  }
   166  
   167  func (d *dal[T, TS]) InsertInBatches(ctx context.Context,
   168  	modList TS, batchSize int, opts ...utils.OptionExtender) error {
   169  	o := utils.ApplyOptions[mysqlDALOption](opts...)
   170  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   171  	sharded, err := d.writeWithTableSharding(ctx, modList)
   172  	if err != nil {
   173  		return err
   174  	}
   175  	for _, mList := range sharded {
   176  		if err = d.WriteDB(ctx).Clauses(o.clauses...).CreateInBatches(mList, batchSize).Error; err != nil {
   177  			return err
   178  		}
   179  	}
   180  
   181  	return nil
   182  }
   183  
   184  func (d *dal[T, TS]) FirstOrCreate(ctx context.Context, mod *T, conds ...any) (int64, error) {
   185  	o, conds := d.parseOptionFromArgs(conds...)
   186  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   187  	result := d.WriteDB(ctx).Clauses(o.clauses...).FirstOrCreate(mod, conds...)
   188  	return result.RowsAffected, result.Error
   189  }
   190  
   191  // Save create or update model
   192  // Only support for passing in *mod, []*mod, [...]*mod, it's recommended to only use *mod to call this method.
   193  // If using mod, []mod, since it's value passing, the upper layer will not be able to
   194  // obtain the auto-incremented id from create or other fields filled in by the lower layer.
   195  // If using [...]mod, it will trigger panic: using unaddressable error.
   196  // In official usage, both mod and [...]mod will trigger panic: using unaddressable error.
   197  func (d *dal[T, TS]) Save(ctx context.Context, mod any, opts ...utils.OptionExtender) error {
   198  	// Translate the struct to slice to follow the insert into with ON DUPLICATE KEY UPDATE
   199  	mList, ok := d.convertAnyToTS(mod)
   200  	if !ok {
   201  		mList = utils.SliceConvert(mod, reflect.TypeOf(TS{})).(TS)
   202  	}
   203  	if len(mList) == 0 {
   204  		return nil
   205  	}
   206  	o := utils.ApplyOptions[mysqlDALOption](opts...)
   207  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   208  	sharded, err := d.writeWithTableSharding(ctx, mList)
   209  	if err != nil {
   210  		return err
   211  	}
   212  	for _, mList := range sharded {
   213  		if err = d.WriteDB(ctx).Clauses(o.clauses...).Save(mList).Error; err != nil {
   214  			return err
   215  		}
   216  	}
   217  
   218  	return nil
   219  }
   220  
   221  func (d *dal[T, TS]) Update(ctx context.Context, column string, value any,
   222  	query any, args ...any) (int64, error) {
   223  	o, args := d.parseOptionFromArgs(args...)
   224  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   225  	u := d.WriteDB(ctx).Clauses(o.clauses...).Where(query, args...).Update(column, value)
   226  	return u.RowsAffected, u.Error
   227  }
   228  
   229  func (d *dal[T, TS]) Updates(ctx context.Context, columns map[string]any,
   230  	query any, args ...any) (int64, error) {
   231  	o, args := d.parseOptionFromArgs(args...)
   232  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   233  	u := d.WriteDB(ctx).Clauses(o.clauses...).Where(query, args...).Updates(columns)
   234  	return u.RowsAffected, u.Error
   235  }
   236  
   237  func (d *dal[T, TS]) Delete(ctx context.Context, query any, args ...any) (int64, error) {
   238  	o, args := d.parseOptionFromArgs(args...)
   239  	ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o)
   240  	mList, ok := d.convertAnyToTS(query)
   241  	if !ok || len(mList) == 0 {
   242  		deleted := d.WriteDB(ctx).Clauses(o.clauses...).Where(query, args...).Delete(d.Model())
   243  		return deleted.RowsAffected, deleted.Error
   244  	} else {
   245  		sharded, err := d.writeWithTableSharding(ctx, mList)
   246  		if err != nil {
   247  			return 0, err
   248  		}
   249  		var rowAffected int64
   250  		for _, mList := range sharded {
   251  			deleted := d.WriteDB(ctx).Clauses(o.clauses...).Delete(mList, args...)
   252  			if deleted.Error != nil {
   253  				return rowAffected, deleted.Error
   254  			}
   255  			rowAffected += deleted.RowsAffected
   256  		}
   257  		return rowAffected, nil
   258  	}
   259  }
   260  
   261  func (d *dal[T, TS]) Transaction(ctx context.Context, fc func(context.Context) error,
   262  	opts ...utils.OptionExtender) error {
   263  	orm := GetCtxGormDBByNameList(ctx, []string{d.writeDBName, d.readDBName})
   264  	o := utils.ApplyOptions[mysqlDALOption](opts...)
   265  	if orm == nil {
   266  		if o.useWriteDB {
   267  			orm = Use(ctx, d.writeDBName, AppName(d.appName))
   268  		} else {
   269  			orm = Use(ctx, d.readDBName, AppName(d.appName))
   270  		}
   271  	}
   272  
   273  	return d.unscopedGormDB(orm.GetProxy().WithContext(ctx), o).Transaction(func(tx *gorm.DB) error {
   274  		return fc(SetCtxGormDB(ctx, &DB{
   275  			DB:                   &ormDrv.DB{DB: tx},
   276  			Name:                 orm.Name,
   277  			tableShardingPlugins: orm.tableShardingPlugins,
   278  		}))
   279  	})
   280  }
   281  
   282  func (d *dal[T, TS]) ReadDB(ctx context.Context) *gorm.DB {
   283  	o, _ := ctx.Value(fusCtx.KeyDALOption).(*mysqlDALOption)
   284  	dbName := d.readDBName
   285  	if o != nil && o.useWriteDB {
   286  		dbName = d.writeDBName
   287  	}
   288  	if orm := GetCtxGormDBByName(ctx, dbName); orm != nil {
   289  		return d.unscopedGormDB(orm.Model(d.Model()), o).WithContext(ctx)
   290  	}
   291  	return d.unscopedGormDB(Use(ctx, dbName, AppName(d.appName)).WithContext(ctx).Model(d.Model()), o)
   292  }
   293  func (d *dal[T, TS]) WriteDB(ctx context.Context) *gorm.DB {
   294  	o, _ := ctx.Value(fusCtx.KeyDALOption).(*mysqlDALOption)
   295  	if orm := GetCtxGormDBByName(ctx, d.writeDBName); orm != nil {
   296  		return d.unscopedGormDB(orm.Model(d.Model()), o).WithContext(ctx)
   297  	}
   298  
   299  	return d.unscopedGormDB(Use(ctx, d.writeDBName, AppName(d.appName)).WithContext(ctx).Model(d.Model()), o)
   300  }
   301  func (d *dal[T, TS]) SetCtxReadDB(src context.Context) (dst context.Context) {
   302  	if orm := GetCtxGormDBByName(src, d.readDBName); orm != nil {
   303  		return src
   304  	}
   305  
   306  	return SetCtxGormDB(src, Use(src, d.readDBName, AppName(d.appName)))
   307  }
   308  func (d *dal[T, TS]) SetCtxWriteDB(src context.Context) (dst context.Context) {
   309  	if orm := GetCtxGormDBByName(src, d.writeDBName); orm != nil {
   310  		return src
   311  	}
   312  	return SetCtxGormDB(src, Use(src, d.writeDBName, AppName(d.appName)))
   313  }
   314  
   315  func (d *dal[T, TS]) Model() *T      { return new(T) }
   316  func (d *dal[T, TS]) ModelSlice() TS { return make(TS, 0) }
   317  func (d *dal[T, TS]) IgnoreErr(err error) error {
   318  	if errors.Is(err, gorm.ErrRecordNotFound) {
   319  		return nil
   320  	}
   321  	return err
   322  }
   323  func (d *dal[T, TS]) CanIgnore(err error) bool { return errors.Is(err, gorm.ErrRecordNotFound) }
   324  
   325  func (d *dal[T, TS]) ShardingByValues(ctx context.Context, src []map[string]any) (
   326  	dst map[string][]map[string]any, err error) {
   327  	writeDB := d.writeDB(ctx)
   328  	tableName := d.tableName(writeDB, new(T))
   329  	tableShardingPlugin, ok := writeDB.tableShardingPlugins[tableName]
   330  	if !ok {
   331  		return map[string][]map[string]any{tableName: src}, nil
   332  	}
   333  	return tableShardingPlugin.ShardingByValues(ctx, src)
   334  }
   335  func (d *dal[T, TS]) ShardingIDGen(ctx context.Context) (id uint64, err error) {
   336  	writeDB := d.writeDB(ctx)
   337  	tableName := d.tableName(writeDB, new(T))
   338  	tableShardingPlugin, ok := writeDB.tableShardingPlugins[tableName]
   339  	if !ok {
   340  		return 0, plugins.ErrIDGeneratorNotFound
   341  	}
   342  	return tableShardingPlugin.ShardingIDGen(ctx)
   343  }
   344  func (d *dal[T, TS]) ShardingIDListGen(ctx context.Context, amount int) (idList []uint64, err error) {
   345  	writeDB := d.writeDB(ctx)
   346  	tableName := d.tableName(writeDB, new(T))
   347  	tableShardingPlugin, ok := writeDB.tableShardingPlugins[tableName]
   348  	if !ok {
   349  		return nil, plugins.ErrIDGeneratorNotFound
   350  	}
   351  	idList = make([]uint64, 0, amount)
   352  	for i := 0; i < amount; i++ {
   353  		id, err := tableShardingPlugin.ShardingIDGen(ctx)
   354  		if err != nil {
   355  			return nil, err
   356  		}
   357  		idList = append(idList, id)
   358  	}
   359  	return
   360  }
   361  func (d *dal[T, TS]) ShardingByModelList(ctx context.Context, src TS) (dst map[string]TS, err error) {
   362  	if len(src) == 0 {
   363  		return make(map[string]TS), nil
   364  	}
   365  	writeDB := d.writeDB(ctx)
   366  	tableName := d.tableName(writeDB, src[0])
   367  	shardingPlugin, ok := writeDB.tableShardingPlugins[tableName]
   368  	if !ok {
   369  		return map[string]TS{tableName: src}, nil
   370  	}
   371  	sharded, err := shardingPlugin.ShardingByModelList(ctx, utils.SliceMapping(src, func(t *T) any { return t })...)
   372  	if err != nil {
   373  		return
   374  	}
   375  	dst = make(map[string]TS, len(sharded))
   376  	for suffix, item := range sharded {
   377  		shardingTableName := tableName + suffix
   378  		dst[shardingTableName] = TS(utils.SliceMapping(item, func(t any) *T { return t.(*T) }))
   379  	}
   380  	return
   381  }
   382  
   383  func (d *dal[T, TS]) writeDB(ctx context.Context) *DB {
   384  	if orm := GetCtxGormDBByName(ctx, d.writeDBName); orm != nil {
   385  		return orm
   386  	}
   387  
   388  	return Use(ctx, d.writeDBName, AppName(d.appName))
   389  }
   390  func (d *dal[T, TS]) writeWithTableSharding(ctx context.Context, src TS) (dst []TS, err error) {
   391  	if len(src) == 0 {
   392  		return
   393  	}
   394  	writeDB := d.writeDB(ctx)
   395  	shardingPlugin, ok := writeDB.tableShardingPlugins[d.tableName(writeDB, src[0])]
   396  	if !ok {
   397  		return []TS{src}, nil
   398  	}
   399  
   400  	sharded, err := shardingPlugin.ShardingByModelList(ctx, utils.SliceMapping(src, func(t *T) any { return t })...)
   401  	if err != nil {
   402  		return
   403  	}
   404  	for _, item := range sharded {
   405  		dst = append(dst, utils.SliceMapping(item, func(t any) *T { return t.(*T) }))
   406  	}
   407  	return
   408  }
   409  func (d *dal[T, TS]) tableName(db *DB, mod *T) (name string) {
   410  	if tabler, ok := any(mod).(schema.Tabler); ok {
   411  		name = tabler.TableName()
   412  	}
   413  	if tabler, ok := any(mod).(schema.TablerWithNamer); ok {
   414  		name = tabler.TableName(db.NamingStrategy)
   415  	}
   416  	// TODO: check if embeddedNamer valid
   417  	embeddedNamer := inspect.TypeOf("gorm.io/gorm/schema.embeddedNamer")
   418  	namingStrategy := reflect.ValueOf(db.NamingStrategy)
   419  	if namingStrategy.CanConvert(embeddedNamer) {
   420  		name = namingStrategy.Convert(embeddedNamer).FieldByName("Table").String()
   421  	}
   422  	return
   423  }
   424  func (d *dal[T, TS]) convertAnyToTS(query any) (mList TS, ok bool) {
   425  	switch q := query.(type) {
   426  	case TS:
   427  		ok = true
   428  		mList = q
   429  	case []*T:
   430  		ok = true
   431  		mList = TS(q)
   432  	case []T:
   433  		ok = true
   434  		mList = TS(utils.SliceMapping(q, func(t T) *T { return &t }))
   435  	case T:
   436  		ok = true
   437  		mList = TS{&q}
   438  	case *T:
   439  		ok = true
   440  		mList = TS{q}
   441  	}
   442  	return
   443  }
   444  func (d *dal[T, TS]) unscopedGormDB(src *gorm.DB, o *mysqlDALOption) (dst *gorm.DB) {
   445  	if o != nil && o.unscoped {
   446  		return src.Unscoped()
   447  	}
   448  	return src
   449  }
   450  
   451  type mysqlDALOption struct {
   452  	unscoped   bool
   453  	useWriteDB bool
   454  	clauses    []clause.Expression
   455  }
   456  
   457  func Unscoped() utils.OptionFunc[mysqlDALOption] {
   458  	return func(m *mysqlDALOption) {
   459  		m.unscoped = true
   460  	}
   461  }
   462  
   463  func Clauses(clauses ...clause.Expression) utils.OptionFunc[mysqlDALOption] {
   464  	return func(m *mysqlDALOption) {
   465  		m.clauses = append(m.clauses, clauses...)
   466  	}
   467  }
   468  
   469  func WriteDB() utils.OptionFunc[mysqlDALOption] {
   470  	return func(m *mysqlDALOption) {
   471  		m.useWriteDB = true
   472  	}
   473  }
   474  
   475  func (d *dal[T, TS]) parseOptionFromArgs(args ...any) (o *mysqlDALOption, r []any) {
   476  	o = new(mysqlDALOption)
   477  	r = make([]any, 0, len(args))
   478  	for _, arg := range args {
   479  		if reflect.TypeOf(arg).Implements(gormClauseExpressionType) {
   480  			o.clauses = append(o.clauses, arg.(clause.Expression))
   481  			continue
   482  		}
   483  
   484  		switch v := arg.(type) {
   485  		case utils.OptionFunc[mysqlDALOption]:
   486  			v(o)
   487  		default:
   488  			r = append(r, arg)
   489  		}
   490  	}
   491  	return
   492  }