github.com/astaxie/beego@v1.12.3/orm/db_alias.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package orm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"fmt"
    21  	"reflect"
    22  	"sync"
    23  	"time"
    24  
    25  	lru "github.com/hashicorp/golang-lru"
    26  )
    27  
    28  // DriverType database driver constant int.
    29  type DriverType int
    30  
    31  // Enum the Database driver
    32  const (
    33  	_          DriverType = iota // int enum type
    34  	DRMySQL                      // mysql
    35  	DRSqlite                     // sqlite
    36  	DROracle                     // oracle
    37  	DRPostgres                   // pgsql
    38  	DRTiDB                       // TiDB
    39  )
    40  
    41  // database driver string.
    42  type driver string
    43  
    44  // get type constant int of current driver..
    45  func (d driver) Type() DriverType {
    46  	a, _ := dataBaseCache.get(string(d))
    47  	return a.Driver
    48  }
    49  
    50  // get name of current driver
    51  func (d driver) Name() string {
    52  	return string(d)
    53  }
    54  
    55  // check driver iis implemented Driver interface or not.
    56  var _ Driver = new(driver)
    57  
    58  var (
    59  	dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
    60  	drivers       = map[string]DriverType{
    61  		"mysql":    DRMySQL,
    62  		"postgres": DRPostgres,
    63  		"sqlite3":  DRSqlite,
    64  		"tidb":     DRTiDB,
    65  		"oracle":   DROracle,
    66  		"oci8":     DROracle, // github.com/mattn/go-oci8
    67  		"ora":      DROracle, //https://github.com/rana/ora
    68  	}
    69  	dbBasers = map[DriverType]dbBaser{
    70  		DRMySQL:    newdbBaseMysql(),
    71  		DRSqlite:   newdbBaseSqlite(),
    72  		DROracle:   newdbBaseOracle(),
    73  		DRPostgres: newdbBasePostgres(),
    74  		DRTiDB:     newdbBaseTidb(),
    75  	}
    76  )
    77  
    78  // database alias cacher.
    79  type _dbCache struct {
    80  	mux   sync.RWMutex
    81  	cache map[string]*alias
    82  }
    83  
    84  // add database alias with original name.
    85  func (ac *_dbCache) add(name string, al *alias) (added bool) {
    86  	ac.mux.Lock()
    87  	defer ac.mux.Unlock()
    88  	if _, ok := ac.cache[name]; !ok {
    89  		ac.cache[name] = al
    90  		added = true
    91  	}
    92  	return
    93  }
    94  
    95  // get database alias if cached.
    96  func (ac *_dbCache) get(name string) (al *alias, ok bool) {
    97  	ac.mux.RLock()
    98  	defer ac.mux.RUnlock()
    99  	al, ok = ac.cache[name]
   100  	return
   101  }
   102  
   103  // get default alias.
   104  func (ac *_dbCache) getDefault() (al *alias) {
   105  	al, _ = ac.get("default")
   106  	return
   107  }
   108  
   109  type DB struct {
   110  	*sync.RWMutex
   111  	DB             *sql.DB
   112  	stmtDecorators *lru.Cache
   113  }
   114  
   115  func (d *DB) Begin() (*sql.Tx, error) {
   116  	return d.DB.Begin()
   117  }
   118  
   119  func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
   120  	return d.DB.BeginTx(ctx, opts)
   121  }
   122  
   123  //su must call release to release *sql.Stmt after using
   124  func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
   125  	d.RLock()
   126  	c, ok := d.stmtDecorators.Get(query)
   127  	if ok {
   128  		c.(*stmtDecorator).acquire()
   129  		d.RUnlock()
   130  		return c.(*stmtDecorator), nil
   131  	}
   132  	d.RUnlock()
   133  
   134  	d.Lock()
   135  	c, ok = d.stmtDecorators.Get(query)
   136  	if ok {
   137  		c.(*stmtDecorator).acquire()
   138  		d.Unlock()
   139  		return c.(*stmtDecorator), nil
   140  	}
   141  
   142  	stmt, err := d.Prepare(query)
   143  	if err != nil {
   144  		d.Unlock()
   145  		return nil, err
   146  	}
   147  	sd := newStmtDecorator(stmt)
   148  	sd.acquire()
   149  	d.stmtDecorators.Add(query, sd)
   150  	d.Unlock()
   151  
   152  	return sd, nil
   153  }
   154  
   155  func (d *DB) Prepare(query string) (*sql.Stmt, error) {
   156  	return d.DB.Prepare(query)
   157  }
   158  
   159  func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
   160  	return d.DB.PrepareContext(ctx, query)
   161  }
   162  
   163  func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
   164  	sd, err := d.getStmtDecorator(query)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	stmt := sd.getStmt()
   169  	defer sd.release()
   170  	return stmt.Exec(args...)
   171  }
   172  
   173  func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
   174  	sd, err := d.getStmtDecorator(query)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	stmt := sd.getStmt()
   179  	defer sd.release()
   180  	return stmt.ExecContext(ctx, args...)
   181  }
   182  
   183  func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
   184  	sd, err := d.getStmtDecorator(query)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	stmt := sd.getStmt()
   189  	defer sd.release()
   190  	return stmt.Query(args...)
   191  }
   192  
   193  func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
   194  	sd, err := d.getStmtDecorator(query)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  	stmt := sd.getStmt()
   199  	defer sd.release()
   200  	return stmt.QueryContext(ctx, args...)
   201  }
   202  
   203  func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
   204  	sd, err := d.getStmtDecorator(query)
   205  	if err != nil {
   206  		panic(err)
   207  	}
   208  	stmt := sd.getStmt()
   209  	defer sd.release()
   210  	return stmt.QueryRow(args...)
   211  
   212  }
   213  
   214  func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
   215  	sd, err := d.getStmtDecorator(query)
   216  	if err != nil {
   217  		panic(err)
   218  	}
   219  	stmt := sd.getStmt()
   220  	defer sd.release()
   221  	return stmt.QueryRowContext(ctx, args)
   222  }
   223  
