github.com/pinpoint-apm/pinpoint-go-agent@v1.4.1-0.20240110120318-a50c2eb18c8c/sql_driver.go (about)

     1  package pinpoint
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"errors"
     8  	"fmt"
     9  	"time"
    10  )
    11  
    12  type DBInfo struct {
    13  	DBType    int
    14  	QueryType int
    15  	DBName    string
    16  	DBHost    string
    17  
    18  	ParseDSN func(info *DBInfo, dsn string)
    19  }
    20  
    21  func parseDSN(info *DBInfo, dsn string) {
    22  	if f := info.ParseDSN; f != nil {
    23  		f(info, dsn)
    24  	}
    25  }
    26  
    27  // NewDatabaseTracer returns a Tracer for database operation.
    28  func NewDatabaseTracer(ctx context.Context, funcName string, info *DBInfo) Tracer {
    29  	tracer := FromContext(ctx)
    30  	tracer.NewSpanEvent(funcName)
    31  	se := tracer.SpanEvent()
    32  	se.SetServiceType(int32(info.QueryType))
    33  	se.SetEndPoint(info.DBHost)
    34  	se.SetDestination(info.DBName)
    35  
    36  	return tracer
    37  }
    38  
    39  func wrapDriver(drv *sqlDriver) driver.Driver {
    40  	if _, ok := drv.Driver.(driver.DriverContext); ok {
    41  		return struct {
    42  			driver.Driver
    43  			driver.DriverContext
    44  		}{drv, drv}
    45  	} else {
    46  		return struct {
    47  			driver.Driver
    48  		}{drv}
    49  	}
    50  }
    51  
    52  // WrapSQLDriver wraps a driver.Driver and instruments SQL query calls.
    53  func WrapSQLDriver(drv driver.Driver, info DBInfo) driver.Driver {
    54  	return wrapDriver(&sqlDriver{Driver: drv, dbInfo: info})
    55  }
    56  
// sqlDriver decorates a driver.Driver so that the connections and connectors
// it opens are wrapped in sqlConn and sqlConnector.
type sqlDriver struct {
	// Driver is the wrapped driver.
	driver.Driver
	// dbInfo is the template copied into each connection and connector opened
	// through this driver.
	dbInfo DBInfo
}
    61  
    62  func (d *sqlDriver) Open(name string) (driver.Conn, error) {
    63  	conn, err := d.Driver.Open(name)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	sc := newSqlConn(conn, d.dbInfo)
    69  	parseDSN(&sc.dbInfo, name)
    70  	return sc, nil
    71  }
    72  
    73  func (d *sqlDriver) OpenConnector(name string) (driver.Connector, error) {
    74  	conn, err := d.Driver.(driver.DriverContext).OpenConnector(name)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	sc := &sqlConnector{
    80  		Connector: conn,
    81  		dbInfo:    d.dbInfo,
    82  		driver:    d,
    83  	}
    84  
    85  	parseDSN(&sc.dbInfo, name)
    86  	return sc, nil
    87  }
    88  
// sqlConnector decorates a driver.Connector so that the connections it creates
// are wrapped in sqlConn.
type sqlConnector struct {
	// Connector is the wrapped connector.
	driver.Connector
	// dbInfo is this connector's copy of the driver's DBInfo, handed to every
	// connection it creates.
	dbInfo DBInfo
	// driver is the sqlDriver that opened this connector; Driver returns it.
	driver *sqlDriver
}
    94  
    95  func (c *sqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
    96  	if conn, err := c.Connector.Connect(ctx); err != nil {
    97  		return nil, err
    98  	} else {
    99  		return newSqlConn(conn, c.dbInfo), nil
   100  	}
   101  }
   102  
   103  func (c *sqlConnector) Driver() driver.Driver {
   104  	return c.driver
   105  }
   106  
