github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/db_alias.go (about)

     1  // The original package is migrated from beego and modified, you can find orignal from following link:
     2  //    "github.com/beego/beego/"
     3  //
     4  // Copyright 2023 IAC. All Rights Reserved.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package orm
    19  
    20  import (
    21  	"context"
    22  	"database/sql"
    23  	"fmt"
    24  	"sync"
    25  	"time"
    26  
    27  	lru "github.com/hashicorp/golang-lru"
    28  )
    29  
    30  // DriverType database driver constant int.
    31  type DriverType int
    32  
    33  // Enum the Database driver
    34  const (
    35  	_          DriverType = iota // int enum type
    36  	DRMySQL                      // mysql
    37  	DRSqlite                     // sqlite
    38  	DROracle                     // oracle
    39  	DRPostgres                   // pgsql
    40  	DRTiDB                       // TiDB
    41  	DRMSSQL                      //MS SQL Server
    42  )
    43  
    44  // database driver string.
    45  type driver string
    46  
    47  // get type constant int of current driver..
    48  func (d driver) Type() DriverType {
    49  	a, _ := dataBaseCache.get(string(d))
    50  	return a.Driver
    51  }
    52  
    53  // get name of current driver
    54  func (d driver) Name() string {
    55  	return string(d)
    56  }
    57  
    58  // check driver iis implemented Driver interface or not.
    59  var _ Driver = new(driver)
    60  
var (
	// dataBaseCache holds every registered alias, keyed by alias name.
	dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
	// drivers maps a database/sql driver name (the name RegisterDataBase passes
	// to sql.Open) to its dialect. RegisterDriver adds entries at runtime.
	drivers       = map[string]DriverType{
		"mysql":     DRMySQL,
		"postgres":  DRPostgres,
		"sqlite3":   DRSqlite,
		"tidb":      DRTiDB,
		"oracle":    DROracle,
		"oci8":      DROracle, // github.com/mattn/go-oci8
		"ora":       DROracle, // https://github.com/rana/ora
		"sqlserver": DRMSSQL,  //"github.com/denisenkom/go-mssqldb"
	}
	// dbBasers maps each dialect to the dbBaser that newAliasWithDb stores in
	// alias.DbBaser.
	// NOTE(review): DRMSSQL has no entry, so an alias registered with the
	// "sqlserver" driver gets a nil DbBaser — confirm this is intended.
	dbBasers = map[DriverType]dbBaser{
		DRMySQL:    newdbBaseMysql(),
		DRSqlite:   newdbBaseSqlite(),
		DROracle:   newdbBaseOracle(),
		DRPostgres: newdbBasePostgres(),
		DRTiDB:     newdbBaseTidb(),
	}
)
    81  
    82  // database alias cacher.
    83  type _dbCache struct {
    84  	mux   sync.RWMutex
    85  	cache map[string]*alias
    86  }
    87  
    88  // add database alias with original name.
    89  func (ac *_dbCache) add(name string, al *alias) (added bool) {
    90  	ac.mux.Lock()
    91  	defer ac.mux.Unlock()
    92  	if _, ok := ac.cache[name]; !ok {
    93  		ac.cache[name] = al
    94  		added = true
    95  	}
    96  	return
    97  }
    98  
    99  // get database alias if cached.
   100  func (ac *_dbCache) get(name string) (al *alias, ok bool) {
   101  	ac.mux.RLock()
   102  	defer ac.mux.RUnlock()
   103  	al, ok = ac.cache[name]
   104  	return
   105  }
   106  
   107  // get default alias.
   108  func (ac *_dbCache) getDefault() (al *alias) {
   109  	al, _ = ac.get("default")
   110  	return
   111  }
   112  
