github.com/astaxie/beego@v1.12.3/orm/db_tables.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package orm
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"time"
    21  )
    22  
// dbTable describes one table joined into a query: a model reached from the
// root model through a path of relation fields.
//
// NOTE: set and add build this struct with a positional composite literal, so
// the field order below must not change.
type dbTable struct {
	// id is the 1-based position of this table in dbTables.tables.
	id int
	// index is the SQL alias of the table, "T<id>"; the root model is "T0".
	index string
	// name is the relation path (names joined by ExprSep); key in dbTables.tablesM.
	name string
	// names holds the individual relation path segments.
	names []string
	// sel is set by parseRelated.
	// NOTE(review): presumably marks whether this table's columns are selected — confirm in callers.
	sel bool
	// inner selects INNER JOIN (true) or LEFT OUTER JOIN (false) in getJoinSQL.
	inner bool
	// mi is the model of the joined table (its table name is used in the JOIN).
	mi *modelInfo
	// fi is the relation field through which this table is reached.
	fi *fieldInfo
	// jtl is the table this one is joined onto; nil means the root table "T0".
	jtl *dbTable
}
    35  
// dbTables collects the tables joined into a single query. All of them are
// reachable from the root model mi, which is always aliased "T0".
type dbTables struct {
	// tablesM indexes the joined tables by their relation path name.
	tablesM map[string]*dbTable
	// tables holds the joined tables in registration order; a table's position
	// determines its id/alias and the order of the generated JOIN clauses.
	tables []*dbTable
	// mi is the root model of the query.
	mi *modelInfo
	// base is the database dialect used for quoting, operators and limits.
	base dbBaser
	// skipEnd forces LEFT OUTER joins. In parseExprs it also avoids joining a
	// relation that is not followed by a resolvable segment, or whose next
	// segment is the related primary key.
	// NOTE(review): presumably enabled for UPDATE/DELETE queries — confirm in callers.
	skipEnd bool
}
    44  
    45  // set table info to collection.
    46  // if not exist, create new.
    47  func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
    48  	name := strings.Join(names, ExprSep)
    49  	if j, ok := t.tablesM[name]; ok {
    50  		j.name = name
    51  		j.mi = mi
    52  		j.fi = fi
    53  		j.inner = inner
    54  	} else {
    55  		i := len(t.tables) + 1
    56  		jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
    57  		t.tablesM[name] = jt
    58  		t.tables = append(t.tables, jt)
    59  	}
    60  	return t.tablesM[name]
    61  }
    62  
    63  // add table info to collection.
    64  func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
    65  	name := strings.Join(names, ExprSep)
    66  	if _, ok := t.tablesM[name]; !ok {
    67  		i := len(t.tables) + 1
    68  		jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
    69  		t.tablesM[name] = jt
    70  		t.tables = append(t.tables, jt)
    71  		return jt, true
    72  	}
    73  	return t.tablesM[name], false
    74  }
    75  
    76  // get table info in collection.
    77  func (t *dbTables) get(name string) (*dbTable, bool) {
    78  	j, ok := t.tablesM[name]
    79  	return j, ok
    80  }
    81  
    82  // get related fields info in recursive depth loop.
    83  // loop once, depth decreases one.
    84  func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
    85  	if depth < 0 || fi.fieldType == RelManyToMany {
    86  		return related
    87  	}
    88  
    89  	if prefix == "" {
    90  		prefix = fi.name
    91  	} else {
    92  		prefix = prefix + ExprSep + fi.name
    93  	}
    94  	related = append(related, prefix)
    95  
    96  	depth--
    97  	for _, fi := range fi.relModelInfo.fields.fieldsRel {
    98  		related = t.loopDepth(depth, prefix, fi, related)
    99  	}
   100  
   101  	return related
   102  }
   103  
