github.com/mier85/go-sensor@v1.30.1-0.20220920111756-9bf41b3bc7e0/instrumentation_sql.go (about) 1 // (c) Copyright IBM Corp. 2021 2 // (c) Copyright Instana Inc. 2020 3 4 package instana 5 6 import ( 7 "context" 8 "database/sql" 9 "database/sql/driver" 10 "errors" 11 "net/url" 12 "regexp" 13 "strings" 14 "sync" 15 16 ot "github.com/opentracing/opentracing-go" 17 "github.com/opentracing/opentracing-go/ext" 18 otlog "github.com/opentracing/opentracing-go/log" 19 20 _ "unsafe" 21 ) 22 23 var ( 24 sqlDriverRegistrationMu sync.Mutex 25 ) 26 27 // InstrumentSQLDriver instruments provided database driver for use with `sql.Open()`. 28 // This method will ignore any attempt to register the driver with the same name again. 29 // 30 // The instrumented version is registered with `_with_instana` suffix, e.g. 31 // if `postgres` provided as a name, the instrumented version is registered as 32 // `postgres_with_instana`. 33 func InstrumentSQLDriver(sensor *Sensor, name string, driver driver.Driver) { 34 sqlDriverRegistrationMu.Lock() 35 defer sqlDriverRegistrationMu.Unlock() 36 37 instrumentedName := name + "_with_instana" 38 39 // Check if the instrumented version of a driver has already been registered 40 // with database/sql and ignore the second attempt to avoid panicking 41 for _, drv := range sql.Drivers() { 42 if drv == instrumentedName { 43 return 44 } 45 } 46 47 sql.Register(instrumentedName, &wrappedSQLDriver{ 48 Driver: driver, 49 sensor: sensor, 50 }) 51 } 52 53 // SQLOpen is a convenience wrapper for `sql.Open()` to use the instrumented version 54 // of a driver previosly registered using `instana.InstrumentSQLDriver()` 55 func SQLOpen(driverName, dataSourceName string) (*sql.DB, error) { 56 57 if !strings.HasSuffix(driverName, "_with_instana") { 58 driverName += "_with_instana" 59 } 60 61 return sql.Open(driverName, dataSourceName) 62 } 63 64 //go:linkname drivers database/sql.drivers 65 var drivers map[string]driver.Driver 66 67 // SQLInstrumentAndOpen returns instrumented `*sql.DB`. 68 // It takes already registered `driver.Driver` by name, instruments it and additionally registers 69 // it with different name. After that it returns instrumented `*sql.DB` or error if any. 70 // 71 // This function can be used as a convenient shortcut for InstrumentSQLDriver and SQLOpen functions. 72 // The main difference is that this approach will use the already registered driver and using InstrumentSQLDriver 73 // requires to explicitly provide an instance of the driver to instrument. 74 func SQLInstrumentAndOpen(sensor *Sensor, driverName, dataSourceName string) (*sql.DB, error) { 75 if d, ok := drivers[driverName]; ok { 76 InstrumentSQLDriver(sensor, driverName, d) 77 } 78 79 return SQLOpen(driverName, dataSourceName) 80 } 81 82 type wrappedSQLDriver struct { 83 driver.Driver 84 85 sensor *Sensor 86 } 87 88 func (drv *wrappedSQLDriver) Open(name string) (driver.Conn, error) { 89 conn, err := drv.Driver.Open(name) 90 if err != nil { 91 return conn, err 92 } 93 94 if conn, ok := conn.(*wrappedSQLConn); ok { 95 return conn, nil 96 } 97 98 return &wrappedSQLConn{ 99 Conn: conn, 100 details: parseDBConnDetails(name), 101 sensor: drv.sensor, 102 }, nil 103 } 104 105 type wrappedSQLConn struct { 106 driver.Conn 107 108 details dbConnDetails 109 sensor *Sensor 110 } 111 112 func (conn *wrappedSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 113 sp := startSQLSpan(ctx, conn.details, query, conn.sensor) 114 defer sp.Finish() 115 116 if c, ok := conn.Conn.(driver.QueryerContext); ok { 117 res, err := c.QueryContext(ctx, query, args) 118 if err != nil && err != driver.ErrSkip { 119 sp.LogFields(otlog.Error(err)) 120 } 121 122 return res, err 123 } 124 125 if c, ok := conn.Conn.(driver.Queryer); ok { //nolint:staticcheck 126 values, err := sqlNamedValuesToValues(args) 127 if err != nil { 128 return nil, err 129 } 130 131 select { 132 default: 133 case <-ctx.Done(): 134 return nil, ctx.Err() 135 } 136 137 res, err := c.Query(query, values) 138 if err != nil && err != driver.ErrSkip { 139 sp.LogFields(otlog.Error(err)) 140 } 141 142 return res, err 143 } 144 145 return nil, driver.ErrSkip 146 } 147 148 func (conn *wrappedSQLConn) Prepare(query string) (driver.Stmt, error) { 149 stmt, err := conn.Conn.Prepare(query) 150 if err != nil { 151 return stmt, err 152 } 153 154 if stmt, ok := stmt.(*wrappedSQLStmt); ok { 155 return stmt, nil 156 } 157 158 return &wrappedSQLStmt{ 159 Stmt: stmt, 160 connDetails: conn.details, 161 query: query, 162 sensor: conn.sensor, 163 }, nil 164 } 165 166 func (conn *wrappedSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 167 var ( 168 stmt driver.Stmt 169 err error 170 ) 171 if c, ok := conn.Conn.(driver.ConnPrepareContext); ok { 172 stmt, err = c.PrepareContext(ctx, query) 173 } else { 174 stmt, err = conn.Prepare(query) 175 } 176 177 if err != nil { 178 return stmt, err 179 } 180 181 if stmt, ok := stmt.(*wrappedSQLStmt); ok { 182 return stmt, nil 183 } 184 185 return &wrappedSQLStmt{ 186 Stmt: stmt, 187 connDetails: conn.details, 188 query: query, 189 sensor: conn.sensor, 190 }, nil 191 } 192 193 func (conn *wrappedSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 194 sp := startSQLSpan(ctx, conn.details, query, conn.sensor) 195 defer sp.Finish() 196 197 if c, ok := conn.Conn.(driver.ExecerContext); ok { 198 res, err := c.ExecContext(ctx, query, args) 199 if err != nil && err != driver.ErrSkip { 200 sp.LogFields(otlog.Error(err)) 201 } 202 203 return res, err 204 } 205 206 if c, ok := conn.Conn.(driver.Execer); ok { //nolint:staticcheck 207 values, err := sqlNamedValuesToValues(args) 208 if err != nil { 209 return nil, err 210 } 211 212 select { 213 default: 214 case <-ctx.Done(): 215 return nil, ctx.Err() 216 } 217 218 res, err := c.Exec(query, values) 219 if err != nil && err != driver.ErrSkip { 220 sp.LogFields(otlog.Error(err)) 221 } 222 223 return res, err 224 } 225 226 return nil, driver.ErrSkip 227 } 228 229 type wrappedSQLStmt struct { 230 driver.Stmt 231 232 connDetails dbConnDetails 233 query string 234 sensor *Sensor 235 } 236 237 func (stmt *wrappedSQLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 238 sp := startSQLSpan(ctx, stmt.connDetails, stmt.query, stmt.sensor) 239 defer sp.Finish() 240 241 if s, ok := stmt.Stmt.(driver.StmtExecContext); ok { 242 res, err := s.ExecContext(ctx, args) 243 if err != nil && err != driver.ErrSkip { 244 sp.LogFields(otlog.Error(err)) 245 } 246 247 return res, err 248 } 249 250 values, err := sqlNamedValuesToValues(args) 251 if err != nil { 252 return nil, err 253 } 254 255 select { 256 default: 257 case <-ctx.Done(): 258 return nil, ctx.Err() 259 } 260 261 res, err := stmt.Exec(values) //nolint:staticcheck 262 if err != nil && err != driver.ErrSkip { 263 sp.LogFields(otlog.Error(err)) 264 } 265 266 return res, err 267 } 268 269 func (stmt *wrappedSQLStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 270 sp := startSQLSpan(ctx, stmt.connDetails, stmt.query, stmt.sensor) 271 defer sp.Finish() 272 273 if s, ok := stmt.Stmt.(driver.StmtQueryContext); ok { 274 res, err := s.QueryContext(ctx, args) 275 if err != nil && err != driver.ErrSkip { 276 sp.LogFields(otlog.Error(err)) 277 } 278 279 return res, err 280 } 281 282 values, err := sqlNamedValuesToValues(args) 283 if err != nil { 284 return nil, err 285 } 286 287 select { 288 default: 289 case <-ctx.Done(): 290 return nil, ctx.Err() 291 } 292 293 res, err := stmt.Stmt.Query(values) //nolint:staticcheck 294 if err != nil && err != driver.ErrSkip { 295 sp.LogFields(otlog.Error(err)) 296 } 297 298 return res, err 299 } 300 301 func startSQLSpan(ctx context.Context, conn dbConnDetails, query string, sensor *Sensor) ot.Span { 302 tags := ot.Tags{ 303 string(ext.DBType): "sql", 304 string(ext.DBStatement): query, 305 string(ext.PeerAddress): conn.RawString, 306 } 307 308 if conn.Schema != "" { 309 tags[string(ext.DBInstance)] = conn.Schema 310 } else { 311 tags[string(ext.DBInstance)] = conn.RawString 312 } 313 314 if conn.Host != "" { 315 tags[string(ext.PeerHostname)] = conn.Host 316 } 317 318 if conn.Port != "" { 319 tags[string(ext.PeerPort)] = conn.Port 320 } 321 322 opts := []ot.StartSpanOption{ext.SpanKindRPCClient, tags} 323 if parentSpan, ok := SpanFromContext(ctx); ok { 324 opts = append(opts, ot.ChildOf(parentSpan.Context())) 325 } 326 327 return sensor.Tracer().StartSpan("sdk.database", opts...) 328 } 329 330 type dbConnDetails struct { 331 RawString string 332 Host, Port string 333 Schema string 334 User string 335 } 336 337 func parseDBConnDetails(connStr string) dbConnDetails { 338 strategies := [...]func(string) (dbConnDetails, bool){ 339 parseDBConnDetailsURI, 340 parsePostgresConnDetailsKV, 341 parseMySQLConnDetailsKV, 342 } 343 for _, parseFn := range strategies { 344 if details, ok := parseFn(connStr); ok { 345 return details 346 } 347 } 348 349 return dbConnDetails{RawString: connStr} 350 } 351 352 // parseDBConnDetailsURI attempts to parse a connection string as an URI, assuming that it has 353 // following format: [scheme://][user[:[password]]@]host[:port][/schema][?attribute1=value1&attribute2=value2...] 354 func parseDBConnDetailsURI(connStr string) (dbConnDetails, bool) { 355 u, err := url.Parse(connStr) 356 if err != nil { 357 return dbConnDetails{}, false 358 } 359 360 if u.Scheme == "" { 361 return dbConnDetails{}, false 362 } 363 364 path := "" 365 if len(u.Path) > 1 { 366 path = u.Path[1:] 367 } 368 369 details := dbConnDetails{ 370 RawString: connStr, 371 Host: u.Hostname(), 372 Port: u.Port(), 373 Schema: path, 374 } 375 376 if u.User != nil { 377 details.User = u.User.Username() 378 379 // create a copy without user password 380 u := cloneURL(u) 381 u.User = url.User(details.User) 382 details.RawString = u.String() 383 } 384 385 return details, true 386 } 387 388 var postgresKVPasswordRegex = regexp.MustCompile(`(^|\s)password=[^\s]+(\s|$)`) 389 390 // parsePostgresConnDetailsKV parses a space-separated PostgreSQL-style connection string 391 func parsePostgresConnDetailsKV(connStr string) (dbConnDetails, bool) { 392 var details dbConnDetails 393 394 for _, field := range strings.Split(connStr, " ") { 395 fieldNorm := strings.ToLower(field) 396 397 var ( 398 prefix string 399 fieldPtr *string 400 ) 401 switch { 402 case strings.HasPrefix(fieldNorm, "host="): 403 if details.Host != "" { 404 // hostaddr= takes precedence 405 continue 406 } 407 408 prefix, fieldPtr = "host=", &details.Host 409 case strings.HasPrefix(fieldNorm, "hostaddr="): 410 prefix, fieldPtr = "hostaddr=", &details.Host 411 case strings.HasPrefix(fieldNorm, "port="): 412 prefix, fieldPtr = "port=", &details.Port 413 case strings.HasPrefix(fieldNorm, "user="): 414 prefix, fieldPtr = "user=", &details.User 415 case strings.HasPrefix(fieldNorm, "dbname="): 416 prefix, fieldPtr = "dbname=", &details.Schema 417 default: 418 continue 419 } 420 421 *fieldPtr = field[len(prefix):] 422 } 423 424 if details.Schema == "" { 425 return dbConnDetails{}, false 426 } 427 428 details.RawString = postgresKVPasswordRegex.ReplaceAllString(connStr, " ") 429 430 return details, true 431 } 432 433 var mysqlKVPasswordRegex = regexp.MustCompile(`(?i)(^|;)Pwd=[^;]+(;|$)`) 434 435 // parseMySQLConnDetailsKV parses a semicolon-separated MySQL-style connection string 436 func parseMySQLConnDetailsKV(connStr string) (dbConnDetails, bool) { 437 details := dbConnDetails{RawString: connStr} 438 439 for _, field := range strings.Split(connStr, ";") { 440 fieldNorm := strings.ToLower(field) 441 442 var ( 443 prefix string 444 fieldPtr *string 445 ) 446 switch { 447 case strings.HasPrefix(fieldNorm, "server="): 448 prefix, fieldPtr = "server=", &details.Host 449 case strings.HasPrefix(fieldNorm, "port="): 450 prefix, fieldPtr = "port=", &details.Port 451 case strings.HasPrefix(fieldNorm, "uid="): 452 prefix, fieldPtr = "uid=", &details.User 453 case strings.HasPrefix(fieldNorm, "database="): 454 prefix, fieldPtr = "database=", &details.Schema 455 default: 456 continue 457 } 458 459 *fieldPtr = field[len(prefix):] 460 } 461 462 if details.Schema == "" { 463 return dbConnDetails{}, false 464 } 465 466 details.RawString = mysqlKVPasswordRegex.ReplaceAllString(connStr, ";") 467 468 return details, true 469 } 470 471 // The following code is ported from $GOROOT/src/database/sql/ctxutil.go 472 // 473 // Copyright 2019 The Go Authors. All rights reserved. 474 // Use of this source code is governed by a BSD-style 475 // license that can be found in the LICENSE file. 476 func sqlNamedValuesToValues(named []driver.NamedValue) ([]driver.Value, error) { 477 dargs := make([]driver.Value, len(named)) 478 for n, param := range named { 479 if len(param.Name) > 0 { 480 return nil, errors.New("sql: driver does not support the use of Named Parameters") 481 } 482 dargs[n] = param.Value 483 } 484 return dargs, nil 485 } 486 487 type dsnConnector struct { 488 dsn string 489 driver driver.Driver 490 } 491 492 func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { 493 return t.driver.Open(t.dsn) 494 } 495 496 func (t dsnConnector) Driver() driver.Driver { 497 return t.driver 498 }