// DB wraps a *sql.DB. When stmtDecorators is non-nil, Exec, Query and QueryRow
// (and their Context variants) run through cached prepared statements;
// otherwise they go straight to the pool.
type DB struct {
	*sync.RWMutex                           // guards the lookup-then-prepare sequence in getStmtDecorator
	DB                  *sql.DB             // underlying connection pool
	stmtDecorators      *lru.Cache          // query string -> *stmtDecorator; nil disables statement caching
	stmtDecoratorsLimit int                 // cache capacity, copied from alias.StmtCacheSize by newAliasWithDb
}

// Compile-time checks that *DB implements dbQuerier and txer.
var (
	_ dbQuerier = new(DB)
	_ txer      = new(DB)
)
   124  
   125  func (d *DB) Begin() (*sql.Tx, error) {
   126  	return d.DB.Begin()
   127  }
   128  
   129  func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
   130  	return d.DB.BeginTx(ctx, opts)
   131  }
   132  
// getStmtDecorator returns the cached prepared-statement wrapper for query,
// preparing and caching it on a miss. The returned decorator has already been
// acquired: the caller must call release on it once it is done with the
// *sql.Stmt obtained from getStmt.
//
// The cache is first checked under the read lock. On a miss the read lock is
// dropped, the write lock is taken, and the cache is checked again, because
// another goroutine may have inserted the statement in between.
func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
	d.RLock()
	c, ok := d.stmtDecorators.Get(query)
	if ok {
		// acquire runs while the lock is still held. Within this file, entries are
		// only added (and therefore evicted) under the write lock below.
		c.(*stmtDecorator).acquire()
		d.RUnlock()
		return c.(*stmtDecorator), nil
	}
	d.RUnlock()

	d.Lock()
	// Re-check: the statement may have been prepared by another goroutine
	// between RUnlock and Lock.
	c, ok = d.stmtDecorators.Get(query)
	if ok {
		c.(*stmtDecorator).acquire()
		d.Unlock()
		return c.(*stmtDecorator), nil
	}

	// NOTE(review): Prepare runs while the write lock is held, so a slow prepare
	// blocks every cached lookup on this DB until it finishes.
	stmt, err := d.Prepare(query)
	if err != nil {
		d.Unlock()
		return nil, err
	}
	sd := newStmtDecorator(stmt)
	sd.acquire()
	// Add may evict the least-recently-used entry; the eviction callback installed
	// by newStmtDecoratorLruWithEvict calls destroy on it.
	d.stmtDecorators.Add(query, sd)
	d.Unlock()

	return sd, nil
}
   164  
   165  func (d *DB) Prepare(query string) (*sql.Stmt, error) {
   166  	return d.DB.Prepare(query)
   167  }
   168  
   169  func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
   170  	return d.DB.PrepareContext(ctx, query)
   171  }
   172  
   173  func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
   174  	return d.ExecContext(context.Background(), query, args...)
   175  }
   176  
   177  func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
   178  	if d.stmtDecorators == nil {
   179  		return d.DB.ExecContext(ctx, query, args...)
   180  	}
   181  
   182  	sd, err := d.getStmtDecorator(query)
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  	stmt := sd.getStmt()
   187  	defer sd.release()
   188  	return stmt.ExecContext(ctx, args...)
   189  }
   190  
   191  func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
   192  	return d.QueryContext(context.Background(), query, args...)
   193  }
   194  
   195  func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
   196  	if d.stmtDecorators == nil {
   197  		return d.DB.QueryContext(ctx, query, args...)
   198  	}
   199  
   200  	sd, err := d.getStmtDecorator(query)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	stmt := sd.getStmt()
   205  	defer sd.release()
   206  	return stmt.QueryContext(ctx, args...)
   207  }
   208  
   209  func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
   210  	return d.QueryRowContext(context.Background(), query, args...)
   211  }
   212  
   213  func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
   214  	if d.stmtDecorators == nil {
   215  		return d.DB.QueryRowContext(ctx, query, args...)
   216  	}
   217  
   218  	sd, err := d.getStmtDecorator(query)
   219  	if err != nil {
   220  		panic(err)
   221  	}
   222  	stmt := sd.getStmt()
   223  	defer sd.release()
   224  	return stmt.QueryRowContext(ctx, args...)
   225  }
   226  
   227  type TxDB struct {
   228  	tx *sql.Tx
   229  }
   230  
   231  var (
   232  	_ dbQuerier = new(TxDB)
   233  	_ txEnder   = new(TxDB)
   234  )
   235  
   236  func (t *TxDB) Commit() error {
   237  	return t.tx.Commit()
   238  }
   239  
   240  func (t *TxDB) Rollback() error {
   241  	return t.tx.Rollback()
   242  }
   243  
   244  func (t *TxDB) RollbackUnlessCommit() error {
   245  	err := t.tx.Rollback()
   246  	if err != sql.ErrTxDone {
   247  		return err
   248  	}
   249  	return nil
   250  }
   251  
   252  var (
   253  	_ dbQuerier = new(TxDB)
   254  	_ txEnder   = new(TxDB)
   255  )
   256  
   257  func (t *TxDB) Prepare(query string) (*sql.Stmt, error) {
   258  	return t.PrepareContext(context.Background(), query)
   259  }
   260  
   261  func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
   262  	return t.tx.PrepareContext(ctx, query)
   263  }
   264  
   265  func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) {
   266  	return t.ExecContext(context.Background(), query, args...)
   267  }
   268  
   269  func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
   270  	return t.tx.ExecContext(ctx, query, args...)
   271  }
   272  
   273  func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) {
   274  	return t.QueryContext(context.Background(), query, args...)
   275  }
   276  
   277  func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
   278  	return t.tx.QueryContext(ctx, query, args...)
   279  }
   280  
   281  func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row {
   282  	return t.QueryRowContext(context.Background(), query, args...)
   283  }
   284  
   285  func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
   286  	return t.tx.QueryRowContext(ctx, query, args...)
   287  }
   288  
