github.com/astaxie/beego@v1.12.3/orm/db_alias.go (about) 1 // Copyright 2014 beego Author. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package orm 16 17 import ( 18 "context" 19 "database/sql" 20 "fmt" 21 "reflect" 22 "sync" 23 "time" 24 25 lru "github.com/hashicorp/golang-lru" 26 ) 27 28 // DriverType database driver constant int. 29 type DriverType int 30 31 // Enum the Database driver 32 const ( 33 _ DriverType = iota // int enum type 34 DRMySQL // mysql 35 DRSqlite // sqlite 36 DROracle // oracle 37 DRPostgres // pgsql 38 DRTiDB // TiDB 39 ) 40 41 // database driver string. 42 type driver string 43 44 // get type constant int of current driver.. 45 func (d driver) Type() DriverType { 46 a, _ := dataBaseCache.get(string(d)) 47 return a.Driver 48 } 49 50 // get name of current driver 51 func (d driver) Name() string { 52 return string(d) 53 } 54 55 // check driver iis implemented Driver interface or not. 56 var _ Driver = new(driver) 57 58 var ( 59 dataBaseCache = &_dbCache{cache: make(map[string]*alias)} 60 drivers = map[string]DriverType{ 61 "mysql": DRMySQL, 62 "postgres": DRPostgres, 63 "sqlite3": DRSqlite, 64 "tidb": DRTiDB, 65 "oracle": DROracle, 66 "oci8": DROracle, // github.com/mattn/go-oci8 67 "ora": DROracle, //https://github.com/rana/ora 68 } 69 dbBasers = map[DriverType]dbBaser{ 70 DRMySQL: newdbBaseMysql(), 71 DRSqlite: newdbBaseSqlite(), 72 DROracle: newdbBaseOracle(), 73 DRPostgres: newdbBasePostgres(), 74 DRTiDB: newdbBaseTidb(), 75 } 76 ) 77 78 // database alias cacher. 79 type _dbCache struct { 80 mux sync.RWMutex 81 cache map[string]*alias 82 } 83 84 // add database alias with original name. 85 func (ac *_dbCache) add(name string, al *alias) (added bool) { 86 ac.mux.Lock() 87 defer ac.mux.Unlock() 88 if _, ok := ac.cache[name]; !ok { 89 ac.cache[name] = al 90 added = true 91 } 92 return 93 } 94 95 // get database alias if cached. 96 func (ac *_dbCache) get(name string) (al *alias, ok bool) { 97 ac.mux.RLock() 98 defer ac.mux.RUnlock() 99 al, ok = ac.cache[name] 100 return 101 } 102 103 // get default alias. 104 func (ac *_dbCache) getDefault() (al *alias) { 105 al, _ = ac.get("default") 106 return 107 } 108 109 type DB struct { 110 *sync.RWMutex 111 DB *sql.DB 112 stmtDecorators *lru.Cache 113 } 114 115 func (d *DB) Begin() (*sql.Tx, error) { 116 return d.DB.Begin() 117 } 118 119 func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { 120 return d.DB.BeginTx(ctx, opts) 121 } 122 123 //su must call release to release *sql.Stmt after using 124 func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { 125 d.RLock() 126 c, ok := d.stmtDecorators.Get(query) 127 if ok { 128 c.(*stmtDecorator).acquire() 129 d.RUnlock() 130 return c.(*stmtDecorator), nil 131 } 132 d.RUnlock() 133 134 d.Lock() 135 c, ok = d.stmtDecorators.Get(query) 136 if ok { 137 c.(*stmtDecorator).acquire() 138 d.Unlock() 139 return c.(*stmtDecorator), nil 140 } 141 142 stmt, err := d.Prepare(query) 143 if err != nil { 144 d.Unlock() 145 return nil, err 146 } 147 sd := newStmtDecorator(stmt) 148 sd.acquire() 149 d.stmtDecorators.Add(query, sd) 150 d.Unlock() 151 152 return sd, nil 153 } 154 155 func (d *DB) Prepare(query string) (*sql.Stmt, error) { 156 return d.DB.Prepare(query) 157 } 158 159 func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 160 return d.DB.PrepareContext(ctx, query) 161 } 162 163 func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { 164 sd, err := d.getStmtDecorator(query) 165 if err != nil { 166 return nil, err 167 } 168 stmt := sd.getStmt() 169 defer sd.release() 170 return stmt.Exec(args...) 171 } 172 173 func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 174 sd, err := d.getStmtDecorator(query) 175 if err != nil { 176 return nil, err 177 } 178 stmt := sd.getStmt() 179 defer sd.release() 180 return stmt.ExecContext(ctx, args...) 181 } 182 183 func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { 184 sd, err := d.getStmtDecorator(query) 185 if err != nil { 186 return nil, err 187 } 188 stmt := sd.getStmt() 189 defer sd.release() 190 return stmt.Query(args...) 191 } 192 193 func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 194 sd, err := d.getStmtDecorator(query) 195 if err != nil { 196 return nil, err 197 } 198 stmt := sd.getStmt() 199 defer sd.release() 200 return stmt.QueryContext(ctx, args...) 201 } 202 203 func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { 204 sd, err := d.getStmtDecorator(query) 205 if err != nil { 206 panic(err) 207 } 208 stmt := sd.getStmt() 209 defer sd.release() 210 return stmt.QueryRow(args...) 211 212 } 213 214 func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 215 sd, err := d.getStmtDecorator(query) 216 if err != nil { 217 panic(err) 218 } 219 stmt := sd.getStmt() 220 defer sd.release() 221 return stmt.QueryRowContext(ctx, args) 222 } 223 224 type alias struct { 225 Name string 226 Driver DriverType 227 DriverName string 228 DataSource string 229 MaxIdleConns int 230 MaxOpenConns int 231 DB *DB 232 DbBaser dbBaser 233 TZ *time.Location 234 Engine string 235 } 236 237 func detectTZ(al *alias) { 238 // orm timezone system match database 239 // default use Local 240 al.TZ = DefaultTimeLoc 241 242 if al.DriverName == "sphinx" { 243 return 244 } 245 246 switch al.Driver { 247 case DRMySQL: 248 row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") 249 var tz string 250 row.Scan(&tz) 251 if len(tz) >= 8 { 252 if tz[0] != '-' { 253 tz = "+" + tz 254 } 255 t, err := time.Parse("-07:00:00", tz) 256 if err == nil { 257 if t.Location().String() != "" { 258 al.TZ = t.Location() 259 } 260 } else { 261 DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) 262 } 263 } 264 265 // get default engine from current database 266 row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") 267 var engine string 268 var tx bool 269 row.Scan(&engine, &tx) 270 271 if engine != "" { 272 al.Engine = engine 273 } else { 274 al.Engine = "INNODB" 275 } 276 277 case DRSqlite, DROracle: 278 al.TZ = time.UTC 279 280 case DRPostgres: 281 row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") 282 var tz string 283 row.Scan(&tz) 284 loc, err := time.LoadLocation(tz) 285 if err == nil { 286 al.TZ = loc 287 } else { 288 DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) 289 } 290 } 291 } 292 293 func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { 294 al := new(alias) 295 al.Name = aliasName 296 al.DriverName = driverName 297 al.DB = &DB{ 298 RWMutex: new(sync.RWMutex), 299 DB: db, 300 stmtDecorators: newStmtDecoratorLruWithEvict(), 301 } 302 303 if dr, ok := drivers[driverName]; ok { 304 al.DbBaser = dbBasers[dr] 305 al.Driver = dr 306 } else { 307 return nil, fmt.Errorf("driver name `%s` have not registered", driverName) 308 } 309 310 err := db.Ping() 311 if err != nil { 312 return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) 313 } 314 315 if !dataBaseCache.add(aliasName, al) { 316 return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) 317 } 318 319 return al, nil 320 } 321 322 // AddAliasWthDB add a aliasName for the drivename 323 func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { 324 _, err := addAliasWthDB(aliasName, driverName, db) 325 return err 326 } 327 328 // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. 329 func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { 330 var ( 331 err error 332 db *sql.DB 333 al *alias 334 ) 335 336 db, err = sql.Open(driverName, dataSource) 337 if err != nil { 338 err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) 339 goto end 340 } 341 342 al, err = addAliasWthDB(aliasName, driverName, db) 343 if err != nil { 344 goto end 345 } 346 347 al.DataSource = dataSource 348 349 detectTZ(al) 350 351 for i, v := range params { 352 switch i { 353 case 0: 354 SetMaxIdleConns(al.Name, v) 355 case 1: 356 SetMaxOpenConns(al.Name, v) 357 } 358 } 359 360 end: 361 if err != nil { 362 if db != nil { 363 db.Close() 364 } 365 DebugLog.Println(err.Error()) 366 } 367 368 return err 369 } 370 371 // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. 372 func RegisterDriver(driverName string, typ DriverType) error { 373 if t, ok := drivers[driverName]; !ok { 374 drivers[driverName] = typ 375 } else { 376 if t != typ { 377 return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) 378 } 379 } 380 return nil 381 } 382 383 // SetDataBaseTZ Change the database default used timezone 384 func SetDataBaseTZ(aliasName string, tz *time.Location) error { 385 if al, ok := dataBaseCache.get(aliasName); ok { 386 al.TZ = tz 387 } else { 388 return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) 389 } 390 return nil 391 } 392 393 // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name 394 func SetMaxIdleConns(aliasName string, maxIdleConns int) { 395 al := getDbAlias(aliasName) 396 al.MaxIdleConns = maxIdleConns 397 al.DB.DB.SetMaxIdleConns(maxIdleConns) 398 } 399 400 // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name 401 func SetMaxOpenConns(aliasName string, maxOpenConns int) { 402 al := getDbAlias(aliasName) 403 al.MaxOpenConns = maxOpenConns 404 al.DB.DB.SetMaxOpenConns(maxOpenConns) 405 // for tip go 1.2 406 if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { 407 fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) 408 } 409 } 410 411 // GetDB Get *sql.DB from registered database by db alias name. 412 // Use "default" as alias name if you not set. 413 func GetDB(aliasNames ...string) (*sql.DB, error) { 414 var name string 415 if len(aliasNames) > 0 { 416 name = aliasNames[0] 417 } else { 418 name = "default" 419 } 420 al, ok := dataBaseCache.get(name) 421 if ok { 422 return al.DB.DB, nil 423 } 424 return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) 425 } 426 427 type stmtDecorator struct { 428 wg sync.WaitGroup 429 stmt *sql.Stmt 430 } 431 432 func (s *stmtDecorator) getStmt() *sql.Stmt { 433 return s.stmt 434 } 435 436 // acquire will add one 437 // since this method will be used inside read lock scope, 438 // so we can not do more things here 439 // we should think about refactor this 440 func (s *stmtDecorator) acquire() { 441 s.wg.Add(1) 442 } 443 444 func (s *stmtDecorator) release() { 445 s.wg.Done() 446 } 447 448 //garbage recycle for stmt 449 func (s *stmtDecorator) destroy() { 450 go func() { 451 s.wg.Wait() 452 _ = s.stmt.Close() 453 }() 454 } 455 456 func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { 457 return &stmtDecorator{ 458 stmt: sqlStmt, 459 } 460 } 461 462 func newStmtDecoratorLruWithEvict() *lru.Cache { 463 // temporarily solution 464 // we fixed this problem in v2.x 465 cache, _ := lru.NewWithEvict(50, func(key interface{}, value interface{}) { 466 value.(*stmtDecorator).destroy() 467 }) 468 return cache 469 }