github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/db_tables.go (about) 1 // The original package is migrated from beego and modified, you can find orignal from following link: 2 // "github.com/beego/beego/" 3 // 4 // Copyright 2023 IAC. All Rights Reserved. 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 package orm 19 20 import ( 21 "fmt" 22 "strings" 23 "time" 24 25 "github.com/mdaxf/iac/databases/orm/clauses" 26 ) 27 28 // table info struct. 29 type dbTable struct { 30 id int 31 index string 32 name string 33 names []string 34 sel bool 35 inner bool 36 mi *modelInfo 37 fi *fieldInfo 38 jtl *dbTable 39 } 40 41 // tables collection struct, contains some tables. 42 type dbTables struct { 43 tablesM map[string]*dbTable 44 tables []*dbTable 45 mi *modelInfo 46 base dbBaser 47 skipEnd bool 48 } 49 50 // set table info to collection. 51 // if not exist, create new. 52 func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { 53 name := strings.Join(names, ExprSep) 54 if j, ok := t.tablesM[name]; ok { 55 j.name = name 56 j.mi = mi 57 j.fi = fi 58 j.inner = inner 59 } else { 60 i := len(t.tables) + 1 61 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} 62 t.tablesM[name] = jt 63 t.tables = append(t.tables, jt) 64 } 65 return t.tablesM[name] 66 } 67 68 // add table info to collection. 
69 func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { 70 name := strings.Join(names, ExprSep) 71 if _, ok := t.tablesM[name]; !ok { 72 i := len(t.tables) + 1 73 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} 74 t.tablesM[name] = jt 75 t.tables = append(t.tables, jt) 76 return jt, true 77 } 78 return t.tablesM[name], false 79 } 80 81 // get table info in collection. 82 func (t *dbTables) get(name string) (*dbTable, bool) { 83 j, ok := t.tablesM[name] 84 return j, ok 85 } 86 87 // get related fields info in recursive depth loop. 88 // loop once, depth decreases one. 89 func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { 90 if depth < 0 || fi.fieldType == RelManyToMany { 91 return related 92 } 93 94 if prefix == "" { 95 prefix = fi.name 96 } else { 97 prefix = prefix + ExprSep + fi.name 98 } 99 related = append(related, prefix) 100 101 depth-- 102 for _, fi := range fi.relModelInfo.fields.fieldsRel { 103 related = t.loopDepth(depth, prefix, fi, related) 104 } 105 106 return related 107 } 108 109 // parse related fields. 
110 func (t *dbTables) parseRelated(rels []string, depth int) { 111 relsNum := len(rels) 112 related := make([]string, relsNum) 113 copy(related, rels) 114 115 relDepth := depth 116 117 if relsNum != 0 { 118 relDepth = 0 119 } 120 121 relDepth-- 122 for _, fi := range t.mi.fields.fieldsRel { 123 related = t.loopDepth(relDepth, "", fi, related) 124 } 125 126 for i, s := range related { 127 var ( 128 exs = strings.Split(s, ExprSep) 129 names = make([]string, 0, len(exs)) 130 mmi = t.mi 131 cancel = true 132 jtl *dbTable 133 ) 134 135 inner := true 136 137 for _, ex := range exs { 138 if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { 139 names = append(names, fi.name) 140 mmi = fi.relModelInfo 141 142 if fi.null || t.skipEnd { 143 inner = false 144 } 145 146 jt := t.set(names, mmi, fi, inner) 147 jt.jtl = jtl 148 149 if fi.reverse { 150 cancel = false 151 } 152 153 if cancel { 154 jt.sel = depth > 0 155 156 if i < relsNum { 157 jt.sel = true 158 } 159 } 160 161 jtl = jt 162 163 } else { 164 panic(fmt.Errorf("unknown model/table name `%s`", ex)) 165 } 166 } 167 } 168 } 169 170 // generate join string. 
171 func (t *dbTables) getJoinSQL() (join string) { 172 Q := t.base.TableQuote() 173 174 for _, jt := range t.tables { 175 if jt.inner { 176 join += "INNER JOIN " 177 } else { 178 join += "LEFT OUTER JOIN " 179 } 180 var ( 181 table string 182 t1, t2 string 183 c1, c2 string 184 ) 185 t1 = "T0" 186 if jt.jtl != nil { 187 t1 = jt.jtl.index 188 } 189 t2 = jt.index 190 table = jt.mi.table 191 192 switch { 193 case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: 194 c1 = jt.fi.mi.fields.pk.column 195 for _, ffi := range jt.mi.fields.fieldsRel { 196 if jt.fi.mi == ffi.relModelInfo { 197 c2 = ffi.column 198 break 199 } 200 } 201 default: 202 c1 = jt.fi.column 203 c2 = jt.fi.relModelInfo.fields.pk.column 204 205 if jt.fi.reverse { 206 c1 = jt.mi.fields.pk.column 207 c2 = jt.fi.reverseFieldInfo.column 208 } 209 } 210 211 join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, 212 t2, Q, c2, Q, t1, Q, c1, Q) 213 } 214 return 215 } 216 217 // parse orm model struct field tag expression. 
// parseExprs resolves a field path (one segment per element of exprs, starting
// at model mi). It registers, via t.add, any joins the path needs.
//
// Results:
//   - index: alias of the table holding the resolved column ("T0" for the root model).
//   - name: ExprSep-joined path of the resolved field.
//   - info: the resolved field.
//   - success: false when a segment cannot be resolved; index, name and info
//     are then zero.
//
// An empty exprs also yields success == false.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
	var (
		jtl *dbTable   // most recently added join along the path
		fi  *fieldInfo // field of the current segment
		fiN *fieldInfo // field of the next segment, looked up one step ahead
		mmi = mi       // model in which the next segment is resolved
	)

	num := len(exprs) - 1
	var names []string

	inner := true