// sqlConn decorates a driver.Conn and records a span event for the statements
// and transactions run through it.
type sqlConn struct {
	// Conn is the wrapped connection.
	driver.Conn
	// dbInfo identifies the database in the span events this connection
	// records.
	dbInfo DBInfo
	// config is the value GetConfig returned when the connection was wrapped;
	// it supplies the SQL tracing options read by this type.
	config *Config
}
   112  
   113  func newSqlConn(conn driver.Conn, dbInfo DBInfo) *sqlConn {
   114  	return &sqlConn{
   115  		Conn:   conn,
   116  		dbInfo: dbInfo,
   117  		config: GetConfig(),
   118  	}
   119  }
   120  
   121  func prepare(stmt driver.Stmt, err error, conn *sqlConn, sql string) (driver.Stmt, error) {
   122  	if nil != err {
   123  		return nil, err
   124  	}
   125  
   126  	return &sqlStmt{
   127  		Stmt: stmt,
   128  		conn: conn,
   129  		sql:  sql,
   130  	}, nil
   131  }
   132  
   133  func (c *sqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   134  	if cpc, ok := c.Conn.(driver.ConnPrepareContext); ok {
   135  		stmt, err := cpc.PrepareContext(ctx, query)
   136  		return prepare(stmt, err, c, query)
   137  	}
   138  
   139  	stmt, err := c.Conn.Prepare(query)
   140  	return prepare(stmt, err, c, query)
   141  }
   142  
   143  func (c *sqlConn) newSqlSpanEventWithNamedValue(ctx context.Context, operation string, start time.Time, err error, sql string, args []driver.NamedValue) {
   144  	tracer := NewDatabaseTracer(ctx, operation, &c.dbInfo)
   145  	defer tracer.EndSpanEvent()
   146  
   147  	if tracer.IsSampled() {
   148  		setSqlSpanEvent(tracer, start, err, sql, c.namedValueToString(args))
   149  	}
   150  }
   151  
   152  func (c *sqlConn) newSqlSpanEventWithValue(ctx context.Context, operation string, start time.Time, err error, sql string, args []driver.Value) {
   153  	tracer := NewDatabaseTracer(ctx, operation, &c.dbInfo)
   154  	defer tracer.EndSpanEvent()
   155  
   156  	if tracer.IsSampled() {
   157  		setSqlSpanEvent(tracer, start, err, sql, c.valueToString(args))
   158  	}
   159  }
   160  
   161  func (c *sqlConn) newSqlSpanEventNoSql(ctx context.Context, operation string, start time.Time, err error) {
   162  	tracer := NewDatabaseTracer(ctx, operation, &c.dbInfo)
   163  	defer tracer.EndSpanEvent()
   164  
   165  	if tracer.IsSampled() {
   166  		setSqlSpanEvent(tracer, start, err, "", "")
   167  	}
   168  }
   169  
   170  func setSqlSpanEvent(tracer Tracer, start time.Time, err error, sql string, args string) {
   171  	tracer.SpanEvent().SetSQL(sql, args)
   172  	tracer.SpanEvent().SetError(err, "SQL error")
   173  	tracer.SpanEvent().FixDuration(start, time.Now())
   174  }
   175  
   176  func (c *sqlConn) namedValueToString(named []driver.NamedValue) string {
   177  	if !c.config.sqlTraceBindValue || named == nil {
   178  		return ""
   179  	}
   180  
   181  	var b bytes.Buffer
   182  	numComma := len(named) - 1
   183  	for i, param := range named {
   184  		if !writeBindValue(&b, i, param.Value, numComma, c.config.sqlMaxBindValueSize) {
   185  			break
   186  		}
   187  	}
   188  	return b.String()
   189  }
   190  
   191  func (c *sqlConn) valueToString(values []driver.Value) string {
   192  	if !c.config.sqlTraceBindValue || values == nil {
   193  		return ""
   194  	}
   195  
   196  	var b bytes.Buffer
   197  	numComma := len(values) - 1
   198  	for i, v := range values {
   199  		if !writeBindValue(&b, i, v, numComma, c.config.sqlMaxBindValueSize) {
   200  			break
   201  		}
   202  	}
   203  	return b.String()
   204  }
   205  
   206  func writeBindValue(b *bytes.Buffer, index int, value interface{}, numComma int, maxSize int) bool {
   207  	b.WriteString(fmt.Sprint(value))
   208  	if index < numComma {
   209  		b.WriteString(", ")
   210  	}
   211  	if b.Len() > maxSize {
   212  		b.WriteString("...(")
   213  		b.WriteString(fmt.Sprint(maxSize))
   214  		b.WriteString(")")
   215  		return false
   216  	}
   217  	return true
   218  }
   219  
   220  func (c *sqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   221  	start := time.Now()
   222  
   223  	if ec, ok := c.Conn.(driver.ExecerContext); ok {
   224  		result, err := ec.ExecContext(ctx, query, args)
   225  
   226  		if err != driver.ErrSkip {
   227  			c.newSqlSpanEventWithNamedValue(ctx, "ConnExecContext", start, err, query, args)
   228  		}
   229  
   230  		return result, err
   231  	}
   232  
   233  	// sourced: database/sql/cxtutil.go
   234  	dargs, err := namedValueToValue(args)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	select {
   239  	default:
   240  	case <-ctx.Done():
   241  		return nil, ctx.Err()
   242  	}
   243  
   244  	if e, ok := c.Conn.(driver.Execer); ok {
   245  		result, err := e.Exec(query, dargs)
   246  		if err != driver.ErrSkip {
   247  			c.newSqlSpanEventWithValue(ctx, "ConnExec", start, err, query, dargs)
   248  		}
   249  
   250  		return result, err
   251  	}
   252  
   253  	return nil, driver.ErrSkip
   254  }
   255  
   256  func (c *sqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   257  	start := time.Now()
   258  
   259  	if qc, ok := c.Conn.(driver.QueryerContext); ok {
   260  		rows, err := qc.QueryContext(ctx, query, args)
   261  		if err != driver.ErrSkip {
   262  			c.newSqlSpanEventWithNamedValue(ctx, "ConnQueryContext", start, err, query, args)
   263  		}
   264  
   265  		return rows, err
   266  	}
   267  
   268  	// sourced: database/sql/cxtutil.go
   269  	dargs, err := namedValueToValue(args)
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  	select {
   274  	default:
   275  	case <-ctx.Done():
   276  		return nil, ctx.Err()
   277  	}
   278  
   279  	if q, ok := c.Conn.(driver.Queryer); ok {
   280  		rows, err := q.Query(query, dargs)
   281  		if err != driver.ErrSkip {
   282  			c.newSqlSpanEventWithValue(ctx, "ConnQuery", start, err, query, dargs)
   283  		}
   284  
   285  		return rows, err
   286  	}
   287  
   288  	return nil, driver.ErrSkip
   289  }
   290  
   291  func (c *sqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   292  	var tx driver.Tx
   293  	var err error
   294  
   295  	start := time.Now()
   296  	if cbt, ok := c.Conn.(driver.ConnBeginTx); ok {
   297  		tx, err = cbt.BeginTx(ctx, opts)
   298  		if c.config.sqlTraceCommit || c.config.sqlTraceRollback {
   299  			c.newSqlSpanEventNoSql(ctx, "BeginTx", start, err)
   300  			if err == nil {
   301  				tx = &sqlTx{tx, c, ctx}
   302  			}
   303  		}
   304  		return tx, err
   305  	}
   306  
   307  	tx, err = c.Conn.Begin()
   308  	if c.config.sqlTraceCommit || c.config.sqlTraceRollback {
   309  		c.newSqlSpanEventNoSql(ctx, "Begin", start, err)
   310  		if err == nil {
   311  			tx = &sqlTx{tx, c, ctx}
   312  		}
   313  	}
   314  	return tx, err
   315  }
   316  
