github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/models.go (about)

// The original package was migrated from beego and modified; the original can be found at the following link:
     2  //    "github.com/beego/beego/"
     3  //
     4  // Copyright 2023 IAC. All Rights Reserved.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package orm
    19  
    20  import (
    21  	"errors"
    22  	"fmt"
    23  	"reflect"
    24  	"runtime/debug"
    25  	"strings"
    26  	"sync"
    27  )
    28  
    29  const (
    30  	odCascade             = "cascade"
    31  	odSetNULL             = "set_null"
    32  	odSetDefault          = "set_default"
    33  	odDoNothing           = "do_nothing"
    34  	defaultStructTagName  = "orm"
    35  	defaultStructTagDelim = ";"
    36  )
    37  
    38  var defaultModelCache = NewModelCacheHandler()
    39  
// modelCache is a collection of model info, indexed both by table name and
// by the model's full name (package path + struct name).
type modelCache struct {
	sync.RWMutex    // only used outside for bootstrap; clean and bootstrap take the write lock, the accessor methods do not lock
	orders          []string              // table names, appended the first time each table is set
	cache           map[string]*modelInfo // table name -> model info
	cacheByFullName map[string]*modelInfo // model full name -> model info
	done            bool                  // set by bootstrap, reset by clean
}
    48  
    49  // NewModelCacheHandler generator of modelCache
    50  func NewModelCacheHandler() *modelCache {
    51  	return &modelCache{
    52  		cache:           make(map[string]*modelInfo),
    53  		cacheByFullName: make(map[string]*modelInfo),
    54  	}
    55  }
    56  
    57  // get all model info
    58  func (mc *modelCache) all() map[string]*modelInfo {
    59  	m := make(map[string]*modelInfo, len(mc.cache))
    60  	for k, v := range mc.cache {
    61  		m[k] = v
    62  	}
    63  	return m
    64  }
    65  
    66  // get ordered model info
    67  func (mc *modelCache) allOrdered() []*modelInfo {
    68  	m := make([]*modelInfo, 0, len(mc.orders))
    69  	for _, table := range mc.orders {
    70  		m = append(m, mc.cache[table])
    71  	}
    72  	return m
    73  }
    74  
    75  // get model info by table name
    76  func (mc *modelCache) get(table string) (mi *modelInfo, ok bool) {
    77  	mi, ok = mc.cache[table]
    78  	return
    79  }
    80  
    81  // get model info by full name
    82  func (mc *modelCache) getByFullName(name string) (mi *modelInfo, ok bool) {
    83  	mi, ok = mc.cacheByFullName[name]
    84  	return
    85  }
    86  
    87  func (mc *modelCache) getByMd(md interface{}) (*modelInfo, bool) {
    88  	val := reflect.ValueOf(md)
    89  	ind := reflect.Indirect(val)
    90  	typ := ind.Type()
    91  	name := getFullName(typ)
    92  	return mc.getByFullName(name)
    93  }
    94  
    95  // set model info to collection
    96  func (mc *modelCache) set(table string, mi *modelInfo) *modelInfo {
    97  	mii := mc.cache[table]
    98  	mc.cache[table] = mi
    99  	mc.cacheByFullName[mi.fullName] = mi
   100  	if mii == nil {
   101  		mc.orders = append(mc.orders, table)
   102  	}
   103  	return mii
   104  }
   105  
   106  // clean all model info.
   107  func (mc *modelCache) clean() {
   108  	mc.Lock()
   109  	defer mc.Unlock()
   110  
   111  	mc.orders = make([]string, 0)
   112  	mc.cache = make(map[string]*modelInfo)
   113  	mc.cacheByFullName = make(map[string]*modelInfo)
   114  	mc.done = false
   115  }
   116  
   117  // bootstrap bootstrap for models
   118  func (mc *modelCache) bootstrap() {
   119  	mc.Lock()
   120  	defer mc.Unlock()
   121  	if mc.done {
   122  		return
   123  	}
   124  	var (
   125  		err    error
   126  		models map[string]*modelInfo
   127  	)
   128  	if dataBaseCache.getDefault() == nil {
   129  		err = fmt.Errorf("must have one register DataBase alias named `default`")
   130  		goto end
   131  	}
   132  
   133  	// set rel and reverse model
   134  	// RelManyToMany set the relTable
   135  	models = mc.all()
   136  	for _, mi := range models {
   137  		for _, fi := range mi.fields.columns {
   138  			if fi.rel || fi.reverse {
   139  				elm := fi.addrValue.Type().Elem()
   140  				if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany {
   141  					elm = elm.Elem()
   142  				}
   143  				// check the rel or reverse model already register
   144  				name := getFullName(elm)
   145  				mii, ok := mc.getByFullName(name)
   146  				if !ok || mii.pkg != elm.PkgPath() {
   147  					err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
   148  					goto end
   149  				}
   150  				fi.relModelInfo = mii
   151  
   152  				switch fi.fieldType {
   153  				case RelManyToMany:
   154  					if fi.relThrough != "" {
   155  						if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
   156  							pn := fi.relThrough[:i]
   157  							rmi, ok := mc.getByFullName(fi.relThrough)
   158  							if !ok || pn != rmi.pkg {
   159  								err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
   160  								goto end
   161  							}
   162  							fi.relThroughModelInfo = rmi
   163  							fi.relTable = rmi.table
   164  						} else {
   165  							err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
   166  							goto end
   167  						}
   168  					} else {
   169  						i := newM2MModelInfo(mi, mii)
   170  						if fi.relTable != "" {
   171  							i.table = fi.relTable
   172  						}
   173  						if v := mc.set(i.table, i); v != nil {
   174  							err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
   175  							goto end
   176  						}
   177  						fi.relTable = i.table
   178  						fi.relThroughModelInfo = i
   179  					}
   180  
   181  					fi.relThroughModelInfo.isThrough = true
   182  				}
   183  			}
   184  		}
   185  	}
   186  
   187  	// check the rel filed while the relModelInfo also has filed point to current model
   188  	// if not exist, add a new field to the relModelInfo
   189  	models = mc.all()
   190  	for _, mi := range models {
   191  		for _, fi := range mi.fields.fieldsRel {
   192  			switch fi.fieldType {
   193  			case RelForeignKey, RelOneToOne, RelManyToMany:
   194  				inModel := false
   195  				for _, ffi := range fi.relModelInfo.fields.fieldsReverse {
   196  					if ffi.relModelInfo == mi {
   197  						inModel = true
   198  						break
   199  					}
   200  				}
   201  				if !inModel {
   202  					rmi := fi.relModelInfo
   203  					ffi := new(fieldInfo)
   204  					ffi.name = mi.name
   205  					ffi.column = ffi.name
   206  					ffi.fullName = rmi.fullName + "." + ffi.name
   207  					ffi.reverse = true
   208  					ffi.relModelInfo = mi
   209  					ffi.mi = rmi
   210  					if fi.fieldType == RelOneToOne {
   211  						ffi.fieldType = RelReverseOne
   212  					} else {
   213  						ffi.fieldType = RelReverseMany
   214  					}
   215  					if !rmi.fields.Add(ffi) {
   216  						added := false
   217  						for cnt := 0; cnt < 5; cnt++ {
   218  							ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
   219  							ffi.column = ffi.name
   220  							ffi.fullName = rmi.fullName + "." + ffi.name
   221  							if added = rmi.fields.Add(ffi); added {
   222  								break
   223  							}
   224  						}
   225  						if !added {
   226  							panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
   227  						}
   228  					}
   229  				}
   230  			}
   231  		}
   232  	}
   233  
   234  	models = mc.all()
   235  	for _, mi := range models {
   236  		for _, fi := range mi.fields.fieldsRel {
   237  			switch fi.fieldType {
   238  			case RelManyToMany:
   239  				for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel {
   240  					switch ffi.fieldType {
   241  					case RelOneToOne, RelForeignKey:
   242  						if ffi.relModelInfo == fi.relModelInfo {
   243  							fi.reverseFieldInfoTwo = ffi
   244  						}
   245  						if ffi.relModelInfo == mi {
   246  							fi.reverseField = ffi.name
   247  							fi.reverseFieldInfo = ffi
   248  						}
   249  					}
   250  				}
   251  				if fi.reverseFieldInfoTwo == nil {
   252  					err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
   253  						fi.relThroughModelInfo.fullName)
   254  					goto end
   255  				}
   256  			}
   257  		}
   258  	}
   259  
   260  	models = mc.all()
   261  	for _, mi := range models {
   262  		for _, fi := range mi.fields.fieldsReverse {
   263  			switch fi.fieldType {
   264  			case RelReverseOne:
   265  				found := false
   266  			mForA:
   267  				for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
   268  					if ffi.relModelInfo == mi {
   269  						found = true
   270  						fi.reverseField = ffi.name
   271  						fi.reverseFieldInfo = ffi
   272  
   273  						ffi.reverseField = fi.name
   274  						ffi.reverseFieldInfo = fi
   275  						break mForA
   276  					}
   277  				}
   278  				if !found {
   279  					err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
   280  					goto end
   281  				}
   282  			case RelReverseMany:
   283  				found := false
   284  			mForB:
   285  				for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
   286  					if ffi.relModelInfo == mi {
   287  						found = true
   288  						fi.reverseField = ffi.name
   289  						fi.reverseFieldInfo = ffi
   290  
   291  						ffi.reverseField = fi.name
   292  						ffi.reverseFieldInfo = fi
   293  
   294  						break mForB
   295  					}
   296  				}
   297  				if !found {
   298  				mForC:
   299  					for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
   300  						conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
   301  							fi.relTable != "" && fi.relTable == ffi.relTable ||
   302  							fi.relThrough == "" && fi.relTable == ""
   303  						if ffi.relModelInfo == mi && conditions {
   304  							found = true
   305  
   306  							fi.reverseField = ffi.reverseFieldInfoTwo.name
   307  							fi.reverseFieldInfo = ffi.reverseFieldInfoTwo
   308  							fi.relThroughModelInfo = ffi.relThroughModelInfo
   309  							fi.reverseFieldInfoTwo = ffi.reverseFieldInfo
   310  							fi.reverseFieldInfoM2M = ffi
   311  							ffi.reverseFieldInfoM2M = fi
   312  
   313  							break mForC
   314  						}
   315  					}
   316  				}
   317  				if !found {
   318  					err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
   319  					goto end
   320  				}
   321  			}
   322  		}
   323  	}
   324  
   325  end:
   326  	if err != nil {
   327  		fmt.Println(err)
   328  		debug.PrintStack()
   329  	}
   330  	mc.done = true
   331  }
   332  
   333  // register register models to model cache
   334  func (mc *modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) {
   335  	for _, model := range models {
   336  		val := reflect.ValueOf(model)
   337  		typ := reflect.Indirect(val).Type()
   338  
   339  		if val.Kind() != reflect.Ptr {
   340  			err = fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ))
   341  			return
   342  		}
   343  		// For this case:
   344  		// u := &User{}
   345  		// registerModel(&u)
   346  		if typ.Kind() == reflect.Ptr {
   347  			err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)
   348  			return
   349  		}
   350  		if val.Elem().Kind() == reflect.Slice {
   351  			val = reflect.New(val.Elem().Type().Elem())
   352  		}
   353  		table := getTableName(val)
   354  
   355  		if prefixOrSuffixStr != "" {
   356  			if prefixOrSuffix {
   357  				table = prefixOrSuffixStr + table
   358  			} else {
   359  				table = table + prefixOrSuffixStr
   360  			}
   361  		}
   362  
   363  		// models's fullname is pkgpath + struct name
   364  		name := getFullName(typ)
   365  		if _, ok := mc.getByFullName(name); ok {
   366  			err = fmt.Errorf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
   367  			return
   368  		}
   369  
   370  		if _, ok := mc.get(table); ok {
   371  			return nil
   372  		}
   373  
   374  		mi := newModelInfo(val)
   375  		if mi.fields.pk == nil {
   376  		outFor:
   377  			for _, fi := range mi.fields.fieldsDB {
   378  				if strings.ToLower(fi.name) == "id" {
   379  					switch fi.addrValue.Elem().Kind() {
   380  					case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
   381  						fi.auto = true
   382  						fi.pk = true
   383  						mi.fields.pk = fi
   384  						break outFor
   385  					}
   386  				}
   387  			}
   388  		}
   389  
   390  		mi.table = table
   391  		mi.pkg = typ.PkgPath()
   392  		mi.model = model
   393  		mi.manual = true
   394  
   395  		mc.set(table, mi)
   396  	}
   397  	return
   398  }
   399  
   400  // getDbDropSQL get database scheme drop sql queries
   401  func (mc *modelCache) getDbDropSQL(al *alias) (queries []string, err error) {
   402  	if len(mc.cache) == 0 {
   403  		err = errors.New("no Model found, need register your model")
   404  		return
   405  	}
   406  
   407  	Q := al.DbBaser.TableQuote()
   408  
   409  	for _, mi := range mc.allOrdered() {
   410  		queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q))
   411  	}
   412  	return queries, nil
   413  }
   414  
   415  // getDbCreateSQL get database scheme creation sql queries
   416  func (mc *modelCache) getDbCreateSQL(al *alias) (queries []string, tableIndexes map[string][]dbIndex, err error) {
   417  	if len(mc.cache) == 0 {
   418  		err = errors.New("no Model found, need register your model")
   419  		return
   420  	}
   421  
   422  	Q := al.DbBaser.TableQuote()
   423  	T := al.DbBaser.DbTypes()
   424  	sep := fmt.Sprintf("%s, %s", Q, Q)
   425  
   426  	tableIndexes = make(map[string][]dbIndex)
   427  
   428  	for _, mi := range mc.allOrdered() {
   429  		sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
   430  		sql += fmt.Sprintf("--  Table Structure for `%s`\n", mi.fullName)
   431  		sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
   432  
   433  		sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
   434  
   435  		columns := make([]string, 0, len(mi.fields.fieldsDB))
   436  
   437  		sqlIndexes := [][]string{}
   438  		var commentIndexes []int // store comment indexes for postgres
   439  
   440  		for i, fi := range mi.fields.fieldsDB {
   441  			column := fmt.Sprintf("    %s%s%s ", Q, fi.column, Q)
   442  			col := getColumnTyp(al, fi)
   443  
   444  			if fi.auto {
   445  				switch al.Driver {
   446  				case DRSqlite, DRPostgres:
   447  					column += T["auto"]
   448  				default:
   449  					column += col + " " + T["auto"]
   450  				}
   451  			} else if fi.pk {
   452  				column += col + " " + T["pk"]
   453  			} else {
   454  				column += col
   455  
   456  				if !fi.null {
   457  					column += " " + "NOT NULL"
   458  				}
   459  
   460  				// if fi.initial.String() != "" {
   461  				//	column += " DEFAULT " + fi.initial.String()
   462  				// }
   463  
   464  				// Append attribute DEFAULT
   465  				column += getColumnDefault(fi)
   466  
   467  				if fi.unique {
   468  					column += " " + "UNIQUE"
   469  				}
   470  
   471  				if fi.index {
   472  					sqlIndexes = append(sqlIndexes, []string{fi.column})
   473  				}
   474  			}
   475  
   476  			if strings.Contains(column, "%COL%") {
   477  				column = strings.Replace(column, "%COL%", fi.column, -1)
   478  			}
   479  
   480  			if fi.description != "" && al.Driver != DRSqlite {
   481  				if al.Driver == DRPostgres {
   482  					commentIndexes = append(commentIndexes, i)
   483  				} else {
   484  					column += " " + fmt.Sprintf("COMMENT '%s'", fi.description)
   485  				}
   486  			}
   487  
   488  			columns = append(columns, column)
   489  		}
   490  
   491  		if mi.model != nil {
   492  			allnames := getTableUnique(mi.addrField)
   493  			if !mi.manual && len(mi.uniques) > 0 {
   494  				allnames = append(allnames, mi.uniques)
   495  			}
   496  			for _, names := range allnames {
   497  				cols := make([]string, 0, len(names))
   498  				for _, name := range names {
   499  					if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
   500  						cols = append(cols, fi.column)
   501  					} else {
   502  						panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName))
   503  					}
   504  				}
   505  				column := fmt.Sprintf("    UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
   506  				columns = append(columns, column)
   507  			}
   508  		}
   509  
   510  		sql += strings.Join(columns, ",\n")
   511  		sql += "\n)"
   512  
   513  		if al.Driver == DRMySQL {
   514  			var engine string
   515  			if mi.model != nil {
   516  				engine = getTableEngine(mi.addrField)
   517  			}
   518  			if engine == "" {
   519  				engine = al.Engine
   520  			}
   521  			sql += " ENGINE=" + engine
   522  		}
   523  
   524  		sql += ";"
   525  		if al.Driver == DRPostgres && len(commentIndexes) > 0 {
   526  			// append comments for postgres only
   527  			for _, index := range commentIndexes {
   528  				sql += fmt.Sprintf("\nCOMMENT ON COLUMN %s%s%s.%s%s%s is '%s';",
   529  					Q,
   530  					mi.table,
   531  					Q,
   532  					Q,
   533  					mi.fields.fieldsDB[index].column,
   534  					Q,
   535  					mi.fields.fieldsDB[index].description)
   536  			}
   537  		}
   538  		queries = append(queries, sql)
   539  
   540  		if mi.model != nil {
   541  			for _, names := range getTableIndex(mi.addrField) {
   542  				cols := make([]string, 0, len(names))
   543  				for _, name := range names {
   544  					if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
   545  						cols = append(cols, fi.column)
   546  					} else {
   547  						panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName))
   548  					}
   549  				}
   550  				sqlIndexes = append(sqlIndexes, cols)
   551  			}
   552  		}
   553  
   554  		for _, names := range sqlIndexes {
   555  			name := mi.table + "_" + strings.Join(names, "_")
   556  			cols := strings.Join(names, sep)
   557  			sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
   558  
   559  			index := dbIndex{}
   560  			index.Table = mi.table
   561  			index.Name = name
   562  			index.SQL = sql
   563  
   564  			tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
   565  		}
   566  
   567  	}
   568  
   569  	return
   570  }
   571  
   572  // ResetModelCache Clean model cache. Then you can re-RegisterModel.
   573  // Common use this api for test case.
   574  func ResetModelCache() {
   575  	defaultModelCache.clean()
   576  }