// parseRelated registers the join tables needed to load related models
// together with the root model t.mi.
//
// rels lists explicitly requested relation paths (segments joined by
// ExprSep).
//   - When rels is empty, the relation fields of the root model are followed
//     automatically up to depth levels (via loopDepth).
//   - When rels is non-empty, no paths are added automatically.
//
// Every path segment must be a non-many-to-many relation field; otherwise it
// panics with "unknown model/table name".
//
// For each table along a path, sel is set to true when the path was
// explicitly requested, else to depth > 0. Tables at or after a reverse field
// in the path keep their previous sel value.
func (t *dbTables) parseRelated(rels []string, depth int) {

	relsNum := len(rels)
	// related starts with the explicitly requested paths. Auto-discovered
	// paths are appended after them, so i < relsNum identifies explicit ones.
	related := make([]string, relsNum)
	copy(related, rels)

	relDepth := depth

	// With explicit rels, relDepth becomes -1 below, so loopDepth adds nothing.
	if relsNum != 0 {
		relDepth = 0
	}

	relDepth--
	for _, fi := range t.mi.fields.fieldsRel {
		related = t.loopDepth(relDepth, "", fi, related)
	}

	for i, s := range related {
		var (
			exs    = strings.Split(s, ExprSep)
			names  = make([]string, 0, len(exs))
			mmi    = t.mi
			cancel = true
			jtl    *dbTable
		)

		// inner stays true only while every relation on the path is non-null
		// and skipEnd is unset. Once false, it stays false for deeper tables.
		inner := true

		for _, ex := range exs {
			if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
				names = append(names, fi.name)
				mmi = fi.relModelInfo

				if fi.null || t.skipEnd {
					inner = false
				}

				// set (not add): re-registering an existing path updates it.
				jt := t.set(names, mmi, fi, inner)
				// Chain to the previous segment's table (nil => root "T0").
				jt.jtl = jtl

				// From a reverse field onwards, sel is left untouched.
				if fi.reverse {
					cancel = false
				}

				if cancel {
					jt.sel = depth > 0

					// Explicitly requested paths are always selected.
					if i < relsNum {
						jt.sel = true
					}
				}

				jtl = jt

			} else {
				panic(fmt.Errorf("unknown model/table name `%s`", ex))
			}
		}
	}
}
   165  
   166  // generate join string.
   167  func (t *dbTables) getJoinSQL() (join string) {
   168  	Q := t.base.TableQuote()
   169  
   170  	for _, jt := range t.tables {
   171  		if jt.inner {
   172  			join += "INNER JOIN "
   173  		} else {
   174  			join += "LEFT OUTER JOIN "
   175  		}
   176  		var (
   177  			table  string
   178  			t1, t2 string
   179  			c1, c2 string
   180  		)
   181  		t1 = "T0"
   182  		if jt.jtl != nil {
   183  			t1 = jt.jtl.index
   184  		}
   185  		t2 = jt.index
   186  		table = jt.mi.table
   187  
   188  		switch {
   189  		case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
   190  			c1 = jt.fi.mi.fields.pk.column
   191  			for _, ffi := range jt.mi.fields.fieldsRel {
   192  				if jt.fi.mi == ffi.relModelInfo {
   193  					c2 = ffi.column
   194  					break
   195  				}
   196  			}
   197  		default:
   198  			c1 = jt.fi.column
   199  			c2 = jt.fi.relModelInfo.fields.pk.column
   200  
   201  			if jt.fi.reverse {
   202  				c1 = jt.mi.fields.pk.column
   203  				c2 = jt.fi.reverseFieldInfo.column
   204  			}
   205  		}
   206  
   207  		join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
   208  			t2, Q, c2, Q, t1, Q, c1, Q)
   209  	}
   210  	return
   211  }
   212  
// parseExprs resolves a field path starting from model mi. exprs holds the
// path segments (an expression split on ExprSep).
//
// While walking relation segments it registers the needed join tables via
// t.add, chaining each one to the previous one through jtl. It returns:
//   - index: alias of the table holding the resolved field ("T0" for the root
//     model, otherwise the last joined table's alias);
//   - name: the resolved path (last joined table's name + ExprSep + field
//     name, or just the field name when nothing was joined);
//   - info: the resolved field;
//   - success: false if any segment cannot be resolved. Tables already
//     registered for earlier segments remain in t.
//
// Special cases:
//   - A many-to-many segment is walked into its through model.
//   - A final relation segment on a through model is not joined.
//   - For a reverse one-to-one/FK field as the last segment, the result points
//     at the joined model's primary key.
//   - When t.skipEnd is set, a relation is joined only if the next segment
//     resolves. The walk stops before joining when that next segment is the
//     related primary key, and info is then the relation field itself.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
	var (
		jtl *dbTable
		fi  *fieldInfo
		fiN *fieldInfo
		mmi = mi
	)

	num := len(exprs) - 1
	var names []string

	inner := true

