github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/db_tables.go (about)

// The original package was migrated from beego and modified; you can find the original at the following link:
     2  //    "github.com/beego/beego/"
     3  //
     4  // Copyright 2023 IAC. All Rights Reserved.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package orm
    19  
    20  import (
    21  	"fmt"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/mdaxf/iac/databases/orm/clauses"
    26  )
    27  
// table info struct.
// dbTable describes one table joined into a query, keyed in dbTables by its
// relation path. Keep the field order stable: composite literals of this type
// may be positional.
type dbTable struct {
	// id is the table's 1-based position in dbTables.tables.
	id    int
	// index is the SQL alias ("T<id>") used to qualify this table's columns.
	index string
	// name is names joined with ExprSep; it is the table's key in dbTables.tablesM.
	name  string
	// names are the relation field names leading from the root model to this table.
	names []string
	// sel is assigned by parseRelated; presumably it marks the table's
	// columns for selection — confirm against its readers.
	sel   bool
	// inner selects INNER JOIN (true) or LEFT OUTER JOIN (false) in getJoinSQL.
	inner bool
	// mi is the model joined in as this table (mi.table is the joined table name).
	mi    *modelInfo
	// fi is the relation field that produced this join.
	fi    *fieldInfo
	// jtl is the table this one is joined to; nil means the root alias T0.
	jtl   *dbTable
}
    40  
// tables collection struct, contains some tables.
// dbTables collects the tables joined into a query rooted at model mi and
// renders the JOIN / WHERE / GROUP BY / ORDER BY / LIMIT fragments that refer
// to them by alias.
type dbTables struct {
	// tablesM indexes the joined tables by their ExprSep-joined relation path.
	tablesM map[string]*dbTable
	// tables holds the same tables in creation order; JOINs are emitted in this order.
	tables  []*dbTable
	// mi is the root model, referenced through the alias T0.
	mi      *modelInfo
	// base is the backend used for quoting (TableQuote), operator SQL,
	// MaxLimit and index clauses.
	base    dbBaser
	// skipEnd forces LEFT OUTER JOINs for relations and makes parseExprs join a
	// relation segment only when the following segment resolves, stopping
	// without a join when that segment is the related model's primary key.
	skipEnd bool
}
    49  
    50  // set table info to collection.
    51  // if not exist, create new.
    52  func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
    53  	name := strings.Join(names, ExprSep)
    54  	if j, ok := t.tablesM[name]; ok {
    55  		j.name = name
    56  		j.mi = mi
    57  		j.fi = fi
    58  		j.inner = inner
    59  	} else {
    60  		i := len(t.tables) + 1
    61  		jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
    62  		t.tablesM[name] = jt
    63  		t.tables = append(t.tables, jt)
    64  	}
    65  	return t.tablesM[name]
    66  }
    67  
    68  // add table info to collection.
    69  func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
    70  	name := strings.Join(names, ExprSep)
    71  	if _, ok := t.tablesM[name]; !ok {
    72  		i := len(t.tables) + 1
    73  		jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
    74  		t.tablesM[name] = jt
    75  		t.tables = append(t.tables, jt)
    76  		return jt, true
    77  	}
    78  	return t.tablesM[name], false
    79  }
    80  
    81  // get table info in collection.
    82  func (t *dbTables) get(name string) (*dbTable, bool) {
    83  	j, ok := t.tablesM[name]
    84  	return j, ok
    85  }
    86  
    87  // get related fields info in recursive depth loop.
    88  // loop once, depth decreases one.
    89  func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
    90  	if depth < 0 || fi.fieldType == RelManyToMany {
    91  		return related
    92  	}
    93  
    94  	if prefix == "" {
    95  		prefix = fi.name
    96  	} else {
    97  		prefix = prefix + ExprSep + fi.name
    98  	}
    99  	related = append(related, prefix)
   100  
   101  	depth--
   102  	for _, fi := range fi.relModelInfo.fields.fieldsRel {
   103  		related = t.loopDepth(depth, prefix, fi, related)
   104  	}
   105  
   106  	return related
   107  }
   108  
