github.com/astaxie/beego@v1.12.3/orm/orm.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // +build go1.8
    16  
// Package orm provides an ORM for MySQL, PostgreSQL and SQLite.
    18  // Simple Usage
    19  //
    20  //	package main
    21  //
    22  //	import (
    23  //		"fmt"
    24  //		"github.com/astaxie/beego/orm"
    25  //		_ "github.com/go-sql-driver/mysql" // import your used driver
    26  //	)
    27  //
    28  //	// Model Struct
    29  //	type User struct {
    30  //		Id   int    `orm:"auto"`
    31  //		Name string `orm:"size(100)"`
    32  //	}
    33  //
    34  //	func init() {
    35  //		orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
    36  //	}
    37  //
    38  //	func main() {
    39  //		o := orm.NewOrm()
    40  //		user := User{Name: "slene"}
    41  //		// insert
    42  //		id, err := o.Insert(&user)
    43  //		// update
    44  //		user.Name = "astaxie"
    45  //		num, err := o.Update(&user)
    46  //		// read one
    47  //		u := User{Id: user.Id}
    48  //		err = o.Read(&u)
    49  //		// delete
    50  //		num, err = o.Delete(&u)
    51  //	}
    52  //
    53  // more docs: http://beego.me/docs/mvc/model/overview.md
    54  package orm
    55  
    56  import (
    57  	"context"
    58  	"database/sql"
    59  	"errors"
    60  	"fmt"
    61  	"os"
    62  	"reflect"
    63  	"sync"
    64  	"time"
    65  )
    66  
// DebugQueries is the only constant in this iota block, so its value is 0.
// NOTE(review): it is not referenced in this file; presumably a debug setting
// flag — confirm at its call sites.
const (
	DebugQueries = iota
)

// Define common vars.
//
//   - Debug: when true, Using and NewOrmWithDB wrap the database handle with
//     newDbQueryLog, and BeginTx hands the new transaction to that wrapper.
//   - DebugLog: logger created with NewLog(os.Stdout). NOTE(review): presumably
//     the sink for debug query logging — confirm in the logging code.
//   - DefaultRowsLimit: package-wide default row limit, initially -1.
//     NOTE(review): presumably -1 means "no limit" — confirm in the query code.
//   - DefaultRelsDepth: relation depth LoadRelated uses when its first optional
//     argument is the boolean true.
//   - DefaultTimeLoc: initially time.Local. NOTE(review): presumably the
//     location used for time conversion — confirm.
//   - ErrTxHasBegan: returned by Begin/BeginTx while a transaction is active.
//   - ErrTxDone: returned by Commit/Rollback when no transaction is active, and
//     in place of sql.ErrTxDone reported by the underlying transaction.
//   - ErrMultiRows, ErrNoRows, ErrStmtClosed: sentinels whose messages refer to
//     QuerySeter; ReadOrCreate treats ErrNoRows from a read as "no such row".
//   - ErrArgs: returned for invalid arguments, e.g. by InsertMulti when mds is
//     not a non-empty slice or array.
//   - ErrNotImplement: NOTE(review): not referenced in this file.
var (
	Debug            = false
	DebugLog         = NewLog(os.Stdout)
	DefaultRowsLimit = -1
	DefaultRelsDepth = 2
	DefaultTimeLoc   = time.Local
	ErrTxHasBegan    = errors.New("<Ormer.Begin> transaction already begin")
	ErrTxDone        = errors.New("<Ormer.Commit/Rollback> transaction not begin")
	ErrMultiRows     = errors.New("<QuerySeter> return multi rows")
	ErrNoRows        = errors.New("<QuerySeter> no row found")
	ErrStmtClosed    = errors.New("<QuerySeter> stmt already closed")
	ErrArgs          = errors.New("<Ormer> args error may be empty")
	ErrNotImplement  = errors.New("have not implement")
)
    87  
// Params is a string-keyed map of arbitrary values.
// NOTE(review): presumably one result row keyed by column/field name, as
// produced by value-style queries elsewhere in the package — confirm.
type Params map[string]interface{}

// ParamsList is a slice of arbitrary values.
// NOTE(review): presumably one result row as positional values — confirm.
type ParamsList []interface{}

// orm is this file's implementation of Ormer.
//
// Fields:
//   - alias: the database alias in use, set by Using or NewOrmWithDB.
//   - db: the handle queries run against — the alias's DB (or, for an orm
//     built by NewOrmWithDB, the supplied *sql.DB), a *dbQueryLog wrapper when
//     Debug is set, or during a transaction the *sql.Tx (handed to the
//     dbQueryLog wrapper in Debug mode).
//   - isTx: set by BeginTx and cleared by Commit/Rollback.
type orm struct {
	alias *alias
	db    dbQuerier
	isTx  bool
}