// alias holds the configuration and runtime handles for one registered
// database alias.
type alias struct {
	Name            string         // alias name; the key used in dataBaseCache
	Driver          DriverType     // dialect looked up from drivers[DriverName]
	DriverName      string         // database/sql driver name
	DataSource      string         // DSN; only set by RegisterDataBase
	MaxIdleConns    int            // last value given to SetMaxIdleConns
	MaxOpenConns    int            // last value given to SetMaxOpenConns
	ConnMaxLifetime time.Duration  // last value given to SetConnMaxLifetime
	StmtCacheSize   int            // prepared-statement cache size; <= 0 disables the cache
	DB              *DB            // wrapped connection pool
	DbBaser         dbBaser        // dialect backend from dbBasers; nil if the dialect has no entry
	TZ              *time.Location // database timezone, set by detectTZ or SetDataBaseTZ
	Engine          string         // MySQL default storage engine, set by detectTZ
}
   303  
   304  func detectTZ(al *alias) {
   305  	// orm timezone system match database
   306  	// default use Local
   307  	al.TZ = DefaultTimeLoc
   308  
   309  	if al.DriverName == "sphinx" {
   310  		return
   311  	}
   312  
   313  	switch al.Driver {
   314  	case DRMySQL:
   315  		row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
   316  		var tz string
   317  		row.Scan(&tz)
   318  		if len(tz) >= 8 {
   319  			if tz[0] != '-' {
   320  				tz = "+" + tz
   321  			}
   322  			t, err := time.Parse("-07:00:00", tz)
   323  			if err == nil {
   324  				if t.Location().String() != "" {
   325  					al.TZ = t.Location()
   326  				}
   327  			} else {
   328  				DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
   329  			}
   330  		}
   331  
   332  		// get default engine from current database
   333  		row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
   334  		var engine string
   335  		var tx bool
   336  		row.Scan(&engine, &tx)
   337  
   338  		if engine != "" {
   339  			al.Engine = engine
   340  		} else {
   341  			al.Engine = "INNODB"
   342  		}
   343  
   344  	case DRSqlite, DROracle:
   345  		al.TZ = time.UTC
   346  
   347  	case DRPostgres:
   348  		row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
   349  		var tz string
   350  		row.Scan(&tz)
   351  		loc, err := time.LoadLocation(tz)
   352  		if err == nil {
   353  			al.TZ = loc
   354  		} else {
   355  			DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
   356  		}
   357  	}
   358  }
   359  
   360  func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) {
   361  	existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
   362  	if _, ok := dataBaseCache.get(aliasName); ok {
   363  		return nil, existErr
   364  	}
   365  
   366  	al, err := newAliasWithDb(aliasName, driverName, db, params...)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  
   371  	if !dataBaseCache.add(aliasName, al) {
   372  		return nil, existErr
   373  	}
   374  
   375  	return al, nil
   376  }
   377  
   378  func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) {
   379  	al := &alias{}
   380  	al.DB = &DB{
   381  		RWMutex: new(sync.RWMutex),
   382  		DB:      db,
   383  	}
   384  
   385  	for _, p := range params {
   386  		p(al)
   387  	}
   388  
   389  	var stmtCache *lru.Cache
   390  	var stmtCacheSize int
   391  
   392  	if al.StmtCacheSize > 0 {
   393  		_stmtCache, errC := newStmtDecoratorLruWithEvict(al.StmtCacheSize)
   394  		if errC != nil {
   395  			return nil, errC
   396  		} else {
   397  			stmtCache = _stmtCache
   398  			stmtCacheSize = al.StmtCacheSize
   399  		}
   400  	}
   401  
   402  	al.Name = aliasName
   403  	al.DriverName = driverName
   404  	al.DB.stmtDecorators = stmtCache
   405  	al.DB.stmtDecoratorsLimit = stmtCacheSize
   406  
   407  	if dr, ok := drivers[driverName]; ok {
   408  		al.DbBaser = dbBasers[dr]
   409  		al.Driver = dr
   410  	} else {
   411  		return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
   412  	}
   413  
   414  	err := db.Ping()
   415  	if err != nil {
   416  		return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
   417  	}
   418  
   419  	detectTZ(al)
   420  
   421  	return al, nil
   422  }
   423  
   424  // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
   425  // Deprecated you should not use this, we will remove it in the future
   426  func SetMaxIdleConns(aliasName string, maxIdleConns int) {
   427  	al := getDbAlias(aliasName)
   428  	al.SetMaxIdleConns(maxIdleConns)
   429  }
   430  
   431  // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
   432  // Deprecated you should not use this, we will remove it in the future
   433  func SetMaxOpenConns(aliasName string, maxOpenConns int) {
   434  	al := getDbAlias(aliasName)
   435  	al.SetMaxOpenConns(maxOpenConns)
   436  }
   437  
   438  // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
   439  func (al *alias) SetMaxIdleConns(maxIdleConns int) {
   440  	al.MaxIdleConns = maxIdleConns
   441  	al.DB.DB.SetMaxIdleConns(maxIdleConns)
   442  }
   443  
   444  // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
   445  func (al *alias) SetMaxOpenConns(maxOpenConns int) {
   446  	al.MaxOpenConns = maxOpenConns
   447  	al.DB.DB.SetMaxOpenConns(maxOpenConns)
   448  }
   449  
   450  func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) {
   451  	al.ConnMaxLifetime = lifeTime
   452  	al.DB.DB.SetConnMaxLifetime(lifeTime)
   453  }
   454  
   455  // AddAliasWthDB add a aliasName for the drivename
   456  func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) error {
   457  	_, err := addAliasWthDB(aliasName, driverName, db, params...)
   458  	return err
   459  }
   460  
   461  // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
   462  func RegisterDataBase(aliasName, driverName, dataSource string, params ...DBOption) error {
   463  	var (
   464  		err error
   465  		db  *sql.DB
   466  		al  *alias
   467  	)
   468  
   469  	db, err = sql.Open(driverName, dataSource)
   470  	if err != nil {
   471  		err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
   472  		goto end
   473  	}
   474  
   475  	al, err = addAliasWthDB(aliasName, driverName, db, params...)
   476  	if err != nil {
   477  		goto end
   478  	}
   479  
   480  	al.DataSource = dataSource
   481  
   482  end:
   483  	if err != nil {
   484  		if db != nil {
   485  			db.Close()
   486  		}
   487  		DebugLog.Println(err.Error())
   488  	}
   489  
   490  	return err
   491  }
   492  
   493  // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
   494  func RegisterDriver(driverName string, typ DriverType) error {
   495  	if t, ok := drivers[driverName]; !ok {
   496  		drivers[driverName] = typ
   497  	} else {
   498  		if t != typ {
   499  			return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
   500  		}
   501  	}
   502  	return nil
   503  }
   504  
   505  // SetDataBaseTZ Change the database default used timezone
   506  func SetDataBaseTZ(aliasName string, tz *time.Location) error {
   507  	if al, ok := dataBaseCache.get(aliasName); ok {
   508  		al.TZ = tz
   509  	} else {
   510  		return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
   511  	}
   512  	return nil
   513  }
   514  
   515  // GetDB Get *sql.DB from registered database by db alias name.
   516  // Use "default" as alias name if you not set.
   517  func GetDB(aliasNames ...string) (*sql.DB, error) {
   518  	var name string
   519  	if len(aliasNames) > 0 {
   520  		name = aliasNames[0]
   521  	} else {
   522  		name = "default"
   523  	}
   524  	al, ok := dataBaseCache.get(name)
   525  	if ok {
   526  		return al.DB.DB, nil
   527  	}
   528  	return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
   529  }
   530  
   531  type stmtDecorator struct {
   532  	wg   sync.WaitGroup
   533  	stmt *sql.Stmt
   534  }
   535  
   536  func (s *stmtDecorator) getStmt() *sql.Stmt {
   537  	return s.stmt
   538  }
   539  
   540  // acquire will add one
   541  // since this method will be used inside read lock scope,
   542  // so we can not do more things here
   543  // we should think about refactor this
   544  func (s *stmtDecorator) acquire() {
   545  	s.wg.Add(1)
   546  }
   547  
   548  func (s *stmtDecorator) release() {
   549  	s.wg.Done()
   550  }
   551  
   552  // garbage recycle for stmt
   553  func (s *stmtDecorator) destroy() {
   554  	go func() {
   555  		s.wg.Wait()
   556  		_ = s.stmt.Close()
   557  	}()
   558  }
   559  
   560  func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
   561  	return &stmtDecorator{
   562  		stmt: sqlStmt,
   563  	}
   564  }
   565  
   566  func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) {
   567  	cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) {
   568  		value.(*stmtDecorator).destroy()
   569  	})
   570  	if err != nil {
   571  		return nil, err
   572  	}
   573  	return cache, nil
   574  }
   575  
   576  type DBOption func(al *alias)
   577  
   578  // MaxIdleConnections return a hint about MaxIdleConnections
   579  func MaxIdleConnections(maxIdleConn int) DBOption {
   580  	return func(al *alias) {
   581  		al.SetMaxIdleConns(maxIdleConn)
   582  	}
   583  }
   584  
   585  // MaxOpenConnections return a hint about MaxOpenConnections
   586  func MaxOpenConnections(maxOpenConn int) DBOption {
   587  	return func(al *alias) {
   588  		al.SetMaxOpenConns(maxOpenConn)
   589  	}
   590  }
   591  
   592  // ConnMaxLifetime return a hint about ConnMaxLifetime
   593  func ConnMaxLifetime(v time.Duration) DBOption {
   594  	return func(al *alias) {
   595  		al.SetConnMaxLifetime(v)
   596  	}
   597  }
   598  
   599  // MaxStmtCacheSize return a hint about MaxStmtCacheSize
   600  func MaxStmtCacheSize(v int) DBOption {
   601  	return func(al *alias) {
   602  		al.StmtCacheSize = v
   603  	}
   604  }