github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/x/mongo/driver/connstring/connstring.go (about)

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package connstring // import "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"net/url"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/internal"
    19  	"go.mongodb.org/mongo-driver/internal/randutil"
    20  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    21  	"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
    22  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    23  )
    24  
    25  // random is a package-global pseudo-random number generator.
    26  var random = randutil.NewLockedRand()
    27  
    28  // ParseAndValidate parses the provided URI into a ConnString object.
    29  // It check that all values are valid.
    30  func ParseAndValidate(s string) (ConnString, error) {
    31  	p := parser{dnsResolver: dns.DefaultResolver}
    32  	err := p.parse(s)
    33  	if err != nil {
    34  		return p.ConnString, internal.WrapErrorf(err, "error parsing uri")
    35  	}
    36  	err = p.ConnString.Validate()
    37  	if err != nil {
    38  		return p.ConnString, internal.WrapErrorf(err, "error validating uri")
    39  	}
    40  	return p.ConnString, nil
    41  }
    42  
    43  // Parse parses the provided URI into a ConnString object
    44  // but does not check that all values are valid. Use `ConnString.Validate()`
    45  // to run the validation checks separately.
    46  func Parse(s string) (ConnString, error) {
    47  	p := parser{dnsResolver: dns.DefaultResolver}
    48  	err := p.parse(s)
    49  	if err != nil {
    50  		err = internal.WrapErrorf(err, "error parsing uri")
    51  	}
    52  	return p.ConnString, err
    53  }
    54  
    55  // ConnString represents a connection string to mongodb.
    56  type ConnString struct {
    57  	Original                           string
    58  	AppName                            string
    59  	AuthMechanism                      string
    60  	AuthMechanismProperties            map[string]string
    61  	AuthMechanismPropertiesSet         bool
    62  	AuthSource                         string
    63  	AuthSourceSet                      bool
    64  	Compressors                        []string
    65  	Connect                            ConnectMode
    66  	ConnectSet                         bool
    67  	DirectConnection                   bool
    68  	DirectConnectionSet                bool
    69  	ConnectTimeout                     time.Duration
    70  	ConnectTimeoutSet                  bool
    71  	Database                           string
    72  	HeartbeatInterval                  time.Duration
    73  	HeartbeatIntervalSet               bool
    74  	Hosts                              []string
    75  	J                                  bool
    76  	JSet                               bool
    77  	LoadBalanced                       bool
    78  	LoadBalancedSet                    bool
    79  	LocalThreshold                     time.Duration
    80  	LocalThresholdSet                  bool
    81  	MaxConnIdleTime                    time.Duration
    82  	MaxConnIdleTimeSet                 bool
    83  	MaxPoolSize                        uint64
    84  	MaxPoolSizeSet                     bool
    85  	MinPoolSize                        uint64
    86  	MinPoolSizeSet                     bool
    87  	MaxConnecting                      uint64
    88  	MaxConnectingSet                   bool
    89  	Password                           string
    90  	PasswordSet                        bool
    91  	ReadConcernLevel                   string
    92  	ReadPreference                     string
    93  	ReadPreferenceTagSets              []map[string]string
    94  	RetryWrites                        bool
    95  	RetryWritesSet                     bool
    96  	RetryReads                         bool
    97  	RetryReadsSet                      bool
    98  	MaxStaleness                       time.Duration
    99  	MaxStalenessSet                    bool
   100  	ReplicaSet                         string
   101  	Scheme                             string
   102  	ServerSelectionTimeout             time.Duration
   103  	ServerSelectionTimeoutSet          bool
   104  	SocketTimeout                      time.Duration
   105  	SocketTimeoutSet                   bool
   106  	SRVMaxHosts                        int
   107  	SRVServiceName                     string
   108  	SSL                                bool
   109  	SSLSet                             bool
   110  	SSLClientCertificateKeyFile        string
   111  	SSLClientCertificateKeyFileSet     bool
   112  	SSLClientCertificateKeyPassword    func() string
   113  	SSLClientCertificateKeyPasswordSet bool
   114  	SSLCertificateFile                 string
   115  	SSLCertificateFileSet              bool
   116  	SSLPrivateKeyFile                  string
   117  	SSLPrivateKeyFileSet               bool
   118  	SSLInsecure                        bool
   119  	SSLInsecureSet                     bool
   120  	SSLCaFile                          string
   121  	SSLCaFileSet                       bool
   122  	SSLDisableOCSPEndpointCheck        bool
   123  	SSLDisableOCSPEndpointCheckSet     bool
   124  	Timeout                            time.Duration
   125  	TimeoutSet                         bool
   126  	WString                            string
   127  	WNumber                            int
   128  	WNumberSet                         bool
   129  	Username                           string
   130  	UsernameSet                        bool
   131  	ZlibLevel                          int
   132  	ZlibLevelSet                       bool
   133  	ZstdLevel                          int
   134  	ZstdLevelSet                       bool
   135  
   136  	WTimeout              time.Duration
   137  	WTimeoutSet           bool
   138  	WTimeoutSetFromOption bool
   139  
   140  	Options        map[string][]string
   141  	UnknownOptions map[string][]string
   142  }
   143  
   144  func (u *ConnString) String() string {
   145  	return u.Original
   146  }
   147  
   148  // HasAuthParameters returns true if this ConnString has any authentication parameters set and therefore represents
   149  // a request for authentication.
   150  func (u *ConnString) HasAuthParameters() bool {
   151  	// Check all auth parameters except for AuthSource because an auth source without other credentials is semantically
   152  	// valid and must not be interpreted as a request for authentication.
   153  	return u.AuthMechanism != "" || u.AuthMechanismProperties != nil || u.UsernameSet || u.PasswordSet
   154  }
   155  
   156  // Validate checks that the Auth and SSL parameters are valid values.
   157  func (u *ConnString) Validate() error {
   158  	p := parser{
   159  		dnsResolver: dns.DefaultResolver,
   160  		ConnString:  *u,
   161  	}
   162  	return p.validate()
   163  }
   164  
   165  // ConnectMode informs the driver on how to connect
   166  // to the server.
   167  type ConnectMode uint8
   168  
   169  var _ fmt.Stringer = ConnectMode(0)
   170  
   171  // ConnectMode constants.
   172  const (
   173  	AutoConnect ConnectMode = iota
   174  	SingleConnect
   175  )
   176  
   177  // String implements the fmt.Stringer interface.
   178  func (c ConnectMode) String() string {
   179  	switch c {
   180  	case AutoConnect:
   181  		return "automatic"
   182  	case SingleConnect:
   183  		return "direct"
   184  	default:
   185  		return "unknown"
   186  	}
   187  }
   188  
   // Scheme constants. SchemeMongoDB is the standard connection string scheme;