// Compile-time check that *orm implements Ormer.
var _ Ormer = new(orm)
   101  
   102  // get model info and model reflect value
   103  func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
   104  	val := reflect.ValueOf(md)
   105  	ind = reflect.Indirect(val)
   106  	typ := ind.Type()
   107  	if needPtr && val.Kind() != reflect.Ptr {
   108  		panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
   109  	}
   110  	name := getFullName(typ)
   111  	if mi, ok := modelCache.getByFullName(name); ok {
   112  		return mi, ind
   113  	}
   114  	panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
   115  }
   116  
   117  // get field info from model info by given field name
   118  func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
   119  	fi, ok := mi.fields.GetByAny(name)
   120  	if !ok {
   121  		panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName))
   122  	}
   123  	return fi
   124  }
   125  
   126  // read data to model
   127  func (o *orm) Read(md interface{}, cols ...string) error {
   128  	mi, ind := o.getMiInd(md, true)
   129  	return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
   130  }
   131  
   132  // read data to model, like Read(), but use "SELECT FOR UPDATE" form
   133  func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
   134  	mi, ind := o.getMiInd(md, true)
   135  	return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
   136  }
   137  
   138  // Try to read a row from the database, or insert one if it doesn't exist
   139  func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
   140  	cols = append([]string{col1}, cols...)
   141  	mi, ind := o.getMiInd(md, true)
   142  	err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
   143  	if err == ErrNoRows {
   144  		// Create
   145  		id, err := o.Insert(md)
   146  		return (err == nil), id, err
   147  	}
   148  
   149  	id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
   150  	if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
   151  		id = int64(vid.Uint())
   152  	} else if mi.fields.pk.rel {
   153  		return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
   154  	} else {
   155  		id = vid.Int()
   156  	}
   157  
   158  	return false, id, err
   159  }
   160  
   161  // insert model data to database
   162  func (o *orm) Insert(md interface{}) (int64, error) {
   163  	mi, ind := o.getMiInd(md, true)
   164  	id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
   165  	if err != nil {
   166  		return id, err
   167  	}
   168  
   169  	o.setPk(mi, ind, id)
   170  
   171  	return id, nil
   172  }
   173  
   174  // set auto pk field
   175  func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
   176  	if mi.fields.pk.auto {
   177  		if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
   178  			ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
   179  		} else {
   180  			ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
   181  		}
   182  	}
   183  }
   184  
   185  // insert some models to database
   186  func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
   187  	var cnt int64
   188  
   189  	sind := reflect.Indirect(reflect.ValueOf(mds))
   190  
   191  	switch sind.Kind() {
   192  	case reflect.Array, reflect.Slice:
   193  		if sind.Len() == 0 {
   194  			return cnt, ErrArgs
   195  		}
   196  	default:
   197  		return cnt, ErrArgs
   198  	}
   199  
   200  	if bulk <= 1 {
   201  		for i := 0; i < sind.Len(); i++ {
   202  			ind := reflect.Indirect(sind.Index(i))
   203  			mi, _ := o.getMiInd(ind.Interface(), false)
   204  			id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
   205  			if err != nil {
   206  				return cnt, err
   207  			}
   208  
   209  			o.setPk(mi, ind, id)
   210  
   211  			cnt++
   212  		}
   213  	} else {
   214  		mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
   215  		return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
   216  	}
   217  	return cnt, nil
   218  }
   219  
   220  // InsertOrUpdate data to database
   221  func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
   222  	mi, ind := o.getMiInd(md, true)
   223  	id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
   224  	if err != nil {
   225  		return id, err
   226  	}
   227  
   228  	o.setPk(mi, ind, id)
   229  
   230  	return id, nil
   231  }
   232  
   233  // update model to database.
   234  // cols set the columns those want to update.
   235  func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
   236  	mi, ind := o.getMiInd(md, true)
   237  	return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
   238  }
   239  
   240  // delete model in database
   241  // cols shows the delete conditions values read from. default is pk
   242  func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
   243  	mi, ind := o.getMiInd(md, true)
   244  	num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
   245  	if err != nil {
   246  		return num, err
   247  	}
   248  	if num > 0 {
   249  		o.setPk(mi, ind, 0)
   250  	}
   251  	return num, nil
   252  }
   253  
   254  // create a models to models queryer
   255  func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
   256  	mi, ind := o.getMiInd(md, true)
   257  	fi := o.getFieldInfo(mi, name)
   258  
   259  	switch {
   260  	case fi.fieldType == RelManyToMany:
   261  	case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough:
   262  	default:
   263  		panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName))
   264  	}
   265  
   266  	return newQueryM2M(md, o, mi, fi, ind)
   267  }
   268  
   269  // load related models to md model.
   270  // args are limit, offset int and order string.
   271  //
   272  // example:
   273  // 	orm.LoadRelated(post,"Tags")
   274  // 	for _,tag := range post.Tags{...}
   275  //
   276  // make sure the relation is defined in model struct tags.
   277  func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
   278  	_, fi, ind, qseter := o.queryRelated(md, name)
   279  
   280  	qs := qseter.(*querySet)
   281  
   282  	var relDepth int
   283  	var limit, offset int64
   284  	var order string
   285  	for i, arg := range args {
   286  		switch i {
   287  		case 0:
   288  			if v, ok := arg.(bool); ok {
   289  				if v {
   290  					relDepth = DefaultRelsDepth
   291  				}
   292  			} else if v, ok := arg.(int); ok {
   293  				relDepth = v
   294  			}
   295  		case 1:
   296  			limit = ToInt64(arg)
   297  		case 2:
   298  			offset = ToInt64(arg)
   299  		case 3:
   300  			order, _ = arg.(string)
   301  		}
   302  	}
   303  
   304  	switch fi.fieldType {
   305  	case RelOneToOne, RelForeignKey, RelReverseOne:
   306  		limit = 1
   307  		offset = 0
   308  	}
   309  
   310  	qs.limit = limit
   311  	qs.offset = offset
   312  	qs.relDepth = relDepth
   313  
   314  	if len(order) > 0 {
   315  		qs.orders = []string{order}
   316  	}
   317  
   318  	find := ind.FieldByIndex(fi.fieldIndex)
   319  
   320  	var nums int64
   321  	var err error
   322  	switch fi.fieldType {
   323  	case RelOneToOne, RelForeignKey, RelReverseOne:
   324  		val := reflect.New(find.Type().Elem())
   325  		container := val.Interface()
   326  		err = qs.One(container)
   327  		if err == nil {
   328  			find.Set(val)
   329  			nums = 1
   330  		}
   331  	default:
   332  		nums, err = qs.All(find.Addr().Interface())
   333  	}
   334  
   335  	return nums, err
   336  }
   337  
   338  // return a QuerySeter for related models to md model.
   339  // it can do all, update, delete in QuerySeter.
   340  // example:
   341  // 	qs := orm.QueryRelated(post,"Tag")
   342  //  qs.All(&[]*Tag{})
   343  //
   344  func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
   345  	// is this api needed ?
   346  	_, _, _, qs := o.queryRelated(md, name)
   347  	return qs
   348  }
   349  
   350  // get QuerySeter for related models to md model
   351  func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
   352  	mi, ind := o.getMiInd(md, true)
   353  	fi := o.getFieldInfo(mi, name)
   354  
   355  	_, _, exist := getExistPk(mi, ind)
   356  	if !exist {
   357  		panic(ErrMissPK)
   358  	}
   359  
   360  	var qs *querySet
   361  
   362  	switch fi.fieldType {
   363  	case RelOneToOne, RelForeignKey, RelManyToMany:
   364  		if !fi.inModel {
   365  			break
   366  		}
   367  		qs = o.getRelQs(md, mi, fi)
   368  	case RelReverseOne, RelReverseMany:
   369  		if !fi.inModel {
   370  			break
   371  		}
   372  		qs = o.getReverseQs(md, mi, fi)
   373  	}
   374  
   375  	if qs == nil {
   376  		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name))
   377  	}
   378  
   379  	return mi, fi, ind, qs
   380  }
   381  
   382  // get reverse relation QuerySeter
   383  func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
   384  	switch fi.fieldType {
   385  	case RelReverseOne, RelReverseMany:
   386  	default:
   387  		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName))
   388  	}
   389  
   390  	var q *querySet
   391  
   392  	if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough {
   393  		q = newQuerySet(o, fi.relModelInfo).(*querySet)
   394  		q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
   395  	} else {
   396  		q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet)
   397  		q.cond = NewCondition().And(fi.reverseFieldInfo.column, md)
   398  	}
   399  
   400  	return q
   401  }
   402  
   403  // get relation QuerySeter
   404  func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
   405  	switch fi.fieldType {
   406  	case RelOneToOne, RelForeignKey, RelManyToMany:
   407  	default:
   408  		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName))
   409  	}
   410  
   411  	q := newQuerySet(o, fi.relModelInfo).(*querySet)
   412  	q.cond = NewCondition()
   413  
   414  	if fi.fieldType == RelManyToMany {
   415  		q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
   416  	} else {
   417  		q.cond = q.cond.And(fi.reverseFieldInfo.column, md)
   418  	}
   419  
   420  	return q
   421  }
   422  
   423  // return a QuerySeter for table operations.
   424  // table name can be string or struct.
   425  // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
   426  func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
   427  	var name string
   428  	if table, ok := ptrStructOrTableName.(string); ok {
   429  		name = nameStrategyMap[defaultNameStrategy](table)
   430  		if mi, ok := modelCache.get(name); ok {
   431  			qs = newQuerySet(o, mi)
   432  		}
   433  	} else {
   434  		name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
   435  		if mi, ok := modelCache.getByFullName(name); ok {
   436  			qs = newQuerySet(o, mi)
   437  		}
   438  	}
   439  	if qs == nil {
   440  		panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
   441  	}
   442  	return
   443  }
   444  
   445  // switch to another registered database driver by given name.
   446  func (o *orm) Using(name string) error {
   447  	if o.isTx {
   448  		panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
   449  	}
   450  	if al, ok := dataBaseCache.get(name); ok {
   451  		o.alias = al
   452  		if Debug {
   453  			o.db = newDbQueryLog(al, al.DB)
   454  		} else {
   455  			o.db = al.DB
   456  		}
   457  	} else {
   458  		return fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", name)
   459  	}
   460  	return nil
   461  }
   462  
   463  // begin transaction
   464  func (o *orm) Begin() error {
   465  	return o.BeginTx(context.Background(), nil)
   466  }
   467  
   468  func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error {
   469  	if o.isTx {
   470  		return ErrTxHasBegan
   471  	}
   472  	var tx *sql.Tx
   473  	tx, err := o.db.(txer).BeginTx(ctx, opts)
   474  	if err != nil {
   475  		return err
   476  	}
   477  	o.isTx = true
   478  	if Debug {
   479  		o.db.(*dbQueryLog).SetDB(tx)
   480  	} else {
   481  		o.db = tx
   482  	}
   483  	return nil
   484  }
   485  
   486  // commit transaction
   487  func (o *orm) Commit() error {
   488  	if !o.isTx {
   489  		return ErrTxDone
   490  	}
   491  	err := o.db.(txEnder).Commit()
   492  	if err == nil {
   493  		o.isTx = false
   494  		o.Using(o.alias.Name)
   495  	} else if err == sql.ErrTxDone {
   496  		return ErrTxDone
   497  	}
   498  	return err
   499  }
   500  
   501  // rollback transaction
   502  func (o *orm) Rollback() error {
   503  	if !o.isTx {
   504  		return ErrTxDone
   505  	}
   506  	err := o.db.(txEnder).Rollback()
   507  	if err == nil {
   508  		o.isTx = false
   509  		o.Using(o.alias.Name)
   510  	} else if err == sql.ErrTxDone {
   511  		return ErrTxDone
   512  	}
   513  	return err
   514  }
   515  
   516  // return a raw query seter for raw sql string.
   517  func (o *orm) Raw(query string, args ...interface{}) RawSeter {
   518  	return newRawSet(o, query, args)
   519  }
   520  
   521  // return current using database Driver
   522  func (o *orm) Driver() Driver {
   523  	return driver(o.alias.Name)
   524  }
   525  
   526  // return sql.DBStats for current database
   527  func (o *orm) DBStats() *sql.DBStats {
   528  	if o.alias != nil && o.alias.DB != nil {
   529  		stats := o.alias.DB.DB.Stats()
   530  		return &stats
   531  	}
   532  	return nil
   533  }
   534  
   535  // NewOrm create new orm
   536  func NewOrm() Ormer {
   537  	BootStrap() // execute only once
   538  
   539  	o := new(orm)
   540  	err := o.Using("default")
   541  	if err != nil {
   542  		panic(err)
   543  	}
   544  	return o
   545  }
   546  
   547  // NewOrmWithDB create a new ormer object with specify *sql.DB for query
   548  func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
   549  	var al *alias
   550  
   551  	if dr, ok := drivers[driverName]; ok {
   552  		al = new(alias)
   553  		al.DbBaser = dbBasers[dr]
   554  		al.Driver = dr
   555  	} else {
   556  		return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
   557  	}
   558  
   559  	al.Name = aliasName
   560  	al.DriverName = driverName
   561  	al.DB = &DB{
   562  		RWMutex:        new(sync.RWMutex),
   563  		DB:             db,
   564  		stmtDecorators: newStmtDecoratorLruWithEvict(),
   565  	}
   566  
   567  	detectTZ(al)
   568  
   569  	o := new(orm)
   570  	o.alias = al
   571  
   572  	if Debug {
   573  		o.db = newDbQueryLog(o.alias, db)
   574  	} else {
   575  		o.db = db
   576  	}
   577  
   578  	return o, nil
   579  }