// parse related fields.
// parseRelated registers join tables for the relation paths that should be
// loaded together with the root model t.mi.
//
// rels lists explicit relation paths, each a sequence of field names joined
// with ExprSep. When rels is empty, paths are discovered automatically by
// walking t.mi's relation fields with loopDepth, up to depth levels deep.
// When rels is non-empty, no discovery is done: relDepth ends up at -1, so
// loopDepth returns immediately.
//
// For every path, each segment must name a relation field (fi.rel) of the
// current model that is not many-to-many; otherwise parseRelated panics. Each
// path prefix is registered through t.set and linked to the previous
// segment's table via jtl.
func (t *dbTables) parseRelated(rels []string, depth int) {
	relsNum := len(rels)
	related := make([]string, relsNum)
	copy(related, rels)

	relDepth := depth

	// Explicit rels disable automatic discovery (relDepth becomes -1 below).
	if relsNum != 0 {
		relDepth = 0
	}

	relDepth--
	for _, fi := range t.mi.fields.fieldsRel {
		related = t.loopDepth(relDepth, "", fi, related)
	}

	// related[:relsNum] are the caller's explicit paths; any later entries
	// were discovered by loopDepth.
	for i, s := range related {
		// names is sized to the path length, so appends never reallocate and
		// every table registered for this path shares its backing array. This
		// is safe only because names is appended to and never overwritten.
		var (
			exs    = strings.Split(s, ExprSep)
			names  = make([]string, 0, len(exs))
			mmi    = t.mi
			cancel = true
			jtl    *dbTable
		)

		// Once a nullable relation is crossed (or skipEnd is set), this table
		// and every deeper table on the same path use LEFT OUTER JOIN.
		inner := true

		for _, ex := range exs {
			if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
				names = append(names, fi.name)
				mmi = fi.relModelInfo

				if fi.null || t.skipEnd {
					inner = false
				}

				jt := t.set(names, mmi, fi, inner)
				jt.jtl = jtl

				// From a reverse relation onward, sel is no longer assigned
				// for this segment or any deeper one on the path.
				if fi.reverse {
					cancel = false
				}

				if cancel {
					jt.sel = depth > 0

					// Explicit paths are always selected.
					if i < relsNum {
						jt.sel = true
					}
				}

				jtl = jt

			} else {
				panic(fmt.Errorf("unknown model/table name `%s`", ex))
			}
		}
	}
}
   169  
   170  // generate join string.
   171  func (t *dbTables) getJoinSQL() (join string) {
   172  	Q := t.base.TableQuote()
   173  
   174  	for _, jt := range t.tables {
   175  		if jt.inner {
   176  			join += "INNER JOIN "
   177  		} else {
   178  			join += "LEFT OUTER JOIN "
   179  		}
   180  		var (
   181  			table  string
   182  			t1, t2 string
   183  			c1, c2 string
   184  		)
   185  		t1 = "T0"
   186  		if jt.jtl != nil {
   187  			t1 = jt.jtl.index
   188  		}
   189  		t2 = jt.index
   190  		table = jt.mi.table
   191  
   192  		switch {
   193  		case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
   194  			c1 = jt.fi.mi.fields.pk.column
   195  			for _, ffi := range jt.mi.fields.fieldsRel {
   196  				if jt.fi.mi == ffi.relModelInfo {
   197  					c2 = ffi.column
   198  					break
   199  				}
   200  			}
   201  		default:
   202  			c1 = jt.fi.column
   203  			c2 = jt.fi.relModelInfo.fields.pk.column
   204  
   205  			if jt.fi.reverse {
   206  				c1 = jt.mi.fields.pk.column
   207  				c2 = jt.fi.reverseFieldInfo.column
   208  			}
   209  		}
   210  
   211  		join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
   212  			t2, Q, c2, Q, t1, Q, c1, Q)
   213  	}
   214  	return
   215  }
   216  
