github.com/dkishere/pop/v6@v6.103.1/connection_instrumented.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"fmt"
     7  
     8  	mysqld "github.com/go-sql-driver/mysql"
     9  	"github.com/dkishere/pop/v6/logging"
    10  	pgx "github.com/jackc/pgx/v4/stdlib"
    11  	"github.com/jmoiron/sqlx"
    12  	"github.com/luna-duclos/instrumentedsql"
    13  )
    14  
    15  const instrumentedDriverName = "instrumented-sql-driver"
    16  
    17  func instrumentDriver(deets *ConnectionDetails, defaultDriverName string) (driverName, dialect string, err error) {
    18  	driverName = defaultDriverName
    19  	if deets.Driver != "" {
    20  		driverName = deets.Driver
    21  	}
    22  	dialect = driverName
    23  
    24  	if !deets.UseInstrumentedDriver {
    25  		if len(deets.InstrumentedDriverOptions) > 0 {
    26  			log(logging.Warn, "SQL driver instrumentation is disabled but `ConnectionDetails.InstrumentedDriverOptions` is not empty. Please double-check if this is a error.")
    27  		}
    28  
    29  		// If instrumentation is disabled, we just return the driver name we got (e.g. "pgx").
    30  		return driverName, dialect, nil
    31  	}
    32  
    33  	if len(deets.InstrumentedDriverOptions) == 0 {
    34  		log(logging.Warn, "SQL driver instrumentation was enabled but no options have been passed to `ConnectionDetails.InstrumentedDriverOptions`. Instrumentation will therefore not result in any output.")
    35  	}
    36  
    37  	var dr driver.Driver
    38  	var newDriverName string
    39  	switch normalizeSynonyms(driverName) {
    40  	case nameCockroach:
    41  		fallthrough
    42  	case namePostgreSQL:
    43  		dr = new(pgx.Driver)
    44  		newDriverName = instrumentedDriverName + "-" + namePostgreSQL
    45  	case nameMariaDB:
    46  		fallthrough
    47  	case nameMySQL:
    48  		dr = mysqld.MySQLDriver{}
    49  		newDriverName = instrumentedDriverName + "-" + nameMySQL
    50  	case nameSQLite3:
    51  		var err error
    52  		dr, err = newSQLiteDriver()
    53  		if err != nil {
    54  			return "", "", err
    55  		}
    56  		newDriverName = instrumentedDriverName + "-" + nameSQLite3
    57  	}
    58  
    59  	var found bool
    60  	for _, n := range sql.Drivers() {
    61  		if n == newDriverName {
    62  			found = true
    63  			break
    64  		}
    65  	}
    66  
    67  	if !found {
    68  		sql.Register(newDriverName, instrumentedsql.WrapDriver(dr, deets.InstrumentedDriverOptions...))
    69  	}
    70  
    71  	return newDriverName, dialect, nil
    72  }
    73  
    74  // openPotentiallyInstrumentedConnection first opens a raw SQL connection and then wraps it with `sqlx`.
    75  //
    76  // We do this because `sqlx` needs the database type in order to properly
    77  // translate arguments (e.g. `?` to `$1`) in SQL queries. Because we use
    78  // a custom driver name when using instrumentation, this detection would fail
    79  // otherwise.
    80  func openPotentiallyInstrumentedConnection(c dialect, dsn string) (*sqlx.DB, error) {
    81  	driverName, dialect, err := instrumentDriver(c.Details(), c.DefaultDriver())
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	con, err := sql.Open(driverName, dsn)
    87  	if err != nil {
    88  		return nil, fmt.Errorf("could not open database connection: %w", err)
    89  	}
    90  
    91  	return sqlx.NewDb(con, dialect), nil
    92  }