github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/driver.go (about)

     1  package sqx
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"log"
     7  	"reflect"
     8  	"strings"
     9  	"sync"
    10  )
    11  
    12  var (
    13  	sqlDriverNamesByType     = map[reflect.Type]string{}
    14  	sqlDriverNamesByTypeLock sync.Mutex
    15  	sqlDriverNamesByTypeOnce sync.Once
    16  )
    17  
    18  // The database/sql API doesn't provide a way to get the registry name for
    19  // a driver from the driver type.
    20  func sqlDriverToDriverName(driver driver.Driver) string {
    21  	driverType := reflect.TypeOf(driver)
    22  	if name, ok := sqlDriverNamesByType[driverType]; ok {
    23  		return name
    24  	}
    25  
    26  	sqlDriverNamesByTypeOnce.Do(func() {
    27  		for _, driverName := range sql.Drivers() {
    28  			// Tested empty string DSN with MySQL, PostgreSQL, and SQLite3 drivers.
    29  			if db, err := sql.Open(driverName, ""); err != nil {
    30  				log.Printf("E! test empty dsn: %v", err)
    31  			} else {
    32  				sqlDriverNamesByType[reflect.TypeOf(db.Driver())] = driverName
    33  			}
    34  		}
    35  	})
    36  
    37  	return sqlDriverNamesByType[driverType]
    38  }
    39  
    40  // RegisterDriverName register the driver name for the current db.
    41  func RegisterDriverName(d driver.Driver, driverName string) {
    42  	sqlDriverNamesByTypeLock.Lock()
    43  	defer sqlDriverNamesByTypeLock.Unlock()
    44  
    45  	sqlDriverNamesByType[reflect.TypeOf(d)] = driverName
    46  }
    47  
    48  // DriverName returns the driver name for the current db.
    49  func DriverName(d driver.Driver) string {
    50  	sqlDriverNamesByTypeLock.Lock()
    51  	defer sqlDriverNamesByTypeLock.Unlock()
    52  
    53  	return sqlDriverToDriverName(d)
    54  }
    55  
    56  // DetectDriverName detects the driver name for database source name.
    57  func DetectDriverName(driverName, dataSourceName string) string {
    58  	// DB | driverName | DSN
    59  	// ---|---|---
    60  	// MySQL |mysql |user:pass@tcp(127.0.0.1:3306)/mydb?charset=utf8
    61  	// 达梦 |dm| dm://user:pass@127.0.0.1:5236
    62  	// 人大金仓|pgx|postgres://user:pass@127.0.0.1:54321/mydb?sslmode=disable
    63  	// 华为GaussDB|opengauss://user:pass@127.0.0.1:54321/mydb?sslmode=disable
    64  
    65  	if driverName == "" {
    66  		if strings.Contains(dataSourceName, "://") {
    67  			driverName = dataSourceName[:strings.Index(dataSourceName, "://")]
    68  		} else {
    69  			driverName = "mysql"
    70  		}
    71  	}
    72  
    73  	switch driverName { // use pgx when dsn starts with postgres
    74  	case "postgres":
    75  		driverName = "pgx"
    76  	}
    77  
    78  	return driverName
    79  }