github.com/nshntarora/pop@v0.1.2/connection_instrumented.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"fmt"
     7  	"sync"
     8  
     9  	mysqld "github.com/go-sql-driver/mysql"
    10  	pgx "github.com/jackc/pgx/v5/stdlib"
    11  	"github.com/jmoiron/sqlx"
    12  	"github.com/luna-duclos/instrumentedsql"
    13  	"github.com/nshntarora/pop/logging"
    14  )
    15  
    16  const instrumentedDriverName = "instrumented-sql-driver"
    17  
    18  var sqlDriverLock = sync.Mutex{}
    19  
    20  func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (driverName, dialect string, err error) {
    21  	driverName = defaultDriverName
    22  	if deets.Driver != "" {
    23  		driverName = deets.Driver
    24  	}
    25  	dialect = driverName
    26  
    27  	if !deets.UseInstrumentedDriver {
    28  		if len(deets.InstrumentedDriverOptions) > 0 {
    29  			log(logging.Warn, "SQL driver instrumentation is disabled but `ConnectionDetails.InstrumentedDriverOptions` is not empty. Please double-check if this is a error.")
    30  		}
    31  
    32  		// If instrumentation is disabled, we just return the driver name we got (e.g. "pgx").
    33  		return driverName, dialect, nil
    34  	}
    35  
    36  	if len(deets.InstrumentedDriverOptions) == 0 {
    37  		log(logging.Warn, "SQL driver instrumentation was enabled but no options have been passed to `ConnectionDetails.InstrumentedDriverOptions`. Instrumentation will therefore not result in any output.")
    38  	}
    39  
    40  	var dr driver.Driver
    41  	var newDriverName string
    42  	switch CanonicalDialect(driverName) {
    43  	case nameCockroach:
    44  		fallthrough
    45  	case namePostgreSQL:
    46  		dr = new(pgx.Driver)
    47  		newDriverName = instrumentedDriverName + "-" + namePostgreSQL
    48  	case nameMariaDB:
    49  		fallthrough
    50  	case nameMySQL:
    51  		dr = mysqld.MySQLDriver{}
    52  		newDriverName = instrumentedDriverName + "-" + nameMySQL
    53  	case nameSQLite3:
    54  		var err error
    55  		dr, err = newSQLiteDriver()
    56  		if err != nil {
    57  			return "", "", err
    58  		}
    59  		newDriverName = instrumentedDriverName + "-" + nameSQLite3
    60  	}
    61  
    62  	sqlDriverLock.Lock()
    63  	defer sqlDriverLock.Unlock()
    64  
    65  	var found bool
    66  	for _, n := range sql.Drivers() {
    67  		if n == newDriverName {
    68  			found = true
    69  			break
    70  		}
    71  	}
    72  
    73  	if !found {
    74  		sql.Register(newDriverName, instrumentedsql.WrapDriver(dr, deets.InstrumentedDriverOptions...))
    75  	}
    76  
    77  	return newDriverName, dialect, nil
    78  }
    79  
    80  // openPotentiallyInstrumentedConnection first opens a raw SQL connection and then wraps it with `sqlx`.
    81  //
    82  // We do this because `sqlx` needs the database type in order to properly
    83  // translate arguments (e.g. `?` to `$1`) in SQL queries. Because we use
    84  // a custom driver name when using instrumentation, this detection would fail
    85  // otherwise.
    86  func openPotentiallyInstrumentedConnection(c dialect, dsn string) (*sqlx.DB, error) {
    87  	driverName, dialect, err := instrumentDriver(c.Details(), c.DefaultDriver())
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	con, err := sql.Open(driverName, dsn)
    93  	if err != nil {
    94  		return nil, fmt.Errorf("could not open database connection: %w", err)
    95  	}
    96  
    97  	return sqlx.NewDb(con, dialect), nil
    98  }