github.com/snowflakedb/gosnowflake@v1.9.0/dsn.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "crypto/rsa" 7 "crypto/x509" 8 "encoding/base64" 9 "encoding/pem" 10 "errors" 11 "fmt" 12 "net" 13 "net/http" 14 "net/url" 15 "os" 16 "strconv" 17 "strings" 18 "time" 19 ) 20 21 const ( 22 defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response 23 defaultJWTClientTimeout = 10 * time.Second // Timeout for network round trip + read out http response but used for JWT auth 24 defaultLoginTimeout = 300 * time.Second // Timeout for retry for login EXCLUDING clientTimeout 25 defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout 26 defaultJWTTimeout = 60 * time.Second 27 defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login 28 defaultMaxRetryCount = 7 // specifies maximum number of subsequent retries 29 defaultDomain = ".snowflakecomputing.com" 30 ) 31 32 // ConfigBool is a type to represent true or false in the Config 33 type ConfigBool uint8 34 35 const ( 36 configBoolNotSet ConfigBool = iota // Reserved for unset to let default value fall into this category 37 // ConfigBoolTrue represents true for the config field 38 ConfigBoolTrue 39 // ConfigBoolFalse represents false for the config field 40 ConfigBoolFalse 41 ) 42 43 // Config is a set of configuration parameters 44 type Config struct { 45 Account string // Account name 46 User string // Username 47 Password string // Password (requires User) 48 Database string // Database name 49 Schema string // Schema 50 Warehouse string // Warehouse 51 Role string // Role 52 Region string // Region 53 54 // ValidateDefaultParameters disable the validation checks for Database, Schema, Warehouse and Role 55 // at the time a connection is established 56 ValidateDefaultParameters ConfigBool 57 58 Params map[string]*string // other connection parameters 59 60 ClientIP net.IP // IP address for network check 61 Protocol string // http or https (optional) 62 Host string // hostname (optional) 63 Port int // port (optional) 64 65 Authenticator AuthType // The authenticator type 66 67 Passcode string 68 PasscodeInPassword bool 69 70 OktaURL *url.URL 71 72 LoginTimeout time.Duration // Login retry timeout EXCLUDING network roundtrip and read out http response 73 RequestTimeout time.Duration // request retry timeout EXCLUDING network roundtrip and read out http response 74 JWTExpireTimeout time.Duration // JWT expire after timeout 75 ClientTimeout time.Duration // Timeout for network round trip + read out http response 76 JWTClientTimeout time.Duration // Timeout for network round trip + read out http response used when JWT token auth is taking place 77 ExternalBrowserTimeout time.Duration // Timeout for external browser login 78 MaxRetryCount int // Specifies how many times non-periodic HTTP request can be retried 79 80 Application string // application name. 81 InsecureMode bool // driver doesn't check certificate revocation status 82 OCSPFailOpen OCSPFailOpenMode // OCSP Fail Open 83 84 Token string // Token to use for OAuth other forms of token based auth 85 TokenAccessor TokenAccessor // Optional token accessor to use 86 KeepSessionAlive bool // Enables the session to persist even after the connection is closed 87 88 PrivateKey *rsa.PrivateKey // Private key used to sign JWT 89 90 Transporter http.RoundTripper // RoundTripper to intercept HTTP requests and responses 91 92 DisableTelemetry bool // indicates whether to disable telemetry 93 94 Tracing string // sets logging level 95 96 TmpDirPath string // sets temporary directory used by a driver for operations like encrypting, compressing etc 97 98 MfaToken string // Internally used to cache the MFA token 99 IDToken string // Internally used to cache the Id Token for external browser 100 ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux. 101 ClientStoreTemporaryCredential ConfigBool // When true the ID token is cached in the credential manager. True by default in Windows/OSX. False for Linux. 102 103 DisableQueryContextCache bool // Should HTAP query context cache be disabled 104 105 IncludeRetryReason ConfigBool // Should retried request contain retry reason 106 107 ClientConfigFile string // File path to the client configuration json file 108 109 DisableConsoleLogin ConfigBool // Indicates whether console login should be disabled 110 } 111 112 // Validate enables testing if config is correct. 113 // A driver client may call it manually, but it is also called during opening first connection. 114 func (c *Config) Validate() error { 115 if c.TmpDirPath != "" { 116 if _, err := os.Stat(c.TmpDirPath); err != nil { 117 return err 118 } 119 } 120 return nil 121 } 122 123 // ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED 124 func (c *Config) ocspMode() string { 125 if c.InsecureMode { 126 return ocspModeInsecure 127 } else if c.OCSPFailOpen == ocspFailOpenNotSet || c.OCSPFailOpen == OCSPFailOpenTrue { 128 // by default or set to true 129 return ocspModeFailOpen 130 } 131 return ocspModeFailClosed 132 } 133 134 // DSN constructs a DSN for Snowflake db. 135 func DSN(cfg *Config) (dsn string, err error) { 136 hasHost := true 137 if cfg.Host == "" { 138 hasHost = false 139 if cfg.Region == "us-west-2" { 140 cfg.Region = "" 141 } 142 if cfg.Region == "" { 143 cfg.Host = cfg.Account + defaultDomain 144 } else { 145 cfg.Host = cfg.Account + "." + cfg.Region + defaultDomain 146 } 147 } 148 // in case account includes region 149 posDot := strings.Index(cfg.Account, ".") 150 if posDot > 0 { 151 if cfg.Region != "" { 152 return "", errInvalidRegion() 153 } 154 cfg.Region = cfg.Account[posDot+1:] 155 cfg.Account = cfg.Account[:posDot] 156 } 157 err = fillMissingConfigParameters(cfg) 158 if err != nil { 159 return "", err 160 } 161 params := &url.Values{} 162 if hasHost && cfg.Account != "" { 163 // account may not be included in a Host string 164 params.Add("account", cfg.Account) 165 } 166 if cfg.Database != "" { 167 params.Add("database", cfg.Database) 168 } 169 if cfg.Schema != "" { 170 params.Add("schema", cfg.Schema) 171 } 172 if cfg.Warehouse != "" { 173 params.Add("warehouse", cfg.Warehouse) 174 } 175 if cfg.Role != "" { 176 params.Add("role", cfg.Role) 177 } 178 if cfg.Region != "" { 179 params.Add("region", cfg.Region) 180 } 181 if cfg.Authenticator != AuthTypeSnowflake { 182 if cfg.Authenticator == AuthTypeOkta { 183 params.Add("authenticator", strings.ToLower(cfg.OktaURL.String())) 184 } else { 185 params.Add("authenticator", strings.ToLower(cfg.Authenticator.String())) 186 } 187 } 188 if cfg.Passcode != "" { 189 params.Add("passcode", cfg.Passcode) 190 } 191 if cfg.PasscodeInPassword { 192 params.Add("passcodeInPassword", strconv.FormatBool(cfg.PasscodeInPassword)) 193 } 194 if cfg.ClientTimeout != defaultClientTimeout { 195 params.Add("clientTimeout", strconv.FormatInt(int64(cfg.ClientTimeout/time.Second), 10)) 196 } 197 if cfg.JWTClientTimeout != defaultJWTClientTimeout { 198 params.Add("jwtClientTimeout", strconv.FormatInt(int64(cfg.JWTClientTimeout/time.Second), 10)) 199 } 200 if cfg.LoginTimeout != defaultLoginTimeout { 201 params.Add("loginTimeout", strconv.FormatInt(int64(cfg.LoginTimeout/time.Second), 10)) 202 } 203 if cfg.RequestTimeout != defaultRequestTimeout { 204 params.Add("requestTimeout", strconv.FormatInt(int64(cfg.RequestTimeout/time.Second), 10)) 205 } 206 if cfg.JWTExpireTimeout != defaultJWTTimeout { 207 params.Add("jwtTimeout", strconv.FormatInt(int64(cfg.JWTExpireTimeout/time.Second), 10)) 208 } 209 if cfg.ExternalBrowserTimeout != defaultExternalBrowserTimeout { 210 params.Add("externalBrowserTimeout", strconv.FormatInt(int64(cfg.ExternalBrowserTimeout/time.Second), 10)) 211 } 212 if cfg.MaxRetryCount != defaultMaxRetryCount { 213 params.Add("maxRetryCount", strconv.Itoa(cfg.MaxRetryCount)) 214 } 215 if cfg.Application != clientType { 216 params.Add("application", cfg.Application) 217 } 218 if cfg.Protocol != "" && cfg.Protocol != "https" { 219 params.Add("protocol", cfg.Protocol) 220 } 221 if cfg.Token != "" { 222 params.Add("token", cfg.Token) 223 } 224 if cfg.Params != nil { 225 for k, v := range cfg.Params { 226 params.Add(k, *v) 227 } 228 } 229 if cfg.PrivateKey != nil { 230 privateKeyInBytes, err := marshalPKCS8PrivateKey(cfg.PrivateKey) 231 if err != nil { 232 return "", err 233 } 234 keyBase64 := base64.URLEncoding.EncodeToString(privateKeyInBytes) 235 params.Add("privateKey", keyBase64) 236 } 237 if cfg.InsecureMode { 238 params.Add("insecureMode", strconv.FormatBool(cfg.InsecureMode)) 239 } 240 if cfg.Tracing != "" { 241 params.Add("tracing", cfg.Tracing) 242 } 243 if cfg.TmpDirPath != "" { 244 params.Add("tmpDirPath", cfg.TmpDirPath) 245 } 246 if cfg.DisableQueryContextCache { 247 params.Add("disableQueryContextCache", "true") 248 } 249 if cfg.IncludeRetryReason == ConfigBoolFalse { 250 params.Add("includeRetryReason", "false") 251 } 252 253 params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse)) 254 255 params.Add("validateDefaultParameters", strconv.FormatBool(cfg.ValidateDefaultParameters != ConfigBoolFalse)) 256 257 if cfg.ClientRequestMfaToken != configBoolNotSet { 258 params.Add("clientRequestMfaToken", strconv.FormatBool(cfg.ClientRequestMfaToken != ConfigBoolFalse)) 259 } 260 261 if cfg.ClientStoreTemporaryCredential != configBoolNotSet { 262 params.Add("clientStoreTemporaryCredential", strconv.FormatBool(cfg.ClientStoreTemporaryCredential != ConfigBoolFalse)) 263 } 264 if cfg.ClientConfigFile != "" { 265 params.Add("clientConfigFile", cfg.ClientConfigFile) 266 } 267 if cfg.DisableConsoleLogin != configBoolNotSet { 268 params.Add("disableConsoleLogin", strconv.FormatBool(cfg.DisableConsoleLogin != ConfigBoolFalse)) 269 } 270 271 dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port) 272 if params.Encode() != "" { 273 dsn += "?" + params.Encode() 274 } 275 return 276 } 277 278 // ParseDSN parses the DSN string to a Config. 279 func ParseDSN(dsn string) (cfg *Config, err error) { 280 // New config with some default values 281 cfg = &Config{ 282 Params: make(map[string]*string), 283 Authenticator: AuthTypeSnowflake, // Default to snowflake 284 } 285 286 // user[:password]@account/database/schema[?param1=value1¶mN=valueN] 287 // or 288 // user[:password]@account/database[?param1=value1¶mN=valueN] 289 // or 290 // user[:password]@host:port/database/schema?account=user_account[?param1=value1¶mN=valueN] 291 // or 292 // host:port/database/schema?account=user_account[?param1=value1¶mN=valueN] 293 294 foundSlash := false 295 secondSlash := false 296 done := false 297 var i int 298 posQuestion := len(dsn) 299 for i = len(dsn) - 1; i >= 0; i-- { 300 switch { 301 case dsn[i] == '/': 302 foundSlash = true 303 304 // left part is empty if i <= 0 305 var j int 306 posSecondSlash := i 307 if i > 0 { 308 for j = i - 1; j >= 0; j-- { 309 switch { 310 case dsn[j] == '/': 311 // second slash 312 secondSlash = true 313 posSecondSlash = j 314 case dsn[j] == '@': 315 // username[:password]@... 316 cfg.User, cfg.Password = parseUserPassword(j, dsn) 317 } 318 if dsn[j] == '@' { 319 break 320 } 321 } 322 323 // account or host:port 324 err = parseAccountHostPort(cfg, j, posSecondSlash, dsn) 325 if err != nil { 326 return nil, err 327 } 328 } 329 // [?param1=value1&...¶mN=valueN] 330 // Find the first '?' in dsn[i+1:] 331 err = parseParams(cfg, i, dsn) 332 if err != nil { 333 return 334 } 335 if secondSlash { 336 cfg.Database = dsn[posSecondSlash+1 : i] 337 cfg.Schema = dsn[i+1 : posQuestion] 338 } else { 339 cfg.Database = dsn[posSecondSlash+1 : posQuestion] 340 } 341 done = true 342 case dsn[i] == '?': 343 posQuestion = i 344 } 345 if done { 346 break 347 } 348 } 349 if !foundSlash { 350 // no db or schema is specified 351 var j int 352 for j = len(dsn) - 1; j >= 0; j-- { 353 switch { 354 case dsn[j] == '@': 355 cfg.User, cfg.Password = parseUserPassword(j, dsn) 356 case dsn[j] == '?': 357 posQuestion = j 358 } 359 if dsn[j] == '@' { 360 break 361 } 362 } 363 err = parseAccountHostPort(cfg, j, posQuestion, dsn) 364 if err != nil { 365 return nil, err 366 } 367 err = parseParams(cfg, posQuestion-1, dsn) 368 if err != nil { 369 return 370 } 371 } 372 if cfg.Account == "" && strings.HasSuffix(cfg.Host, defaultDomain) { 373 posDot := strings.Index(cfg.Host, ".") 374 if posDot > 0 { 375 cfg.Account = cfg.Host[:posDot] 376 } 377 } 378 posDot := strings.Index(cfg.Account, ".") 379 if posDot >= 0 { 380 cfg.Account = cfg.Account[:posDot] 381 } 382 383 err = fillMissingConfigParameters(cfg) 384 if err != nil { 385 return nil, err 386 } 387 388 // unescape parameters 389 var s string 390 s, err = url.QueryUnescape(cfg.User) 391 if err != nil { 392 return nil, err 393 } 394 cfg.User = s 395 s, err = url.QueryUnescape(cfg.Password) 396 if err != nil { 397 return nil, err 398 } 399 cfg.Password = s 400 s, err = url.QueryUnescape(cfg.Database) 401 if err != nil { 402 return nil, err 403 } 404 cfg.Database = s 405 s, err = url.QueryUnescape(cfg.Schema) 406 if err != nil { 407 return nil, err 408 } 409 cfg.Schema = s 410 s, err = url.QueryUnescape(cfg.Role) 411 if err != nil { 412 return nil, err 413 } 414 cfg.Role = s 415 s, err = url.QueryUnescape(cfg.Warehouse) 416 if err != nil { 417 return nil, err 418 } 419 cfg.Warehouse = s 420 return cfg, nil 421 } 422 423 func fillMissingConfigParameters(cfg *Config) error { 424 posDash := strings.LastIndex(cfg.Account, "-") 425 if posDash > 0 { 426 if strings.Contains(cfg.Host, ".global.") { 427 cfg.Account = cfg.Account[:posDash] 428 } 429 } 430 if strings.Trim(cfg.Account, " ") == "" { 431 return errEmptyAccount() 432 } 433 434 if authRequiresUser(cfg) && strings.TrimSpace(cfg.User) == "" { 435 return errEmptyUsername() 436 } 437 438 if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" { 439 return errEmptyPassword() 440 } 441 if strings.Trim(cfg.Protocol, " ") == "" { 442 cfg.Protocol = "https" 443 } 444 if cfg.Port == 0 { 445 cfg.Port = 443 446 } 447 448 cfg.Region = strings.Trim(cfg.Region, " ") 449 if cfg.Region != "" { 450 // region is specified but not included in Host 451 i := strings.Index(cfg.Host, defaultDomain) 452 if i >= 1 { 453 hostPrefix := cfg.Host[0:i] 454 if !strings.HasSuffix(hostPrefix, cfg.Region) { 455 cfg.Host = hostPrefix + "." + cfg.Region + defaultDomain 456 } 457 } 458 } 459 if cfg.Host == "" { 460 if cfg.Region != "" { 461 cfg.Host = cfg.Account + "." + cfg.Region + defaultDomain 462 } else { 463 cfg.Host = cfg.Account + defaultDomain 464 } 465 } 466 if cfg.LoginTimeout == 0 { 467 cfg.LoginTimeout = defaultLoginTimeout 468 } 469 if cfg.RequestTimeout == 0 { 470 cfg.RequestTimeout = defaultRequestTimeout 471 } 472 if cfg.JWTExpireTimeout == 0 { 473 cfg.JWTExpireTimeout = defaultJWTTimeout 474 } 475 if cfg.ClientTimeout == 0 { 476 cfg.ClientTimeout = defaultClientTimeout 477 } 478 if cfg.JWTClientTimeout == 0 { 479 cfg.JWTClientTimeout = defaultJWTClientTimeout 480 } 481 if cfg.ExternalBrowserTimeout == 0 { 482 cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout 483 } 484 if cfg.MaxRetryCount == 0 { 485 cfg.MaxRetryCount = defaultMaxRetryCount 486 } 487 if strings.Trim(cfg.Application, " ") == "" { 488 cfg.Application = clientType 489 } 490 491 if cfg.OCSPFailOpen == ocspFailOpenNotSet { 492 cfg.OCSPFailOpen = OCSPFailOpenTrue 493 } 494 495 if cfg.ValidateDefaultParameters == configBoolNotSet { 496 cfg.ValidateDefaultParameters = ConfigBoolTrue 497 } 498 499 if cfg.IncludeRetryReason == configBoolNotSet { 500 cfg.IncludeRetryReason = ConfigBoolTrue 501 } 502 503 if strings.HasSuffix(cfg.Host, defaultDomain) && len(cfg.Host) == len(defaultDomain) { 504 return &SnowflakeError{ 505 Number: ErrCodeFailedToParseHost, 506 Message: errMsgFailedToParseHost, 507 MessageArgs: []interface{}{cfg.Host}, 508 } 509 } 510 return nil 511 } 512 513 func authRequiresUser(cfg *Config) bool { 514 return cfg.Authenticator != AuthTypeOAuth && 515 cfg.Authenticator != AuthTypeTokenAccessor && 516 cfg.Authenticator != AuthTypeExternalBrowser 517 } 518 519 func authRequiresPassword(cfg *Config) bool { 520 return cfg.Authenticator != AuthTypeOAuth && 521 cfg.Authenticator != AuthTypeTokenAccessor && 522 cfg.Authenticator != AuthTypeExternalBrowser && 523 cfg.Authenticator != AuthTypeJwt 524 } 525 526 // transformAccountToHost transforms host to account name 527 func transformAccountToHost(cfg *Config) (err error) { 528 if cfg.Port == 0 && !strings.HasSuffix(cfg.Host, defaultDomain) && cfg.Host != "" { 529 // account name is specified instead of host:port 530 cfg.Account = cfg.Host 531 cfg.Host = cfg.Account + defaultDomain 532 cfg.Port = 443 533 posDot := strings.Index(cfg.Account, ".") 534 if posDot > 0 { 535 cfg.Region = cfg.Account[posDot+1:] 536 cfg.Account = cfg.Account[:posDot] 537 } 538 } 539 return nil 540 } 541 542 // parseAccountHostPort parses the DSN string to attempt to get account or host and port. 543 func parseAccountHostPort(cfg *Config, posAt, posSlash int, dsn string) (err error) { 544 // account or host:port 545 var k int 546 for k = posAt + 1; k < posSlash; k++ { 547 if dsn[k] == ':' { 548 cfg.Port, err = strconv.Atoi(dsn[k+1 : posSlash]) 549 if err != nil { 550 err = &SnowflakeError{ 551 Number: ErrCodeFailedToParsePort, 552 Message: errMsgFailedToParsePort, 553 MessageArgs: []interface{}{dsn[k+1 : posSlash]}, 554 } 555 return 556 } 557 break 558 } 559 } 560 cfg.Host = dsn[posAt+1 : k] 561 return transformAccountToHost(cfg) 562 } 563 564 // parseUserPassword parses the DSN string for username and password 565 func parseUserPassword(posAt int, dsn string) (user, password string) { 566 var k int 567 for k = 0; k < posAt; k++ { 568 if dsn[k] == ':' { 569 password = dsn[k+1 : posAt] 570 break 571 } 572 } 573 user = dsn[:k] 574 return 575 } 576 577 // parseParams parse parameters 578 func parseParams(cfg *Config, posQuestion int, dsn string) (err error) { 579 for j := posQuestion + 1; j < len(dsn); j++ { 580 if dsn[j] == '?' { 581 if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { 582 return 583 } 584 break 585 } 586 } 587 return 588 } 589 590 // parseDSNParams parses the DSN "query string". Values must be url.QueryEscape'ed 591 func parseDSNParams(cfg *Config, params string) (err error) { 592 logger.Infof("Query String: %v\n", params) 593 for _, v := range strings.Split(params, "&") { 594 param := strings.SplitN(v, "=", 2) 595 if len(param) != 2 { 596 continue 597 } 598 var value string 599 value, err = url.QueryUnescape(param[1]) 600 if err != nil { 601 return err 602 } 603 switch param[0] { 604 // Disable INFILE whitelist / enable all files 605 case "account": 606 cfg.Account = value 607 case "warehouse": 608 cfg.Warehouse = value 609 case "database": 610 cfg.Database = value 611 case "schema": 612 cfg.Schema = value 613 case "role": 614 cfg.Role = value 615 case "region": 616 cfg.Region = value 617 case "protocol": 618 cfg.Protocol = value 619 case "passcode": 620 cfg.Passcode = value 621 case "passcodeInPassword": 622 var vv bool 623 vv, err = strconv.ParseBool(value) 624 if err != nil { 625 return 626 } 627 cfg.PasscodeInPassword = vv 628 case "clientTimeout": 629 cfg.ClientTimeout, err = parseTimeout(value) 630 if err != nil { 631 return 632 } 633 case "jwtClientTimeout": 634 cfg.JWTClientTimeout, err = parseTimeout(value) 635 if err != nil { 636 return 637 } 638 case "loginTimeout": 639 cfg.LoginTimeout, err = parseTimeout(value) 640 if err != nil { 641 return 642 } 643 case "requestTimeout": 644 cfg.RequestTimeout, err = parseTimeout(value) 645 if err != nil { 646 return 647 } 648 case "jwtTimeout": 649 cfg.JWTExpireTimeout, err = parseTimeout(value) 650 if err != nil { 651 return err 652 } 653 case "externalBrowserTimeout": 654 cfg.ExternalBrowserTimeout, err = parseTimeout(value) 655 if err != nil { 656 return err 657 } 658 case "maxRetryCount": 659 cfg.MaxRetryCount, err = strconv.Atoi(value) 660 if err != nil { 661 return err 662 } 663 case "application": 664 cfg.Application = value 665 case "authenticator": 666 err := determineAuthenticatorType(cfg, value) 667 if err != nil { 668 return err 669 } 670 case "insecureMode": 671 var vv bool 672 vv, err = strconv.ParseBool(value) 673 if err != nil { 674 return 675 } 676 cfg.InsecureMode = vv 677 case "ocspFailOpen": 678 var vv bool 679 vv, err = strconv.ParseBool(value) 680 if err != nil { 681 return 682 } 683 if vv { 684 cfg.OCSPFailOpen = OCSPFailOpenTrue 685 } else { 686 cfg.OCSPFailOpen = OCSPFailOpenFalse 687 } 688 689 case "token": 690 cfg.Token = value 691 case "privateKey": 692 var decodeErr error 693 block, decodeErr := base64.URLEncoding.DecodeString(value) 694 if decodeErr != nil { 695 err = &SnowflakeError{ 696 Number: ErrCodePrivateKeyParseError, 697 Message: "Base64 decode failed", 698 } 699 return 700 } 701 cfg.PrivateKey, err = parsePKCS8PrivateKey(block) 702 if err != nil { 703 return err 704 } 705 case "validateDefaultParameters": 706 var vv bool 707 vv, err = strconv.ParseBool(value) 708 if err != nil { 709 return 710 } 711 if vv { 712 cfg.ValidateDefaultParameters = ConfigBoolTrue 713 } else { 714 cfg.ValidateDefaultParameters = ConfigBoolFalse 715 } 716 case "clientRequestMfaToken": 717 var vv bool 718 vv, err = strconv.ParseBool(value) 719 if err != nil { 720 return 721 } 722 if vv { 723 cfg.ClientRequestMfaToken = ConfigBoolTrue 724 } else { 725 cfg.ClientRequestMfaToken = ConfigBoolFalse 726 } 727 case "clientStoreTemporaryCredential": 728 var vv bool 729 vv, err = strconv.ParseBool(value) 730 if err != nil { 731 return 732 } 733 if vv { 734 cfg.ClientStoreTemporaryCredential = ConfigBoolTrue 735 } else { 736 cfg.ClientStoreTemporaryCredential = ConfigBoolFalse 737 } 738 case "tracing": 739 cfg.Tracing = value 740 case "tmpDirPath": 741 cfg.TmpDirPath = value 742 case "disableQueryContextCache": 743 var b bool 744 b, err = strconv.ParseBool(value) 745 if err != nil { 746 return 747 } 748 cfg.DisableQueryContextCache = b 749 case "includeRetryReason": 750 var vv bool 751 vv, err = strconv.ParseBool(value) 752 if err != nil { 753 return 754 } 755 if vv { 756 cfg.IncludeRetryReason = ConfigBoolTrue 757 } else { 758 cfg.IncludeRetryReason = ConfigBoolFalse 759 } 760 case "clientConfigFile": 761 cfg.ClientConfigFile = value 762 case "disableConsoleLogin": 763 var vv bool 764 vv, err = strconv.ParseBool(value) 765 if err != nil { 766 return 767 } 768 if vv { 769 cfg.DisableConsoleLogin = ConfigBoolTrue 770 } else { 771 cfg.DisableConsoleLogin = ConfigBoolFalse 772 } 773 default: 774 if cfg.Params == nil { 775 cfg.Params = make(map[string]*string) 776 } 777 cfg.Params[param[0]] = &value 778 } 779 } 780 return 781 } 782 783 func parseTimeout(value string) (time.Duration, error) { 784 var vv int64 785 var err error 786 vv, err = strconv.ParseInt(value, 10, 64) 787 if err != nil { 788 return time.Duration(0), err 789 } 790 return time.Duration(vv * int64(time.Second)), nil 791 } 792 793 // ConfigParam is used to bind the name of the Config field with the environment variable and set the requirement for it 794 type ConfigParam struct { 795 Name string 796 EnvName string 797 FailOnMissing bool 798 } 799 800 // GetConfigFromEnv is used to parse the environment variable values to specific fields of the Config 801 func GetConfigFromEnv(properties []*ConfigParam) (*Config, error) { 802 var account, user, password, role, host, portStr, protocol, warehouse, database, schema, region, passcode, application string 803 var privateKey *rsa.PrivateKey 804 var err error 805 if len(properties) == 0 || properties == nil { 806 return nil, errors.New("missing configuration parameters for the connection") 807 } 808 for _, prop := range properties { 809 value, err := GetFromEnv(prop.EnvName, prop.FailOnMissing) 810 if err != nil { 811 return nil, err 812 } 813 switch prop.Name { 814 case "Account": 815 account = value 816 case "User": 817 user = value 818 case "Password": 819 password = value 820 case "Role": 821 role = value 822 case "Host": 823 host = value 824 case "Port": 825 portStr = value 826 case "Protocol": 827 protocol = value 828 case "Warehouse": 829 warehouse = value 830 case "Database": 831 database = value 832 case "Region": 833 region = value 834 case "Passcode": 835 passcode = value 836 case "Schema": 837 schema = value 838 case "Application": 839 application = value 840 case "PrivateKey": 841 privateKey, err = parsePrivateKeyFromFile(value) 842 if err != nil { 843 return nil, err 844 } 845 } 846 } 847 848 port := 443 // snowflake default port 849 if len(portStr) > 0 { 850 port, err = strconv.Atoi(portStr) 851 if err != nil { 852 return nil, err 853 } 854 } 855 856 cfg := &Config{ 857 Account: account, 858 User: user, 859 Password: password, 860 Role: role, 861 Host: host, 862 Port: port, 863 Protocol: protocol, 864 Warehouse: warehouse, 865 Database: database, 866 Schema: schema, 867 PrivateKey: privateKey, 868 Region: region, 869 Passcode: passcode, 870 Application: application, 871 } 872 return cfg, nil 873 } 874 875 func parsePrivateKeyFromFile(path string) (*rsa.PrivateKey, error) { 876 bytes, err := os.ReadFile(path) 877 if err != nil { 878 return nil, err 879 } 880 block, _ := pem.Decode(bytes) 881 if block == nil { 882 return nil, errors.New("failed to parse PEM block containing the private key") 883 } 884 privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) 885 if err != nil { 886 return nil, err 887 } 888 pk, ok := privateKey.(*rsa.PrivateKey) 889 if !ok { 890 return nil, fmt.Errorf("interface convertion. expected type *rsa.PrivateKey, but got %T", privateKey) 891 } 892 return pk, nil 893 }