go.undefinedlabs.com/scopeagent@v0.4.2/instrumentation/sql/vendor_postgres.go (about)

     1  package sql
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	nurl "net/url"
     7  	"sort"
     8  	"strings"
     9  	"unicode"
    10  )
    11  
// postgresExtension is the vendor extension for PostgreSQL drivers. It carries
// no state; init appends a single instance to vendorExtensions.
type postgresExtension struct{}
    13  
// scanner implements a tokenizer for libpq-style option strings. It walks the
// input one rune at a time via Next and SkipSpaces.
type scanner struct {
	s []rune // the option string, decoded into runes
	i int    // index of the rune the next call to Next returns
}
    19  
    20  func init() {
    21  	vendorExtensions = append(vendorExtensions, &postgresExtension{})
    22  }
    23  
    24  // Gets if the extension is compatible with the component name
    25  func (ext *postgresExtension) IsCompatible(componentName string) bool {
    26  	return componentName == "pq.Driver" || componentName == "stdlib.Driver" || componentName == "pgsqldriver.postgresDriver"
    27  }
    28  
    29  // Complete the missing driver data from the connection string
    30  func (ext *postgresExtension) ProcessConnectionString(connectionString string, configuration *driverConfiguration) {
    31  	configuration.peerService = "postgresql"
    32  
    33  	dsn := connectionString
    34  	if strings.HasPrefix(connectionString, "postgres://") || strings.HasPrefix(connectionString, "postgresql://") {
    35  		if pDsn, err := ext.parseUrl(connectionString); err == nil {
    36  			dsn = pDsn
    37  		}
    38  	}
    39  	o := make(values)
    40  	o["host"] = "localhost"
    41  	o["port"] = "5432"
    42  	_ = ext.parseOpts(dsn, o)
    43  
    44  	if user, ok := o["user"]; ok {
    45  		configuration.user = user
    46  	}
    47  	if port, ok := o["port"]; ok {
    48  		configuration.port = port
    49  	}
    50  	if dbname, ok := o["dbname"]; ok {
    51  		configuration.instance = dbname
    52  	}
    53  	if host, ok := o["host"]; ok {
    54  		configuration.host = host
    55  	}
    56  
    57  	configuration.connString = strings.Replace(connectionString, o["password"], "******", -1)
    58  }
    59  
    60  // postgress ParseURL no longer needs to be used by clients of this library since supplying a URL as a
    61  // connection string to sql.Open() is now supported:
    62  //
    63  //	sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")
    64  //
    65  // It remains exported here for backwards-compatibility.
    66  //
    67  // ParseURL converts a url to a connection string for driver.Open.
    68  // Example:
    69  //
    70  //	"postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
    71  //
    72  // converts to:
    73  //
    74  //	"user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
    75  //
    76  // A minimal example:
    77  //
    78  //	"postgres://"
    79  //
    80  // This will be blank, causing driver.Open to use all of the defaults
    81  func (ext *postgresExtension) parseUrl(url string) (string, error) {
    82  	u, err := nurl.Parse(url)
    83  	if err != nil {
    84  		return "", err
    85  	}
    86  
    87  	if u.Scheme != "postgres" && u.Scheme != "postgresql" {
    88  		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
    89  	}
    90  
    91  	var kvs []string
    92  	escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
    93  	accrue := func(k, v string) {
    94  		if v != "" {
    95  			kvs = append(kvs, k+"="+escaper.Replace(v))
    96  		}
    97  	}
    98  
    99  	if u.User != nil {
   100  		v := u.User.Username()
   101  		accrue("user", v)
   102  
   103  		v, _ = u.User.Password()
   104  		accrue("password", v)
   105  	}
   106  
   107  	if host, port, err := net.SplitHostPort(u.Host); err != nil {
   108  		accrue("host", u.Host)
   109  	} else {
   110  		accrue("host", host)
   111  		accrue("port", port)
   112  	}
   113  
   114  	if u.Path != "" {
   115  		accrue("dbname", u.Path[1:])
   116  	}
   117  
   118  	q := u.Query()
   119  	for k := range q {
   120  		accrue(k, q.Get(k))
   121  	}
   122  
   123  	sort.Strings(kvs) // Makes testing easier (not a performance concern)
   124  	return strings.Join(kvs, " "), nil
   125  }
   126  
   127  // parseOpts parses the options from name and adds them to the values.
   128  //
   129  // The parsing code is based on conninfo_parse from libpq's fe-connect.c
   130  func (ext *postgresExtension) parseOpts(name string, o values) error {
   131  	s := ext.newScanner(name)
   132  
   133  	for {
   134  		var (
   135  			keyRunes, valRunes []rune
   136  			r                  rune
   137  			ok                 bool
   138  		)
   139  
   140  		if r, ok = s.SkipSpaces(); !ok {
   141  			break
   142  		}
   143  
   144  		// Scan the key
   145  		for !unicode.IsSpace(r) && r != '=' {
   146  			keyRunes = append(keyRunes, r)
   147  			if r, ok = s.Next(); !ok {
   148  				break
   149  			}
   150  		}
   151  
   152  		// Skip any whitespace if we're not at the = yet
   153  		if r != '=' {
   154  			r, ok = s.SkipSpaces()
   155  		}
   156  
   157  		// The current character should be =
   158  		if r != '=' || !ok {
   159  			return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
   160  		}
   161  
   162  		// Skip any whitespace after the =
   163  		if r, ok = s.SkipSpaces(); !ok {
   164  			// If we reach the end here, the last value is just an empty string as per libpq.
   165  			o[string(keyRunes)] = ""
   166  			break
   167  		}
   168  
   169  		if r != '\'' {
   170  			for !unicode.IsSpace(r) {
   171  				if r == '\\' {
   172  					if r, ok = s.Next(); !ok {
   173  						return fmt.Errorf(`missing character after backslash`)
   174  					}
   175  				}
   176  				valRunes = append(valRunes, r)
   177  
   178  				if r, ok = s.Next(); !ok {
   179  					break
   180  				}
   181  			}
   182  		} else {
   183  		quote:
   184  			for {
   185  				if r, ok = s.Next(); !ok {
   186  					return fmt.Errorf(`unterminated quoted string literal in connection string`)
   187  				}
   188  				switch r {
   189  				case '\'':
   190  					break quote
   191  				case '\\':
   192  					r, _ = s.Next()
   193  					fallthrough
   194  				default:
   195  					valRunes = append(valRunes, r)
   196  				}
   197  			}
   198  		}
   199  
   200  		o[string(keyRunes)] = string(valRunes)
   201  	}
   202  
   203  	return nil
   204  }
   205  
   206  // newScanner returns a new scanner initialized with the option string s.
   207  func (ext *postgresExtension) newScanner(s string) *scanner {
   208  	return &scanner{[]rune(s), 0}
   209  }
   210  
   211  // Next returns the next rune.
   212  // It returns 0, false if the end of the text has been reached.
   213  func (s *scanner) Next() (rune, bool) {
   214  	if s.i >= len(s.s) {
   215  		return 0, false
   216  	}
   217  	r := s.s[s.i]
   218  	s.i++
   219  	return r, true
   220  }
   221  
   222  // SkipSpaces returns the next non-whitespace rune.
   223  // It returns 0, false if the end of the text has been reached.
   224  func (s *scanner) SkipSpaces() (rune, bool) {
   225  	r, ok := s.Next()
   226  	for unicode.IsSpace(r) && ok {
   227  		r, ok = s.Next()
   228  	}
   229  	return r, ok
   230  }