github.com/newrelic/go-agent@v3.26.0+incompatible/sql_driver.go (about)

     1  // Copyright 2020 New Relic Corporation. All rights reserved.
     2  // SPDX-License-Identifier: Apache-2.0
     3  
     4  // +build go1.10
     5  
     6  package newrelic
     7  
     8  import (
     9  	"context"
    10  	"database/sql/driver"
    11  )
    12  
// SQLDriverSegmentBuilder populates DatastoreSegments for sql.Driver
// instrumentation.  Use this to instrument a database that is not supported by
// an existing integration package (nrmysql, nrpq, and nrsqlite3). See
// https://github.com/newrelic/go-agent/blob/master/_integrations/nrmysql/nrmysql.go
// for example use.
//
// BaseSegment is the template that is copied each time the instrumentation
// starts a segment.  ParseQuery and ParseDSN are optional hooks: when non-nil
// they are called with a pointer to the BaseSegment of a copy of the builder,
// together with the query text or the data source name respectively, so they
// can fill in segment fields.  A nil hook is simply skipped.
type SQLDriverSegmentBuilder struct {
	BaseSegment DatastoreSegment
	ParseQuery  func(segment *DatastoreSegment, query string)
	ParseDSN    func(segment *DatastoreSegment, dataSourceName string)
}
    23  
    24  // InstrumentSQLDriver wraps a driver.Driver, adding instrumentation for exec
    25  // and query calls made with a transaction-containing context.  Use this to
    26  // instrument a database driver that is not supported by an existing integration
    27  // package (nrmysql, nrpq, and nrsqlite3). See
    28  // https://github.com/newrelic/go-agent/blob/master/_integrations/nrmysql/nrmysql.go
    29  // for example use.
    30  func InstrumentSQLDriver(d driver.Driver, bld SQLDriverSegmentBuilder) driver.Driver {
    31  	return optionalMethodsDriver(&wrapDriver{bld: bld, original: d})
    32  }
    33  
    34  // InstrumentSQLConnector wraps a driver.Connector, adding instrumentation for
    35  // exec and query calls made with a transaction-containing context.  Use this to
    36  // instrument a database connector that is not supported by an existing
    37  // integration package (nrmysql, nrpq, and nrsqlite3). See
    38  // https://github.com/newrelic/go-agent/blob/master/_integrations/nrmysql/nrmysql.go
    39  // for example use.
    40  func InstrumentSQLConnector(connector driver.Connector, bld SQLDriverSegmentBuilder) driver.Connector {
    41  	return &wrapConnector{original: connector, bld: bld}
    42  }
    43  
    44  func (bld SQLDriverSegmentBuilder) useDSN(dsn string) SQLDriverSegmentBuilder {
    45  	if f := bld.ParseDSN; nil != f {
    46  		f(&bld.BaseSegment, dsn)
    47  	}
    48  	return bld
    49  }
    50  
    51  func (bld SQLDriverSegmentBuilder) useQuery(query string) SQLDriverSegmentBuilder {
    52  	if f := bld.ParseQuery; nil != f {
    53  		f(&bld.BaseSegment, query)
    54  	}
    55  	return bld
    56  }
    57  
    58  func (bld SQLDriverSegmentBuilder) startSegment(ctx context.Context) DatastoreSegment {
    59  	segment := bld.BaseSegment
    60  	segment.StartTime = StartSegmentNow(FromContext(ctx))
    61  	return segment
    62  }
    63  
    64  type wrapDriver struct {
    65  	bld      SQLDriverSegmentBuilder
    66  	original driver.Driver
    67  }
    68  
    69  type wrapConnector struct {
    70  	bld      SQLDriverSegmentBuilder
    71  	original driver.Connector
    72  }
    73  
    74  type wrapConn struct {
    75  	bld      SQLDriverSegmentBuilder
    76  	original driver.Conn
    77  }
    78  
    79  type wrapStmt struct {
    80  	bld      SQLDriverSegmentBuilder
    81  	original driver.Stmt
    82  }
    83  
    84  func (w *wrapDriver) Open(name string) (driver.Conn, error) {
    85  	original, err := w.original.Open(name)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	return optionalMethodsConn(&wrapConn{
    90  		original: original,
    91  		bld:      w.bld.useDSN(name),
    92  	}), nil
    93  }
    94  
    95  // OpenConnector implements DriverContext.
    96  func (w *wrapDriver) OpenConnector(name string) (driver.Connector, error) {
    97  	original, err := w.original.(driver.DriverContext).OpenConnector(name)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return &wrapConnector{
   102  		original: original,
   103  		bld:      w.bld.useDSN(name),
   104  	}, nil
   105  }
   106  
   107  func (w *wrapConnector) Connect(ctx context.Context) (driver.Conn, error) {
   108  	original, err := w.original.Connect(ctx)
   109  	if nil != err {
   110  		return nil, err
   111  	}
   112  	return optionalMethodsConn(&wrapConn{
   113  		bld:      w.bld,
   114  		original: original,
   115  	}), nil
   116  }
   117  
   118  func (w *wrapConnector) Driver() driver.Driver {
   119  	return optionalMethodsDriver(&wrapDriver{
   120  		bld:      w.bld,
   121  		original: w.original.Driver(),
   122  	})
   123  }
   124  
   125  func prepare(original driver.Stmt, err error, bld SQLDriverSegmentBuilder, query string) (driver.Stmt, error) {
   126  	if nil != err {
   127  		return nil, err
   128  	}
   129  	return optionalMethodsStmt(&wrapStmt{
   130  		bld:      bld.useQuery(query),
   131  		original: original,
   132  	}), nil
   133  }
   134  
   135  func (w *wrapConn) Prepare(query string) (driver.Stmt, error) {
   136  	original, err := w.original.Prepare(query)
   137  	return prepare(original, err, w.bld, query)
   138  }
   139  
   140  // PrepareContext implements ConnPrepareContext.
   141  func (w *wrapConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   142  	original, err := w.original.(driver.ConnPrepareContext).PrepareContext(ctx, query)
   143  	return prepare(original, err, w.bld, query)
   144  }
   145  
   146  func (w *wrapConn) Close() error {
   147  	return w.original.Close()
   148  }
   149  
   150  func (w *wrapConn) Begin() (driver.Tx, error) {
   151  	return w.original.Begin()
   152  }
   153  
   154  // BeginTx implements ConnBeginTx.
   155  func (w *wrapConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   156  	return w.original.(driver.ConnBeginTx).BeginTx(ctx, opts)
   157  }
   158  
   159  // Exec implements Execer.
   160  func (w *wrapConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   161  	return w.original.(driver.Execer).Exec(query, args)
   162  }
   163  
   164  // ExecContext implements ExecerContext.
   165  func (w *wrapConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   166  	segment := w.bld.useQuery(query).startSegment(ctx)
   167  	result, err := w.original.(driver.ExecerContext).ExecContext(ctx, query, args)
   168  	if err != driver.ErrSkip {
   169  		segment.End()
   170  	}
   171  	return result, err
   172  }
   173  
   174  // CheckNamedValue implements NamedValueChecker.
   175  func (w *wrapConn) CheckNamedValue(v *driver.NamedValue) error {
   176  	return w.original.(driver.NamedValueChecker).CheckNamedValue(v)
   177  }
   178  
   179  // Ping implements Pinger.
   180  func (w *wrapConn) Ping(ctx context.Context) error {
   181  	return w.original.(driver.Pinger).Ping(ctx)
   182  }
   183  
   184  func (w *wrapConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   185  	return w.original.(driver.Queryer).Query(query, args)
   186  }
   187  
   188  // QueryContext implements QueryerContext.
   189  func (w *wrapConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   190  	segment := w.bld.useQuery(query).startSegment(ctx)
   191  	rows, err := w.original.(driver.QueryerContext).QueryContext(ctx, query, args)
   192  	if err != driver.ErrSkip {
   193  		segment.End()
   194  	}
   195  	return rows, err
   196  }
   197  
   198  // ResetSession implements SessionResetter.
   199  func (w *wrapConn) ResetSession(ctx context.Context) error {
   200  	return w.original.(driver.SessionResetter).ResetSession(ctx)
   201  }
   202  
   203  func (w *wrapStmt) Close() error {
   204  	return w.original.Close()
   205  }
   206  
   207  func (w *wrapStmt) NumInput() int {
   208  	return w.original.NumInput()
   209  }
   210  
   211  func (w *wrapStmt) Exec(args []driver.Value) (driver.Result, error) {
   212  	return w.original.Exec(args)
   213  }
   214  
   215  func (w *wrapStmt) Query(args []driver.Value) (driver.Rows, error) {
   216  	return w.original.Query(args)
   217  }
   218  
   219  // ColumnConverter implements ColumnConverter.
   220  func (w *wrapStmt) ColumnConverter(idx int) driver.ValueConverter {
   221  	return w.original.(driver.ColumnConverter).ColumnConverter(idx)
   222  }
   223  
   224  // CheckNamedValue implements NamedValueChecker.
   225  func (w *wrapStmt) CheckNamedValue(v *driver.NamedValue) error {
   226  	return w.original.(driver.NamedValueChecker).CheckNamedValue(v)
   227  }
   228  
   229  // ExecContext implements StmtExecContext.
   230  func (w *wrapStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   231  	segment := w.bld.startSegment(ctx)
   232  	result, err := w.original.(driver.StmtExecContext).ExecContext(ctx, args)
   233  	segment.End()
   234  	return result, err
   235  }
   236  
   237  // QueryContext implements StmtQueryContext.
   238  func (w *wrapStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   239  	segment := w.bld.startSegment(ctx)
   240  	rows, err := w.original.(driver.StmtQueryContext).QueryContext(ctx, args)
   241  	segment.End()
   242  	return rows, err
   243  }
   244  
   245  var (
   246  	_ interface {
   247  		driver.Driver
   248  		driver.DriverContext
   249  	} = &wrapDriver{}
   250  	_ interface {
   251  		driver.Connector
   252  	} = &wrapConnector{}
   253  	_ interface {
   254  		driver.Conn
   255  		driver.ConnBeginTx
   256  		driver.ConnPrepareContext
   257  		driver.Execer
   258  		driver.ExecerContext
   259  		driver.NamedValueChecker
   260  		driver.Pinger
   261  		driver.Queryer
   262  		driver.QueryerContext
   263  	} = &wrapConn{}
   264  	_ interface {
   265  		driver.Stmt
   266  		driver.ColumnConverter
   267  		driver.NamedValueChecker
   268  		driver.StmtExecContext
   269  		driver.StmtQueryContext
   270  	} = &wrapStmt{}
   271  )