loopFor:
	for i, ex := range exprs {

		var ok, okN bool

		// Segments after the first were already resolved by the look-ahead of
		// the previous iteration.
		if fiN != nil {
			fi = fiN
			ok = true
			fiN = nil
		}

		if i == 0 {
			fi, ok = mmi.fields.GetByAny(ex)
		}

		_ = okN

		if ok {

			isRel := fi.rel || fi.reverse

			names = append(names, fi.name)

			// Step into the model the relation points at (the through model for
			// many-to-many, the owning model of the forward field for reverse).
			switch {
			case fi.rel:
				mmi = fi.relModelInfo
				if fi.fieldType == RelManyToMany {
					mmi = fi.relThroughModelInfo
				}
			case fi.reverse:
				mmi = fi.reverseFieldInfo.mi
			}

			if i < num {
				fiN, okN = mmi.fields.GetByAny(exprs[i+1])
			}

			// No join is added for the last segment when fi belongs to a through model.
			if isRel && (!fi.mi.isThrough || num != i) {
				if fi.null || t.skipEnd {
					inner = false
				}

				// With skipEnd, a join is added only when the next segment resolves.
				// If that next segment is a pk, the path is finalized at the current
				// field without joining (presumably because the current table's
				// relation column already holds that pk value).
				if t.skipEnd && okN || !t.skipEnd {
					if t.skipEnd && okN && fiN.pk {
						goto loopEnd
					}

					jt, _ := t.add(names, mmi, fi, inner)
					jt.jtl = jtl
					jtl = jt
				}

			}

			if num != i {
				continue
			}

		loopEnd:

			// A first-segment field (or one reached without any join) lives on the
			// root alias; otherwise it lives on the last join added.
			if i == 0 || jtl == nil {
				index = "T0"
			} else {
				index = jtl.index
			}

			info = fi

			if jtl == nil {
				name = fi.name
			} else {
				name = jtl.name + ExprSep + fi.name
			}

			switch {
			case fi.rel:
				// Forward relations keep the index/info/name computed above.

			case fi.reverse:
				switch fi.reverseFieldInfo.fieldType {
				case RelOneToOne, RelForeignKey:
					// NOTE(review): jtl can be nil here (e.g. with skipEnd, when a
					// reverse field is the first segment and its join was skipped),
					// which would panic on jtl.index. Confirm callers never reach this.
					index = jtl.index
					info = fi.reverseFieldInfo.mi.fields.pk
					name = info.name
				}
			}

			break loopFor

		} else {
			// Unresolvable segment: report failure with zeroed results.
			index = ""
			name = ""
			info = nil
			success = false
			return
		}
	}

	success = index != "" && info != nil
	return
}
333 func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { 334 if cond == nil || cond.IsEmpty() { 335 return 336 } 337 338 Q := t.base.TableQuote() 339 340 mi := t.mi 341 342 for i, p := range cond.params { 343 if i > 0 { 344 if p.isOr { 345 where += "OR " 346 } else { 347 where += "AND " 348 } 349 } 350 if p.isNot { 351 where += "NOT " 352 } 353 if p.isCond { 354 w, ps := t.getCondSQL(p.cond, true, tz) 355 if w != "" { 356 w = fmt.Sprintf("( %s) ", w) 357 } 358 where += w 359 params = append(params, ps...) 360 } else { 361 exprs := p.exprs 362 363 num := len(exprs) - 1 364 operator := "" 365 if operators[exprs[num]] { 366 operator = exprs[num] 367 exprs = exprs[:num] 368 } 369 370 index, _, fi, suc := t.parseExprs(mi, exprs) 371 if !suc { 372 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) 373 } 374 375 if operator == "" { 376 operator = "exact" 377 } 378 379 var operSQL string 380 var args []interface{} 381 if p.isRaw { 382 operSQL = p.sql 383 } else { 384 operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) 385 } 386 387 leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) 388 t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) 389 390 where += fmt.Sprintf("%s %s ", leftCol, operSQL) 391 params = append(params, args...) 392 393 } 394 } 395 396 if !sub && where != "" { 397 where = "WHERE " + where 398 } 399 400 return 401 } 402 403 // generate group sql. 
404 func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { 405 if len(groups) == 0 { 406 return 407 } 408 409 Q := t.base.TableQuote() 410 411 groupSqls := make([]string, 0, len(groups)) 412 for _, group := range groups { 413 exprs := strings.Split(group, ExprSep) 414 415 index, _, fi, suc := t.parseExprs(t.mi, exprs) 416 if !suc { 417 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) 418 } 419 420 groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) 421 } 422 423 groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) 424 return 425 } 426 427 // generate order sql. 428 func (t *dbTables) getOrderSQL(orders []*clauses.Order) (orderSQL string) { 429 if len(orders) == 0 { 430 return 431 } 432 433 Q := t.base.TableQuote() 434 435 orderSqls := make([]string, 0, len(orders)) 436 for _, order := range orders { 437 column := order.GetColumn() 438 clause := strings.Split(column, clauses.ExprDot) 439 440 if order.IsRaw() { 441 if len(clause) == 2 { 442 orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString())) 443 } else if len(clause) == 1 { 444 orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString())) 445 } else { 446 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) 447 } 448 } else { 449 index, _, fi, suc := t.parseExprs(t.mi, clause) 450 if !suc { 451 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) 452 } 453 454 orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString())) 455 } 456 } 457 458 orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) 459 return 460 } 461 462 // generate limit sql. 
463 func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { 464 if limit == 0 { 465 limit = int64(DefaultRowsLimit) 466 } 467 if limit < 0 { 468 // no limit 469 if offset > 0 { 470 maxLimit := t.base.MaxLimit() 471 if maxLimit == 0 { 472 limits = fmt.Sprintf("OFFSET %d", offset) 473 } else { 474 limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) 475 } 476 } 477 } else if offset <= 0 { 478 limits = fmt.Sprintf("LIMIT %d", limit) 479 } else { 480 limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) 481 } 482 return 483 } 484 485 // getIndexSql generate index sql. 486 func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string) (clause string) { 487 if len(indexes) == 0 { 488 return 489 } 490 491 return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes) 492 } 493 494 // crete new tables collection. 495 func newDbTables(mi *modelInfo, base dbBaser) *dbTables { 496 tables := &dbTables{} 497 tables.tablesM = make(map[string]*dbTable) 498 tables.mi = mi 499 tables.base = base 500 return tables 501 }