github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/go-sql-driver/mysql/dsn.go (about)

     1  // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
     2  //
     3  // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package mysql
    10  
    11  import (
    12  	"bytes"
    13  	"crypto/rsa"
    14  	"errors"
    15  	"fmt"
    16  	"github.com/hellobchain/newcryptosm/tls"
    17  	"math/big"
    18  	"net"
    19  	"net/url"
    20  	"sort"
    21  	"strconv"
    22  	"strings"
    23  	"time"
    24  )
    25  
    var (
	// errInvalidDSNAddr is returned when an address was opened with '(' but the
	// text before the final '/' does not end with ')'.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")

	// errInvalidDSNNoSlash is returned for a non-empty DSN that contains no '/'
	// in front of the database name.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")

	// errInvalidDSNUnescaped is returned when a ')' exists after the '(' but
	// not directly before the last '/', which suggests a '/' inside an
	// unescaped parameter value was taken as the database separator.
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")

	// errInvalidDSNUnsafeCollation is returned when InterpolateParams is
	// combined with a collation listed in unsafeCollations.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
    32  
    33  // Config is a configuration parsed from a DSN string.
    34  // If a new Config is created instead of being parsed from a DSN string,
    35  // the NewConfig function should be used, which sets default values.
    36  type Config struct {
    37  	User             string            // Username
    38  	Passwd           string            // Password (requires User)
    39  	Net              string            // Network type
    40  	Addr             string            // Network address (requires Net)
    41  	DBName           string            // Database name
    42  	Params           map[string]string // Connection parameters
    43  	Collation        string            // Connection collation
    44  	Loc              *time.Location    // Location for time.Time values
    45  	MaxAllowedPacket int               // Max packet size allowed
    46  	ServerPubKey     string            // Server public key name
    47  	pubKey           *rsa.PublicKey    // Server public key
    48  	TLSConfig        string            // TLS configuration name
    49  	tls              *tls.Config       // TLS configuration
    50  	Timeout          time.Duration     // Dial timeout
    51  	ReadTimeout      time.Duration     // I/O read timeout
    52  	WriteTimeout     time.Duration     // I/O write timeout
    53  
    54  	AllowAllFiles           bool // Allow all files to be used with LOAD DATA LOCAL INFILE
    55  	AllowCleartextPasswords bool // Allows the cleartext client side plugin
    56  	AllowNativePasswords    bool // Allows the native password authentication method
    57  	AllowOldPasswords       bool // Allows the old insecure password method
    58  	ClientFoundRows         bool // Return number of matching rows instead of rows changed
    59  	ColumnsWithAlias        bool // Prepend table alias to column names
    60  	InterpolateParams       bool // Interpolate placeholders into query string
    61  	MultiStatements         bool // Allow multiple statements in one query
    62  	ParseTime               bool // Parse time values to time.Time
    63  	RejectReadOnly          bool // Reject read-only connections
    64  }
    65  
    66  // NewConfig creates a new Config and sets default values.
    67  func NewConfig() *Config {
    68  	return &Config{
    69  		Collation:            defaultCollation,
    70  		Loc:                  time.UTC,
    71  		MaxAllowedPacket:     defaultMaxAllowedPacket,
    72  		AllowNativePasswords: true,
    73  	}
    74  }
    75  
    76  func (cfg *Config) Clone() *Config {
    77  	cp := *cfg
    78  	if cp.tls != nil {
    79  		cp.tls = cfg.tls.Clone()
    80  	}
    81  	if len(cp.Params) > 0 {
    82  		cp.Params = make(map[string]string, len(cfg.Params))
    83  		for k, v := range cfg.Params {
    84  			cp.Params[k] = v
    85  		}
    86  	}
    87  	if cfg.pubKey != nil {
    88  		cp.pubKey = &rsa.PublicKey{
    89  			N: new(big.Int).Set(cfg.pubKey.N),
    90  			E: cfg.pubKey.E,
    91  		}
    92  	}
    93  	return &cp
    94  }
    95  
    96  func (cfg *Config) normalize() error {
    97  	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
    98  		return errInvalidDSNUnsafeCollation
    99  	}
   100  
   101  	// Set default network if empty
   102  	if cfg.Net == "" {
   103  		cfg.Net = "tcp"
   104  	}
   105  
   106  	// Set default address if empty
   107  	if cfg.Addr == "" {
   108  		switch cfg.Net {
   109  		case "tcp":
   110  			cfg.Addr = "127.0.0.1:3306"
   111  		case "unix":
   112  			cfg.Addr = "/tmp/mysql.sock"
   113  		default:
   114  			return errors.New("default addr for network '" + cfg.Net + "' unknown")
   115  		}
   116  	} else if cfg.Net == "tcp" {
   117  		cfg.Addr = ensureHavePort(cfg.Addr)
   118  	}
   119  
   120  	switch cfg.TLSConfig {
   121  	case "false", "":
   122  		// don't set anything
   123  	case "true":
   124  		cfg.tls = &tls.Config{}
   125  	case "skip-verify", "preferred":
   126  		cfg.tls = &tls.Config{InsecureSkipVerify: true}
   127  	default:
   128  		cfg.tls = getTLSConfigClone(cfg.TLSConfig)
   129  		if cfg.tls == nil {
   130  			return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
   131  		}
   132  	}
   133  
   134  	if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
   135  		host, _, err := net.SplitHostPort(cfg.Addr)
   136  		if err == nil {
   137  			cfg.tls.ServerName = host
   138  		}
   139  	}
   140  
   141  	if cfg.ServerPubKey != "" {
   142  		cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
   143  		if cfg.pubKey == nil {
   144  			return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
   145  		}
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  // FormatDSN formats the given Config into a DSN string which can be passed to
   152  // the driver.
   153  func (cfg *Config) FormatDSN() string {
   154  	var buf bytes.Buffer
   155  
   156  	// [username[:password]@]
   157  	if len(cfg.User) > 0 {
   158  		buf.WriteString(cfg.User)
   159  		if len(cfg.Passwd) > 0 {
   160  			buf.WriteByte(':')
   161  			buf.WriteString(cfg.Passwd)
   162  		}
   163  		buf.WriteByte('@')
   164  	}
   165  
   166  	// [protocol[(address)]]
   167  	if len(cfg.Net) > 0 {
   168  		buf.WriteString(cfg.Net)
   169  		if len(cfg.Addr) > 0 {
   170  			buf.WriteByte('(')
   171  			buf.WriteString(cfg.Addr)
   172  			buf.WriteByte(')')
   173  		}
   174  	}
   175  
   176  	// /dbname
   177  	buf.WriteByte('/')
   178  	buf.WriteString(cfg.DBName)
   179  
   180  	// [?param1=value1&...&paramN=valueN]
   181  	hasParam := false
   182  
   183  	if cfg.AllowAllFiles {
   184  		hasParam = true
   185  		buf.WriteString("?allowAllFiles=true")
   186  	}
   187  
   188  	if cfg.AllowCleartextPasswords {
   189  		if hasParam {
   190  			buf.WriteString("&allowCleartextPasswords=true")
   191  		} else {
   192  			hasParam = true
   193  			buf.WriteString("?allowCleartextPasswords=true")
   194  		}
   195  	}
   196  
   197  	if !cfg.AllowNativePasswords {
   198  		if hasParam {
   199  			buf.WriteString("&allowNativePasswords=false")
   200  		} else {
   201  			hasParam = true
   202  			buf.WriteString("?allowNativePasswords=false")
   203  		}
   204  	}
   205  
   206  	if cfg.AllowOldPasswords {
   207  		if hasParam {
   208  			buf.WriteString("&allowOldPasswords=true")
   209  		} else {
   210  			hasParam = true
   211  			buf.WriteString("?allowOldPasswords=true")
   212  		}
   213  	}
   214  
   215  	if cfg.ClientFoundRows {
   216  		if hasParam {
   217  			buf.WriteString("&clientFoundRows=true")
   218  		} else {
   219  			hasParam = true
   220  			buf.WriteString("?clientFoundRows=true")
   221  		}
   222  	}
   223  
   224  	if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
   225  		if hasParam {
   226  			buf.WriteString("&collation=")
   227  		} else {
   228  			hasParam = true
   229  			buf.WriteString("?collation=")
   230  		}
   231  		buf.WriteString(col)
   232  	}
   233  
   234  	if cfg.ColumnsWithAlias {
   235  		if hasParam {
   236  			buf.WriteString("&columnsWithAlias=true")
   237  		} else {
   238  			hasParam = true
   239  			buf.WriteString("?columnsWithAlias=true")
   240  		}
   241  	}
   242  
   243  	if cfg.InterpolateParams {
   244  		if hasParam {
   245  			buf.WriteString("&interpolateParams=true")
   246  		} else {
   247  			hasParam = true
   248  			buf.WriteString("?interpolateParams=true")
   249  		}
   250  	}
   251  
   252  	if cfg.Loc != time.UTC && cfg.Loc != nil {
   253  		if hasParam {
   254  			buf.WriteString("&loc=")
   255  		} else {
   256  			hasParam = true
   257  			buf.WriteString("?loc=")
   258  		}
   259  		buf.WriteString(url.QueryEscape(cfg.Loc.String()))
   260  	}
   261  
   262  	if cfg.MultiStatements {
   263  		if hasParam {
   264  			buf.WriteString("&multiStatements=true")
   265  		} else {
   266  			hasParam = true
   267  			buf.WriteString("?multiStatements=true")
   268  		}
   269  	}
   270  
   271  	if cfg.ParseTime {
   272  		if hasParam {
   273  			buf.WriteString("&parseTime=true")
   274  		} else {
   275  			hasParam = true
   276  			buf.WriteString("?parseTime=true")
   277  		}
   278  	}
   279  
   280  	if cfg.ReadTimeout > 0 {
   281  		if hasParam {
   282  			buf.WriteString("&readTimeout=")
   283  		} else {
   284  			hasParam = true
   285  			buf.WriteString("?readTimeout=")
   286  		}
   287  		buf.WriteString(cfg.ReadTimeout.String())
   288  	}
   289  
   290  	if cfg.RejectReadOnly {
   291  		if hasParam {
   292  			buf.WriteString("&rejectReadOnly=true")
   293  		} else {
   294  			hasParam = true
   295  			buf.WriteString("?rejectReadOnly=true")
   296  		}
   297  	}
   298  
   299  	if len(cfg.ServerPubKey) > 0 {
   300  		if hasParam {
   301  			buf.WriteString("&serverPubKey=")
   302  		} else {
   303  			hasParam = true
   304  			buf.WriteString("?serverPubKey=")
   305  		}
   306  		buf.WriteString(url.QueryEscape(cfg.ServerPubKey))
   307  	}
   308  
   309  	if cfg.Timeout > 0 {
   310  		if hasParam {
   311  			buf.WriteString("&timeout=")
   312  		} else {
   313  			hasParam = true
   314  			buf.WriteString("?timeout=")
   315  		}
   316  		buf.WriteString(cfg.Timeout.String())
   317  	}
   318  
   319  	if len(cfg.TLSConfig) > 0 {
   320  		if hasParam {
   321  			buf.WriteString("&tls=")
   322  		} else {
   323  			hasParam = true
   324  			buf.WriteString("?tls=")
   325  		}
   326  		buf.WriteString(url.QueryEscape(cfg.TLSConfig))
   327  	}
   328  
   329  	if cfg.WriteTimeout > 0 {
   330  		if hasParam {
   331  			buf.WriteString("&writeTimeout=")
   332  		} else {
   333  			hasParam = true
   334  			buf.WriteString("?writeTimeout=")
   335  		}
   336  		buf.WriteString(cfg.WriteTimeout.String())
   337  	}
   338  
   339  	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
   340  		if hasParam {
   341  			buf.WriteString("&maxAllowedPacket=")
   342  		} else {
   343  			hasParam = true
   344  			buf.WriteString("?maxAllowedPacket=")
   345  		}
   346  		buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))
   347  
   348  	}
   349  
   350  	// other params
   351  	if cfg.Params != nil {
   352  		var params []string
   353  		for param := range cfg.Params {
   354  			params = append(params, param)
   355  		}
   356  		sort.Strings(params)
   357  		for _, param := range params {
   358  			if hasParam {
   359  				buf.WriteByte('&')
   360  			} else {
   361  				hasParam = true
   362  				buf.WriteByte('?')
   363  			}
   364  
   365  			buf.WriteString(param)
   366  			buf.WriteByte('=')
   367  			buf.WriteString(url.QueryEscape(cfg.Params[param]))
   368  		}
   369  	}
   370  
   371  	return buf.String()
   372  }
   373  
   374  // ParseDSN parses the DSN string to a Config
   375  func ParseDSN(dsn string) (cfg *Config, err error) {
   376  	// New config with some default values
   377  	cfg = NewConfig()
   378  
   379  	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
   380  	// Find the last '/' (since the password or the net addr might contain a '/')
   381  	foundSlash := false
   382  	for i := len(dsn) - 1; i >= 0; i-- {
   383  		if dsn[i] == '/' {
   384  			foundSlash = true
   385  			var j, k int
   386  
   387  			// left part is empty if i <= 0
   388  			if i > 0 {
   389  				// [username[:password]@][protocol[(address)]]
   390  				// Find the last '@' in dsn[:i]
   391  				for j = i; j >= 0; j-- {
   392  					if dsn[j] == '@' {
   393  						// username[:password]
   394  						// Find the first ':' in dsn[:j]
   395  						for k = 0; k < j; k++ {
   396  							if dsn[k] == ':' {
   397  								cfg.Passwd = dsn[k+1 : j]
   398  								break
   399  							}
   400  						}
   401  						cfg.User = dsn[:k]
   402  
   403  						break
   404  					}
   405  				}
   406  
   407  				// [protocol[(address)]]
   408  				// Find the first '(' in dsn[j+1:i]
   409  				for k = j + 1; k < i; k++ {
   410  					if dsn[k] == '(' {
   411  						// dsn[i-1] must be == ')' if an address is specified
   412  						if dsn[i-1] != ')' {
   413  							if strings.ContainsRune(dsn[k+1:i], ')') {
   414  								return nil, errInvalidDSNUnescaped
   415  							}
   416  							return nil, errInvalidDSNAddr
   417  						}
   418  						cfg.Addr = dsn[k+1 : i-1]
   419  						break
   420  					}
   421  				}
   422  				cfg.Net = dsn[j+1 : k]
   423  			}
   424  
   425  			// dbname[?param1=value1&...&paramN=valueN]
   426  			// Find the first '?' in dsn[i+1:]
   427  			for j = i + 1; j < len(dsn); j++ {
   428  				if dsn[j] == '?' {
   429  					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
   430  						return
   431  					}
   432  					break
   433  				}
   434  			}
   435  			cfg.DBName = dsn[i+1 : j]
   436  
   437  			break
   438  		}
   439  	}
   440  
   441  	if !foundSlash && len(dsn) > 0 {
   442  		return nil, errInvalidDSNNoSlash
   443  	}
   444  
   445  	if err = cfg.normalize(); err != nil {
   446  		return nil, err
   447  	}
   448  	return
   449  }
   450  
   451  // parseDSNParams parses the DSN "query string"
   452  // Values must be url.QueryEscape'ed
   453  func parseDSNParams(cfg *Config, params string) (err error) {
   454  	for _, v := range strings.Split(params, "&") {
   455  		param := strings.SplitN(v, "=", 2)
   456  		if len(param) != 2 {
   457  			continue
   458  		}
   459  
   460  		// cfg params
   461  		switch value := param[1]; param[0] {
   462  		// Disable INFILE whitelist / enable all files
   463  		case "allowAllFiles":
   464  			var isBool bool
   465  			cfg.AllowAllFiles, isBool = readBool(value)
   466  			if !isBool {
   467  				return errors.New("invalid bool value: " + value)
   468  			}
   469  
   470  		// Use cleartext authentication mode (MySQL 5.5.10+)
   471  		case "allowCleartextPasswords":
   472  			var isBool bool
   473  			cfg.AllowCleartextPasswords, isBool = readBool(value)
   474  			if !isBool {
   475  				return errors.New("invalid bool value: " + value)
   476  			}
   477  
   478  		// Use native password authentication
   479  		case "allowNativePasswords":
   480  			var isBool bool
   481  			cfg.AllowNativePasswords, isBool = readBool(value)
   482  			if !isBool {
   483  				return errors.New("invalid bool value: " + value)
   484  			}
   485  
   486  		// Use old authentication mode (pre MySQL 4.1)
   487  		case "allowOldPasswords":
   488  			var isBool bool
   489  			cfg.AllowOldPasswords, isBool = readBool(value)
   490  			if !isBool {
   491  				return errors.New("invalid bool value: " + value)
   492  			}
   493  
   494  		// Switch "rowsAffected" mode
   495  		case "clientFoundRows":
   496  			var isBool bool
   497  			cfg.ClientFoundRows, isBool = readBool(value)
   498  			if !isBool {
   499  				return errors.New("invalid bool value: " + value)
   500  			}
   501  
   502  		// Collation
   503  		case "collation":
   504  			cfg.Collation = value
   505  			break
   506  
   507  		case "columnsWithAlias":
   508  			var isBool bool
   509  			cfg.ColumnsWithAlias, isBool = readBool(value)
   510  			if !isBool {
   511  				return errors.New("invalid bool value: " + value)
   512  			}
   513  
   514  		// Compression
   515  		case "compress":
   516  			return errors.New("compression not implemented yet")
   517  
   518  		// Enable client side placeholder substitution
   519  		case "interpolateParams":
   520  			var isBool bool
   521  			cfg.InterpolateParams, isBool = readBool(value)
   522  			if !isBool {
   523  				return errors.New("invalid bool value: " + value)
   524  			}
   525  
   526  		// Time Location
   527  		case "loc":
   528  			if value, err = url.QueryUnescape(value); err != nil {
   529  				return
   530  			}
   531  			cfg.Loc, err = time.LoadLocation(value)
   532  			if err != nil {
   533  				return
   534  			}
   535  
   536  		// multiple statements in one query
   537  		case "multiStatements":
   538  			var isBool bool
   539  			cfg.MultiStatements, isBool = readBool(value)
   540  			if !isBool {
   541  				return errors.New("invalid bool value: " + value)
   542  			}
   543  
   544  		// time.Time parsing
   545  		case "parseTime":
   546  			var isBool bool
   547  			cfg.ParseTime, isBool = readBool(value)
   548  			if !isBool {
   549  				return errors.New("invalid bool value: " + value)
   550  			}
   551  
   552  		// I/O read Timeout
   553  		case "readTimeout":
   554  			cfg.ReadTimeout, err = time.ParseDuration(value)
   555  			if err != nil {
   556  				return
   557  			}
   558  
   559  		// Reject read-only connections
   560  		case "rejectReadOnly":
   561  			var isBool bool
   562  			cfg.RejectReadOnly, isBool = readBool(value)
   563  			if !isBool {
   564  				return errors.New("invalid bool value: " + value)
   565  			}
   566  
   567  		// Server public key
   568  		case "serverPubKey":
   569  			name, err := url.QueryUnescape(value)
   570  			if err != nil {
   571  				return fmt.Errorf("invalid value for server pub key name: %v", err)
   572  			}
   573  			cfg.ServerPubKey = name
   574  
   575  		// Strict mode
   576  		case "strict":
   577  			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
   578  
   579  		// Dial Timeout
   580  		case "timeout":
   581  			cfg.Timeout, err = time.ParseDuration(value)
   582  			if err != nil {
   583  				return
   584  			}
   585  
   586  		// TLS-Encryption
   587  		case "tls":
   588  			boolValue, isBool := readBool(value)
   589  			if isBool {
   590  				if boolValue {
   591  					cfg.TLSConfig = "true"
   592  				} else {
   593  					cfg.TLSConfig = "false"
   594  				}
   595  			} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
   596  				cfg.TLSConfig = vl
   597  			} else {
   598  				name, err := url.QueryUnescape(value)
   599  				if err != nil {
   600  					return fmt.Errorf("invalid value for TLS config name: %v", err)
   601  				}
   602  				cfg.TLSConfig = name
   603  			}
   604  
   605  		// I/O write Timeout
   606  		case "writeTimeout":
   607  			cfg.WriteTimeout, err = time.ParseDuration(value)
   608  			if err != nil {
   609  				return
   610  			}
   611  		case "maxAllowedPacket":
   612  			cfg.MaxAllowedPacket, err = strconv.Atoi(value)
   613  			if err != nil {
   614  				return
   615  			}
   616  		default:
   617  			// lazy init
   618  			if cfg.Params == nil {
   619  				cfg.Params = make(map[string]string)
   620  			}
   621  
   622  			if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
   623  				return
   624  			}
   625  		}
   626  	}
   627  
   628  	return
   629  }
   630  
   // ensureHavePort returns addr unchanged when it already has a port, and
// otherwise appends the default MySQL port 3306.
//
// A bracketed IPv6 literal without a port (e.g. "[::1]") has its brackets
// stripped before joining. net.JoinHostPort brackets any host containing ':'
// itself, so it would otherwise produce "[[::1]]:3306".
func ensureHavePort(addr string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	host := addr
	if len(host) > 2 && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, "3306")
}