github.com/astaxie/beego@v1.12.3/orm/db_tables.go (about) 1 // Copyright 2014 beego Author. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package orm 16 17 import ( 18 "fmt" 19 "strings" 20 "time" 21 ) 22 23 // table info struct. 24 type dbTable struct { 25 id int 26 index string 27 name string 28 names []string 29 sel bool 30 inner bool 31 mi *modelInfo 32 fi *fieldInfo 33 jtl *dbTable 34 } 35 36 // tables collection struct, contains some tables. 37 type dbTables struct { 38 tablesM map[string]*dbTable 39 tables []*dbTable 40 mi *modelInfo 41 base dbBaser 42 skipEnd bool 43 } 44 45 // set table info to collection. 46 // if not exist, create new. 47 func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { 48 name := strings.Join(names, ExprSep) 49 if j, ok := t.tablesM[name]; ok { 50 j.name = name 51 j.mi = mi 52 j.fi = fi 53 j.inner = inner 54 } else { 55 i := len(t.tables) + 1 56 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} 57 t.tablesM[name] = jt 58 t.tables = append(t.tables, jt) 59 } 60 return t.tablesM[name] 61 } 62 63 // add table info to collection. 64 func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { 65 name := strings.Join(names, ExprSep) 66 if _, ok := t.tablesM[name]; !ok { 67 i := len(t.tables) + 1 68 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} 69 t.tablesM[name] = jt 70 t.tables = append(t.tables, jt) 71 return jt, true 72 } 73 return t.tablesM[name], false 74 } 75 76 // get table info in collection. 77 func (t *dbTables) get(name string) (*dbTable, bool) { 78 j, ok := t.tablesM[name] 79 return j, ok 80 } 81 82 // get related fields info in recursive depth loop. 83 // loop once, depth decreases one. 84 func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { 85 if depth < 0 || fi.fieldType == RelManyToMany { 86 return related 87 } 88 89 if prefix == "" { 90 prefix = fi.name 91 } else { 92 prefix = prefix + ExprSep + fi.name 93 } 94 related = append(related, prefix) 95 96 depth-- 97 for _, fi := range fi.relModelInfo.fields.fieldsRel { 98 related = t.loopDepth(depth, prefix, fi, related) 99 } 100 101 return related 102 } 103 104 // parse related fields. 105 func (t *dbTables) parseRelated(rels []string, depth int) { 106 107 relsNum := len(rels) 108 related := make([]string, relsNum) 109 copy(related, rels) 110 111 relDepth := depth 112 113 if relsNum != 0 { 114 relDepth = 0 115 } 116 117 relDepth-- 118 for _, fi := range t.mi.fields.fieldsRel { 119 related = t.loopDepth(relDepth, "", fi, related) 120 } 121 122 for i, s := range related { 123 var ( 124 exs = strings.Split(s, ExprSep) 125 names = make([]string, 0, len(exs)) 126 mmi = t.mi 127 cancel = true 128 jtl *dbTable 129 ) 130 131 inner := true 132 133 for _, ex := range exs { 134 if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { 135 names = append(names, fi.name) 136 mmi = fi.relModelInfo 137 138 if fi.null || t.skipEnd { 139 inner = false 140 } 141 142 jt := t.set(names, mmi, fi, inner) 143 jt.jtl = jtl 144 145 if fi.reverse { 146 cancel = false 147 } 148 149 if cancel { 150 jt.sel = depth > 0 151 152 if i < relsNum { 153 jt.sel = true 154 } 155 } 156 157 jtl = jt 158 159 } else { 160 panic(fmt.Errorf("unknown model/table name `%s`", ex)) 161 } 162 } 163 } 164 } 165 166 // generate join string. 167 func (t *dbTables) getJoinSQL() (join string) { 168 Q := t.base.TableQuote() 169 170 for _, jt := range t.tables { 171 if jt.inner { 172 join += "INNER JOIN " 173 } else { 174 join += "LEFT OUTER JOIN " 175 } 176 var ( 177 table string 178 t1, t2 string 179 c1, c2 string 180 ) 181 t1 = "T0" 182 if jt.jtl != nil { 183 t1 = jt.jtl.index 184 } 185 t2 = jt.index 186 table = jt.mi.table 187 188 switch { 189 case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: 190 c1 = jt.fi.mi.fields.pk.column 191 for _, ffi := range jt.mi.fields.fieldsRel { 192 if jt.fi.mi == ffi.relModelInfo { 193 c2 = ffi.column 194 break 195 } 196 } 197 default: 198 c1 = jt.fi.column 199 c2 = jt.fi.relModelInfo.fields.pk.column 200 201 if jt.fi.reverse { 202 c1 = jt.mi.fields.pk.column 203 c2 = jt.fi.reverseFieldInfo.column 204 } 205 } 206 207 join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, 208 t2, Q, c2, Q, t1, Q, c1, Q) 209 } 210 return 211 } 212 213 // parse orm model struct field tag expression. 214 func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { 215 var ( 216 jtl *dbTable 217 fi *fieldInfo 218 fiN *fieldInfo 219 mmi = mi 220 ) 221 222 num := len(exprs) - 1 223 var names []string 224 225 inner := true 226 227 loopFor: 228 for i, ex := range exprs { 229 230 var ok, okN bool 231 232 if fiN != nil { 233 fi = fiN 234 ok = true 235 fiN = nil 236 } 237 238 if i == 0 { 239 fi, ok = mmi.fields.GetByAny(ex) 240 } 241 242 _ = okN 243 244 if ok { 245 246 isRel := fi.rel || fi.reverse 247 248 names = append(names, fi.name) 249 250 switch { 251 case fi.rel: 252 mmi = fi.relModelInfo 253 if fi.fieldType == RelManyToMany { 254 mmi = fi.relThroughModelInfo 255 } 256 case fi.reverse: 257 mmi = fi.reverseFieldInfo.mi 258 } 259 260 if i < num { 261 fiN, okN = mmi.fields.GetByAny(exprs[i+1]) 262 } 263 264 if isRel && (!fi.mi.isThrough || num != i) { 265 if fi.null || t.skipEnd { 266 inner = false 267 } 268 269 if t.skipEnd && okN || !t.skipEnd { 270 if t.skipEnd && okN && fiN.pk { 271 goto loopEnd 272 } 273 274 jt, _ := t.add(names, mmi, fi, inner) 275 jt.jtl = jtl 276 jtl = jt 277 } 278 279 } 280 281 if num != i { 282 continue 283 } 284 285 loopEnd: 286 287 if i == 0 || jtl == nil { 288 index = "T0" 289 } else { 290 index = jtl.index 291 } 292 293 info = fi 294 295 if jtl == nil { 296 name = fi.name 297 } else { 298 name = jtl.name + ExprSep + fi.name 299 } 300 301 switch { 302 case fi.rel: 303 304 case fi.reverse: 305 switch fi.reverseFieldInfo.fieldType { 306 case RelOneToOne, RelForeignKey: 307 index = jtl.index 308 info = fi.reverseFieldInfo.mi.fields.pk 309 name = info.name 310 } 311 } 312 313 break loopFor 314 315 } else { 316 index = "" 317 name = "" 318 info = nil 319 success = false 320 return 321 } 322 } 323 324 success = index != "" && info != nil 325 return 326 } 327 328 // generate condition sql. 329 func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { 330 if cond == nil || cond.IsEmpty() { 331 return 332 } 333 334 Q := t.base.TableQuote() 335 336 mi := t.mi 337 338 for i, p := range cond.params { 339 if i > 0 { 340 if p.isOr { 341 where += "OR " 342 } else { 343 where += "AND " 344 } 345 } 346 if p.isNot { 347 where += "NOT " 348 } 349 if p.isCond { 350 w, ps := t.getCondSQL(p.cond, true, tz) 351 if w != "" { 352 w = fmt.Sprintf("( %s) ", w) 353 } 354 where += w 355 params = append(params, ps...) 356 } else { 357 exprs := p.exprs 358 359 num := len(exprs) - 1 360 operator := "" 361 if operators[exprs[num]] { 362 operator = exprs[num] 363 exprs = exprs[:num] 364 } 365 366 index, _, fi, suc := t.parseExprs(mi, exprs) 367 if !suc { 368 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) 369 } 370 371 if operator == "" { 372 operator = "exact" 373 } 374 375 var operSQL string 376 var args []interface{} 377 if p.isRaw { 378 operSQL = p.sql 379 } else { 380 operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) 381 } 382 383 leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) 384 t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) 385 386 where += fmt.Sprintf("%s %s ", leftCol, operSQL) 387 params = append(params, args...) 388 389 } 390 } 391 392 if !sub && where != "" { 393 where = "WHERE " + where 394 } 395 396 return 397 } 398 399 // generate group sql. 400 func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { 401 if len(groups) == 0 { 402 return 403 } 404 405 Q := t.base.TableQuote() 406 407 groupSqls := make([]string, 0, len(groups)) 408 for _, group := range groups { 409 exprs := strings.Split(group, ExprSep) 410 411 index, _, fi, suc := t.parseExprs(t.mi, exprs) 412 if !suc { 413 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) 414 } 415 416 groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) 417 } 418 419 groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) 420 return 421 } 422 423 // generate order sql. 424 func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { 425 if len(orders) == 0 { 426 return 427 } 428 429 Q := t.base.TableQuote() 430 431 orderSqls := make([]string, 0, len(orders)) 432 for _, order := range orders { 433 asc := "ASC" 434 if order[0] == '-' { 435 asc = "DESC" 436 order = order[1:] 437 } 438 exprs := strings.Split(order, ExprSep) 439 440 index, _, fi, suc := t.parseExprs(t.mi, exprs) 441 if !suc { 442 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) 443 } 444 445 orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) 446 } 447 448 orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) 449 return 450 } 451 452 // generate limit sql. 453 func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { 454 if limit == 0 { 455 limit = int64(DefaultRowsLimit) 456 } 457 if limit < 0 { 458 // no limit 459 if offset > 0 { 460 maxLimit := t.base.MaxLimit() 461 if maxLimit == 0 { 462 limits = fmt.Sprintf("OFFSET %d", offset) 463 } else { 464 limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) 465 } 466 } 467 } else if offset <= 0 { 468 limits = fmt.Sprintf("LIMIT %d", limit) 469 } else { 470 limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) 471 } 472 return 473 } 474 475 // crete new tables collection. 476 func newDbTables(mi *modelInfo, base dbBaser) *dbTables { 477 tables := &dbTables{} 478 tables.tablesM = make(map[string]*dbTable) 479 tables.mi = mi 480 tables.base = base 481 return tables 482 }