github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/models.go (about) 1 // The original package is migrated from beego and modified, you can find orignal from following link: 2 // "github.com/beego/beego/" 3 // 4 // Copyright 2023 IAC. All Rights Reserved. 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 package orm 19 20 import ( 21 "errors" 22 "fmt" 23 "reflect" 24 "runtime/debug" 25 "strings" 26 "sync" 27 ) 28 29 const ( 30 odCascade = "cascade" 31 odSetNULL = "set_null" 32 odSetDefault = "set_default" 33 odDoNothing = "do_nothing" 34 defaultStructTagName = "orm" 35 defaultStructTagDelim = ";" 36 ) 37 38 var defaultModelCache = NewModelCacheHandler() 39 40 // model info collection 41 type modelCache struct { 42 sync.RWMutex // only used outsite for bootStrap 43 orders []string 44 cache map[string]*modelInfo 45 cacheByFullName map[string]*modelInfo 46 done bool 47 } 48 49 // NewModelCacheHandler generator of modelCache 50 func NewModelCacheHandler() *modelCache { 51 return &modelCache{ 52 cache: make(map[string]*modelInfo), 53 cacheByFullName: make(map[string]*modelInfo), 54 } 55 } 56 57 // get all model info 58 func (mc *modelCache) all() map[string]*modelInfo { 59 m := make(map[string]*modelInfo, len(mc.cache)) 60 for k, v := range mc.cache { 61 m[k] = v 62 } 63 return m 64 } 65 66 // get ordered model info 67 func (mc *modelCache) allOrdered() []*modelInfo { 68 m := make([]*modelInfo, 0, len(mc.orders)) 69 for _, table := range mc.orders { 70 m = append(m, mc.cache[table]) 71 } 72 return m 73 } 74 75 // get model info by table name 76 func (mc *modelCache) get(table string) (mi *modelInfo, ok bool) { 77 mi, ok = mc.cache[table] 78 return 79 } 80 81 // get model info by full name 82 func (mc *modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { 83 mi, ok = mc.cacheByFullName[name] 84 return 85 } 86 87 func (mc *modelCache) getByMd(md interface{}) (*modelInfo, bool) { 88 val := reflect.ValueOf(md) 89 ind := reflect.Indirect(val) 90 typ := ind.Type() 91 name := getFullName(typ) 92 return mc.getByFullName(name) 93 } 94 95 // set model info to collection 96 func (mc *modelCache) set(table string, mi *modelInfo) *modelInfo { 97 mii := mc.cache[table] 98 mc.cache[table] = mi 99 mc.cacheByFullName[mi.fullName] = mi 100 if mii == nil { 101 mc.orders = append(mc.orders, table) 102 } 103 return mii 104 } 105 106 // clean all model info. 107 func (mc *modelCache) clean() { 108 mc.Lock() 109 defer mc.Unlock() 110 111 mc.orders = make([]string, 0) 112 mc.cache = make(map[string]*modelInfo) 113 mc.cacheByFullName = make(map[string]*modelInfo) 114 mc.done = false 115 } 116 117 // bootstrap bootstrap for models 118 func (mc *modelCache) bootstrap() { 119 mc.Lock() 120 defer mc.Unlock() 121 if mc.done { 122 return 123 } 124 var ( 125 err error 126 models map[string]*modelInfo 127 ) 128 if dataBaseCache.getDefault() == nil { 129 err = fmt.Errorf("must have one register DataBase alias named `default`") 130 goto end 131 } 132 133 // set rel and reverse model 134 // RelManyToMany set the relTable 135 models = mc.all() 136 for _, mi := range models { 137 for _, fi := range mi.fields.columns { 138 if fi.rel || fi.reverse { 139 elm := fi.addrValue.Type().Elem() 140 if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { 141 elm = elm.Elem() 142 } 143 // check the rel or reverse model already register 144 name := getFullName(elm) 145 mii, ok := mc.getByFullName(name) 146 if !ok || mii.pkg != elm.PkgPath() { 147 err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) 148 goto end 149 } 150 fi.relModelInfo = mii 151 152 switch fi.fieldType { 153 case RelManyToMany: 154 if fi.relThrough != "" { 155 if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { 156 pn := fi.relThrough[:i] 157 rmi, ok := mc.getByFullName(fi.relThrough) 158 if !ok || pn != rmi.pkg { 159 err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) 160 goto end 161 } 162 fi.relThroughModelInfo = rmi 163 fi.relTable = rmi.table 164 } else { 165 err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) 166 goto end 167 } 168 } else { 169 i := newM2MModelInfo(mi, mii) 170 if fi.relTable != "" { 171 i.table = fi.relTable 172 } 173 if v := mc.set(i.table, i); v != nil { 174 err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) 175 goto end 176 } 177 fi.relTable = i.table 178 fi.relThroughModelInfo = i 179 } 180 181 fi.relThroughModelInfo.isThrough = true 182 } 183 } 184 } 185 } 186 187 // check the rel filed while the relModelInfo also has filed point to current model 188 // if not exist, add a new field to the relModelInfo 189 models = mc.all() 190 for _, mi := range models { 191 for _, fi := range mi.fields.fieldsRel { 192 switch fi.fieldType { 193 case RelForeignKey, RelOneToOne, RelManyToMany: 194 inModel := false 195 for _, ffi := range fi.relModelInfo.fields.fieldsReverse { 196 if ffi.relModelInfo == mi { 197 inModel = true 198 break 199 } 200 } 201 if !inModel { 202 rmi := fi.relModelInfo 203 ffi := new(fieldInfo) 204 ffi.name = mi.name 205 ffi.column = ffi.name 206 ffi.fullName = rmi.fullName + "." + ffi.name 207 ffi.reverse = true 208 ffi.relModelInfo = mi 209 ffi.mi = rmi 210 if fi.fieldType == RelOneToOne { 211 ffi.fieldType = RelReverseOne 212 } else { 213 ffi.fieldType = RelReverseMany 214 } 215 if !rmi.fields.Add(ffi) { 216 added := false 217 for cnt := 0; cnt < 5; cnt++ { 218 ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) 219 ffi.column = ffi.name 220 ffi.fullName = rmi.fullName + "." + ffi.name 221 if added = rmi.fields.Add(ffi); added { 222 break 223 } 224 } 225 if !added { 226 panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) 227 } 228 } 229 } 230 } 231 } 232 } 233 234 models = mc.all() 235 for _, mi := range models { 236 for _, fi := range mi.fields.fieldsRel { 237 switch fi.fieldType { 238 case RelManyToMany: 239 for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { 240 switch ffi.fieldType { 241 case RelOneToOne, RelForeignKey: 242 if ffi.relModelInfo == fi.relModelInfo { 243 fi.reverseFieldInfoTwo = ffi 244 } 245 if ffi.relModelInfo == mi { 246 fi.reverseField = ffi.name 247 fi.reverseFieldInfo = ffi 248 } 249 } 250 } 251 if fi.reverseFieldInfoTwo == nil { 252 err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", 253 fi.relThroughModelInfo.fullName) 254 goto end 255 } 256 } 257 } 258 } 259 260 models = mc.all() 261 for _, mi := range models { 262 for _, fi := range mi.fields.fieldsReverse { 263 switch fi.fieldType { 264 case RelReverseOne: 265 found := false 266 mForA: 267 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { 268 if ffi.relModelInfo == mi { 269 found = true 270 fi.reverseField = ffi.name 271 fi.reverseFieldInfo = ffi 272 273 ffi.reverseField = fi.name 274 ffi.reverseFieldInfo = fi 275 break mForA 276 } 277 } 278 if !found { 279 err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) 280 goto end 281 } 282 case RelReverseMany: 283 found := false 284 mForB: 285 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { 286 if ffi.relModelInfo == mi { 287 found = true 288 fi.reverseField = ffi.name 289 fi.reverseFieldInfo = ffi 290 291 ffi.reverseField = fi.name 292 ffi.reverseFieldInfo = fi 293 294 break mForB 295 } 296 } 297 if !found { 298 mForC: 299 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { 300 conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || 301 fi.relTable != "" && fi.relTable == ffi.relTable || 302 fi.relThrough == "" && fi.relTable == "" 303 if ffi.relModelInfo == mi && conditions { 304 found = true 305 306 fi.reverseField = ffi.reverseFieldInfoTwo.name 307 fi.reverseFieldInfo = ffi.reverseFieldInfoTwo 308 fi.relThroughModelInfo = ffi.relThroughModelInfo 309 fi.reverseFieldInfoTwo = ffi.reverseFieldInfo 310 fi.reverseFieldInfoM2M = ffi 311 ffi.reverseFieldInfoM2M = fi 312 313 break mForC 314 } 315 } 316 } 317 if !found { 318 err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) 319 goto end 320 } 321 } 322 } 323 } 324 325 end: 326 if err != nil { 327 fmt.Println(err) 328 debug.PrintStack() 329 } 330 mc.done = true 331 } 332 333 // register register models to model cache 334 func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { 335 for _, model := range models { 336 val := reflect.ValueOf(model) 337 typ := reflect.Indirect(val).Type() 338 339 if val.Kind() != reflect.Ptr { 340 err = fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)) 341 return 342 } 343 // For this case: 344 // u := &User{} 345 // registerModel(&u) 346 if typ.Kind() == reflect.Ptr { 347 err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) 348 return 349 } 350 if val.Elem().Kind() == reflect.Slice { 351 val = reflect.New(val.Elem().Type().Elem()) 352 } 353 table := getTableName(val) 354 355 if prefixOrSuffixStr != "" { 356 if prefixOrSuffix { 357 table = prefixOrSuffixStr + table 358 } else { 359 table = table + prefixOrSuffixStr 360 } 361 } 362 363 // models's fullname is pkgpath + struct name 364 name := getFullName(typ) 365 if _, ok := mc.getByFullName(name); ok { 366 err = fmt.Errorf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name) 367 return 368 } 369 370 if _, ok := mc.get(table); ok { 371 return nil 372 } 373 374 mi := newModelInfo(val) 375 if mi.fields.pk == nil { 376 outFor: 377 for _, fi := range mi.fields.fieldsDB { 378 if strings.ToLower(fi.name) == "id" { 379 switch fi.addrValue.Elem().Kind() { 380 case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: 381 fi.auto = true 382 fi.pk = true 383 mi.fields.pk = fi 384 break outFor 385 } 386 } 387 } 388 } 389 390 mi.table = table 391 mi.pkg = typ.PkgPath() 392 mi.model = model 393 mi.manual = true 394 395 mc.set(table, mi) 396 } 397 return 398 } 399 400 // getDbDropSQL get database scheme drop sql queries 401 func (mc *modelCache) getDbDropSQL(al *alias) (queries []string, err error) { 402 if len(mc.cache) == 0 { 403 err = errors.New("no Model found, need register your model") 404 return 405 } 406 407 Q := al.DbBaser.TableQuote() 408 409 for _, mi := range mc.allOrdered() { 410 queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) 411 } 412 return queries, nil 413 } 414 415 // getDbCreateSQL get database scheme creation sql queries 416 func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes map[string][]dbIndex, err error) { 417 if len(mc.cache) == 0 { 418 err = errors.New("no Model found, need register your model") 419 return 420 } 421 422 Q := al.DbBaser.TableQuote() 423 T := al.DbBaser.DbTypes() 424 sep := fmt.Sprintf("%s, %s", Q, Q) 425 426 tableIndexes = make(map[string][]dbIndex) 427 428 for _, mi := range mc.allOrdered() { 429 sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) 430 sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) 431 sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) 432 433 sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) 434 435 columns := make([]string, 0, len(mi.fields.fieldsDB)) 436 437 sqlIndexes := [][]string{} 438 var commentIndexes []int // store comment indexes for postgres 439 440 for i, fi := range mi.fields.fieldsDB { 441 column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) 442 col := getColumnTyp(al, fi) 443 444 if fi.auto { 445 switch al.Driver { 446 case DRSqlite, DRPostgres: 447 column += T["auto"] 448 default: 449 column += col + " " + T["auto"] 450 } 451 } else if fi.pk { 452 column += col + " " + T["pk"] 453 } else { 454 column += col 455 456 if !fi.null { 457 column += " " + "NOT NULL" 458 } 459 460 // if fi.initial.String() != "" { 461 // column += " DEFAULT " + fi.initial.String() 462 // } 463 464 // Append attribute DEFAULT 465 column += getColumnDefault(fi) 466 467 if fi.unique { 468 column += " " + "UNIQUE" 469 } 470 471 if fi.index { 472 sqlIndexes = append(sqlIndexes, []string{fi.column}) 473 } 474 } 475 476 if strings.Contains(column, "%COL%") { 477 column = strings.Replace(column, "%COL%", fi.column, -1) 478 } 479 480 if fi.description != "" && al.Driver != DRSqlite { 481 if al.Driver == DRPostgres { 482 commentIndexes = append(commentIndexes, i) 483 } else { 484 column += " " + fmt.Sprintf("COMMENT '%s'", fi.description) 485 } 486 } 487 488 columns = append(columns, column) 489 } 490 491 if mi.model != nil { 492 allnames := getTableUnique(mi.addrField) 493 if !mi.manual && len(mi.uniques) > 0 { 494 allnames = append(allnames, mi.uniques) 495 } 496 for _, names := range allnames { 497 cols := make([]string, 0, len(names)) 498 for _, name := range names { 499 if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { 500 cols = append(cols, fi.column) 501 } else { 502 panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) 503 } 504 } 505 column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) 506 columns = append(columns, column) 507 } 508 } 509 510 sql += strings.Join(columns, ",\n") 511 sql += "\n)" 512 513 if al.Driver == DRMySQL { 514 var engine string 515 if mi.model != nil { 516 engine = getTableEngine(mi.addrField) 517 } 518 if engine == "" { 519 engine = al.Engine 520 } 521 sql += " ENGINE=" + engine 522 } 523 524 sql += ";" 525 if al.Driver == DRPostgres && len(commentIndexes) > 0 { 526 // append comments for postgres only 527 for _, index := range commentIndexes { 528 sql += fmt.Sprintf("\nCOMMENT ON COLUMN %s%s%s.%s%s%s is '%s';", 529 Q, 530 mi.table, 531 Q, 532 Q, 533 mi.fields.fieldsDB[index].column, 534 Q, 535 mi.fields.fieldsDB[index].description) 536 } 537 } 538 queries = append(queries, sql) 539 540 if mi.model != nil { 541 for _, names := range getTableIndex(mi.addrField) { 542 cols := make([]string, 0, len(names)) 543 for _, name := range names { 544 if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { 545 cols = append(cols, fi.column) 546 } else { 547 panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) 548 } 549 } 550 sqlIndexes = append(sqlIndexes, cols) 551 } 552 } 553 554 for _, names := range sqlIndexes { 555 name := mi.table + "_" + strings.Join(names, "_") 556 cols := strings.Join(names, sep) 557 sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) 558 559 index := dbIndex{} 560 index.Table = mi.table 561 index.Name = name 562 index.SQL = sql 563 564 tableIndexes[mi.table] = append(tableIndexes[mi.table], index) 565 } 566 567 } 568 569 return 570 } 571 572 // ResetModelCache Clean model cache. Then you can re-RegisterModel. 573 // Common use this api for test case. 574 func ResetModelCache() { 575 defaultModelCache.clean() 576 }