github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/driver.go (about) 1 package sqx 2 3 import ( 4 "database/sql" 5 "database/sql/driver" 6 "log" 7 "reflect" 8 "strings" 9 "sync" 10 ) 11 12 var ( 13 sqlDriverNamesByType = map[reflect.Type]string{} 14 sqlDriverNamesByTypeLock sync.Mutex 15 sqlDriverNamesByTypeOnce sync.Once 16 ) 17 18 // The database/sql API doesn't provide a way to get the registry name for 19 // a driver from the driver type. 20 func sqlDriverToDriverName(driver driver.Driver) string { 21 driverType := reflect.TypeOf(driver) 22 if name, ok := sqlDriverNamesByType[driverType]; ok { 23 return name 24 } 25 26 sqlDriverNamesByTypeOnce.Do(func() { 27 for _, driverName := range sql.Drivers() { 28 // Tested empty string DSN with MySQL, PostgreSQL, and SQLite3 drivers. 29 if db, err := sql.Open(driverName, ""); err != nil { 30 log.Printf("E! test empty dsn: %v", err) 31 } else { 32 sqlDriverNamesByType[reflect.TypeOf(db.Driver())] = driverName 33 } 34 } 35 }) 36 37 return sqlDriverNamesByType[driverType] 38 } 39 40 // RegisterDriverName register the driver name for the current db. 41 func RegisterDriverName(d driver.Driver, driverName string) { 42 sqlDriverNamesByTypeLock.Lock() 43 defer sqlDriverNamesByTypeLock.Unlock() 44 45 sqlDriverNamesByType[reflect.TypeOf(d)] = driverName 46 } 47 48 // DriverName returns the driver name for the current db. 49 func DriverName(d driver.Driver) string { 50 sqlDriverNamesByTypeLock.Lock() 51 defer sqlDriverNamesByTypeLock.Unlock() 52 53 return sqlDriverToDriverName(d) 54 } 55 56 // DetectDriverName detects the driver name for database source name. 
//
// If driverName is non-empty it is used as given. Otherwise, when
// dataSourceName begins with a well-formed URL scheme followed by "://", that
// scheme is used as the driver name; any other DSN is assumed to be MySQL.
// Finally, the PostgreSQL names "postgres" and "postgresql" are mapped to
// "pgx".
//
// Examples:
//
//	DB             | driverName | DSN
//	---------------|------------|-----------------------------------------------------------
//	MySQL          | mysql      | user:pass@tcp(127.0.0.1:3306)/mydb?charset=utf8
//	Dameng (DM)    | dm         | dm://user:pass@127.0.0.1:5236
//	Kingbase       | pgx        | postgres://user:pass@127.0.0.1:54321/mydb?sslmode=disable
//	Huawei GaussDB | opengauss  | opengauss://user:pass@127.0.0.1:54321/mydb?sslmode=disable
func DetectDriverName(driverName, dataSourceName string) string {
	if driverName == "" {
		driverName = "mysql"
		// Only a syntactically valid scheme counts. A MySQL DSN may contain
		// "://" later on (e.g. inside a query parameter value or a password),
		// which must not be mistaken for a scheme separator.
		if i := strings.Index(dataSourceName, "://"); i >= 0 && isURLScheme(dataSourceName[:i]) {
			driverName = dataSourceName[:i]
		}
	}

	switch driverName {
	case "postgres", "postgresql": // both are valid PostgreSQL URI schemes; use pgx
		driverName = "pgx"
	}

	return driverName
}

// isURLScheme reports whether s is a syntactically valid URL scheme as defined
// by RFC 3986 section 3.1: an ASCII letter followed by any number of ASCII
// letters, digits, '+', '-' or '.'.
func isURLScheme(s string) bool {
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z':
		case i > 0 && ('0' <= c && c <= '9' || c == '+' || c == '-' || c == '.'):
		default:
			return false
		}
	}
	return s != ""
}