github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/db_alias.go (about) 1 // The original package is migrated from beego and modified, you can find orignal from following link: 2 // "github.com/beego/beego/" 3 // 4 // Copyright 2023 IAC. All Rights Reserved. 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 package orm 19 20 import ( 21 "context" 22 "database/sql" 23 "fmt" 24 "sync" 25 "time" 26 27 lru "github.com/hashicorp/golang-lru" 28 ) 29 30 // DriverType database driver constant int. 31 type DriverType int 32 33 // Enum the Database driver 34 const ( 35 _ DriverType = iota // int enum type 36 DRMySQL // mysql 37 DRSqlite // sqlite 38 DROracle // oracle 39 DRPostgres // pgsql 40 DRTiDB // TiDB 41 DRMSSQL //MS SQL Server 42 ) 43 44 // database driver string. 45 type driver string 46 47 // get type constant int of current driver.. 48 func (d driver) Type() DriverType { 49 a, _ := dataBaseCache.get(string(d)) 50 return a.Driver 51 } 52 53 // get name of current driver 54 func (d driver) Name() string { 55 return string(d) 56 } 57 58 // check driver iis implemented Driver interface or not. 59 var _ Driver = new(driver) 60 61 var ( 62 dataBaseCache = &_dbCache{cache: make(map[string]*alias)} 63 drivers = map[string]DriverType{ 64 "mysql": DRMySQL, 65 "postgres": DRPostgres, 66 "sqlite3": DRSqlite, 67 "tidb": DRTiDB, 68 "oracle": DROracle, 69 "oci8": DROracle, // github.com/mattn/go-oci8 70 "ora": DROracle, // https://github.com/rana/ora 71 "sqlserver": DRMSSQL, //"github.com/denisenkom/go-mssqldb" 72 } 73 dbBasers = map[DriverType]dbBaser{ 74 DRMySQL: newdbBaseMysql(), 75 DRSqlite: newdbBaseSqlite(), 76 DROracle: newdbBaseOracle(), 77 DRPostgres: newdbBasePostgres(), 78 DRTiDB: newdbBaseTidb(), 79 } 80 ) 81 82 // database alias cacher. 83 type _dbCache struct { 84 mux sync.RWMutex 85 cache map[string]*alias 86 } 87 88 // add database alias with original name. 89 func (ac *_dbCache) add(name string, al *alias) (added bool) { 90 ac.mux.Lock() 91 defer ac.mux.Unlock() 92 if _, ok := ac.cache[name]; !ok { 93 ac.cache[name] = al 94 added = true 95 } 96 return 97 } 98 99 // get database alias if cached. 100 func (ac *_dbCache) get(name string) (al *alias, ok bool) { 101 ac.mux.RLock() 102 defer ac.mux.RUnlock() 103 al, ok = ac.cache[name] 104 return 105 } 106 107 // get default alias. 108 func (ac *_dbCache) getDefault() (al *alias) { 109 al, _ = ac.get("default") 110 return 111 } 112 113 type DB struct { 114 *sync.RWMutex 115 DB *sql.DB 116 stmtDecorators *lru.Cache 117 stmtDecoratorsLimit int 118 } 119 120 var ( 121 _ dbQuerier = new(DB) 122 _ txer = new(DB) 123 ) 124 125 func (d *DB) Begin() (*sql.Tx, error) { 126 return d.DB.Begin() 127 } 128 129 func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { 130 return d.DB.BeginTx(ctx, opts) 131 } 132 133 // su must call release to release *sql.Stmt after using 134 func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { 135 d.RLock() 136 c, ok := d.stmtDecorators.Get(query) 137 if ok { 138 c.(*stmtDecorator).acquire() 139 d.RUnlock() 140 return c.(*stmtDecorator), nil 141 } 142 d.RUnlock() 143 144 d.Lock() 145 c, ok = d.stmtDecorators.Get(query) 146 if ok { 147 c.(*stmtDecorator).acquire() 148 d.Unlock() 149 return c.(*stmtDecorator), nil 150 } 151 152 stmt, err := d.Prepare(query) 153 if err != nil { 154 d.Unlock() 155 return nil, err 156 } 157 sd := newStmtDecorator(stmt) 158 sd.acquire() 159 d.stmtDecorators.Add(query, sd) 160 d.Unlock() 161 162 return sd, nil 163 } 164 165 func (d *DB) Prepare(query string) (*sql.Stmt, error) { 166 return d.DB.Prepare(query) 167 } 168 169 func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 170 return d.DB.PrepareContext(ctx, query) 171 } 172 173 func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { 174 return d.ExecContext(context.Background(), query, args...) 175 } 176 177 func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 178 if d.stmtDecorators == nil { 179 return d.DB.ExecContext(ctx, query, args...) 180 } 181 182 sd, err := d.getStmtDecorator(query) 183 if err != nil { 184 return nil, err 185 } 186 stmt := sd.getStmt() 187 defer sd.release() 188 return stmt.ExecContext(ctx, args...) 189 } 190 191 func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { 192 return d.QueryContext(context.Background(), query, args...) 193 } 194 195 func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 196 if d.stmtDecorators == nil { 197 return d.DB.QueryContext(ctx, query, args...) 198 } 199 200 sd, err := d.getStmtDecorator(query) 201 if err != nil { 202 return nil, err 203 } 204 stmt := sd.getStmt() 205 defer sd.release() 206 return stmt.QueryContext(ctx, args...) 207 } 208 209 func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { 210 return d.QueryRowContext(context.Background(), query, args...) 211 } 212 213 func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 214 if d.stmtDecorators == nil { 215 return d.DB.QueryRowContext(ctx, query, args...) 216 } 217 218 sd, err := d.getStmtDecorator(query) 219 if err != nil { 220 panic(err) 221 } 222 stmt := sd.getStmt() 223 defer sd.release() 224 return stmt.QueryRowContext(ctx, args...) 225 } 226 227 type TxDB struct { 228 tx *sql.Tx 229 } 230 231 var ( 232 _ dbQuerier = new(TxDB) 233 _ txEnder = new(TxDB) 234 ) 235 236 func (t *TxDB) Commit() error { 237 return t.tx.Commit() 238 } 239 240 func (t *TxDB) Rollback() error { 241 return t.tx.Rollback() 242 } 243 244 func (t *TxDB) RollbackUnlessCommit() error { 245 err := t.tx.Rollback() 246 if err != sql.ErrTxDone { 247 return err 248 } 249 return nil 250 } 251 252 var ( 253 _ dbQuerier = new(TxDB) 254 _ txEnder = new(TxDB) 255 ) 256 257 func (t *TxDB) Prepare(query string) (*sql.Stmt, error) { 258 return t.PrepareContext(context.Background(), query) 259 } 260 261 func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 262 return t.tx.PrepareContext(ctx, query) 263 } 264 265 func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) { 266 return t.ExecContext(context.Background(), query, args...) 267 } 268 269 func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 270 return t.tx.ExecContext(ctx, query, args...) 271 } 272 273 func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) { 274 return t.QueryContext(context.Background(), query, args...) 275 } 276 277 func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 278 return t.tx.QueryContext(ctx, query, args...) 279 } 280 281 func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row { 282 return t.QueryRowContext(context.Background(), query, args...) 283 } 284 285 func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 286 return t.tx.QueryRowContext(ctx, query, args...) 287 } 288 289 type alias struct { 290 Name string 291 Driver DriverType 292 DriverName string 293 DataSource string 294 MaxIdleConns int 295 MaxOpenConns int 296 ConnMaxLifetime time.Duration 297 StmtCacheSize int 298 DB *DB 299 DbBaser dbBaser 300 TZ *time.Location 301 Engine string 302 } 303 304 func detectTZ(al *alias) { 305 // orm timezone system match database 306 // default use Local 307 al.TZ = DefaultTimeLoc 308 309 if al.DriverName == "sphinx" { 310 return 311 } 312 313 switch al.Driver { 314 case DRMySQL: 315 row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") 316 var tz string 317 row.Scan(&tz) 318 if len(tz) >= 8 { 319 if tz[0] != '-' { 320 tz = "+" + tz 321 } 322 t, err := time.Parse("-07:00:00", tz) 323 if err == nil { 324 if t.Location().String() != "" { 325 al.TZ = t.Location() 326 } 327 } else { 328 DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) 329 } 330 } 331 332 // get default engine from current database 333 row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") 334 var engine string 335 var tx bool 336 row.Scan(&engine, &tx) 337 338 if engine != "" { 339 al.Engine = engine 340 } else { 341 al.Engine = "INNODB" 342 } 343 344 case DRSqlite, DROracle: 345 al.TZ = time.UTC 346 347 case DRPostgres: 348 row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") 349 var tz string 350 row.Scan(&tz) 351 loc, err := time.LoadLocation(tz) 352 if err == nil { 353 al.TZ = loc 354 } else { 355 DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) 356 } 357 } 358 } 359 360 func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { 361 existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) 362 if _, ok := dataBaseCache.get(aliasName); ok { 363 return nil, existErr 364 } 365 366 al, err := newAliasWithDb(aliasName, driverName, db, params...) 367 if err != nil { 368 return nil, err 369 } 370 371 if !dataBaseCache.add(aliasName, al) { 372 return nil, existErr 373 } 374 375 return al, nil 376 } 377 378 func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { 379 al := &alias{} 380 al.DB = &DB{ 381 RWMutex: new(sync.RWMutex), 382 DB: db, 383 } 384 385 for _, p := range params { 386 p(al) 387 } 388 389 var stmtCache *lru.Cache 390 var stmtCacheSize int 391 392 if al.StmtCacheSize > 0 { 393 _stmtCache, errC := newStmtDecoratorLruWithEvict(al.StmtCacheSize) 394 if errC != nil { 395 return nil, errC 396 } else { 397 stmtCache = _stmtCache 398 stmtCacheSize = al.StmtCacheSize 399 } 400 } 401 402 al.Name = aliasName 403 al.DriverName = driverName 404 al.DB.stmtDecorators = stmtCache 405 al.DB.stmtDecoratorsLimit = stmtCacheSize 406 407 if dr, ok := drivers[driverName]; ok { 408 al.DbBaser = dbBasers[dr] 409 al.Driver = dr 410 } else { 411 return nil, fmt.Errorf("driver name `%s` have not registered", driverName) 412 } 413 414 err := db.Ping() 415 if err != nil { 416 return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) 417 } 418 419 detectTZ(al) 420 421 return al, nil 422 } 423 424 // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name 425 // Deprecated you should not use this, we will remove it in the future 426 func SetMaxIdleConns(aliasName string, maxIdleConns int) { 427 al := getDbAlias(aliasName) 428 al.SetMaxIdleConns(maxIdleConns) 429 } 430 431 // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name 432 // Deprecated you should not use this, we will remove it in the future 433 func SetMaxOpenConns(aliasName string, maxOpenConns int) { 434 al := getDbAlias(aliasName) 435 al.SetMaxOpenConns(maxOpenConns) 436 } 437 438 // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name 439 func (al *alias) SetMaxIdleConns(maxIdleConns int) { 440 al.MaxIdleConns = maxIdleConns 441 al.DB.DB.SetMaxIdleConns(maxIdleConns) 442 } 443 444 // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name 445 func (al *alias) SetMaxOpenConns(maxOpenConns int) { 446 al.MaxOpenConns = maxOpenConns 447 al.DB.DB.SetMaxOpenConns(maxOpenConns) 448 } 449 450 func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) { 451 al.ConnMaxLifetime = lifeTime 452 al.DB.DB.SetConnMaxLifetime(lifeTime) 453 } 454 455 // AddAliasWthDB add a aliasName for the drivename 456 func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) error { 457 _, err := addAliasWthDB(aliasName, driverName, db, params...) 458 return err 459 } 460 461 // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. 462 func RegisterDataBase(aliasName, driverName, dataSource string, params ...DBOption) error { 463 var ( 464 err error 465 db *sql.DB 466 al *alias 467 ) 468 469 db, err = sql.Open(driverName, dataSource) 470 if err != nil { 471 err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) 472 goto end 473 } 474 475 al, err = addAliasWthDB(aliasName, driverName, db, params...) 476 if err != nil { 477 goto end 478 } 479 480 al.DataSource = dataSource 481 482 end: 483 if err != nil { 484 if db != nil { 485 db.Close() 486 } 487 DebugLog.Println(err.Error()) 488 } 489 490 return err 491 } 492 493 // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. 494 func RegisterDriver(driverName string, typ DriverType) error { 495 if t, ok := drivers[driverName]; !ok { 496 drivers[driverName] = typ 497 } else { 498 if t != typ { 499 return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) 500 } 501 } 502 return nil 503 } 504 505 // SetDataBaseTZ Change the database default used timezone 506 func SetDataBaseTZ(aliasName string, tz *time.Location) error { 507 if al, ok := dataBaseCache.get(aliasName); ok { 508 al.TZ = tz 509 } else { 510 return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) 511 } 512 return nil 513 } 514 515 // GetDB Get *sql.DB from registered database by db alias name. 516 // Use "default" as alias name if you not set. 517 func GetDB(aliasNames ...string) (*sql.DB, error) { 518 var name string 519 if len(aliasNames) > 0 { 520 name = aliasNames[0] 521 } else { 522 name = "default" 523 } 524 al, ok := dataBaseCache.get(name) 525 if ok { 526 return al.DB.DB, nil 527 } 528 return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) 529 } 530 531 type stmtDecorator struct { 532 wg sync.WaitGroup 533 stmt *sql.Stmt 534 } 535 536 func (s *stmtDecorator) getStmt() *sql.Stmt { 537 return s.stmt 538 } 539 540 // acquire will add one 541 // since this method will be used inside read lock scope, 542 // so we can not do more things here 543 // we should think about refactor this 544 func (s *stmtDecorator) acquire() { 545 s.wg.Add(1) 546 } 547 548 func (s *stmtDecorator) release() { 549 s.wg.Done() 550 } 551 552 // garbage recycle for stmt 553 func (s *stmtDecorator) destroy() { 554 go func() { 555 s.wg.Wait() 556 _ = s.stmt.Close() 557 }() 558 } 559 560 func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { 561 return &stmtDecorator{ 562 stmt: sqlStmt, 563 } 564 } 565 566 func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { 567 cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) { 568 value.(*stmtDecorator).destroy() 569 }) 570 if err != nil { 571 return nil, err 572 } 573 return cache, nil 574 } 575 576 type DBOption func(al *alias) 577 578 // MaxIdleConnections return a hint about MaxIdleConnections 579 func MaxIdleConnections(maxIdleConn int) DBOption { 580 return func(al *alias) { 581 al.SetMaxIdleConns(maxIdleConn) 582 } 583 } 584 585 // MaxOpenConnections return a hint about MaxOpenConnections 586 func MaxOpenConnections(maxOpenConn int) DBOption { 587 return func(al *alias) { 588 al.SetMaxOpenConns(maxOpenConn) 589 } 590 } 591 592 // ConnMaxLifetime return a hint about ConnMaxLifetime 593 func ConnMaxLifetime(v time.Duration) DBOption { 594 return func(al *alias) { 595 al.SetConnMaxLifetime(v) 596 } 597 } 598 599 // MaxStmtCacheSize return a hint about MaxStmtCacheSize 600 func MaxStmtCacheSize(v int) DBOption { 601 return func(al *alias) { 602 al.StmtCacheSize = v 603 } 604 }