github.com/mier85/go-sensor@v1.30.1-0.20220920111756-9bf41b3bc7e0/instrumentation_sql.go (about)

     1  // (c) Copyright IBM Corp. 2021
     2  // (c) Copyright Instana Inc. 2020
     3  
     4  package instana
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"net/url"
    12  	"regexp"
    13  	"strings"
    14  	"sync"
    15  
    16  	ot "github.com/opentracing/opentracing-go"
    17  	"github.com/opentracing/opentracing-go/ext"
    18  	otlog "github.com/opentracing/opentracing-go/log"
    19  
    20  	_ "unsafe"
    21  )
    22  
    23  var (
    24  	sqlDriverRegistrationMu sync.Mutex
    25  )
    26  
    27  // InstrumentSQLDriver instruments provided database driver for  use with `sql.Open()`.
    28  // This method will ignore any attempt to register the driver with the same name again.
    29  //
    30  // The instrumented version is registered with `_with_instana` suffix, e.g.
    31  // if `postgres` provided as a name, the instrumented version is registered as
    32  // `postgres_with_instana`.
    33  func InstrumentSQLDriver(sensor *Sensor, name string, driver driver.Driver) {
    34  	sqlDriverRegistrationMu.Lock()
    35  	defer sqlDriverRegistrationMu.Unlock()
    36  
    37  	instrumentedName := name + "_with_instana"
    38  
    39  	// Check if the instrumented version of a driver has already been registered
    40  	// with database/sql and ignore the second attempt to avoid panicking
    41  	for _, drv := range sql.Drivers() {
    42  		if drv == instrumentedName {
    43  			return
    44  		}
    45  	}
    46  
    47  	sql.Register(instrumentedName, &wrappedSQLDriver{
    48  		Driver: driver,
    49  		sensor: sensor,
    50  	})
    51  }
    52  
    53  // SQLOpen is a convenience wrapper for `sql.Open()` to use the instrumented version
    54  // of a driver previosly registered using `instana.InstrumentSQLDriver()`
    55  func SQLOpen(driverName, dataSourceName string) (*sql.DB, error) {
    56  
    57  	if !strings.HasSuffix(driverName, "_with_instana") {
    58  		driverName += "_with_instana"
    59  	}
    60  
    61  	return sql.Open(driverName, dataSourceName)
    62  }
    63  
    64  //go:linkname drivers database/sql.drivers
    65  var drivers map[string]driver.Driver
    66  
    67  // SQLInstrumentAndOpen returns instrumented `*sql.DB`.
    68  // It takes already registered `driver.Driver` by name, instruments it and additionally registers
    69  // it with different name. After that it returns instrumented `*sql.DB` or error if any.
    70  //
    71  // This function can be used as a convenient shortcut for InstrumentSQLDriver and SQLOpen functions.
    72  // The main difference is that this approach will use the already registered driver and using InstrumentSQLDriver
    73  // requires to explicitly provide an instance of the driver to instrument.
    74  func SQLInstrumentAndOpen(sensor *Sensor, driverName, dataSourceName string) (*sql.DB, error) {
    75  	if d, ok := drivers[driverName]; ok {
    76  		InstrumentSQLDriver(sensor, driverName, d)
    77  	}
    78  
    79  	return SQLOpen(driverName, dataSourceName)
    80  }
    81  
    82  type wrappedSQLDriver struct {
    83  	driver.Driver
    84  
    85  	sensor *Sensor
    86  }
    87  
    88  func (drv *wrappedSQLDriver) Open(name string) (driver.Conn, error) {
    89  	conn, err := drv.Driver.Open(name)
    90  	if err != nil {
    91  		return conn, err
    92  	}
    93  
    94  	if conn, ok := conn.(*wrappedSQLConn); ok {
    95  		return conn, nil
    96  	}
    97  
    98  	return &wrappedSQLConn{
    99  		Conn:    conn,
   100  		details: parseDBConnDetails(name),
   101  		sensor:  drv.sensor,
   102  	}, nil
   103  }
   104  
   105  type wrappedSQLConn struct {
   106  	driver.Conn
   107  
   108  	details dbConnDetails
   109  	sensor  *Sensor
   110  }
   111  
   112  func (conn *wrappedSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   113  	sp := startSQLSpan(ctx, conn.details, query, conn.sensor)
   114  	defer sp.Finish()
   115  
   116  	if c, ok := conn.Conn.(driver.QueryerContext); ok {
   117  		res, err := c.QueryContext(ctx, query, args)
   118  		if err != nil && err != driver.ErrSkip {
   119  			sp.LogFields(otlog.Error(err))
   120  		}
   121  
   122  		return res, err
   123  	}
   124  
   125  	if c, ok := conn.Conn.(driver.Queryer); ok { //nolint:staticcheck
   126  		values, err := sqlNamedValuesToValues(args)
   127  		if err != nil {
   128  			return nil, err
   129  		}
   130  
   131  		select {
   132  		default:
   133  		case <-ctx.Done():
   134  			return nil, ctx.Err()
   135  		}
   136  
   137  		res, err := c.Query(query, values)
   138  		if err != nil && err != driver.ErrSkip {
   139  			sp.LogFields(otlog.Error(err))
   140  		}
   141  
   142  		return res, err
   143  	}
   144  
   145  	return nil, driver.ErrSkip
   146  }
   147  
   148  func (conn *wrappedSQLConn) Prepare(query string) (driver.Stmt, error) {
   149  	stmt, err := conn.Conn.Prepare(query)
   150  	if err != nil {
   151  		return stmt, err
   152  	}
   153  
   154  	if stmt, ok := stmt.(*wrappedSQLStmt); ok {
   155  		return stmt, nil
   156  	}
   157  
   158  	return &wrappedSQLStmt{
   159  		Stmt:        stmt,
   160  		connDetails: conn.details,
   161  		query:       query,
   162  		sensor:      conn.sensor,
   163  	}, nil
   164  }
   165  
   166  func (conn *wrappedSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   167  	var (
   168  		stmt driver.Stmt
   169  		err  error
   170  	)
   171  	if c, ok := conn.Conn.(driver.ConnPrepareContext); ok {
   172  		stmt, err = c.PrepareContext(ctx, query)
   173  	} else {
   174  		stmt, err = conn.Prepare(query)
   175  	}
   176  
   177  	if err != nil {
   178  		return stmt, err
   179  	}
   180  
   181  	if stmt, ok := stmt.(*wrappedSQLStmt); ok {
   182  		return stmt, nil
   183  	}
   184  
   185  	return &wrappedSQLStmt{
   186  		Stmt:        stmt,
   187  		connDetails: conn.details,
   188  		query:       query,
   189  		sensor:      conn.sensor,
   190  	}, nil
   191  }
   192  
   193  func (conn *wrappedSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   194  	sp := startSQLSpan(ctx, conn.details, query, conn.sensor)
   195  	defer sp.Finish()
   196  
   197  	if c, ok := conn.Conn.(driver.ExecerContext); ok {
   198  		res, err := c.ExecContext(ctx, query, args)
   199  		if err != nil && err != driver.ErrSkip {
   200  			sp.LogFields(otlog.Error(err))
   201  		}
   202  
   203  		return res, err
   204  	}
   205  
   206  	if c, ok := conn.Conn.(driver.Execer); ok { //nolint:staticcheck
   207  		values, err := sqlNamedValuesToValues(args)
   208  		if err != nil {
   209  			return nil, err
   210  		}
   211  
   212  		select {
   213  		default:
   214  		case <-ctx.Done():
   215  			return nil, ctx.Err()
   216  		}
   217  
   218  		res, err := c.Exec(query, values)
   219  		if err != nil && err != driver.ErrSkip {
   220  			sp.LogFields(otlog.Error(err))
   221  		}
   222  
   223  		return res, err
   224  	}
   225  
   226  	return nil, driver.ErrSkip
   227  }
   228  
   229  type wrappedSQLStmt struct {
   230  	driver.Stmt
   231  
   232  	connDetails dbConnDetails
   233  	query       string
   234  	sensor      *Sensor
   235  }
   236  
   237  func (stmt *wrappedSQLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   238  	sp := startSQLSpan(ctx, stmt.connDetails, stmt.query, stmt.sensor)
   239  	defer sp.Finish()
   240  
   241  	if s, ok := stmt.Stmt.(driver.StmtExecContext); ok {
   242  		res, err := s.ExecContext(ctx, args)
   243  		if err != nil && err != driver.ErrSkip {
   244  			sp.LogFields(otlog.Error(err))
   245  		}
   246  
   247  		return res, err
   248  	}
   249  
   250  	values, err := sqlNamedValuesToValues(args)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	select {
   256  	default:
   257  	case <-ctx.Done():
   258  		return nil, ctx.Err()
   259  	}
   260  
   261  	res, err := stmt.Exec(values) //nolint:staticcheck
   262  	if err != nil && err != driver.ErrSkip {
   263  		sp.LogFields(otlog.Error(err))
   264  	}
   265  
   266  	return res, err
   267  }
   268  
   269  func (stmt *wrappedSQLStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   270  	sp := startSQLSpan(ctx, stmt.connDetails, stmt.query, stmt.sensor)
   271  	defer sp.Finish()
   272  
   273  	if s, ok := stmt.Stmt.(driver.StmtQueryContext); ok {
   274  		res, err := s.QueryContext(ctx, args)
   275  		if err != nil && err != driver.ErrSkip {
   276  			sp.LogFields(otlog.Error(err))
   277  		}
   278  
   279  		return res, err
   280  	}
   281  
   282  	values, err := sqlNamedValuesToValues(args)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  
   287  	select {
   288  	default:
   289  	case <-ctx.Done():
   290  		return nil, ctx.Err()
   291  	}
   292  
   293  	res, err := stmt.Stmt.Query(values) //nolint:staticcheck
   294  	if err != nil && err != driver.ErrSkip {
   295  		sp.LogFields(otlog.Error(err))
   296  	}
   297  
   298  	return res, err
   299  }
   300  
   301  func startSQLSpan(ctx context.Context, conn dbConnDetails, query string, sensor *Sensor) ot.Span {
   302  	tags := ot.Tags{
   303  		string(ext.DBType):      "sql",
   304  		string(ext.DBStatement): query,
   305  		string(ext.PeerAddress): conn.RawString,
   306  	}
   307  
   308  	if conn.Schema != "" {
   309  		tags[string(ext.DBInstance)] = conn.Schema
   310  	} else {
   311  		tags[string(ext.DBInstance)] = conn.RawString
   312  	}
   313  
   314  	if conn.Host != "" {
   315  		tags[string(ext.PeerHostname)] = conn.Host
   316  	}
   317  
   318  	if conn.Port != "" {
   319  		tags[string(ext.PeerPort)] = conn.Port
   320  	}
   321  
   322  	opts := []ot.StartSpanOption{ext.SpanKindRPCClient, tags}
   323  	if parentSpan, ok := SpanFromContext(ctx); ok {
   324  		opts = append(opts, ot.ChildOf(parentSpan.Context()))
   325  	}
   326  
   327  	return sensor.Tracer().StartSpan("sdk.database", opts...)
   328  }
   329  
   330  type dbConnDetails struct {
   331  	RawString  string
   332  	Host, Port string
   333  	Schema     string
   334  	User       string
   335  }
   336  
   337  func parseDBConnDetails(connStr string) dbConnDetails {
   338  	strategies := [...]func(string) (dbConnDetails, bool){
   339  		parseDBConnDetailsURI,
   340  		parsePostgresConnDetailsKV,
   341  		parseMySQLConnDetailsKV,
   342  	}
   343  	for _, parseFn := range strategies {
   344  		if details, ok := parseFn(connStr); ok {
   345  			return details
   346  		}
   347  	}
   348  
   349  	return dbConnDetails{RawString: connStr}
   350  }
   351  
   352  // parseDBConnDetailsURI attempts to parse a connection string as an URI, assuming that it has
   353  // following format: [scheme://][user[:[password]]@]host[:port][/schema][?attribute1=value1&attribute2=value2...]
   354  func parseDBConnDetailsURI(connStr string) (dbConnDetails, bool) {
   355  	u, err := url.Parse(connStr)
   356  	if err != nil {
   357  		return dbConnDetails{}, false
   358  	}
   359  
   360  	if u.Scheme == "" {
   361  		return dbConnDetails{}, false
   362  	}
   363  
   364  	path := ""
   365  	if len(u.Path) > 1 {
   366  		path = u.Path[1:]
   367  	}
   368  
   369  	details := dbConnDetails{
   370  		RawString: connStr,
   371  		Host:      u.Hostname(),
   372  		Port:      u.Port(),
   373  		Schema:    path,
   374  	}
   375  
   376  	if u.User != nil {
   377  		details.User = u.User.Username()
   378  
   379  		// create a copy without user password
   380  		u := cloneURL(u)
   381  		u.User = url.User(details.User)
   382  		details.RawString = u.String()
   383  	}
   384  
   385  	return details, true
   386  }
   387  
   388  var postgresKVPasswordRegex = regexp.MustCompile(`(^|\s)password=[^\s]+(\s|$)`)
   389  
   390  // parsePostgresConnDetailsKV parses a space-separated PostgreSQL-style connection string
   391  func parsePostgresConnDetailsKV(connStr string) (dbConnDetails, bool) {
   392  	var details dbConnDetails
   393  
   394  	for _, field := range strings.Split(connStr, " ") {
   395  		fieldNorm := strings.ToLower(field)
   396  
   397  		var (
   398  			prefix   string
   399  			fieldPtr *string
   400  		)
   401  		switch {
   402  		case strings.HasPrefix(fieldNorm, "host="):
   403  			if details.Host != "" {
   404  				// hostaddr= takes precedence
   405  				continue
   406  			}
   407  
   408  			prefix, fieldPtr = "host=", &details.Host
   409  		case strings.HasPrefix(fieldNorm, "hostaddr="):
   410  			prefix, fieldPtr = "hostaddr=", &details.Host
   411  		case strings.HasPrefix(fieldNorm, "port="):
   412  			prefix, fieldPtr = "port=", &details.Port
   413  		case strings.HasPrefix(fieldNorm, "user="):
   414  			prefix, fieldPtr = "user=", &details.User
   415  		case strings.HasPrefix(fieldNorm, "dbname="):
   416  			prefix, fieldPtr = "dbname=", &details.Schema
   417  		default:
   418  			continue
   419  		}
   420  
   421  		*fieldPtr = field[len(prefix):]
   422  	}
   423  
   424  	if details.Schema == "" {
   425  		return dbConnDetails{}, false
   426  	}
   427  
   428  	details.RawString = postgresKVPasswordRegex.ReplaceAllString(connStr, " ")
   429  
   430  	return details, true
   431  }
   432  
   433  var mysqlKVPasswordRegex = regexp.MustCompile(`(?i)(^|;)Pwd=[^;]+(;|$)`)
   434  
   435  // parseMySQLConnDetailsKV parses a semicolon-separated MySQL-style connection string
   436  func parseMySQLConnDetailsKV(connStr string) (dbConnDetails, bool) {
   437  	details := dbConnDetails{RawString: connStr}
   438  
   439  	for _, field := range strings.Split(connStr, ";") {
   440  		fieldNorm := strings.ToLower(field)
   441  
   442  		var (
   443  			prefix   string
   444  			fieldPtr *string
   445  		)
   446  		switch {
   447  		case strings.HasPrefix(fieldNorm, "server="):
   448  			prefix, fieldPtr = "server=", &details.Host
   449  		case strings.HasPrefix(fieldNorm, "port="):
   450  			prefix, fieldPtr = "port=", &details.Port
   451  		case strings.HasPrefix(fieldNorm, "uid="):
   452  			prefix, fieldPtr = "uid=", &details.User
   453  		case strings.HasPrefix(fieldNorm, "database="):
   454  			prefix, fieldPtr = "database=", &details.Schema
   455  		default:
   456  			continue
   457  		}
   458  
   459  		*fieldPtr = field[len(prefix):]
   460  	}
   461  
   462  	if details.Schema == "" {
   463  		return dbConnDetails{}, false
   464  	}
   465  
   466  	details.RawString = mysqlKVPasswordRegex.ReplaceAllString(connStr, ";")
   467  
   468  	return details, true
   469  }
   470  
   471  // The following code is ported from $GOROOT/src/database/sql/ctxutil.go
   472  //
   473  // Copyright 2019 The Go Authors. All rights reserved.
   474  // Use of this source code is governed by a BSD-style
   475  // license that can be found in the LICENSE file.
   476  func sqlNamedValuesToValues(named []driver.NamedValue) ([]driver.Value, error) {
   477  	dargs := make([]driver.Value, len(named))
   478  	for n, param := range named {
   479  		if len(param.Name) > 0 {
   480  			return nil, errors.New("sql: driver does not support the use of Named Parameters")
   481  		}
   482  		dargs[n] = param.Value
   483  	}
   484  	return dargs, nil
   485  }
   486  
   487  type dsnConnector struct {
   488  	dsn    string
   489  	driver driver.Driver
   490  }
   491  
   492  func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
   493  	return t.driver.Open(t.dsn)
   494  }
   495  
   496  func (t dsnConnector) Driver() driver.Driver {
   497  	return t.driver
   498  }