// sqlTx decorates a driver.Tx so that Commit and Rollback can be recorded as
// span events.
type sqlTx struct {
	// Tx is the wrapped transaction.
	driver.Tx
	// conn is the connection that started the transaction; it supplies the
	// tracing configuration and the span-event helper.
	conn *sqlConn
	// ctx is the context BeginTx was called with. It is kept here because
	// Commit and Rollback take no context of their own.
	ctx context.Context
}
   322  
   323  func (t *sqlTx) Commit() (err error) {
   324  	start := time.Now()
   325  	err = t.Tx.Commit()
   326  	if t.conn.config.sqlTraceCommit {
   327  		t.conn.newSqlSpanEventNoSql(t.ctx, "Commit", start, err)
   328  	}
   329  	return err
   330  }
   331  
   332  func (t *sqlTx) Rollback() (err error) {
   333  	start := time.Now()
   334  	err = t.Tx.Rollback()
   335  	if t.conn.config.sqlTraceRollback {
   336  		t.conn.newSqlSpanEventNoSql(t.ctx, "Rollback", start, err)
   337  	}
   338  	return err
   339  }
   340  
// sqlStmt decorates a prepared driver.Stmt and records a span event for each
// execution.
type sqlStmt struct {
	// Stmt is the wrapped prepared statement.
	driver.Stmt
	// conn is the connection the statement was prepared on; it supplies the
	// span-event helpers.
	conn *sqlConn
	// sql is the query text the statement was prepared from, reported with
	// every execution.
	sql string
}
   346  
   347  func (s *sqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   348  	start := time.Now()
   349  
   350  	if sec, ok := s.Stmt.(driver.StmtExecContext); ok {
   351  		result, err := sec.ExecContext(ctx, args)
   352  		s.conn.newSqlSpanEventWithNamedValue(ctx, "StmtExecContext", start, err, s.sql, args)
   353  		return result, err
   354  	}
   355  
   356  	// sourced: database/sql/cxtutil.go
   357  	dargs, err := namedValueToValue(args)
   358  	if err != nil {
   359  		return nil, err
   360  	}
   361  	select {
   362  	default:
   363  	case <-ctx.Done():
   364  		return nil, ctx.Err()
   365  	}
   366  
   367  	result, err := s.Stmt.Exec(dargs)
   368  	s.conn.newSqlSpanEventWithValue(ctx, "StmtExec", start, err, s.sql, dargs)
   369  	return result, err
   370  }
   371  
   372  func (s *sqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   373  	start := time.Now()
   374  
   375  	if sqc, ok := s.Stmt.(driver.StmtQueryContext); ok {
   376  		rows, err := sqc.QueryContext(ctx, args)
   377  		s.conn.newSqlSpanEventWithNamedValue(ctx, "StmtQueryContext", start, err, s.sql, args)
   378  		return rows, err
   379  	}
   380  
   381  	// sourced: database/sql/cxtutil.go
   382  	dargs, err := namedValueToValue(args)
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  	select {
   387  	default:
   388  	case <-ctx.Done():
   389  		return nil, ctx.Err()
   390  	}
   391  
   392  	rows, err := s.Stmt.Query(dargs)
   393  	s.conn.newSqlSpanEventWithValue(ctx, "StmtQuery", start, err, s.sql, dargs)
   394  	return rows, err
   395  }
   396  
   397  // sourced: database/sql/cxtutil.go
   398  func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
   399  	dargs := make([]driver.Value, len(named))
   400  	for n, param := range named {
   401  		if len(param.Name) > 0 {
   402  			return nil, errors.New("sql: driver does not support the use of Named Parameters")
   403  		}
   404  		dargs[n] = param.Value
   405  	}
   406  	return dargs, nil
   407  }