github.com/astaxie/beego@v1.12.3/orm/models_boot.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package orm
    16  
    17  import (
    18  	"fmt"
    19  	"os"
    20  	"reflect"
    21  	"runtime/debug"
    22  	"strings"
    23  )
    24  
    25  // register models.
    26  // PrefixOrSuffix means table name prefix or suffix.
    27  // isPrefix whether the prefix is prefix or suffix
    28  func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
    29  	val := reflect.ValueOf(model)
    30  	typ := reflect.Indirect(val).Type()
    31  
    32  	if val.Kind() != reflect.Ptr {
    33  		panic(fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)))
    34  	}
    35  	// For this case:
    36  	// u := &User{}
    37  	// registerModel(&u)
    38  	if typ.Kind() == reflect.Ptr {
    39  		panic(fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ))
    40  	}
    41  
    42  	table := getTableName(val)
    43  
    44  	if PrefixOrSuffix != "" {
    45  		if isPrefix {
    46  			table = PrefixOrSuffix + table
    47  		} else {
    48  			table = table + PrefixOrSuffix
    49  		}
    50  	}
    51  	// models's fullname is pkgpath + struct name
    52  	name := getFullName(typ)
    53  	if _, ok := modelCache.getByFullName(name); ok {
    54  		fmt.Printf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
    55  		os.Exit(2)
    56  	}
    57  
    58  	if _, ok := modelCache.get(table); ok {
    59  		fmt.Printf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table)
    60  		os.Exit(2)
    61  	}
    62  
    63  	mi := newModelInfo(val)
    64  	if mi.fields.pk == nil {
    65  	outFor:
    66  		for _, fi := range mi.fields.fieldsDB {
    67  			if strings.ToLower(fi.name) == "id" {
    68  				switch fi.addrValue.Elem().Kind() {
    69  				case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
    70  					fi.auto = true
    71  					fi.pk = true
    72  					mi.fields.pk = fi
    73  					break outFor
    74  				}
    75  			}
    76  		}
    77  
    78  		if mi.fields.pk == nil {
    79  			fmt.Printf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
    80  			os.Exit(2)
    81  		}
    82  
    83  	}
    84  
    85  	mi.table = table
    86  	mi.pkg = typ.PkgPath()
    87  	mi.model = model
    88  	mi.manual = true
    89  
    90  	modelCache.set(table, mi)
    91  }
    92  
    93  // bootstrap models
    94  func bootStrap() {
    95  	if modelCache.done {
    96  		return
    97  	}
    98  	var (
    99  		err    error
   100  		models map[string]*modelInfo
   101  	)
   102  	if dataBaseCache.getDefault() == nil {
   103  		err = fmt.Errorf("must have one register DataBase alias named `default`")
   104  		goto end
   105  	}
   106  
   107  	// set rel and reverse model
   108  	// RelManyToMany set the relTable
   109  	models = modelCache.all()
   110  	for _, mi := range models {
   111  		for _, fi := range mi.fields.columns {
   112  			if fi.rel || fi.reverse {
   113  				elm := fi.addrValue.Type().Elem()
   114  				if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany {
   115  					elm = elm.Elem()
   116  				}
   117  				// check the rel or reverse model already register
   118  				name := getFullName(elm)
   119  				mii, ok := modelCache.getByFullName(name)
   120  				if !ok || mii.pkg != elm.PkgPath() {
   121  					err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
   122  					goto end
   123  				}
   124  				fi.relModelInfo = mii
   125  
   126  				switch fi.fieldType {
   127  				case RelManyToMany:
   128  					if fi.relThrough != "" {
   129  						if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
   130  							pn := fi.relThrough[:i]
   131  							rmi, ok := modelCache.getByFullName(fi.relThrough)
   132  							if !ok || pn != rmi.pkg {
   133  								err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
   134  								goto end
   135  							}
   136  							fi.relThroughModelInfo = rmi
   137  							fi.relTable = rmi.table
   138  						} else {
   139  							err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
   140  							goto end
   141  						}
   142  					} else {
   143  						i := newM2MModelInfo(mi, mii)
   144  						if fi.relTable != "" {
   145  							i.table = fi.relTable
   146  						}
   147  						if v := modelCache.set(i.table, i); v != nil {
   148  							err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
   149  							goto end
   150  						}
   151  						fi.relTable = i.table
   152  						fi.relThroughModelInfo = i
   153  					}
   154  
   155  					fi.relThroughModelInfo.isThrough = true
   156  				}
   157  			}
   158  		}
   159  	}
   160  
   161  	// check the rel filed while the relModelInfo also has filed point to current model
   162  	// if not exist, add a new field to the relModelInfo
   163  	models = modelCache.all()
   164  	for _, mi := range models {
   165  		for _, fi := range mi.fields.fieldsRel {
   166  			switch fi.fieldType {
   167  			case RelForeignKey, RelOneToOne, RelManyToMany:
   168  				inModel := false
   169  				for _, ffi := range fi.relModelInfo.fields.fieldsReverse {
   170  					if ffi.relModelInfo == mi {
   171  						inModel = true
   172  						break
   173  					}
   174  				}
   175  				if !inModel {
   176  					rmi := fi.relModelInfo
   177  					ffi := new(fieldInfo)
   178  					ffi.name = mi.name
   179  					ffi.column = ffi.name
   180  					ffi.fullName = rmi.fullName + "." + ffi.name
   181  					ffi.reverse = true
   182  					ffi.relModelInfo = mi
   183  					ffi.mi = rmi
   184  					if fi.fieldType == RelOneToOne {
   185  						ffi.fieldType = RelReverseOne
   186  					} else {
   187  						ffi.fieldType = RelReverseMany
   188  					}
   189  					if !rmi.fields.Add(ffi) {
   190  						added := false
   191  						for cnt := 0; cnt < 5; cnt++ {
   192  							ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
   193  							ffi.column = ffi.name
   194  							ffi.fullName = rmi.fullName + "." + ffi.name
   195  							if added = rmi.fields.Add(ffi); added {
   196  								break
   197  							}
   198  						}
   199  						if !added {
   200  							panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
   201  						}
   202  					}
   203  				}
   204  			}
   205  		}
   206  	}
   207  
   208  	models = modelCache.all()
   209  	for _, mi := range models {
   210  		for _, fi := range mi.fields.fieldsRel {
   211  			switch fi.fieldType {
   212  			case RelManyToMany:
   213  				for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel {
   214  					switch ffi.fieldType {
   215  					case RelOneToOne, RelForeignKey:
   216  						if ffi.relModelInfo == fi.relModelInfo {
   217  							fi.reverseFieldInfoTwo = ffi
   218  						}
   219  						if ffi.relModelInfo == mi {
   220  							fi.reverseField = ffi.name
   221  							fi.reverseFieldInfo = ffi
   222  						}
   223  					}
   224  				}
   225  				if fi.reverseFieldInfoTwo == nil {
   226  					err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
   227  						fi.relThroughModelInfo.fullName)
   228  					goto end
   229  				}
   230  			}
   231  		}
   232  	}
   233  
   234  	models = modelCache.all()
   235  	for _, mi := range models {
   236  		for _, fi := range mi.fields.fieldsReverse {
   237  			switch fi.fieldType {
   238  			case RelReverseOne:
   239  				found := false
   240  			mForA:
   241  				for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
   242  					if ffi.relModelInfo == mi {
   243  						found = true
   244  						fi.reverseField = ffi.name
   245  						fi.reverseFieldInfo = ffi
   246  
   247  						ffi.reverseField = fi.name
   248  						ffi.reverseFieldInfo = fi
   249  						break mForA
   250  					}
   251  				}
   252  				if !found {
   253  					err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
   254  					goto end
   255  				}
   256  			case RelReverseMany:
   257  				found := false
   258  			mForB:
   259  				for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
   260  					if ffi.relModelInfo == mi {
   261  						found = true
   262  						fi.reverseField = ffi.name
   263  						fi.reverseFieldInfo = ffi
   264  
   265  						ffi.reverseField = fi.name
   266  						ffi.reverseFieldInfo = fi
   267  
   268  						break mForB
   269  					}
   270  				}
   271  				if !found {
   272  				mForC:
   273  					for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
   274  						conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
   275  							fi.relTable != "" && fi.relTable == ffi.relTable ||
   276  							fi.relThrough == "" && fi.relTable == ""
   277  						if ffi.relModelInfo == mi && conditions {
   278  							found = true
   279  
   280  							fi.reverseField = ffi.reverseFieldInfoTwo.name
   281  							fi.reverseFieldInfo = ffi.reverseFieldInfoTwo
   282  							fi.relThroughModelInfo = ffi.relThroughModelInfo
   283  							fi.reverseFieldInfoTwo = ffi.reverseFieldInfo
   284  							fi.reverseFieldInfoM2M = ffi
   285  							ffi.reverseFieldInfoM2M = fi
   286  
   287  							break mForC
   288  						}
   289  					}
   290  				}
   291  				if !found {
   292  					err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
   293  					goto end
   294  				}
   295  			}
   296  		}
   297  	}
   298  
   299  end:
   300  	if err != nil {
   301  		fmt.Println(err)
   302  		debug.PrintStack()
   303  		os.Exit(2)
   304  	}
   305  }
   306  
   307  // RegisterModel register models
   308  func RegisterModel(models ...interface{}) {
   309  	if modelCache.done {
   310  		panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
   311  	}
   312  	RegisterModelWithPrefix("", models...)
   313  }
   314  
   315  // RegisterModelWithPrefix register models with a prefix
   316  func RegisterModelWithPrefix(prefix string, models ...interface{}) {
   317  	if modelCache.done {
   318  		panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap"))
   319  	}
   320  
   321  	for _, model := range models {
   322  		registerModel(prefix, model, true)
   323  	}
   324  }
   325  
   326  // RegisterModelWithSuffix register models with a suffix
   327  func RegisterModelWithSuffix(suffix string, models ...interface{}) {
   328  	if modelCache.done {
   329  		panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap"))
   330  	}
   331  
   332  	for _, model := range models {
   333  		registerModel(suffix, model, false)
   334  	}
   335  }
   336  
   337  // BootStrap bootstrap models.
   338  // make all model parsed and can not add more models
   339  func BootStrap() {
   340  	modelCache.Lock()
   341  	defer modelCache.Unlock()
   342  	if modelCache.done {
   343  		return
   344  	}
   345  	bootStrap()
   346  	modelCache.done = true
   347  }