// SchemeMongoDBSRV marks a URI whose hosts and extra options are looked up
// through DNS SRV and TXT records.
const (
	SchemeMongoDB    = "mongodb"
	SchemeMongoDBSRV = SchemeMongoDB + "+srv"
)
   194  
   195  type parser struct {
   196  	ConnString
   197  
   198  	dnsResolver *dns.Resolver
   199  	tlsssl      *bool // used to determine if tls and ssl options are both specified and set differently.
   200  }
   201  
   202  func (p *parser) parse(original string) error {
   203  	p.Original = original
   204  	uri := original
   205  
   206  	var err error
   207  	if strings.HasPrefix(uri, SchemeMongoDBSRV+"://") {
   208  		p.Scheme = SchemeMongoDBSRV
   209  		// remove the scheme
   210  		uri = uri[len(SchemeMongoDBSRV)+3:]
   211  	} else if strings.HasPrefix(uri, SchemeMongoDB+"://") {
   212  		p.Scheme = SchemeMongoDB
   213  		// remove the scheme
   214  		uri = uri[len(SchemeMongoDB)+3:]
   215  	} else {
   216  		return fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"")
   217  	}
   218  
   219  	if idx := strings.Index(uri, "@"); idx != -1 {
   220  		userInfo := uri[:idx]
   221  		uri = uri[idx+1:]
   222  
   223  		username := userInfo
   224  		var password string
   225  
   226  		if idx := strings.Index(userInfo, ":"); idx != -1 {
   227  			username = userInfo[:idx]
   228  			password = userInfo[idx+1:]
   229  			p.PasswordSet = true
   230  		}
   231  
   232  		// Validate and process the username.
   233  		if strings.Contains(username, "/") {
   234  			return fmt.Errorf("unescaped slash in username")
   235  		}
   236  		p.Username, err = url.PathUnescape(username)
   237  		if err != nil {
   238  			return internal.WrapErrorf(err, "invalid username")
   239  		}
   240  		p.UsernameSet = true
   241  
   242  		// Validate and process the password.
   243  		if strings.Contains(password, ":") {
   244  			return fmt.Errorf("unescaped colon in password")
   245  		}
   246  		if strings.Contains(password, "/") {
   247  			return fmt.Errorf("unescaped slash in password")
   248  		}
   249  		p.Password, err = url.PathUnescape(password)
   250  		if err != nil {
   251  			return internal.WrapErrorf(err, "invalid password")
   252  		}
   253  	}
   254  
   255  	// fetch the hosts field
   256  	hosts := uri
   257  	if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
   258  		if uri[idx] == '@' {
   259  			return fmt.Errorf("unescaped @ sign in user info")
   260  		}
   261  		if uri[idx] == '?' {
   262  			return fmt.Errorf("must have a / before the query ?")
   263  		}
   264  		hosts = uri[:idx]
   265  	}
   266  
   267  	parsedHosts := strings.Split(hosts, ",")
   268  	uri = uri[len(hosts):]
   269  	extractedDatabase, err := extractDatabaseFromURI(uri)
   270  	if err != nil {
   271  		return err
   272  	}
   273  
   274  	uri = extractedDatabase.uri
   275  	p.Database = extractedDatabase.db
   276  
   277  	// grab connection arguments from URI
   278  	connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
   279  	if err != nil {
   280  		return err
   281  	}
   282  
   283  	// grab connection arguments from TXT record and enable SSL if "mongodb+srv://"
   284  	var connectionArgsFromTXT []string
   285  	if p.Scheme == SchemeMongoDBSRV {
   286  		connectionArgsFromTXT, err = p.dnsResolver.GetConnectionArgsFromTXT(hosts)
   287  		if err != nil {
   288  			return err
   289  		}
   290  
   291  		// SSL is enabled by default for SRV, but can be manually disabled with "ssl=false".
   292  		p.SSL = true
   293  		p.SSLSet = true
   294  	}
   295  
   296  	// add connection arguments from URI and TXT records to connstring
   297  	connectionArgPairs := make([]string, 0, len(connectionArgsFromTXT)+len(connectionArgsFromQueryString))
   298  	connectionArgPairs = append(connectionArgPairs, connectionArgsFromTXT...)
   299  	connectionArgPairs = append(connectionArgPairs, connectionArgsFromQueryString...)
   300  
   301  	for _, pair := range connectionArgPairs {
   302  		err := p.addOption(pair)
   303  		if err != nil {
   304  			return err
   305  		}
   306  	}
   307  
   308  	// do SRV lookup if "mongodb+srv://"
   309  	if p.Scheme == SchemeMongoDBSRV {
   310  		parsedHosts, err = p.dnsResolver.ParseHosts(hosts, p.SRVServiceName, true)
   311  		if err != nil {
   312  			return err
   313  		}
   314  
   315  		// If p.SRVMaxHosts is non-zero and is less than the number of hosts, randomly
   316  		// select SRVMaxHosts hosts from parsedHosts.
   317  		if p.SRVMaxHosts > 0 && p.SRVMaxHosts < len(parsedHosts) {
   318  			random.Shuffle(len(parsedHosts), func(i, j int) {
   319  				parsedHosts[i], parsedHosts[j] = parsedHosts[j], parsedHosts[i]
   320  			})
   321  			parsedHosts = parsedHosts[:p.SRVMaxHosts]
   322  		}
   323  	}
   324  
   325  	for _, host := range parsedHosts {
   326  		err = p.addHost(host)
   327  		if err != nil {
   328  			return internal.WrapErrorf(err, "invalid host %q", host)
   329  		}
   330  	}
   331  	if len(p.Hosts) == 0 {
   332  		return fmt.Errorf("must have at least 1 host")
   333  	}
   334  
   335  	err = p.setDefaultAuthParams(extractedDatabase.db)
   336  	if err != nil {
   337  		return err
   338  	}
   339  
   340  	// If WTimeout was set from manual options passed in, set WTImeoutSet to true.
   341  	if p.WTimeoutSetFromOption {
   342  		p.WTimeoutSet = true
   343  	}
   344  
   345  	return nil
   346  }
   347  
   348  func (p *parser) validate() error {
   349  	var err error
   350  
   351  	err = p.validateAuth()
   352  	if err != nil {
   353  		return err
   354  	}
   355  
   356  	if err = p.validateSSL(); err != nil {
   357  		return err
   358  	}
   359  
   360  	// Check for invalid write concern (i.e. w=0 and j=true)
   361  	if p.WNumberSet && p.WNumber == 0 && p.JSet && p.J {
   362  		return writeconcern.ErrInconsistent
   363  	}
   364  
   365  	// Check for invalid use of direct connections.
   366  	if (p.ConnectSet && p.Connect == SingleConnect) || (p.DirectConnectionSet && p.DirectConnection) {
   367  		if len(p.Hosts) > 1 {
   368  			return errors.New("a direct connection cannot be made if multiple hosts are specified")
   369  		}
   370  		if p.Scheme == SchemeMongoDBSRV {
   371  			return errors.New("a direct connection cannot be made if an SRV URI is used")
   372  		}
   373  		if p.LoadBalancedSet && p.LoadBalanced {
   374  			return internal.ErrLoadBalancedWithDirectConnection
   375  		}
   376  	}
   377  
   378  	// Validation for load-balanced mode.
   379  	if p.LoadBalancedSet && p.LoadBalanced {
   380  		if len(p.Hosts) > 1 {
   381  			return internal.ErrLoadBalancedWithMultipleHosts
   382  		}
   383  		if p.ReplicaSet != "" {
   384  			return internal.ErrLoadBalancedWithReplicaSet
   385  		}
   386  	}
   387  
   388  	// Check for invalid use of SRVMaxHosts.
   389  	if p.SRVMaxHosts > 0 {
   390  		if p.ReplicaSet != "" {
   391  			return internal.ErrSRVMaxHostsWithReplicaSet
   392  		}
   393  		if p.LoadBalanced {
   394  			return internal.ErrSRVMaxHostsWithLoadBalanced
   395  		}
   396  	}
   397  
   398  	return nil
   399  }
   400  
   401  func (p *parser) setDefaultAuthParams(dbName string) error {
   402  	// We do this check here rather than in validateAuth because this function is called as part of parsing and sets
   403  	// the value of AuthSource if authentication is enabled.
   404  	if p.AuthSourceSet && p.AuthSource == "" {
   405  		return errors.New("authSource must be non-empty when supplied in a URI")
   406  	}
   407  
   408  	switch strings.ToLower(p.AuthMechanism) {
   409  	case "plain":
   410  		if p.AuthSource == "" {
   411  			p.AuthSource = dbName
   412  			if p.AuthSource == "" {
   413  				p.AuthSource = "$external"
   414  			}
   415  		}
   416  	case "gssapi":
   417  		if p.AuthMechanismProperties == nil {
   418  			p.AuthMechanismProperties = map[string]string{
   419  				"SERVICE_NAME": "mongodb",
   420  			}
   421  		} else if v, ok := p.AuthMechanismProperties["SERVICE_NAME"]; !ok || v == "" {
   422  			p.AuthMechanismProperties["SERVICE_NAME"] = "mongodb"
   423  		}
   424  		fallthrough
   425  	case "mongodb-aws", "mongodb-x509":
   426  		if p.AuthSource == "" {
   427  			p.AuthSource = "$external"
   428  		} else if p.AuthSource != "$external" {
   429  			return fmt.Errorf("auth source must be $external")
   430  		}
   431  	case "mongodb-cr":
   432  		fallthrough
   433  	case "scram-sha-1":
   434  		fallthrough
   435  	case "scram-sha-256":
   436  		if p.AuthSource == "" {
   437  			p.AuthSource = dbName
   438  			if p.AuthSource == "" {
   439  				p.AuthSource = "admin"
   440  			}
   441  		}
   442  	case "":
   443  		// Only set auth source if there is a request for authentication via non-empty credentials.
   444  		if p.AuthSource == "" && (p.AuthMechanismProperties != nil || p.Username != "" || p.PasswordSet) {
   445  			p.AuthSource = dbName
   446  			if p.AuthSource == "" {
   447  				p.AuthSource = "admin"
   448  			}
   449  		}
   450  	default:
   451  		return fmt.Errorf("invalid auth mechanism")
   452  	}
   453  	return nil
   454  }
   455  
   456  func (p *parser) validateAuth() error {
   457  	switch strings.ToLower(p.AuthMechanism) {
   458  	case "mongodb-cr":
   459  		if p.Username == "" {
   460  			return fmt.Errorf("username required for MONGO-CR")
   461  		}
   462  		if p.Password == "" {
   463  			return fmt.Errorf("password required for MONGO-CR")
   464  		}
   465  		if p.AuthMechanismProperties != nil {
   466  			return fmt.Errorf("MONGO-CR cannot have mechanism properties")
   467  		}
   468  	case "mongodb-x509":
   469  		if p.Password != "" {
   470  			return fmt.Errorf("password cannot be specified for MONGO-X509")
   471  		}
   472  		if p.AuthMechanismProperties != nil {
   473  			return fmt.Errorf("MONGO-X509 cannot have mechanism properties")
   474  		}
   475  	case "mongodb-aws":
   476  		if p.Username != "" && p.Password == "" {
   477  			return fmt.Errorf("username without password is invalid for MONGODB-AWS")
   478  		}
   479  		if p.Username == "" && p.Password != "" {
   480  			return fmt.Errorf("password without username is invalid for MONGODB-AWS")
   481  		}
   482  		var token bool
   483  		for k := range p.AuthMechanismProperties {
   484  			if k != "AWS_SESSION_TOKEN" {
   485  				return fmt.Errorf("invalid auth property for MONGODB-AWS")
   486  			}
   487  			token = true
   488  		}
   489  		if token && p.Username == "" && p.Password == "" {
   490  			return fmt.Errorf("token without username and password is invalid for MONGODB-AWS")
   491  		}
   492  	case "gssapi":
   493  		if p.Username == "" {
   494  			return fmt.Errorf("username required for GSSAPI")
   495  		}
   496  		for k := range p.AuthMechanismProperties {
   497  			if k != "SERVICE_NAME" && k != "CANONICALIZE_HOST_NAME" && k != "SERVICE_REALM" && k != "SERVICE_HOST" {
   498  				return fmt.Errorf("invalid auth property for GSSAPI")
   499  			}
   500  		}
   501  	case "plain":
   502  		if p.Username == "" {
   503  			return fmt.Errorf("username required for PLAIN")
   504  		}
   505  		if p.Password == "" {
   506  			return fmt.Errorf("password required for PLAIN")
   507  		}
   508  		if p.AuthMechanismProperties != nil {
   509  			return fmt.Errorf("PLAIN cannot have mechanism properties")
   510  		}
   511  	case "scram-sha-1":
   512  		if p.Username == "" {
   513  			return fmt.Errorf("username required for SCRAM-SHA-1")
   514  		}
   515  		if p.Password == "" {
   516  			return fmt.Errorf("password required for SCRAM-SHA-1")
   517  		}
   518  		if p.AuthMechanismProperties != nil {
   519  			return fmt.Errorf("SCRAM-SHA-1 cannot have mechanism properties")
   520  		}
   521  	case "scram-sha-256":
   522  		if p.Username == "" {
   523  			return fmt.Errorf("username required for SCRAM-SHA-256")
   524  		}
   525  		if p.Password == "" {
   526  			return fmt.Errorf("password required for SCRAM-SHA-256")
   527  		}
   528  		if p.AuthMechanismProperties != nil {
   529  			return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties")
   530  		}
   531  	case "":
   532  		if p.UsernameSet && p.Username == "" {
   533  			return fmt.Errorf("username required if URI contains user info")
   534  		}
   535  	default:
   536  		return fmt.Errorf("invalid auth mechanism")
   537  	}
   538  	return nil
   539  }
   540  
   541  func (p *parser) validateSSL() error {
   542  	if !p.SSL {
   543  		return nil
   544  	}
   545  
   546  	if p.SSLClientCertificateKeyFileSet {
   547  		if p.SSLCertificateFileSet || p.SSLPrivateKeyFileSet {
   548  			return errors.New("the sslClientCertificateKeyFile/tlsCertificateKeyFile URI option cannot be provided " +
   549  				"along with tlsCertificateFile or tlsPrivateKeyFile")
   550  		}
   551  		return nil
   552  	}
   553  	if p.SSLCertificateFileSet && !p.SSLPrivateKeyFileSet {
   554  		return errors.New("the tlsPrivateKeyFile URI option must be provided if the tlsCertificateFile option is specified")
   555  	}
   556  	if p.SSLPrivateKeyFileSet && !p.SSLCertificateFileSet {
   557  		return errors.New("the tlsCertificateFile URI option must be provided if the tlsPrivateKeyFile option is specified")
   558  	}
   559  
   560  	if p.SSLInsecureSet && p.SSLDisableOCSPEndpointCheckSet {
   561  		return errors.New("the sslInsecure/tlsInsecure URI option cannot be provided along with " +
   562  			"tlsDisableOCSPEndpointCheck ")
   563  	}
   564  	return nil
   565  }
   566  
   567  func (p *parser) addHost(host string) error {
   568  	if host == "" {
   569  		return nil
   570  	}
   571  	host, err := url.QueryUnescape(host)
   572  	if err != nil {
   573  		return internal.WrapErrorf(err, "invalid host %q", host)
   574  	}
   575  
   576  	_, port, err := net.SplitHostPort(host)
   577  	// this is unfortunate that SplitHostPort actually requires
   578  	// a port to exist.
   579  	if err != nil {
   580  		if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" {
   581  			return err
   582  		}
   583  	}
   584  
   585  	if port != "" {
   586  		d, err := strconv.Atoi(port)
   587  		if err != nil {
   588  			return internal.WrapErrorf(err, "port must be an integer")
   589  		}
   590  		if d <= 0 || d >= 65536 {
   591  			return fmt.Errorf("port must be in the range [1, 65535]")
   592  		}
   593  	}
   594  	p.Hosts = append(p.Hosts, host)
   595  	return nil
   596  }
   597  
   598  func (p *parser) addOption(pair string) error {
   599  	kv := strings.SplitN(pair, "=", 2)
   600  	if len(kv) != 2 || kv[0] == "" {
   601  		return fmt.Errorf("invalid option")
   602  	}
   603  
   604  	key, err := url.QueryUnescape(kv[0])
   605  	if err != nil {
   606  		return internal.WrapErrorf(err, "invalid option key %q", kv[0])
   607  	}
   608  
   609  	value, err := url.QueryUnescape(kv[1])
   610  	if err != nil {
   611  		return internal.WrapErrorf(err, "invalid option value %q", kv[1])
   612  	}
   613  
   614  	lowerKey := strings.ToLower(key)
   615  	switch lowerKey {
   616  	case "appname":
   617  		p.AppName = value
   618  	case "authmechanism":
   619  		p.AuthMechanism = value
   620  	case "authmechanismproperties":
   621  		p.AuthMechanismProperties = make(map[string]string)
   622  		pairs := strings.Split(value, ",")
   623  		for _, pair := range pairs {
   624  			kv := strings.SplitN(pair, ":", 2)
   625  			if len(kv) != 2 || kv[0] == "" {
   626  				return fmt.Errorf("invalid authMechanism property")
   627  			}
   628  			p.AuthMechanismProperties[kv[0]] = kv[1]
   629  		}
   630  		p.AuthMechanismPropertiesSet = true
   631  	case "authsource":
   632  		p.AuthSource = value
   633  		p.AuthSourceSet = true
   634  	case "compressors":
   635  		compressors := strings.Split(value, ",")
   636  		if len(compressors) < 1 {
   637  			return fmt.Errorf("must have at least 1 compressor")
   638  		}
   639  		p.Compressors = compressors
   640  	case "connect":
   641  		switch strings.ToLower(value) {
   642  		case "automatic":
   643  		case "direct":
   644  			p.Connect = SingleConnect
   645  		default:
   646  			return fmt.Errorf("invalid 'connect' value: %q", value)
   647  		}
   648  		if p.DirectConnectionSet {
   649  			expectedValue := p.Connect == SingleConnect // directConnection should be true if connect=direct
   650  			if p.DirectConnection != expectedValue {
   651  				return fmt.Errorf("options connect=%q and directConnection=%v conflict", value, p.DirectConnection)
   652  			}
   653  		}
   654  
   655  		p.ConnectSet = true
   656  	case "directconnection":
   657  		switch strings.ToLower(value) {
   658  		case "true":
   659  			p.DirectConnection = true
   660  		case "false":
   661  		default:
   662  			return fmt.Errorf("invalid 'directConnection' value: %q", value)
   663  		}
   664  
   665  		if p.ConnectSet {
   666  			expectedValue := AutoConnect
   667  			if p.DirectConnection {
   668  				expectedValue = SingleConnect
   669  			}
   670  
   671  			if p.Connect != expectedValue {
   672  				return fmt.Errorf("options connect=%q and directConnection=%q conflict", p.Connect, value)
   673  			}
   674  		}
   675  		p.DirectConnectionSet = true
   676  	case "connecttimeoutms":
   677  		n, err := strconv.Atoi(value)
   678  		if err != nil || n < 0 {
   679  			return fmt.Errorf("invalid value for %q: %q", key, value)
   680  		}
   681  		p.ConnectTimeout = time.Duration(n) * time.Millisecond
   682  		p.ConnectTimeoutSet = true
   683  	case "heartbeatintervalms", "heartbeatfrequencyms":
   684  		n, err := strconv.Atoi(value)
   685  		if err != nil || n < 0 {
   686  			return fmt.Errorf("invalid value for %q: %q", key, value)
   687  		}
   688  		p.HeartbeatInterval = time.Duration(n) * time.Millisecond
   689  		p.HeartbeatIntervalSet = true
   690  	case "journal":
   691  		switch value {
   692  		case "true":
   693  			p.J = true
   694  		case "false":
   695  			p.J = false
   696  		default:
   697  			return fmt.Errorf("invalid value for %q: %q", key, value)
   698  		}
   699  
   700  		p.JSet = true
   701  	case "loadbalanced":
   702  		switch value {
   703  		case "true":
   704  			p.LoadBalanced = true
   705  		case "false":
   706  			p.LoadBalanced = false
   707  		default:
   708  			return fmt.Errorf("invalid value for %q: %q", key, value)
   709  		}
   710  
   711  		p.LoadBalancedSet = true
   712  	case "localthresholdms":
   713  		n, err := strconv.Atoi(value)
   714  		if err != nil || n < 0 {
   715  			return fmt.Errorf("invalid value for %q: %q", key, value)
   716  		}
   717  		p.LocalThreshold = time.Duration(n) * time.Millisecond
   718  		p.LocalThresholdSet = true
   719  	case "maxidletimems":
   720  		n, err := strconv.Atoi(value)
   721  		if err != nil || n < 0 {
   722  			return fmt.Errorf("invalid value for %q: %q", key, value)
   723  		}
   724  		p.MaxConnIdleTime = time.Duration(n) * time.Millisecond
   725  		p.MaxConnIdleTimeSet = true
   726  	case "maxpoolsize":
   727  		n, err := strconv.Atoi(value)
   728  		if err != nil || n < 0 {
   729  			return fmt.Errorf("invalid value for %q: %q", key, value)
   730  		}
   731  		p.MaxPoolSize = uint64(n)
   732  		p.MaxPoolSizeSet = true
   733  	case "minpoolsize":
   734  		n, err := strconv.Atoi(value)
   735  		if err != nil || n < 0 {
   736  			return fmt.Errorf("invalid value for %q: %q", key, value)
   737  		}
   738  		p.MinPoolSize = uint64(n)
   739  		p.MinPoolSizeSet = true
   740  	case "maxconnecting":
   741  		n, err := strconv.Atoi(value)
   742  		if err != nil || n < 0 {
   743  			return fmt.Errorf("invalid value for %q: %q", key, value)
   744  		}
   745  		p.MaxConnecting = uint64(n)
   746  		p.MaxConnectingSet = true
   747  	case "readconcernlevel":
   748  		p.ReadConcernLevel = value
   749  	case "readpreference":
   750  		p.ReadPreference = value
   751  	case "readpreferencetags":
   752  		if value == "" {
   753  			// If "readPreferenceTags=" is supplied, append an empty map to tag sets to
   754  			// represent a wild-card.
   755  			p.ReadPreferenceTagSets = append(p.ReadPreferenceTagSets, map[string]string{})
   756  			break
   757  		}
   758  
   759  		tags := make(map[string]string)
   760  		items := strings.Split(value, ",")
   761  		for _, item := range items {
   762  			parts := strings.Split(item, ":")
   763  			if len(parts) != 2 {
   764  				return fmt.Errorf("invalid value for %q: %q", key, value)
   765  			}
   766  			tags[parts[0]] = parts[1]
   767  		}
   768  		p.ReadPreferenceTagSets = append(p.ReadPreferenceTagSets, tags)
   769  	case "maxstaleness", "maxstalenessseconds":
   770  		n, err := strconv.Atoi(value)
   771  		if err != nil || n < 0 {
   772  			return fmt.Errorf("invalid value for %q: %q", key, value)
   773  		}
   774  		p.MaxStaleness = time.Duration(n) * time.Second
   775  		p.MaxStalenessSet = true
   776  	case "replicaset":
   777  		p.ReplicaSet = value
   778  	case "retrywrites":
   779  		switch value {
   780  		case "true":
   781  			p.RetryWrites = true
   782  		case "false":
   783  			p.RetryWrites = false
   784  		default:
   785  			return fmt.Errorf("invalid value for %q: %q", key, value)
   786  		}
   787  
   788  		p.RetryWritesSet = true
   789  	case "retryreads":
   790  		switch value {
   791  		case "true":
   792  			p.RetryReads = true
   793  		case "false":
   794  			p.RetryReads = false
   795  		default:
   796  			return fmt.Errorf("invalid value for %q: %q", key, value)
   797  		}
   798  
   799  		p.RetryReadsSet = true
   800  	case "serverselectiontimeoutms":
   801  		n, err := strconv.Atoi(value)
   802  		if err != nil || n < 0 {
   803  			return fmt.Errorf("invalid value for %q: %q", key, value)
   804  		}
   805  		p.ServerSelectionTimeout = time.Duration(n) * time.Millisecond
   806  		p.ServerSelectionTimeoutSet = true
   807  	case "sockettimeoutms":
   808  		n, err := strconv.Atoi(value)
   809  		if err != nil || n < 0 {
   810  			return fmt.Errorf("invalid value for %q: %q", key, value)
   811  		}
   812  		p.SocketTimeout = time.Duration(n) * time.Millisecond
   813  		p.SocketTimeoutSet = true
   814  	case "srvmaxhosts":
   815  		// srvMaxHosts can only be set on URIs with the "mongodb+srv" scheme
   816  		if p.Scheme != SchemeMongoDBSRV {
   817  			return fmt.Errorf("cannot specify srvMaxHosts on non-SRV URI")
   818  		}
   819  
   820  		n, err := strconv.Atoi(value)
   821  		if err != nil || n < 0 {
   822  			return fmt.Errorf("invalid value for %q: %q", key, value)
   823  		}
   824  		p.SRVMaxHosts = n
   825  	case "srvservicename":
   826  		// srvServiceName can only be set on URIs with the "mongodb+srv" scheme
   827  		if p.Scheme != SchemeMongoDBSRV {
   828  			return fmt.Errorf("cannot specify srvServiceName on non-SRV URI")
   829  		}
   830  
   831  		// srvServiceName must be between 1 and 62 characters according to
   832  		// our specification. Empty service names are not valid, and the service
   833  		// name (including prepended underscore) should not exceed the 63 character
   834  		// limit for DNS query subdomains.
   835  		if len(value) < 1 || len(value) > 62 {
   836  			return fmt.Errorf("srvServiceName value must be between 1 and 62 characters")
   837  		}
   838  		p.SRVServiceName = value
   839  	case "ssl", "tls":
   840  		switch value {
   841  		case "true":
   842  			p.SSL = true
   843  		case "false":
   844  			p.SSL = false
   845  		default:
   846  			return fmt.Errorf("invalid value for %q: %q", key, value)
   847  		}
   848  		if p.tlsssl != nil && *p.tlsssl != p.SSL {
   849  			return errors.New("tls and ssl options, when both specified, must be equivalent")
   850  		}
   851  
   852  		p.tlsssl = new(bool)
   853  		*p.tlsssl = p.SSL
   854  
   855  		p.SSLSet = true
   856  	case "sslclientcertificatekeyfile", "tlscertificatekeyfile":
   857  		p.SSL = true
   858  		p.SSLSet = true
   859  		p.SSLClientCertificateKeyFile = value
   860  		p.SSLClientCertificateKeyFileSet = true
   861  	case "sslclientcertificatekeypassword", "tlscertificatekeyfilepassword":
   862  		p.SSLClientCertificateKeyPassword = func() string { return value }
   863  		p.SSLClientCertificateKeyPasswordSet = true
   864  	case "tlscertificatefile":
   865  		p.SSL = true
   866  		p.SSLSet = true
   867  		p.SSLCertificateFile = value
   868  		p.SSLCertificateFileSet = true
   869  	case "tlsprivatekeyfile":
   870  		p.SSL = true
   871  		p.SSLSet = true
   872  		p.SSLPrivateKeyFile = value
   873  		p.SSLPrivateKeyFileSet = true
   874  	case "sslinsecure", "tlsinsecure":
   875  		switch value {
   876  		case "true":
   877  			p.SSLInsecure = true
   878  		case "false":
   879  			p.SSLInsecure = false
   880  		default:
   881  			return fmt.Errorf("invalid value for %q: %q", key, value)
   882  		}
   883  
   884  		p.SSLInsecureSet = true
   885  	case "sslcertificateauthorityfile", "tlscafile":
   886  		p.SSL = true
   887  		p.SSLSet = true
   888  		p.SSLCaFile = value
   889  		p.SSLCaFileSet = true
   890  	case "timeoutms":
   891  		n, err := strconv.Atoi(value)
   892  		if err != nil || n < 0 {
   893  			return fmt.Errorf("invalid value for %q: %q", key, value)
   894  		}
   895  		p.Timeout = time.Duration(n) * time.Millisecond
   896  		p.TimeoutSet = true
   897  	case "tlsdisableocspendpointcheck":
   898  		p.SSL = true
   899  		p.SSLSet = true
   900  
   901  		switch value {
   902  		case "true":
   903  			p.SSLDisableOCSPEndpointCheck = true
   904  		case "false":
   905  			p.SSLDisableOCSPEndpointCheck = false
   906  		default:
   907  			return fmt.Errorf("invalid value for %q: %q", key, value)
   908  		}
   909  		p.SSLDisableOCSPEndpointCheckSet = true
   910  	case "w":
   911  		if w, err := strconv.Atoi(value); err == nil {
   912  			if w < 0 {
   913  				return fmt.Errorf("invalid value for %q: %q", key, value)
   914  			}
   915  
   916  			p.WNumber = w
   917  			p.WNumberSet = true
   918  			p.WString = ""
   919  			break
   920  		}
   921  
   922  		p.WString = value
   923  		p.WNumberSet = false
   924  
   925  	case "wtimeoutms":
   926  		n, err := strconv.Atoi(value)
   927  		if err != nil || n < 0 {
   928  			return fmt.Errorf("invalid value for %q: %q", key, value)
   929  		}
   930  		p.WTimeout = time.Duration(n) * time.Millisecond
   931  		p.WTimeoutSet = true
   932  	case "wtimeout":
   933  		// Defer to wtimeoutms, but not to a manually-set option.
   934  		if p.WTimeoutSet {
   935  			break
   936  		}
   937  		n, err := strconv.Atoi(value)
   938  		if err != nil || n < 0 {
   939  			return fmt.Errorf("invalid value for %q: %q", key, value)
   940  		}
   941  		p.WTimeout = time.Duration(n) * time.Millisecond
   942  	case "zlibcompressionlevel":
   943  		level, err := strconv.Atoi(value)
   944  		if err != nil || (level < -1 || level > 9) {
   945  			return fmt.Errorf("invalid value for %q: %q", key, value)
   946  		}
   947  
   948  		if level == -1 {
   949  			level = wiremessage.DefaultZlibLevel
   950  		}
   951  		p.ZlibLevel = level
   952  		p.ZlibLevelSet = true
   953  	case "zstdcompressionlevel":
   954  		const maxZstdLevel = 22 // https://github.com/facebook/zstd/blob/a880ca239b447968493dd2fed3850e766d6305cc/contrib/linux-kernel/lib/zstd/compress.c#L3291
   955  		level, err := strconv.Atoi(value)
   956  		if err != nil || (level < -1 || level > maxZstdLevel) {
   957  			return fmt.Errorf("invalid value for %q: %q", key, value)
   958  		}
   959  
   960  		if level == -1 {
   961  			level = wiremessage.DefaultZstdLevel
   962  		}
   963  		p.ZstdLevel = level
   964  		p.ZstdLevelSet = true
   965  	default:
   966  		if p.UnknownOptions == nil {
   967  			p.UnknownOptions = make(map[string][]string)
   968  		}
   969  		p.UnknownOptions[lowerKey] = append(p.UnknownOptions[lowerKey], value)
   970  	}
   971  
   972  	if p.Options == nil {
   973  		p.Options = make(map[string][]string)
   974  	}
   975  	p.Options[lowerKey] = append(p.Options[lowerKey], value)
   976  
   977  	return nil
   978  }
   979  
   980  func extractQueryArgsFromURI(uri string) ([]string, error) {
   981  	if len(uri) == 0 {
   982  		return nil, nil
   983  	}
   984  
   985  	if uri[0] != '?' {
   986  		return nil, errors.New("must have a ? separator between path and query")
   987  	}
   988  
   989  	uri = uri[1:]
   990  	if len(uri) == 0 {
   991  		return nil, nil
   992  	}
   993  	return strings.FieldsFunc(uri, func(r rune) bool { return r == ';' || r == '&' }), nil
   994  
   995  }
   996  
   997  type extractedDatabase struct {
   998  	uri string
   999  	db  string
  1000  }
  1001  
  1002  // extractDatabaseFromURI is a helper function to retrieve information about
  1003  // the database from the passed in URI. It accepts as an argument the currently
  1004  // parsed URI and returns the remainder of the uri, the database it found,
  1005  // and any error it encounters while parsing.
  1006  func extractDatabaseFromURI(uri string) (extractedDatabase, error) {
  1007  	if len(uri) == 0 {
  1008  		return extractedDatabase{}, nil
  1009  	}
  1010  
  1011  	if uri[0] != '/' {
  1012  		return extractedDatabase{}, errors.New("must have a / separator between hosts and path")
  1013  	}
  1014  
  1015  	uri = uri[1:]
  1016  	if len(uri) == 0 {
  1017  		return extractedDatabase{}, nil
  1018  	}
  1019  
  1020  	database := uri
  1021  	if idx := strings.IndexRune(uri, '?'); idx != -1 {
  1022  		database = uri[:idx]
  1023  	}
  1024  
  1025  	escapedDatabase, err := url.QueryUnescape(database)
  1026  	if err != nil {
  1027  		return extractedDatabase{}, internal.WrapErrorf(err, "invalid database %q", database)
  1028  	}
  1029  
  1030  	uri = uri[len(database):]
  1031  
  1032  	return extractedDatabase{
  1033  		uri: uri,
  1034  		db:  escapedDatabase,
  1035  	}, nil
  1036  }