go.undefinedlabs.com/scopeagent@v0.4.2/instrumentation/sql/vendor_mysql.go (about)

     1  package sql
     2  
     3  import (
     4  	"net"
     5  	"strings"
     6  )
     7  
     8  type mysqlExtension struct{}
     9  
    10  func init() {
    11  	vendorExtensions = append(vendorExtensions, &mysqlExtension{})
    12  }
    13  
    14  // Gets if the extension is compatible with the component name
    15  func (ext *mysqlExtension) IsCompatible(componentName string) bool {
    16  	return componentName == "mysql.MySQLDriver"
    17  }
    18  
    19  // Complete the missing driver data from the connection string
    20  func (ext *mysqlExtension) ProcessConnectionString(connectionString string, configuration *driverConfiguration) {
    21  	configuration.peerService = "mysql"
    22  
    23  	dsn := *ext.parseDSN(connectionString)
    24  	configuration.user = dsn["User"]
    25  	configuration.port = dsn["Port"]
    26  	configuration.instance = dsn["DBName"]
    27  	configuration.host = dsn["Host"]
    28  	configuration.connString = strings.Replace(connectionString, dsn["Passwd"], "******", -1)
    29  }
    30  
    31  // ParseDSN parses the DSN string to a Config
    32  func (ext *mysqlExtension) parseDSN(dsn string) *values {
    33  	// New config with some default values
    34  	tmpCfg := values{}
    35  
    36  	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
    37  	// Find the last '/' (since the password or the net addr might contain a '/')
    38  	foundSlash := false
    39  	for i := len(dsn) - 1; i >= 0; i-- {
    40  		if dsn[i] == '/' {
    41  			foundSlash = true
    42  			var j, k int
    43  
    44  			// left part is empty if i <= 0
    45  			if i > 0 {
    46  				// [username[:password]@][protocol[(address)]]
    47  				// Find the last '@' in dsn[:i]
    48  				for j = i; j >= 0; j-- {
    49  					if dsn[j] == '@' {
    50  						// username[:password]
    51  						// Find the first ':' in dsn[:j]
    52  						for k = 0; k < j; k++ {
    53  							if dsn[k] == ':' {
    54  								tmpCfg["Passwd"] = dsn[k+1 : j]
    55  								break
    56  							}
    57  						}
    58  						tmpCfg["User"] = dsn[:k]
    59  						break
    60  					}
    61  				}
    62  
    63  				// [protocol[(address)]]
    64  				// Find the first '(' in dsn[j+1:i]
    65  				for k = j + 1; k < i; k++ {
    66  					if dsn[k] == '(' {
    67  						// dsn[i-1] must be == ')' if an address is specified
    68  						if dsn[i-1] != ')' {
    69  							if strings.ContainsRune(dsn[k+1:i], ')') {
    70  								return nil
    71  							}
    72  							return nil
    73  						}
    74  						tmpCfg["Addr"] = dsn[k+1 : i-1]
    75  						break
    76  					}
    77  				}
    78  				tmpCfg["Net"] = dsn[j+1 : k]
    79  			}
    80  
    81  			// dbname[?param1=value1&...&paramN=valueN]
    82  			// Find the first '?' in dsn[i+1:]
    83  			for j = i + 1; j < len(dsn); j++ {
    84  				if dsn[j] == '?' {
    85  					ext.parseDSNParams(&tmpCfg, dsn[j+1:])
    86  					break
    87  				}
    88  			}
    89  			tmpCfg["DBName"] = dsn[i+1 : j]
    90  			break
    91  		}
    92  	}
    93  
    94  	if !foundSlash && len(dsn) > 0 {
    95  		return nil
    96  	}
    97  	ext.normalize(&tmpCfg)
    98  	return &tmpCfg
    99  }
   100  
   101  // parseDSNParams parses the DSN "query string"
   102  // Values must be url.QueryEscape'ed
   103  func (ext *mysqlExtension) parseDSNParams(cfg *values, params string) {
   104  	for _, v := range strings.Split(params, "&") {
   105  		param := strings.SplitN(v, "=", 2)
   106  		if len(param) != 2 {
   107  			continue
   108  		}
   109  		(*cfg)[param[0]] = param[1]
   110  	}
   111  }
   112  
   113  func (ext *mysqlExtension) normalize(cfg *values) {
   114  	// Set default network if empty
   115  	if (*cfg)["Net"] == "" {
   116  		(*cfg)["Net"] = "tcp"
   117  	}
   118  
   119  	// Set default address if empty
   120  	if (*cfg)["Addr"] == "" {
   121  		switch (*cfg)["Net"] {
   122  		case "tcp":
   123  			(*cfg)["Addr"] = "127.0.0.1:3306"
   124  		case "unix":
   125  			(*cfg)["Addr"] = "/tmp/mysql.sock"
   126  		}
   127  	} else if (*cfg)["Net"] == "tcp" {
   128  		(*cfg)["Addr"] = ext.ensureHavePort((*cfg)["Addr"])
   129  	}
   130  
   131  	if host, port, err := net.SplitHostPort((*cfg)["Addr"]); err == nil {
   132  		(*cfg)["Host"] = host
   133  		(*cfg)["Port"] = port
   134  	}
   135  }
   136  
   137  func (ext *mysqlExtension) ensureHavePort(addr string) string {
   138  	if _, _, err := net.SplitHostPort(addr); err != nil {
   139  		return net.JoinHostPort(addr, "3306")
   140  	}
   141  	return addr
   142  }