github.com/wfusion/gofusion@v1.1.14/db/dal.go (about) 1 package db 2 3 import ( 4 "context" 5 "reflect" 6 7 "github.com/pkg/errors" 8 "gorm.io/gorm" 9 "gorm.io/gorm/clause" 10 "gorm.io/gorm/schema" 11 12 "github.com/wfusion/gofusion/common/utils" 13 "github.com/wfusion/gofusion/common/utils/inspect" 14 "github.com/wfusion/gofusion/db/plugins" 15 16 ormDrv "github.com/wfusion/gofusion/common/infra/drivers/orm" 17 fusCtx "github.com/wfusion/gofusion/context" 18 ) 19 20 // DalInterface 21 //nolint: revive // interface issue 22 type DalInterface[T any, TS ~[]*T] interface { 23 Query(ctx context.Context, query any, args ...any) (TS, error) 24 QueryFirst(ctx context.Context, query any, args ...any) (*T, error) 25 QueryLast(ctx context.Context, query any, args ...any) (*T, error) 26 QueryInBatches(ctx context.Context, batchSize int, fc func(tx *DB, batch int, found TS) error, query any, args ...any) error 27 Count(ctx context.Context, query any, args ...any) (int64, error) 28 Pluck(ctx context.Context, column string, dest any, query any, args ...any) error 29 Take(ctx context.Context, dest any, conds ...any) error 30 InsertOne(ctx context.Context, mod *T, opts ...utils.OptionExtender) error 31 InsertInBatches(ctx context.Context, modList TS, batchSize int, opts ...utils.OptionExtender) error 32 Save(ctx context.Context, mod any, opts ...utils.OptionExtender) error 33 Update(ctx context.Context, column string, value any, query any, args ...any) (int64, error) 34 Updates(ctx context.Context, columns map[string]any, query any, args ...any) (int64, error) 35 Delete(ctx context.Context, query any, args ...any) (int64, error) 36 FirstOrCreate(ctx context.Context, mod *T, conds ...any) (int64, error) 37 Transaction(ctx context.Context, fc func(tx context.Context) error, opts ...utils.OptionExtender) error 38 ReadDB(ctx context.Context) *gorm.DB 39 WriteDB(ctx context.Context) *gorm.DB 40 SetCtxReadDB(src context.Context) (dst context.Context) 41 SetCtxWriteDB(src context.Context) (dst context.Context) 42 Model() *T 43 ModelSlice() TS 44 IgnoreErr(err error) error 45 CanIgnore(err error) bool 46 ShardingByValues(ctx context.Context, src []map[string]any) (dst map[string][]map[string]any, err error) 47 ShardingIDGen(ctx context.Context) (id uint64, err error) 48 ShardingIDListGen(ctx context.Context, amount int) (idList []uint64, err error) 49 ShardingByModelList(ctx context.Context, src TS) (dst map[string]TS, err error) 50 } 51 52 type dal[T any, TS ~[]*T] struct { 53 appName string 54 readDBName string 55 writeDBName string 56 } 57 58 func NewDAL[T any, TS ~[]*T](readDBName, writeDBName string, opts ...utils.OptionExtender) DalInterface[T, TS] { 59 instance := new(T) 60 if _, ok := any(instance).(schema.Tabler); !ok { 61 panic(errors.Errorf("model unimplement schema.Tabler [model[%T] read_db[%s] write_db[%s]]", 62 instance, readDBName, writeDBName)) 63 } 64 opt := utils.ApplyOptions[useOption](opts...) 65 return &dal[T, TS]{ 66 appName: opt.appName, 67 readDBName: readDBName, 68 writeDBName: writeDBName, 69 } 70 } 71 72 func (d *dal[T, TS]) Query(ctx context.Context, query any, args ...any) (TS, error) { 73 o, args := d.parseOptionFromArgs(args...) 74 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 75 76 found := d.ModelSlice() 77 result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).Find(&found) 78 if d.CanIgnore(result.Error) { 79 return nil, nil 80 } 81 return found, d.IgnoreErr(result.Error) 82 } 83 84 func (d *dal[T, TS]) QueryLast(ctx context.Context, query any, args ...any) (*T, error) { 85 o, args := d.parseOptionFromArgs(args...) 86 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 87 88 found := d.Model() 89 result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).Last(found) 90 if d.CanIgnore(result.Error) { 91 return nil, nil 92 } 93 return found, d.IgnoreErr(result.Error) 94 } 95 96 func (d *dal[T, TS]) QueryFirst(ctx context.Context, query any, args ...any) (*T, error) { 97 o, args := d.parseOptionFromArgs(args...) 98 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 99 100 found := d.Model() 101 result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).First(found) 102 if d.CanIgnore(result.Error) { 103 return nil, nil 104 } 105 return found, d.IgnoreErr(result.Error) 106 } 107 108 func (d *dal[T, TS]) QueryInBatches(ctx context.Context, batchSize int, 109 fc func(tx *DB, batch int, found TS) error, query any, args ...any) (err error) { 110 o, args := d.parseOptionFromArgs(args...) 111 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 112 113 orm := Use(ctx, d.readDBName, AppName(d.appName)) 114 found := make(TS, 0, batchSize) 115 result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).FindInBatches(&found, batchSize, 116 func(tx *gorm.DB, batch int) error { 117 wrapper := &DB{ 118 DB: &ormDrv.DB{DB: tx}, 119 Name: orm.Name, 120 tableShardingPlugins: orm.tableShardingPlugins, 121 } 122 return fc(wrapper, batch, found) 123 }, 124 ) 125 if d.CanIgnore(result.Error) { 126 return 127 } 128 return d.IgnoreErr(result.Error) 129 } 130 131 func (d *dal[T, TS]) Count(ctx context.Context, query any, args ...any) (int64, error) { 132 var count int64 133 134 o, args := d.parseOptionFromArgs(args...) 135 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 136 137 result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).Count(&count) 138 if d.CanIgnore(result.Error) { 139 return 0, nil 140 } 141 return count, d.IgnoreErr(result.Error) 142 } 143 144 func (d *dal[T, TS]) Pluck(ctx context.Context, column string, dest any, 145 query any, args ...any) error { 146 o, args := d.parseOptionFromArgs(args...) 147 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 148 149 result := d.ReadDB(ctx).Clauses(o.clauses...).Where(query, args...).Pluck(column, dest) 150 return d.IgnoreErr(result.Error) 151 } 152 153 func (d *dal[T, TS]) Take(ctx context.Context, dest any, conds ...any) error { 154 o, args := d.parseOptionFromArgs(conds...) 155 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 156 157 result := d.ReadDB(ctx).Clauses(o.clauses...).Take(dest, args...) 158 return d.IgnoreErr(result.Error) 159 } 160 161 func (d *dal[T, TS]) InsertOne(ctx context.Context, mod *T, opts ...utils.OptionExtender) error { 162 o := utils.ApplyOptions[mysqlDALOption](opts...) 163 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 164 return d.WriteDB(ctx).Clauses(o.clauses...).Create(mod).Error 165 } 166 167 func (d *dal[T, TS]) InsertInBatches(ctx context.Context, 168 modList TS, batchSize int, opts ...utils.OptionExtender) error { 169 o := utils.ApplyOptions[mysqlDALOption](opts...) 170 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 171 sharded, err := d.writeWithTableSharding(ctx, modList) 172 if err != nil { 173 return err 174 } 175 for _, mList := range sharded { 176 if err = d.WriteDB(ctx).Clauses(o.clauses...).CreateInBatches(mList, batchSize).Error; err != nil { 177 return err 178 } 179 } 180 181 return nil 182 } 183 184 func (d *dal[T, TS]) FirstOrCreate(ctx context.Context, mod *T, conds ...any) (int64, error) { 185 o, conds := d.parseOptionFromArgs(conds...) 186 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 187 result := d.WriteDB(ctx).Clauses(o.clauses...).FirstOrCreate(mod, conds...) 188 return result.RowsAffected, result.Error 189 } 190 191 // Save create or update model 192 // Only support for passing in *mod, []*mod, [...]*mod, it's recommended to only use *mod to call this method. 193 // If using mod, []mod, since it's value passing, the upper layer will not be able to 194 // obtain the auto-incremented id from create or other fields filled in by the lower layer. 195 // If using [...]mod, it will trigger panic: using unaddressable error. 196 // In official usage, both mod and [...]mod will trigger panic: using unaddressable error. 197 func (d *dal[T, TS]) Save(ctx context.Context, mod any, opts ...utils.OptionExtender) error { 198 // Translate the struct to slice to follow the insert into with ON DUPLICATE KEY UPDATE 199 mList, ok := d.convertAnyToTS(mod) 200 if !ok { 201 mList = utils.SliceConvert(mod, reflect.TypeOf(TS{})).(TS) 202 } 203 if len(mList) == 0 { 204 return nil 205 } 206 o := utils.ApplyOptions[mysqlDALOption](opts...) 207 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 208 sharded, err := d.writeWithTableSharding(ctx, mList) 209 if err != nil { 210 return err 211 } 212 for _, mList := range sharded { 213 if err = d.WriteDB(ctx).Clauses(o.clauses...).Save(mList).Error; err != nil { 214 return err 215 } 216 } 217 218 return nil 219 } 220 221 func (d *dal[T, TS]) Update(ctx context.Context, column string, value any, 222 query any, args ...any) (int64, error) { 223 o, args := d.parseOptionFromArgs(args...) 224 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 225 u := d.WriteDB(ctx).Clauses(o.clauses...).Where(query, args...).Update(column, value) 226 return u.RowsAffected, u.Error 227 } 228 229 func (d *dal[T, TS]) Updates(ctx context.Context, columns map[string]any, 230 query any, args ...any) (int64, error) { 231 o, args := d.parseOptionFromArgs(args...) 232 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 233 u := d.WriteDB(ctx).Clauses(o.clauses...).Where(query, args...).Updates(columns) 234 return u.RowsAffected, u.Error 235 } 236 237 func (d *dal[T, TS]) Delete(ctx context.Context, query any, args ...any) (int64, error) { 238 o, args := d.parseOptionFromArgs(args...) 239 ctx = context.WithValue(ctx, fusCtx.KeyDALOption, o) 240 mList, ok := d.convertAnyToTS(query) 241 if !ok || len(mList) == 0 { 242 deleted := d.WriteDB(ctx).Clauses(o.clauses...).Where(query, args...).Delete(d.Model()) 243 return deleted.RowsAffected, deleted.Error 244 } else { 245 sharded, err := d.writeWithTableSharding(ctx, mList) 246 if err != nil { 247 return 0, err 248 } 249 var rowAffected int64 250 for _, mList := range sharded { 251 deleted := d.WriteDB(ctx).Clauses(o.clauses...).Delete(mList, args...) 252 if deleted.Error != nil { 253 return rowAffected, deleted.Error 254 } 255 rowAffected += deleted.RowsAffected 256 } 257 return rowAffected, nil 258 } 259 } 260 261 func (d *dal[T, TS]) Transaction(ctx context.Context, fc func(context.Context) error, 262 opts ...utils.OptionExtender) error { 263 orm := GetCtxGormDBByNameList(ctx, []string{d.writeDBName, d.readDBName}) 264 o := utils.ApplyOptions[mysqlDALOption](opts...) 265 if orm == nil { 266 if o.useWriteDB { 267 orm = Use(ctx, d.writeDBName, AppName(d.appName)) 268 } else { 269 orm = Use(ctx, d.readDBName, AppName(d.appName)) 270 } 271 } 272 273 return d.unscopedGormDB(orm.GetProxy().WithContext(ctx), o).Transaction(func(tx *gorm.DB) error { 274 return fc(SetCtxGormDB(ctx, &DB{ 275 DB: &ormDrv.DB{DB: tx}, 276 Name: orm.Name, 277 tableShardingPlugins: orm.tableShardingPlugins, 278 })) 279 }) 280 } 281 282 func (d *dal[T, TS]) ReadDB(ctx context.Context) *gorm.DB { 283 o, _ := ctx.Value(fusCtx.KeyDALOption).(*mysqlDALOption) 284 dbName := d.readDBName 285 if o != nil && o.useWriteDB { 286 dbName = d.writeDBName 287 } 288 if orm := GetCtxGormDBByName(ctx, dbName); orm != nil { 289 return d.unscopedGormDB(orm.Model(d.Model()), o).WithContext(ctx) 290 } 291 return d.unscopedGormDB(Use(ctx, dbName, AppName(d.appName)).WithContext(ctx).Model(d.Model()), o) 292 } 293 func (d *dal[T, TS]) WriteDB(ctx context.Context) *gorm.DB { 294 o, _ := ctx.Value(fusCtx.KeyDALOption).(*mysqlDALOption) 295 if orm := GetCtxGormDBByName(ctx, d.writeDBName); orm != nil { 296 return d.unscopedGormDB(orm.Model(d.Model()), o).WithContext(ctx) 297 } 298 299 return d.unscopedGormDB(Use(ctx, d.writeDBName, AppName(d.appName)).WithContext(ctx).Model(d.Model()), o) 300 } 301 func (d *dal[T, TS]) SetCtxReadDB(src context.Context) (dst context.Context) { 302 if orm := GetCtxGormDBByName(src, d.readDBName); orm != nil { 303 return src 304 } 305 306 return SetCtxGormDB(src, Use(src, d.readDBName, AppName(d.appName))) 307 } 308 func (d *dal[T, TS]) SetCtxWriteDB(src context.Context) (dst context.Context) { 309 if orm := GetCtxGormDBByName(src, d.writeDBName); orm != nil { 310 return src 311 } 312 return SetCtxGormDB(src, Use(src, d.writeDBName, AppName(d.appName))) 313 } 314 315 func (d *dal[T, TS]) Model() *T { return new(T) } 316 func (d *dal[T, TS]) ModelSlice() TS { return make(TS, 0) } 317 func (d *dal[T, TS]) IgnoreErr(err error) error { 318 if errors.Is(err, gorm.ErrRecordNotFound) { 319 return nil 320 } 321 return err 322 } 323 func (d *dal[T, TS]) CanIgnore(err error) bool { return errors.Is(err, gorm.ErrRecordNotFound) } 324 325 func (d *dal[T, TS]) ShardingByValues(ctx context.Context, src []map[string]any) ( 326 dst map[string][]map[string]any, err error) { 327 writeDB := d.writeDB(ctx) 328 tableName := d.tableName(writeDB, new(T)) 329 tableShardingPlugin, ok := writeDB.tableShardingPlugins[tableName] 330 if !ok { 331 return map[string][]map[string]any{tableName: src}, nil 332 } 333 return tableShardingPlugin.ShardingByValues(ctx, src) 334 } 335 func (d *dal[T, TS]) ShardingIDGen(ctx context.Context) (id uint64, err error) { 336 writeDB := d.writeDB(ctx) 337 tableName := d.tableName(writeDB, new(T)) 338 tableShardingPlugin, ok := writeDB.tableShardingPlugins[tableName] 339 if !ok { 340 return 0, plugins.ErrIDGeneratorNotFound 341 } 342 return tableShardingPlugin.ShardingIDGen(ctx) 343 } 344 func (d *dal[T, TS]) ShardingIDListGen(ctx context.Context, amount int) (idList []uint64, err error) { 345 writeDB := d.writeDB(ctx) 346 tableName := d.tableName(writeDB, new(T)) 347 tableShardingPlugin, ok := writeDB.tableShardingPlugins[tableName] 348 if !ok { 349 return nil, plugins.ErrIDGeneratorNotFound 350 } 351 idList = make([]uint64, 0, amount) 352 for i := 0; i < amount; i++ { 353 id, err := tableShardingPlugin.ShardingIDGen(ctx) 354 if err != nil { 355 return nil, err 356 } 357 idList = append(idList, id) 358 } 359 return 360 } 361 func (d *dal[T, TS]) ShardingByModelList(ctx context.Context, src TS) (dst map[string]TS, err error) { 362 if len(src) == 0 { 363 return make(map[string]TS), nil 364 } 365 writeDB := d.writeDB(ctx) 366 tableName := d.tableName(writeDB, src[0]) 367 shardingPlugin, ok := writeDB.tableShardingPlugins[tableName] 368 if !ok { 369 return map[string]TS{tableName: src}, nil 370 } 371 sharded, err := shardingPlugin.ShardingByModelList(ctx, utils.SliceMapping(src, func(t *T) any { return t })...) 372 if err != nil { 373 return 374 } 375 dst = make(map[string]TS, len(sharded)) 376 for suffix, item := range sharded { 377 shardingTableName := tableName + suffix 378 dst[shardingTableName] = TS(utils.SliceMapping(item, func(t any) *T { return t.(*T) })) 379 } 380 return 381 } 382 383 func (d *dal[T, TS]) writeDB(ctx context.Context) *DB { 384 if orm := GetCtxGormDBByName(ctx, d.writeDBName); orm != nil { 385 return orm 386 } 387 388 return Use(ctx, d.writeDBName, AppName(d.appName)) 389 } 390 func (d *dal[T, TS]) writeWithTableSharding(ctx context.Context, src TS) (dst []TS, err error) { 391 if len(src) == 0 { 392 return 393 } 394 writeDB := d.writeDB(ctx) 395 shardingPlugin, ok := writeDB.tableShardingPlugins[d.tableName(writeDB, src[0])] 396 if !ok { 397 return []TS{src}, nil 398 } 399 400 sharded, err := shardingPlugin.ShardingByModelList(ctx, utils.SliceMapping(src, func(t *T) any { return t })...) 401 if err != nil { 402 return 403 } 404 for _, item := range sharded { 405 dst = append(dst, utils.SliceMapping(item, func(t any) *T { return t.(*T) })) 406 } 407 return 408 } 409 func (d *dal[T, TS]) tableName(db *DB, mod *T) (name string) { 410 if tabler, ok := any(mod).(schema.Tabler); ok { 411 name = tabler.TableName() 412 } 413 if tabler, ok := any(mod).(schema.TablerWithNamer); ok { 414 name = tabler.TableName(db.NamingStrategy) 415 } 416 // TODO: check if embeddedNamer valid 417 embeddedNamer := inspect.TypeOf("gorm.io/gorm/schema.embeddedNamer") 418 namingStrategy := reflect.ValueOf(db.NamingStrategy) 419 if namingStrategy.CanConvert(embeddedNamer) { 420 name = namingStrategy.Convert(embeddedNamer).FieldByName("Table").String() 421 } 422 return 423 } 424 func (d *dal[T, TS]) convertAnyToTS(query any) (mList TS, ok bool) { 425 switch q := query.(type) { 426 case TS: 427 ok = true 428 mList = q 429 case []*T: 430 ok = true 431 mList = TS(q) 432 case []T: 433 ok = true 434 mList = TS(utils.SliceMapping(q, func(t T) *T { return &t })) 435 case T: 436 ok = true 437 mList = TS{&q} 438 case *T: 439 ok = true 440 mList = TS{q} 441 } 442 return 443 } 444 func (d *dal[T, TS]) unscopedGormDB(src *gorm.DB, o *mysqlDALOption) (dst *gorm.DB) { 445 if o != nil && o.unscoped { 446 return src.Unscoped() 447 } 448 return src 449 } 450 451 type mysqlDALOption struct { 452 unscoped bool 453 useWriteDB bool 454 clauses []clause.Expression 455 } 456 457 func Unscoped() utils.OptionFunc[mysqlDALOption] { 458 return func(m *mysqlDALOption) { 459 m.unscoped = true 460 } 461 } 462 463 func Clauses(clauses ...clause.Expression) utils.OptionFunc[mysqlDALOption] { 464 return func(m *mysqlDALOption) { 465 m.clauses = append(m.clauses, clauses...) 466 } 467 } 468 469 func WriteDB() utils.OptionFunc[mysqlDALOption] { 470 return func(m *mysqlDALOption) { 471 m.useWriteDB = true 472 } 473 } 474 475 func (d *dal[T, TS]) parseOptionFromArgs(args ...any) (o *mysqlDALOption, r []any) { 476 o = new(mysqlDALOption) 477 r = make([]any, 0, len(args)) 478 for _, arg := range args { 479 if reflect.TypeOf(arg).Implements(gormClauseExpressionType) { 480 o.clauses = append(o.clauses, arg.(clause.Expression)) 481 continue 482 } 483 484 switch v := arg.(type) { 485 case utils.OptionFunc[mysqlDALOption]: 486 v(o) 487 default: 488 r = append(r, arg) 489 } 490 } 491 return 492 }