github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/orm.go (about) 1 // The original package is migrated from beego and modified, you can find orignal from following link: 2 // "github.com/beego/beego/" 3 // 4 // Copyright 2023 IAC. All Rights Reserved. 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 //go:build go1.8 19 // +build go1.8 20 21 // Package orm provide ORM for MySQL/PostgreSQL/sqlite 22 // Simple Usage 23 // 24 // package main 25 // 26 // import ( 27 // "fmt" 28 // "github.com/mdaxf/iac/databases/orm" 29 // _ "github.com/go-sql-driver/mysql" // import your used driver 30 // ) 31 // 32 // // Model Struct 33 // type User struct { 34 // Id int `orm:"auto"` 35 // Name string `orm:"size(100)"` 36 // } 37 // 38 // func init() { 39 // orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) 40 // } 41 // 42 // func main() { 43 // o := orm.NewOrm() 44 // user := User{Name: "slene"} 45 // // insert 46 // id, err := o.Insert(&user) 47 // // update 48 // user.Name = "astaxie" 49 // num, err := o.Update(&user) 50 // // read one 51 // u := User{Id: user.Id} 52 // err = o.Read(&u) 53 // // delete 54 // num, err = o.Delete(&u) 55 // } 56 package orm 57 58 import ( 59 "context" 60 "database/sql" 61 "errors" 62 "fmt" 63 "os" 64 "reflect" 65 "time" 66 67 "github.com/mdaxf/iac/databases/orm/clauses" 68 "github.com/mdaxf/iac/databases/orm/hints" 69 logs "github.com/mdaxf/iac/framework/logs" 70 "github.com/mdaxf/iac/framework/utils" 71 ) 72 73 // DebugQueries define the debug 74 const ( 75 DebugQueries = iota 76 ) 77 78 // Define common vars 79 var ( 80 Debug = false 81 DebugLog = NewLog(os.Stdout) 82 DefaultRowsLimit = -1 83 DefaultRelsDepth = 2 84 DefaultTimeLoc = time.Local 85 ErrTxDone = errors.New("<TxOrmer.Commit/Rollback> transaction already done") 86 ErrMultiRows = errors.New("<QuerySeter> return multi rows") 87 ErrNoRows = errors.New("<QuerySeter> no row found") 88 ErrStmtClosed = errors.New("<QuerySeter> stmt already closed") 89 ErrArgs = errors.New("<Ormer> args error may be empty") 90 ErrNotImplement = errors.New("have not implement") 91 92 ErrLastInsertIdUnavailable = errors.New("<Ormer> last insert id is unavailable") 93 ) 94 95 // Params stores the Params 96 type Params map[string]interface{} 97 98 // ParamsList stores paramslist 99 type ParamsList []interface{} 100 101 type ormBase struct { 102 alias *alias 103 db dbQuerier 104 } 105 106 var ( 107 _ DQL = new(ormBase) 108 _ DML = new(ormBase) 109 _ DriverGetter = new(ormBase) 110 ) 111 112 // get model info and model reflect value 113 func (*ormBase) getMi(md interface{}) (mi *modelInfo) { 114 val := reflect.ValueOf(md) 115 ind := reflect.Indirect(val) 116 typ := ind.Type() 117 mi = getTypeMi(typ) 118 return 119 } 120 121 // get need ptr model info and model reflect value 122 func (*ormBase) getPtrMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { 123 val := reflect.ValueOf(md) 124 ind = reflect.Indirect(val) 125 typ := ind.Type() 126 if val.Kind() != reflect.Ptr { 127 panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) 128 } 129 mi = getTypeMi(typ) 130 return 131 } 132 133 func getTypeMi(mdTyp reflect.Type) *modelInfo { 134 name := getFullName(mdTyp) 135 if mi, ok := defaultModelCache.getByFullName(name); ok { 136 return mi 137 } 138 panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) 139 } 140 141 // get field info from model info by given field name 142 func (*ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { 143 fi, ok := mi.fields.GetByAny(name) 144 if !ok { 145 panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName)) 146 } 147 return fi 148 } 149 150 // read data to model 151 func (o *ormBase) Read(md interface{}, cols ...string) error { 152 return o.ReadWithCtx(context.Background(), md, cols...) 153 } 154 155 func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { 156 mi, ind := o.getPtrMiInd(md) 157 return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) 158 } 159 160 // read data to model, like Read(), but use "SELECT FOR UPDATE" form 161 func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { 162 return o.ReadForUpdateWithCtx(context.Background(), md, cols...) 163 } 164 165 func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { 166 mi, ind := o.getPtrMiInd(md) 167 return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) 168 } 169 170 // Try to read a row from the database, or insert one if it doesn't exist 171 func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { 172 return o.ReadOrCreateWithCtx(context.Background(), md, col1, cols...) 173 } 174 175 func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { 176 cols = append([]string{col1}, cols...) 177 mi, ind := o.getPtrMiInd(md) 178 err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) 179 if err == ErrNoRows { 180 // Create 181 id, err := o.InsertWithCtx(ctx, md) 182 return err == nil, id, err 183 } 184 185 id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) 186 if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { 187 id = int64(vid.Uint()) 188 } else if mi.fields.pk.rel { 189 return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) 190 } else { 191 id = vid.Int() 192 } 193 194 return false, id, err 195 } 196 197 // insert model data to database 198 func (o *ormBase) Insert(md interface{}) (int64, error) { 199 return o.InsertWithCtx(context.Background(), md) 200 } 201 202 func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { 203 mi, ind := o.getPtrMiInd(md) 204 id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) 205 if err != nil { 206 return id, err 207 } 208 209 o.setPk(mi, ind, id) 210 211 return id, nil 212 } 213 214 // set auto pk field 215 func (*ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { 216 if mi.fields.pk.auto { 217 if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { 218 ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) 219 } else { 220 ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) 221 } 222 } 223 } 224 225 // insert some models to database 226 func (o *ormBase) InsertMulti(bulk int, mds interface{}) (int64, error) { 227 return o.InsertMultiWithCtx(context.Background(), bulk, mds) 228 } 229 230 func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { 231 var cnt int64 232 233 sind := reflect.Indirect(reflect.ValueOf(mds)) 234 235 switch sind.Kind() { 236 case reflect.Array, reflect.Slice: 237 if sind.Len() == 0 { 238 return cnt, ErrArgs 239 } 240 default: 241 return cnt, ErrArgs 242 } 243 244 if bulk <= 1 { 245 for i := 0; i < sind.Len(); i++ { 246 ind := reflect.Indirect(sind.Index(i)) 247 mi := o.getMi(ind.Interface()) 248 id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) 249 if err != nil { 250 return cnt, err 251 } 252 253 o.setPk(mi, ind, id) 254 255 cnt++ 256 } 257 } else { 258 mi := o.getMi(sind.Index(0).Interface()) 259 return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) 260 } 261 return cnt, nil 262 } 263 264 // InsertOrUpdate data to database 265 func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (int64, error) { 266 return o.InsertOrUpdateWithCtx(context.Background(), md, colConflictAndArgs...) 267 } 268 269 func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { 270 mi, ind := o.getPtrMiInd(md) 271 id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) 272 if err != nil { 273 return id, err 274 } 275 276 o.setPk(mi, ind, id) 277 278 return id, nil 279 } 280 281 // update model to database. 282 // cols set the columns those want to update. 283 func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { 284 return o.UpdateWithCtx(context.Background(), md, cols...) 285 } 286 287 func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { 288 mi, ind := o.getPtrMiInd(md) 289 return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) 290 } 291 292 // delete model in database 293 // cols shows the delete conditions values read from. default is pk 294 func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { 295 return o.DeleteWithCtx(context.Background(), md, cols...) 296 } 297 298 func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { 299 mi, ind := o.getPtrMiInd(md) 300 num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) 301 return num, err 302 } 303 304 // create a models to models queryer 305 func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { 306 mi, ind := o.getPtrMiInd(md) 307 fi := o.getFieldInfo(mi, name) 308 309 switch { 310 case fi.fieldType == RelManyToMany: 311 case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: 312 default: 313 panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) 314 } 315 316 return newQueryM2M(md, o, mi, fi, ind) 317 } 318 319 // NOTE: this method is deprecated, context parameter will not take effect. 320 func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { 321 logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.") 322 return o.QueryM2M(md, name) 323 } 324 325 // load related models to md model. 326 // args are limit, offset int and order string. 327 // 328 // example: 329 // 330 // orm.LoadRelated(post,"Tags") 331 // for _,tag := range post.Tags{...} 332 // 333 // make sure the relation is defined in model struct tags. 334 func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) { 335 return o.LoadRelatedWithCtx(context.Background(), md, name, args...) 336 } 337 338 func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { 339 _, fi, ind, qs := o.queryRelated(md, name) 340 341 var relDepth int 342 var limit, offset int64 343 var order string 344 345 kvs := utils.NewKVs(args...) 346 kvs.IfContains(hints.KeyRelDepth, func(value interface{}) { 347 if v, ok := value.(bool); ok { 348 if v { 349 relDepth = DefaultRelsDepth 350 } 351 } else if v, ok := value.(int); ok { 352 relDepth = v 353 } 354 }).IfContains(hints.KeyLimit, func(value interface{}) { 355 if v, ok := value.(int64); ok { 356 limit = v 357 } 358 }).IfContains(hints.KeyOffset, func(value interface{}) { 359 if v, ok := value.(int64); ok { 360 offset = v 361 } 362 }).IfContains(hints.KeyOrderBy, func(value interface{}) { 363 if v, ok := value.(string); ok { 364 order = v 365 } 366 }) 367 368 switch fi.fieldType { 369 case RelOneToOne, RelForeignKey, RelReverseOne: 370 limit = 1 371 offset = 0 372 } 373 374 qs.limit = limit 375 qs.offset = offset 376 qs.relDepth = relDepth 377 378 if len(order) > 0 { 379 qs.orders = clauses.ParseOrder(order) 380 } 381 382 find := ind.FieldByIndex(fi.fieldIndex) 383 384 var nums int64 385 var err error 386 switch fi.fieldType { 387 case RelOneToOne, RelForeignKey, RelReverseOne: 388 val := reflect.New(find.Type().Elem()) 389 container := val.Interface() 390 err = qs.One(container) 391 if err == nil { 392 find.Set(val) 393 nums = 1 394 } 395 default: 396 nums, err = qs.All(find.Addr().Interface()) 397 } 398 399 return nums, err 400 } 401 402 // get QuerySeter for related models to md model 403 func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { 404 mi, ind := o.getPtrMiInd(md) 405 fi := o.getFieldInfo(mi, name) 406 407 _, _, exist := getExistPk(mi, ind) 408 if !exist { 409 panic(ErrMissPK) 410 } 411 412 var qs *querySet 413 414 switch fi.fieldType { 415 case RelOneToOne, RelForeignKey, RelManyToMany: 416 if !fi.inModel { 417 break 418 } 419 qs = o.getRelQs(md, mi, fi) 420 case RelReverseOne, RelReverseMany: 421 if !fi.inModel { 422 break 423 } 424 qs = o.getReverseQs(md, mi, fi) 425 } 426 427 if qs == nil { 428 panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name)) 429 } 430 431 return mi, fi, ind, qs 432 } 433 434 // get reverse relation QuerySeter 435 func (o *ormBase) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { 436 switch fi.fieldType { 437 case RelReverseOne, RelReverseMany: 438 default: 439 panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) 440 } 441 442 var q *querySet 443 444 if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { 445 q = newQuerySet(o, fi.relModelInfo).(*querySet) 446 q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) 447 } else { 448 q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) 449 q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) 450 } 451 452 return q 453 } 454 455 // get relation QuerySeter 456 func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { 457 switch fi.fieldType { 458 case RelOneToOne, RelForeignKey, RelManyToMany: 459 default: 460 panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) 461 } 462 463 q := newQuerySet(o, fi.relModelInfo).(*querySet) 464 q.cond = NewCondition() 465 466 if fi.fieldType == RelManyToMany { 467 q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) 468 } else { 469 q.cond = q.cond.And(fi.reverseFieldInfo.column, md) 470 } 471 472 return q 473 } 474 475 // return a QuerySeter for table operations. 476 // table name can be string or struct. 477 // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), 478 func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { 479 var name string 480 if table, ok := ptrStructOrTableName.(string); ok { 481 name = nameStrategyMap[defaultNameStrategy](table) 482 if mi, ok := defaultModelCache.get(name); ok { 483 qs = newQuerySet(o, mi) 484 } 485 } else { 486 name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) 487 if mi, ok := defaultModelCache.getByFullName(name); ok { 488 qs = newQuerySet(o, mi) 489 } 490 } 491 if qs == nil { 492 panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name)) 493 } 494 return qs 495 } 496 497 // NOTE: this method is deprecated, context parameter will not take effect. 498 func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { 499 logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.") 500 return o.QueryTable(ptrStructOrTableName) 501 } 502 503 // return a raw query seter for raw sql string. 504 func (o *ormBase) Raw(query string, args ...interface{}) RawSeter { 505 return o.RawWithCtx(context.Background(), query, args...) 506 } 507 508 func (o *ormBase) RawWithCtx(_ context.Context, query string, args ...interface{}) RawSeter { 509 return newRawSet(o, query, args) 510 } 511 512 // return current using database Driver 513 func (o *ormBase) Driver() Driver { 514 return driver(o.alias.Name) 515 } 516 517 // return sql.DBStats for current database 518 func (o *ormBase) DBStats() *sql.DBStats { 519 if o.alias != nil && o.alias.DB != nil { 520 stats := o.alias.DB.DB.Stats() 521 return &stats 522 } 523 return nil 524 } 525 526 type orm struct { 527 ormBase 528 } 529 530 var _ Ormer = new(orm) 531 532 func (o *orm) Begin() (TxOrmer, error) { 533 return o.BeginWithCtx(context.Background()) 534 } 535 536 func (o *orm) BeginWithCtx(ctx context.Context) (TxOrmer, error) { 537 return o.BeginWithCtxAndOpts(ctx, nil) 538 } 539 540 func (o *orm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { 541 return o.BeginWithCtxAndOpts(context.Background(), opts) 542 } 543 544 func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { 545 tx, err := o.db.(txer).BeginTx(ctx, opts) 546 if err != nil { 547 return nil, err 548 } 549 550 _txOrm := &txOrm{ 551 ormBase: ormBase{ 552 alias: o.alias, 553 db: &TxDB{tx: tx}, 554 }, 555 } 556 557 if Debug { 558 _txOrm.db = newDbQueryLog(o.alias, _txOrm.db) 559 } 560 561 var taskTxOrm TxOrmer = _txOrm 562 return taskTxOrm, nil 563 } 564 565 func (o *orm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { 566 return o.DoTxWithCtx(context.Background(), task) 567 } 568 569 func (o *orm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { 570 return o.DoTxWithCtxAndOpts(ctx, nil, task) 571 } 572 573 func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { 574 return o.DoTxWithCtxAndOpts(context.Background(), opts, task) 575 } 576 577 func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { 578 return doTxTemplate(ctx, o, opts, task) 579 } 580 581 func doTxTemplate(ctx context.Context, o TxBeginner, opts *sql.TxOptions, 582 task func(ctx context.Context, txOrm TxOrmer) error) error { 583 _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) 584 if err != nil { 585 return err 586 } 587 panicked := true 588 defer func() { 589 if panicked || err != nil { 590 e := _txOrm.Rollback() 591 if e != nil { 592 logs.Error("rollback transaction failed: %v,%v", e, panicked) 593 } 594 } else { 595 e := _txOrm.Commit() 596 if e != nil { 597 logs.Error("commit transaction failed: %v,%v", e, panicked) 598 } 599 } 600 }() 601 taskTxOrm := _txOrm 602 err = task(ctx, taskTxOrm) 603 panicked = false 604 return err 605 } 606 607 type txOrm struct { 608 ormBase 609 } 610 611 var _ TxOrmer = new(txOrm) 612 613 func (t *txOrm) Commit() error { 614 return t.db.(txEnder).Commit() 615 } 616 617 func (t *txOrm) Rollback() error { 618 return t.db.(txEnder).Rollback() 619 } 620 621 func (t *txOrm) RollbackUnlessCommit() error { 622 return t.db.(txEnder).RollbackUnlessCommit() 623 } 624 625 // NewOrm create new orm 626 func NewOrm() Ormer { 627 BootStrap() // execute only once 628 return NewOrmUsingDB(`default`) 629 } 630 631 // NewOrmUsingDB create new orm with the name 632 func NewOrmUsingDB(aliasName string) Ormer { 633 if al, ok := dataBaseCache.get(aliasName); ok { 634 return newDBWithAlias(al) 635 } 636 panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName)) 637 } 638 639 // NewOrmWithDB create a new ormer object with specify *sql.DB for query 640 func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...DBOption) (Ormer, error) { 641 al, err := newAliasWithDb(aliasName, driverName, db, params...) 642 if err != nil { 643 return nil, err 644 } 645 646 return newDBWithAlias(al), nil 647 } 648 649 func newDBWithAlias(al *alias) Ormer { 650 o := new(orm) 651 o.alias = al 652 653 if Debug { 654 o.db = newDbQueryLog(al, al.DB) 655 } else { 656 o.db = al.DB 657 } 658 659 if len(globalFilterChains) > 0 { 660 return NewFilterOrmDecorator(o, globalFilterChains...) 661 } 662 return o 663 }