github.com/pinpoint-apm/pinpoint-go-agent@v1.4.1-0.20240110120318-a50c2eb18c8c/sql_driver.go (about) 1 package pinpoint 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql/driver" 7 "errors" 8 "fmt" 9 "time" 10 ) 11 12 type DBInfo struct { 13 DBType int 14 QueryType int 15 DBName string 16 DBHost string 17 18 ParseDSN func(info *DBInfo, dsn string) 19 } 20 21 func parseDSN(info *DBInfo, dsn string) { 22 if f := info.ParseDSN; f != nil { 23 f(info, dsn) 24 } 25 } 26 27 // NewDatabaseTracer returns a Tracer for database operation. 28 func NewDatabaseTracer(ctx context.Context, funcName string, info *DBInfo) Tracer { 29 tracer := FromContext(ctx) 30 tracer.NewSpanEvent(funcName) 31 se := tracer.SpanEvent() 32 se.SetServiceType(int32(info.QueryType)) 33 se.SetEndPoint(info.DBHost) 34 se.SetDestination(info.DBName) 35 36 return tracer 37 } 38 39 func wrapDriver(drv *sqlDriver) driver.Driver { 40 if _, ok := drv.Driver.(driver.DriverContext); ok { 41 return struct { 42 driver.Driver 43 driver.DriverContext 44 }{drv, drv} 45 } else { 46 return struct { 47 driver.Driver 48 }{drv} 49 } 50 } 51 52 // WrapSQLDriver wraps a driver.Driver and instruments SQL query calls. 53 func WrapSQLDriver(drv driver.Driver, info DBInfo) driver.Driver { 54 return wrapDriver(&sqlDriver{Driver: drv, dbInfo: info}) 55 } 56 57 type sqlDriver struct { 58 driver.Driver 59 dbInfo DBInfo 60 } 61 62 func (d *sqlDriver) Open(name string) (driver.Conn, error) { 63 conn, err := d.Driver.Open(name) 64 if err != nil { 65 return nil, err 66 } 67 68 sc := newSqlConn(conn, d.dbInfo) 69 parseDSN(&sc.dbInfo, name) 70 return sc, nil 71 } 72 73 func (d *sqlDriver) OpenConnector(name string) (driver.Connector, error) { 74 conn, err := d.Driver.(driver.DriverContext).OpenConnector(name) 75 if err != nil { 76 return nil, err 77 } 78 79 sc := &sqlConnector{ 80 Connector: conn, 81 dbInfo: d.dbInfo, 82 driver: d, 83 } 84 85 parseDSN(&sc.dbInfo, name) 86 return sc, nil 87 } 88 89 type sqlConnector struct { 90 driver.Connector 91 dbInfo DBInfo 92 driver *sqlDriver 93 } 94 95 func (c *sqlConnector) Connect(ctx context.Context) (driver.Conn, error) { 96 if conn, err := c.Connector.Connect(ctx); err != nil { 97 return nil, err 98 } else { 99 return newSqlConn(conn, c.dbInfo), nil 100 } 101 } 102 103 func (c *sqlConnector) Driver() driver.Driver { 104 return c.driver 105 } 106 107 type sqlConn struct { 108 driver.Conn 109 dbInfo DBInfo 110 config *Config 111 } 112 113 func newSqlConn(conn driver.Conn, dbInfo DBInfo) *sqlConn { 114 return &sqlConn{ 115 Conn: conn, 116 dbInfo: dbInfo, 117 config: GetConfig(), 118 } 119 } 120 121 func prepare(stmt driver.Stmt, err error, conn *sqlConn, sql string) (driver.Stmt, error) { 122 if nil != err { 123 return nil, err 124 } 125 126 return &sqlStmt{ 127 Stmt: stmt, 128 conn: conn, 129 sql: sql, 130 }, nil 131 } 132 133 func (c *sqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 134 if cpc, ok := c.Conn.(driver.ConnPrepareContext); ok { 135 stmt, err := cpc.PrepareContext(ctx, query) 136 return prepare(stmt, err, c, query) 137 } 138 139 stmt, err := c.Conn.Prepare(query) 140 return prepare(stmt, err, c, query) 141 } 142 143 func (c *sqlConn) newSqlSpanEventWithNamedValue(ctx context.Context, operation string, start time.Time, err error, sql string, args []driver.NamedValue) { 144 tracer := NewDatabaseTracer(ctx, operation, &c.dbInfo) 145 defer tracer.EndSpanEvent() 146 147 if tracer.IsSampled() { 148 setSqlSpanEvent(tracer, start, err, sql, c.namedValueToString(args)) 149 } 150 } 151 152 func (c *sqlConn) newSqlSpanEventWithValue(ctx context.Context, operation string, start time.Time, err error, sql string, args []driver.Value) { 153 tracer := NewDatabaseTracer(ctx, operation, &c.dbInfo) 154 defer tracer.EndSpanEvent() 155 156 if tracer.IsSampled() { 157 setSqlSpanEvent(tracer, start, err, sql, c.valueToString(args)) 158 } 159 } 160 161 func (c *sqlConn) newSqlSpanEventNoSql(ctx context.Context, operation string, start time.Time, err error) { 162 tracer := NewDatabaseTracer(ctx, operation, &c.dbInfo) 163 defer tracer.EndSpanEvent() 164 165 if tracer.IsSampled() { 166 setSqlSpanEvent(tracer, start, err, "", "") 167 } 168 } 169 170 func setSqlSpanEvent(tracer Tracer, start time.Time, err error, sql string, args string) { 171 tracer.SpanEvent().SetSQL(sql, args) 172 tracer.SpanEvent().SetError(err, "SQL error") 173 tracer.SpanEvent().FixDuration(start, time.Now()) 174 } 175 176 func (c *sqlConn) namedValueToString(named []driver.NamedValue) string { 177 if !c.config.sqlTraceBindValue || named == nil { 178 return "" 179 } 180 181 var b bytes.Buffer 182 numComma := len(named) - 1 183 for i, param := range named { 184 if !writeBindValue(&b, i, param.Value, numComma, c.config.sqlMaxBindValueSize) { 185 break 186 } 187 } 188 return b.String() 189 } 190 191 func (c *sqlConn) valueToString(values []driver.Value) string { 192 if !c.config.sqlTraceBindValue || values == nil { 193 return "" 194 } 195 196 var b bytes.Buffer 197 numComma := len(values) - 1 198 for i, v := range values { 199 if !writeBindValue(&b, i, v, numComma, c.config.sqlMaxBindValueSize) { 200 break 201 } 202 } 203 return b.String() 204 } 205 206 func writeBindValue(b *bytes.Buffer, index int, value interface{}, numComma int, maxSize int) bool { 207 b.WriteString(fmt.Sprint(value)) 208 if index < numComma { 209 b.WriteString(", ") 210 } 211 if b.Len() > maxSize { 212 b.WriteString("...(") 213 b.WriteString(fmt.Sprint(maxSize)) 214 b.WriteString(")") 215 return false 216 } 217 return true 218 } 219 220 func (c *sqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 221 start := time.Now() 222 223 if ec, ok := c.Conn.(driver.ExecerContext); ok { 224 result, err := ec.ExecContext(ctx, query, args) 225 226 if err != driver.ErrSkip { 227 c.newSqlSpanEventWithNamedValue(ctx, "ConnExecContext", start, err, query, args) 228 } 229 230 return result, err 231 } 232 233 // sourced: database/sql/cxtutil.go 234 dargs, err := namedValueToValue(args) 235 if err != nil { 236 return nil, err 237 } 238 select { 239 default: 240 case <-ctx.Done(): 241 return nil, ctx.Err() 242 } 243 244 if e, ok := c.Conn.(driver.Execer); ok { 245 result, err := e.Exec(query, dargs) 246 if err != driver.ErrSkip { 247 c.newSqlSpanEventWithValue(ctx, "ConnExec", start, err, query, dargs) 248 } 249 250 return result, err 251 } 252 253 return nil, driver.ErrSkip 254 } 255 256 func (c *sqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 257 start := time.Now() 258 259 if qc, ok := c.Conn.(driver.QueryerContext); ok { 260 rows, err := qc.QueryContext(ctx, query, args) 261 if err != driver.ErrSkip { 262 c.newSqlSpanEventWithNamedValue(ctx, "ConnQueryContext", start, err, query, args) 263 } 264 265 return rows, err 266 } 267 268 // sourced: database/sql/cxtutil.go 269 dargs, err := namedValueToValue(args) 270 if err != nil { 271 return nil, err 272 } 273 select { 274 default: 275 case <-ctx.Done(): 276 return nil, ctx.Err() 277 } 278 279 if q, ok := c.Conn.(driver.Queryer); ok { 280 rows, err := q.Query(query, dargs) 281 if err != driver.ErrSkip { 282 c.newSqlSpanEventWithValue(ctx, "ConnQuery", start, err, query, dargs) 283 } 284 285 return rows, err 286 } 287 288 return nil, driver.ErrSkip 289 } 290 291 func (c *sqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 292 var tx driver.Tx 293 var err error 294 295 start := time.Now() 296 if cbt, ok := c.Conn.(driver.ConnBeginTx); ok { 297 tx, err = cbt.BeginTx(ctx, opts) 298 if c.config.sqlTraceCommit || c.config.sqlTraceRollback { 299 c.newSqlSpanEventNoSql(ctx, "BeginTx", start, err) 300 if err == nil { 301 tx = &sqlTx{tx, c, ctx} 302 } 303 } 304 return tx, err 305 } 306 307 tx, err = c.Conn.Begin() 308 if c.config.sqlTraceCommit || c.config.sqlTraceRollback { 309 c.newSqlSpanEventNoSql(ctx, "Begin", start, err) 310 if err == nil { 311 tx = &sqlTx{tx, c, ctx} 312 } 313 } 314 return tx, err 315 } 316 317 type sqlTx struct { 318 driver.Tx 319 conn *sqlConn 320 ctx context.Context 321 } 322 323 func (t *sqlTx) Commit() (err error) { 324 start := time.Now() 325 err = t.Tx.Commit() 326 if t.conn.config.sqlTraceCommit { 327 t.conn.newSqlSpanEventNoSql(t.ctx, "Commit", start, err) 328 } 329 return err 330 } 331 332 func (t *sqlTx) Rollback() (err error) { 333 start := time.Now() 334 err = t.Tx.Rollback() 335 if t.conn.config.sqlTraceRollback { 336 t.conn.newSqlSpanEventNoSql(t.ctx, "Rollback", start, err) 337 } 338 return err 339 } 340 341 type sqlStmt struct { 342 driver.Stmt 343 conn *sqlConn 344 sql string 345 } 346 347 func (s *sqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 348 start := time.Now() 349 350 if sec, ok := s.Stmt.(driver.StmtExecContext); ok { 351 result, err := sec.ExecContext(ctx, args) 352 s.conn.newSqlSpanEventWithNamedValue(ctx, "StmtExecContext", start, err, s.sql, args) 353 return result, err 354 } 355 356 // sourced: database/sql/cxtutil.go 357 dargs, err := namedValueToValue(args) 358 if err != nil { 359 return nil, err 360 } 361 select { 362 default: 363 case <-ctx.Done(): 364 return nil, ctx.Err() 365 } 366 367 result, err := s.Stmt.Exec(dargs) 368 s.conn.newSqlSpanEventWithValue(ctx, "StmtExec", start, err, s.sql, dargs) 369 return result, err 370 } 371 372 func (s *sqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 373 start := time.Now() 374 375 if sqc, ok := s.Stmt.(driver.StmtQueryContext); ok { 376 rows, err := sqc.QueryContext(ctx, args) 377 s.conn.newSqlSpanEventWithNamedValue(ctx, "StmtQueryContext", start, err, s.sql, args) 378 return rows, err 379 } 380 381 // sourced: database/sql/cxtutil.go 382 dargs, err := namedValueToValue(args) 383 if err != nil { 384 return nil, err 385 } 386 select { 387 default: 388 case <-ctx.Done(): 389 return nil, ctx.Err() 390 } 391 392 rows, err := s.Stmt.Query(dargs) 393 s.conn.newSqlSpanEventWithValue(ctx, "StmtQuery", start, err, s.sql, dargs) 394 return rows, err 395 } 396 397 // sourced: database/sql/cxtutil.go 398 func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { 399 dargs := make([]driver.Value, len(named)) 400 for n, param := range named { 401 if len(param.Name) > 0 { 402 return nil, errors.New("sql: driver does not support the use of Named Parameters") 403 } 404 dargs[n] = param.Value 405 } 406 return dargs, nil 407 }