loopFor:
	for i, ex := range exprs {

		var ok, okN bool

		// Segments after the first were already looked up (into fiN) in the
		// model reached by the previous segment.
		// NOTE(review): relies on GetByAny returning nil on a miss — confirm.
		if fiN != nil {
			fi = fiN
			ok = true
			fiN = nil
		}

		// Only the first segment is looked up directly in mi.
		if i == 0 {
			fi, ok = mmi.fields.GetByAny(ex)
		}

		// No-op; okN is also read further below.
		_ = okN

		if ok {

			isRel := fi.rel || fi.reverse

			names = append(names, fi.name)

			// Move to the model the relation points at (the through model for
			// many-to-many, the declaring model for reverse fields).
			switch {
			case fi.rel:
				mmi = fi.relModelInfo
				if fi.fieldType == RelManyToMany {
					mmi = fi.relThroughModelInfo
				}
			case fi.reverse:
				mmi = fi.reverseFieldInfo.mi
			}

			// Look ahead: resolve the next segment in the model just reached.
			if i < num {
				fiN, okN = mmi.fields.GetByAny(exprs[i+1])
			}

			// Join relation segments, except a final one on a through model.
			if isRel && (!fi.mi.isThrough || num != i) {
				// Once false, inner stays false for all deeper joins.
				if fi.null || t.skipEnd {
					inner = false
				}

				if t.skipEnd && okN || !t.skipEnd {
					// With skipEnd, stop before joining when only the related
					// pk is requested; the relation field itself is returned.
					if t.skipEnd && okN && fiN.pk {
						goto loopEnd
					}

					jt, _ := t.add(names, mmi, fi, inner)
					jt.jtl = jtl
					jtl = jt
				}

			}

			// Not the last segment: keep walking.
			if num != i {
				continue
			}

		loopEnd:

			if i == 0 || jtl == nil {
				index = "T0"
			} else {
				index = jtl.index
			}

			info = fi

			if jtl == nil {
				name = fi.name
			} else {
				name = jtl.name + ExprSep + fi.name
			}

			switch {
			case fi.rel:

			case fi.reverse:
				switch fi.reverseFieldInfo.fieldType {
				case RelOneToOne, RelForeignKey:
					// NOTE(review): jtl can be nil here when t.skipEnd prevented
					// every join (e.g. the goto above at i == 0), which would
					// panic with a nil dereference — confirm whether reachable.
					index = jtl.index
					info = fi.reverseFieldInfo.mi.fields.pk
					name = info.name
				}
			}

			break loopFor

		} else {
			index = ""
			name = ""
			info = nil
			success = false
			return
		}
	}

	success = index != "" && info != nil
	return
}
   327  
   328  // generate condition sql.
   329  func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
   330  	if cond == nil || cond.IsEmpty() {
   331  		return
   332  	}
   333  
   334  	Q := t.base.TableQuote()
   335  
   336  	mi := t.mi
   337  
   338  	for i, p := range cond.params {
   339  		if i > 0 {
   340  			if p.isOr {
   341  				where += "OR "
   342  			} else {
   343  				where += "AND "
   344  			}
   345  		}
   346  		if p.isNot {
   347  			where += "NOT "
   348  		}
   349  		if p.isCond {
   350  			w, ps := t.getCondSQL(p.cond, true, tz)
   351  			if w != "" {
   352  				w = fmt.Sprintf("( %s) ", w)
   353  			}
   354  			where += w
   355  			params = append(params, ps...)
   356  		} else {
   357  			exprs := p.exprs
   358  
   359  			num := len(exprs) - 1
   360  			operator := ""
   361  			if operators[exprs[num]] {
   362  				operator = exprs[num]
   363  				exprs = exprs[:num]
   364  			}
   365  
   366  			index, _, fi, suc := t.parseExprs(mi, exprs)
   367  			if !suc {
   368  				panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
   369  			}
   370  
   371  			if operator == "" {
   372  				operator = "exact"
   373  			}
   374  
   375  			var operSQL string
   376  			var args []interface{}
   377  			if p.isRaw {
   378  				operSQL = p.sql
   379  			} else {
   380  				operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
   381  			}
   382  
   383  			leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
   384  			t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
   385  
   386  			where += fmt.Sprintf("%s %s ", leftCol, operSQL)
   387  			params = append(params, args...)
   388  
   389  		}
   390  	}
   391  
   392  	if !sub && where != "" {
   393  		where = "WHERE " + where
   394  	}
   395  
   396  	return
   397  }
   398  
   399  // generate group sql.
   400  func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
   401  	if len(groups) == 0 {
   402  		return
   403  	}
   404  
   405  	Q := t.base.TableQuote()
   406  
   407  	groupSqls := make([]string, 0, len(groups))
   408  	for _, group := range groups {
   409  		exprs := strings.Split(group, ExprSep)
   410  
   411  		index, _, fi, suc := t.parseExprs(t.mi, exprs)
   412  		if !suc {
   413  			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
   414  		}
   415  
   416  		groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
   417  	}
   418  
   419  	groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
   420  	return
   421  }
   422  
   423  // generate order sql.
   424  func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
   425  	if len(orders) == 0 {
   426  		return
   427  	}
   428  
   429  	Q := t.base.TableQuote()
   430  
   431  	orderSqls := make([]string, 0, len(orders))
   432  	for _, order := range orders {
   433  		asc := "ASC"
   434  		if order[0] == '-' {
   435  			asc = "DESC"
   436  			order = order[1:]
   437  		}
   438  		exprs := strings.Split(order, ExprSep)
   439  
   440  		index, _, fi, suc := t.parseExprs(t.mi, exprs)
   441  		if !suc {
   442  			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
   443  		}
   444  
   445  		orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
   446  	}
   447  
   448  	orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
   449  	return
   450  }
   451  
   452  // generate limit sql.
   453  func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
   454  	if limit == 0 {
   455  		limit = int64(DefaultRowsLimit)
   456  	}
   457  	if limit < 0 {
   458  		// no limit
   459  		if offset > 0 {
   460  			maxLimit := t.base.MaxLimit()
   461  			if maxLimit == 0 {
   462  				limits = fmt.Sprintf("OFFSET %d", offset)
   463  			} else {
   464  				limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
   465  			}
   466  		}
   467  	} else if offset <= 0 {
   468  		limits = fmt.Sprintf("LIMIT %d", limit)
   469  	} else {
   470  		limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
   471  	}
   472  	return
   473  }
   474  
   475  // crete new tables collection.
   476  func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
   477  	tables := &dbTables{}
   478  	tables.tablesM = make(map[string]*dbTable)
   479  	tables.mi = mi
   480  	tables.base = base
   481  	return tables
   482  }