// alias is the registry entry for one named database connection; entries are
// stored in dataBaseCache under Name.
type alias struct {
	Name         string         // alias name; the dataBaseCache key
	Driver       DriverType     // dialect resolved from DriverName via the drivers map
	DriverName   string         // database/sql driver name the alias was registered with
	DataSource   string         // DSN; set by RegisterDataBase, left empty by addAliasWthDB
	MaxIdleConns int            // last value applied through SetMaxIdleConns
	MaxOpenConns int            // last value applied through SetMaxOpenConns
	DB           *DB            // wrapped *sql.DB plus its prepared-statement cache
	DbBaser      dbBaser        // dialect implementation taken from dbBasers[Driver]
	TZ           *time.Location // timezone for this database (see detectTZ and SetDataBaseTZ)
	Engine       string         // MySQL default storage engine, filled in by detectTZ
}
   236  
   237  func detectTZ(al *alias) {
   238  	// orm timezone system match database
   239  	// default use Local
   240  	al.TZ = DefaultTimeLoc
   241  
   242  	if al.DriverName == "sphinx" {
   243  		return
   244  	}
   245  
   246  	switch al.Driver {
   247  	case DRMySQL:
   248  		row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
   249  		var tz string
   250  		row.Scan(&tz)
   251  		if len(tz) >= 8 {
   252  			if tz[0] != '-' {
   253  				tz = "+" + tz
   254  			}
   255  			t, err := time.Parse("-07:00:00", tz)
   256  			if err == nil {
   257  				if t.Location().String() != "" {
   258  					al.TZ = t.Location()
   259  				}
   260  			} else {
   261  				DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
   262  			}
   263  		}
   264  
   265  		// get default engine from current database
   266  		row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
   267  		var engine string
   268  		var tx bool
   269  		row.Scan(&engine, &tx)
   270  
   271  		if engine != "" {
   272  			al.Engine = engine
   273  		} else {
   274  			al.Engine = "INNODB"
   275  		}
   276  
   277  	case DRSqlite, DROracle:
   278  		al.TZ = time.UTC
   279  
   280  	case DRPostgres:
   281  		row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
   282  		var tz string
   283  		row.Scan(&tz)
   284  		loc, err := time.LoadLocation(tz)
   285  		if err == nil {
   286  			al.TZ = loc
   287  		} else {
   288  			DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
   289  		}
   290  	}
   291  }
   292  
   293  func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
   294  	al := new(alias)
   295  	al.Name = aliasName
   296  	al.DriverName = driverName
   297  	al.DB = &DB{
   298  		RWMutex:        new(sync.RWMutex),
   299  		DB:             db,
   300  		stmtDecorators: newStmtDecoratorLruWithEvict(),
   301  	}
   302  
   303  	if dr, ok := drivers[driverName]; ok {
   304  		al.DbBaser = dbBasers[dr]
   305  		al.Driver = dr
   306  	} else {
   307  		return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
   308  	}
   309  
   310  	err := db.Ping()
   311  	if err != nil {
   312  		return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
   313  	}
   314  
   315  	if !dataBaseCache.add(aliasName, al) {
   316  		return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
   317  	}
   318  
   319  	return al, nil
   320  }
   321  
   322  // AddAliasWthDB add a aliasName for the drivename
   323  func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
   324  	_, err := addAliasWthDB(aliasName, driverName, db)
   325  	return err
   326  }
   327  
   328  // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
   329  func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
   330  	var (
   331  		err error
   332  		db  *sql.DB
   333  		al  *alias
   334  	)
   335  
   336  	db, err = sql.Open(driverName, dataSource)
   337  	if err != nil {
   338  		err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
   339  		goto end
   340  	}
   341  
   342  	al, err = addAliasWthDB(aliasName, driverName, db)
   343  	if err != nil {
   344  		goto end
   345  	}
   346  
   347  	al.DataSource = dataSource
   348  
   349  	detectTZ(al)
   350  
   351  	for i, v := range params {
   352  		switch i {
   353  		case 0:
   354  			SetMaxIdleConns(al.Name, v)
   355  		case 1:
   356  			SetMaxOpenConns(al.Name, v)
   357  		}
   358  	}
   359  
   360  end:
   361  	if err != nil {
   362  		if db != nil {
   363  			db.Close()
   364  		}
   365  		DebugLog.Println(err.Error())
   366  	}
   367  
   368  	return err
   369  }
   370  
   371  // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
   372  func RegisterDriver(driverName string, typ DriverType) error {
   373  	if t, ok := drivers[driverName]; !ok {
   374  		drivers[driverName] = typ
   375  	} else {
   376  		if t != typ {
   377  			return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
   378  		}
   379  	}
   380  	return nil
   381  }
   382  
   383  // SetDataBaseTZ Change the database default used timezone
   384  func SetDataBaseTZ(aliasName string, tz *time.Location) error {
   385  	if al, ok := dataBaseCache.get(aliasName); ok {
   386  		al.TZ = tz
   387  	} else {
   388  		return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
   389  	}
   390  	return nil
   391  }
   392  
   393  // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
   394  func SetMaxIdleConns(aliasName string, maxIdleConns int) {
   395  	al := getDbAlias(aliasName)
   396  	al.MaxIdleConns = maxIdleConns
   397  	al.DB.DB.SetMaxIdleConns(maxIdleConns)
   398  }
   399  
   400  // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
   401  func SetMaxOpenConns(aliasName string, maxOpenConns int) {
   402  	al := getDbAlias(aliasName)
   403  	al.MaxOpenConns = maxOpenConns
   404  	al.DB.DB.SetMaxOpenConns(maxOpenConns)
   405  	// for tip go 1.2
   406  	if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() {
   407  		fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)})
   408  	}
   409  }
   410  
   411  // GetDB Get *sql.DB from registered database by db alias name.
   412  // Use "default" as alias name if you not set.
   413  func GetDB(aliasNames ...string) (*sql.DB, error) {
   414  	var name string
   415  	if len(aliasNames) > 0 {
   416  		name = aliasNames[0]
   417  	} else {
   418  		name = "default"
   419  	}
   420  	al, ok := dataBaseCache.get(name)
   421  	if ok {
   422  		return al.DB.DB, nil
   423  	}
   424  	return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
   425  }
   426  
   427  type stmtDecorator struct {
   428  	wg   sync.WaitGroup
   429  	stmt *sql.Stmt
   430  }
   431  
   432  func (s *stmtDecorator) getStmt() *sql.Stmt {
   433  	return s.stmt
   434  }
   435  
   436  // acquire will add one
   437  // since this method will be used inside read lock scope,
   438  // so we can not do more things here
   439  // we should think about refactor this
   440  func (s *stmtDecorator) acquire() {
   441  	s.wg.Add(1)
   442  }
   443  
   444  func (s *stmtDecorator) release() {
   445  	s.wg.Done()
   446  }
   447  
   448  //garbage recycle for stmt
   449  func (s *stmtDecorator) destroy() {
   450  	go func() {
   451  		s.wg.Wait()
   452  		_ = s.stmt.Close()
   453  	}()
   454  }
   455  
   456  func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
   457  	return &stmtDecorator{
   458  		stmt: sqlStmt,
   459  	}
   460  }
   461  
   462  func newStmtDecoratorLruWithEvict() *lru.Cache {
   463  	// temporarily solution
   464  	// we fixed this problem in v2.x
   465  	cache, _ := lru.NewWithEvict(50, func(key interface{}, value interface{}) {
   466  		value.(*stmtDecorator).destroy()
   467  	})
   468  	return cache
   469  }