// parse orm model struct field tag expression.
// parseExprs resolves a field expression path (exprs are its segments),
// starting at model mi. Along the way it registers, via t.add, the join
// tables needed to reach the final field.
//
// Results:
//   - index: alias of the table that owns the resolved column ("T0" for the
//     root model, otherwise a joined table's index);
//   - name: fi.name, prefixed with the last joined table's name and ExprSep
//     when a table was joined;
//   - info: the resolved field;
//   - success: index != "" && info != nil.
//
// If a segment cannot be resolved, every result is zeroed and success is
// false. When the final segment is the reverse side of a one-to-one or
// foreign-key relation, index/info/name are redirected to the primary key of
// the reverse field's model.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
	var (
		jtl *dbTable
		fi  *fieldInfo
		fiN *fieldInfo
		mmi = mi
	)

	num := len(exprs) - 1
	var names []string

	// inner turns false once a nullable relation is crossed (or skipEnd is
	// set) and stays false for every table joined after that.
	inner := true

loopFor:
	for i, ex := range exprs {

		var ok, okN bool

		// For i > 0, this segment's field was already looked up at the end of
		// the previous iteration (in that segment's target model) and stashed
		// in fiN. A failed lookup leaves fiN nil, so ok stays false.
		if fiN != nil {
			fi = fiN
			ok = true
			fiN = nil
		}

		if i == 0 {
			fi, ok = mmi.fields.GetByAny(ex)
		}

		// No effect: okN is read further down in this iteration.
		_ = okN

		if ok {

			isRel := fi.rel || fi.reverse

			names = append(names, fi.name)

			// Move mmi to the model in which the next segment is resolved:
			// the related model, the through model for many-to-many, or the
			// owning model of a reverse relation.
			switch {
			case fi.rel:
				mmi = fi.relModelInfo
				if fi.fieldType == RelManyToMany {
					mmi = fi.relThroughModelInfo
				}
			case fi.reverse:
				mmi = fi.reverseFieldInfo.mi
			}

			if i < num {
				fiN, okN = mmi.fields.GetByAny(exprs[i+1])
			}

			// Join relation segments, except a final segment whose owning
			// model is a through model.
			if isRel && (!fi.mi.isThrough || num != i) {
				if fi.null || t.skipEnd {
					inner = false
				}

				// Without skipEnd every such segment is joined. With skipEnd, a
				// segment is joined only if the next segment resolves; when
				// that next segment is the target model's primary key, the
				// walk finishes here without joining.
				if t.skipEnd && okN || !t.skipEnd {
					if t.skipEnd && okN && fiN.pk {
						goto loopEnd
					}

					jt, _ := t.add(names, mmi, fi, inner)
					jt.jtl = jtl
					jtl = jt
				}

			}

			if num != i {
				continue
			}

			// Reached on the final segment, or early via goto: build results.
		loopEnd:

			if i == 0 || jtl == nil {
				index = "T0"
			} else {
				index = jtl.index
			}

			info = fi

			if jtl == nil {
				name = fi.name
			} else {
				name = jtl.name + ExprSep + fi.name
			}

			switch {
			case fi.rel:

			case fi.reverse:
				switch fi.reverseFieldInfo.fieldType {
				case RelOneToOne, RelForeignKey:
					// NOTE(review): jtl can be nil here, for example when skipEnd
					// is set and the only segment is a reverse relation (nothing
					// is joined). jtl.index would then panic; confirm that callers
					// never reach this.
					index = jtl.index
					info = fi.reverseFieldInfo.mi.fields.pk
					name = info.name
				}
			}

			break loopFor

		} else {
			index = ""
			name = ""
			info = nil
			success = false
			return
		}
	}

	success = index != "" && info != nil
	return
}
   331  
   332  // generate condition sql.
   333  func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
   334  	if cond == nil || cond.IsEmpty() {
   335  		return
   336  	}
   337  
   338  	Q := t.base.TableQuote()
   339  
   340  	mi := t.mi
   341  
   342  	for i, p := range cond.params {
   343  		if i > 0 {
   344  			if p.isOr {
   345  				where += "OR "
   346  			} else {
   347  				where += "AND "
   348  			}
   349  		}
   350  		if p.isNot {
   351  			where += "NOT "
   352  		}
   353  		if p.isCond {
   354  			w, ps := t.getCondSQL(p.cond, true, tz)
   355  			if w != "" {
   356  				w = fmt.Sprintf("( %s) ", w)
   357  			}
   358  			where += w
   359  			params = append(params, ps...)
   360  		} else {
   361  			exprs := p.exprs
   362  
   363  			num := len(exprs) - 1
   364  			operator := ""
   365  			if operators[exprs[num]] {
   366  				operator = exprs[num]
   367  				exprs = exprs[:num]
   368  			}
   369  
   370  			index, _, fi, suc := t.parseExprs(mi, exprs)
   371  			if !suc {
   372  				panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
   373  			}
   374  
   375  			if operator == "" {
   376  				operator = "exact"
   377  			}
   378  
   379  			var operSQL string
   380  			var args []interface{}
   381  			if p.isRaw {
   382  				operSQL = p.sql
   383  			} else {
   384  				operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
   385  			}
   386  
   387  			leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
   388  			t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
   389  
   390  			where += fmt.Sprintf("%s %s ", leftCol, operSQL)
   391  			params = append(params, args...)
   392  
   393  		}
   394  	}
   395  
   396  	if !sub && where != "" {
   397  		where = "WHERE " + where
   398  	}
   399  
   400  	return
   401  }
   402  
   403  // generate group sql.
   404  func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
   405  	if len(groups) == 0 {
   406  		return
   407  	}
   408  
   409  	Q := t.base.TableQuote()
   410  
   411  	groupSqls := make([]string, 0, len(groups))
   412  	for _, group := range groups {
   413  		exprs := strings.Split(group, ExprSep)
   414  
   415  		index, _, fi, suc := t.parseExprs(t.mi, exprs)
   416  		if !suc {
   417  			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
   418  		}
   419  
   420  		groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
   421  	}
   422  
   423  	groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
   424  	return
   425  }
   426  
   427  // generate order sql.
   428  func (t *dbTables) getOrderSQL(orders []*clauses.Order) (orderSQL string) {
   429  	if len(orders) == 0 {
   430  		return
   431  	}
   432  
   433  	Q := t.base.TableQuote()
   434  
   435  	orderSqls := make([]string, 0, len(orders))
   436  	for _, order := range orders {
   437  		column := order.GetColumn()
   438  		clause := strings.Split(column, clauses.ExprDot)
   439  
   440  		if order.IsRaw() {
   441  			if len(clause) == 2 {
   442  				orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString()))
   443  			} else if len(clause) == 1 {
   444  				orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString()))
   445  			} else {
   446  				panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
   447  			}
   448  		} else {
   449  			index, _, fi, suc := t.parseExprs(t.mi, clause)
   450  			if !suc {
   451  				panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
   452  			}
   453  
   454  			orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString()))
   455  		}
   456  	}
   457  
   458  	orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
   459  	return
   460  }
   461  
   462  // generate limit sql.
   463  func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
   464  	if limit == 0 {
   465  		limit = int64(DefaultRowsLimit)
   466  	}
   467  	if limit < 0 {
   468  		// no limit
   469  		if offset > 0 {
   470  			maxLimit := t.base.MaxLimit()
   471  			if maxLimit == 0 {
   472  				limits = fmt.Sprintf("OFFSET %d", offset)
   473  			} else {
   474  				limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
   475  			}
   476  		}
   477  	} else if offset <= 0 {
   478  		limits = fmt.Sprintf("LIMIT %d", limit)
   479  	} else {
   480  		limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
   481  	}
   482  	return
   483  }
   484  
   485  // getIndexSql generate index sql.
   486  func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string) (clause string) {
   487  	if len(indexes) == 0 {
   488  		return
   489  	}
   490  
   491  	return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes)
   492  }
   493  
   494  // crete new tables collection.
   495  func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
   496  	tables := &dbTables{}
   497  	tables.tablesM = make(map[string]*dbTable)
   498  	tables.mi = mi
   499  	tables.base = base
   500  	return tables
   501  }