github.com/jackc/pgx/v5@v5.5.5/pgconn/config.go (about) 1 package pgconn 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "encoding/pem" 8 "errors" 9 "fmt" 10 "io" 11 "math" 12 "net" 13 "net/url" 14 "os" 15 "path/filepath" 16 "strconv" 17 "strings" 18 "time" 19 20 "github.com/jackc/pgpassfile" 21 "github.com/jackc/pgservicefile" 22 "github.com/jackc/pgx/v5/pgproto3" 23 ) 24 25 type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error 26 type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error 27 type GetSSLPasswordFunc func(ctx context.Context) string 28 29 // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A 30 // manually initialized Config will cause ConnectConfig to panic. 31 type Config struct { 32 Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) 33 Port uint16 34 Database string 35 User string 36 Password string 37 TLSConfig *tls.Config // nil disables TLS 38 ConnectTimeout time.Duration 39 DialFunc DialFunc // e.g. net.Dialer.DialContext 40 LookupFunc LookupFunc // e.g. net.Resolver.LookupHost 41 BuildFrontend BuildFrontendFunc 42 RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) 43 44 KerberosSrvName string 45 KerberosSpn string 46 Fallbacks []*FallbackConfig 47 48 // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. 49 // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next 50 // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. 51 ValidateConnect ValidateConnectFunc 52 53 // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables 54 // or prepare statements). If this returns an error the connection attempt fails. 55 AfterConnect AfterConnectFunc 56 57 // OnNotice is a callback function called when a notice response is received. 58 OnNotice NoticeHandler 59 60 // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. 61 OnNotification NotificationHandler 62 63 // OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close 64 // the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure 65 // that you close on FATAL errors by returning false. 66 OnPgError PgErrorHandler 67 68 createdByParseConfig bool // Used to enforce created by ParseConfig rule. 69 } 70 71 // ParseConfigOptions contains options that control how a config is built such as GetSSLPassword. 72 type ParseConfigOptions struct { 73 // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function 74 // PQsetSSLKeyPassHook_OpenSSL. 75 GetSSLPassword GetSSLPasswordFunc 76 } 77 78 // Copy returns a deep copy of the config that is safe to use and modify. 79 // The only exception is the TLSConfig field: 80 // according to the tls.Config docs it must not be modified after creation. 81 func (c *Config) Copy() *Config { 82 newConf := new(Config) 83 *newConf = *c 84 if newConf.TLSConfig != nil { 85 newConf.TLSConfig = c.TLSConfig.Clone() 86 } 87 if newConf.RuntimeParams != nil { 88 newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) 89 for k, v := range c.RuntimeParams { 90 newConf.RuntimeParams[k] = v 91 } 92 } 93 if newConf.Fallbacks != nil { 94 newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) 95 for i, fallback := range c.Fallbacks { 96 newFallback := new(FallbackConfig) 97 *newFallback = *fallback 98 if newFallback.TLSConfig != nil { 99 newFallback.TLSConfig = fallback.TLSConfig.Clone() 100 } 101 newConf.Fallbacks[i] = newFallback 102 } 103 } 104 return newConf 105 } 106 107 // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a 108 // network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. 109 type FallbackConfig struct { 110 Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) 111 Port uint16 112 TLSConfig *tls.Config // nil disables TLS 113 } 114 115 // isAbsolutePath checks if the provided value is an absolute path either 116 // beginning with a forward slash (as on Linux-based systems) or with a capital 117 // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). 118 func isAbsolutePath(path string) bool { 119 isWindowsPath := func(p string) bool { 120 if len(p) < 3 { 121 return false 122 } 123 drive := p[0] 124 colon := p[1] 125 backslash := p[2] 126 if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' { 127 return true 128 } 129 return false 130 } 131 return strings.HasPrefix(path, "/") || isWindowsPath(path) 132 } 133 134 // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with 135 // net.Dial. 136 func NetworkAddress(host string, port uint16) (network, address string) { 137 if isAbsolutePath(host) { 138 network = "unix" 139 address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) 140 } else { 141 network = "tcp" 142 address = net.JoinHostPort(host, strconv.Itoa(int(port))) 143 } 144 return network, address 145 } 146 147 // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It 148 // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely 149 // matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). 150 // See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be 151 // empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. 152 // 153 // # Example DSN 154 // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca 155 // 156 // # Example URL 157 // postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca 158 // 159 // The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done 160 // through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be 161 // interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should 162 // not be modified individually. They should all be modified or all left unchanged. 163 // 164 // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated 165 // values that will be tried in order. This can be used as part of a high availability system. See 166 // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. 167 // 168 // # Example URL 169 // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb 170 // 171 // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed 172 // via database URL or DSN: 173 // 174 // PGHOST 175 // PGPORT 176 // PGDATABASE 177 // PGUSER 178 // PGPASSWORD 179 // PGPASSFILE 180 // PGSERVICE 181 // PGSERVICEFILE 182 // PGSSLMODE 183 // PGSSLCERT 184 // PGSSLKEY 185 // PGSSLROOTCERT 186 // PGSSLPASSWORD 187 // PGAPPNAME 188 // PGCONNECT_TIMEOUT 189 // PGTARGETSESSIONATTRS 190 // 191 // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. 192 // 193 // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are 194 // usually but not always the environment variable name downcased and without the "PG" prefix. 195 // 196 // Important Security Notes: 197 // 198 // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if 199 // not set. 200 // 201 // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of 202 // security each sslmode provides. 203 // 204 // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of 205 // the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of 206 // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback 207 // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually 208 // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting 209 // TLSConfig. 210 // 211 // Other known differences with libpq: 212 // 213 // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn 214 // does not. 215 // 216 // In addition, ParseConfig accepts the following options: 217 // 218 // - servicefile. 219 // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a 220 // part of the connection string. 221 func ParseConfig(connString string) (*Config, error) { 222 var parseConfigOptions ParseConfigOptions 223 return ParseConfigWithOptions(connString, parseConfigOptions) 224 } 225 226 // ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard 227 // C library libpq. options contains settings that cannot be specified in a connString such as providing a function to 228 // get the SSL password. 229 func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) { 230 defaultSettings := defaultSettings() 231 envSettings := parseEnvSettings() 232 233 connStringSettings := make(map[string]string) 234 if connString != "" { 235 var err error 236 // connString may be a database URL or a DSN 237 if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { 238 connStringSettings, err = parseURLSettings(connString) 239 if err != nil { 240 return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err} 241 } 242 } else { 243 connStringSettings, err = parseDSNSettings(connString) 244 if err != nil { 245 return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err} 246 } 247 } 248 } 249 250 settings := mergeSettings(defaultSettings, envSettings, connStringSettings) 251 if service, present := settings["service"]; present { 252 serviceSettings, err := parseServiceSettings(settings["servicefile"], service) 253 if err != nil { 254 return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err} 255 } 256 257 settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) 258 } 259 260 config := &Config{ 261 createdByParseConfig: true, 262 Database: settings["database"], 263 User: settings["user"], 264 Password: settings["password"], 265 RuntimeParams: make(map[string]string), 266 BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { 267 return pgproto3.NewFrontend(r, w) 268 }, 269 OnPgError: func(_ *PgConn, pgErr *PgError) bool { 270 // we want to automatically close any fatal errors 271 if strings.EqualFold(pgErr.Severity, "FATAL") { 272 return false 273 } 274 return true 275 }, 276 } 277 278 if connectTimeoutSetting, present := settings["connect_timeout"]; present { 279 connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) 280 if err != nil { 281 return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err} 282 } 283 config.ConnectTimeout = connectTimeout 284 config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) 285 } else { 286 defaultDialer := makeDefaultDialer() 287 config.DialFunc = defaultDialer.DialContext 288 } 289 290 config.LookupFunc = makeDefaultResolver().LookupHost 291 292 notRuntimeParams := map[string]struct{}{ 293 "host": {}, 294 "port": {}, 295 "database": {}, 296 "user": {}, 297 "password": {}, 298 "passfile": {}, 299 "connect_timeout": {}, 300 "sslmode": {}, 301 "sslkey": {}, 302 "sslcert": {}, 303 "sslrootcert": {}, 304 "sslpassword": {}, 305 "sslsni": {}, 306 "krbspn": {}, 307 "krbsrvname": {}, 308 "target_session_attrs": {}, 309 "service": {}, 310 "servicefile": {}, 311 } 312 313 // Adding kerberos configuration 314 if _, present := settings["krbsrvname"]; present { 315 config.KerberosSrvName = settings["krbsrvname"] 316 } 317 if _, present := settings["krbspn"]; present { 318 config.KerberosSpn = settings["krbspn"] 319 } 320 321 for k, v := range settings { 322 if _, present := notRuntimeParams[k]; present { 323 continue 324 } 325 config.RuntimeParams[k] = v 326 } 327 328 fallbacks := []*FallbackConfig{} 329 330 hosts := strings.Split(settings["host"], ",") 331 ports := strings.Split(settings["port"], ",") 332 333 for i, host := range hosts { 334 var portStr string 335 if i < len(ports) { 336 portStr = ports[i] 337 } else { 338 portStr = ports[0] 339 } 340 341 port, err := parsePort(portStr) 342 if err != nil { 343 return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err} 344 } 345 346 var tlsConfigs []*tls.Config 347 348 // Ignore TLS settings if Unix domain socket like libpq 349 if network, _ := NetworkAddress(host, port); network == "unix" { 350 tlsConfigs = append(tlsConfigs, nil) 351 } else { 352 var err error 353 tlsConfigs, err = configTLS(settings, host, options) 354 if err != nil { 355 return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err} 356 } 357 } 358 359 for _, tlsConfig := range tlsConfigs { 360 fallbacks = append(fallbacks, &FallbackConfig{ 361 Host: host, 362 Port: port, 363 TLSConfig: tlsConfig, 364 }) 365 } 366 } 367 368 config.Host = fallbacks[0].Host 369 config.Port = fallbacks[0].Port 370 config.TLSConfig = fallbacks[0].TLSConfig 371 config.Fallbacks = fallbacks[1:] 372 373 passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) 374 if err == nil { 375 if config.Password == "" { 376 host := config.Host 377 if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { 378 host = "localhost" 379 } 380 381 config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) 382 } 383 } 384 385 switch tsa := settings["target_session_attrs"]; tsa { 386 case "read-write": 387 config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite 388 case "read-only": 389 config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly 390 case "primary": 391 config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary 392 case "standby": 393 config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby 394 case "prefer-standby": 395 config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby 396 case "any": 397 // do nothing 398 default: 399 return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} 400 } 401 402 return config, nil 403 } 404 405 func mergeSettings(settingSets ...map[string]string) map[string]string { 406 settings := make(map[string]string) 407 408 for _, s2 := range settingSets { 409 for k, v := range s2 { 410 settings[k] = v 411 } 412 } 413 414 return settings 415 } 416 417 func parseEnvSettings() map[string]string { 418 settings := make(map[string]string) 419 420 nameMap := map[string]string{ 421 "PGHOST": "host", 422 "PGPORT": "port", 423 "PGDATABASE": "database", 424 "PGUSER": "user", 425 "PGPASSWORD": "password", 426 "PGPASSFILE": "passfile", 427 "PGAPPNAME": "application_name", 428 "PGCONNECT_TIMEOUT": "connect_timeout", 429 "PGSSLMODE": "sslmode", 430 "PGSSLKEY": "sslkey", 431 "PGSSLCERT": "sslcert", 432 "PGSSLSNI": "sslsni", 433 "PGSSLROOTCERT": "sslrootcert", 434 "PGSSLPASSWORD": "sslpassword", 435 "PGTARGETSESSIONATTRS": "target_session_attrs", 436 "PGSERVICE": "service", 437 "PGSERVICEFILE": "servicefile", 438 } 439 440 for envname, realname := range nameMap { 441 value := os.Getenv(envname) 442 if value != "" { 443 settings[realname] = value 444 } 445 } 446 447 return settings 448 } 449 450 func parseURLSettings(connString string) (map[string]string, error) { 451 settings := make(map[string]string) 452 453 url, err := url.Parse(connString) 454 if err != nil { 455 return nil, err 456 } 457 458 if url.User != nil { 459 settings["user"] = url.User.Username() 460 if password, present := url.User.Password(); present { 461 settings["password"] = password 462 } 463 } 464 465 // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. 466 var hosts []string 467 var ports []string 468 for _, host := range strings.Split(url.Host, ",") { 469 if host == "" { 470 continue 471 } 472 if isIPOnly(host) { 473 hosts = append(hosts, strings.Trim(host, "[]")) 474 continue 475 } 476 h, p, err := net.SplitHostPort(host) 477 if err != nil { 478 return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) 479 } 480 if h != "" { 481 hosts = append(hosts, h) 482 } 483 if p != "" { 484 ports = append(ports, p) 485 } 486 } 487 if len(hosts) > 0 { 488 settings["host"] = strings.Join(hosts, ",") 489 } 490 if len(ports) > 0 { 491 settings["port"] = strings.Join(ports, ",") 492 } 493 494 database := strings.TrimLeft(url.Path, "/") 495 if database != "" { 496 settings["database"] = database 497 } 498 499 nameMap := map[string]string{ 500 "dbname": "database", 501 } 502 503 for k, v := range url.Query() { 504 if k2, present := nameMap[k]; present { 505 k = k2 506 } 507 508 settings[k] = v[0] 509 } 510 511 return settings, nil 512 } 513 514 func isIPOnly(host string) bool { 515 return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") 516 } 517 518 var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} 519 520 func parseDSNSettings(s string) (map[string]string, error) { 521 settings := make(map[string]string) 522 523 nameMap := map[string]string{ 524 "dbname": "database", 525 } 526 527 for len(s) > 0 { 528 var key, val string 529 eqIdx := strings.IndexRune(s, '=') 530 if eqIdx < 0 { 531 return nil, errors.New("invalid dsn") 532 } 533 534 key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") 535 s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") 536 if len(s) == 0 { 537 } else if s[0] != '\'' { 538 end := 0 539 for ; end < len(s); end++ { 540 if asciiSpace[s[end]] == 1 { 541 break 542 } 543 if s[end] == '\\' { 544 end++ 545 if end == len(s) { 546 return nil, errors.New("invalid backslash") 547 } 548 } 549 } 550 val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) 551 if end == len(s) { 552 s = "" 553 } else { 554 s = s[end+1:] 555 } 556 } else { // quoted string 557 s = s[1:] 558 end := 0 559 for ; end < len(s); end++ { 560 if s[end] == '\'' { 561 break 562 } 563 if s[end] == '\\' { 564 end++ 565 } 566 } 567 if end == len(s) { 568 return nil, errors.New("unterminated quoted string in connection info string") 569 } 570 val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) 571 if end == len(s) { 572 s = "" 573 } else { 574 s = s[end+1:] 575 } 576 } 577 578 if k, ok := nameMap[key]; ok { 579 key = k 580 } 581 582 if key == "" { 583 return nil, errors.New("invalid dsn") 584 } 585 586 settings[key] = val 587 } 588 589 return settings, nil 590 } 591 592 func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { 593 servicefile, err := pgservicefile.ReadServicefile(servicefilePath) 594 if err != nil { 595 return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) 596 } 597 598 service, err := servicefile.GetService(serviceName) 599 if err != nil { 600 return nil, fmt.Errorf("unable to find service: %v", serviceName) 601 } 602 603 nameMap := map[string]string{ 604 "dbname": "database", 605 } 606 607 settings := make(map[string]string, len(service.Settings)) 608 for k, v := range service.Settings { 609 if k2, present := nameMap[k]; present { 610 k = k2 611 } 612 settings[k] = v 613 } 614 615 return settings, nil 616 } 617 618 // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is 619 // necessary to allow returning multiple TLS configs as sslmode "allow" and 620 // "prefer" allow fallback. 621 func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) { 622 host := thisHost 623 sslmode := settings["sslmode"] 624 sslrootcert := settings["sslrootcert"] 625 sslcert := settings["sslcert"] 626 sslkey := settings["sslkey"] 627 sslpassword := settings["sslpassword"] 628 sslsni := settings["sslsni"] 629 630 // Match libpq default behavior 631 if sslmode == "" { 632 sslmode = "prefer" 633 } 634 if sslsni == "" { 635 sslsni = "1" 636 } 637 638 tlsConfig := &tls.Config{} 639 640 switch sslmode { 641 case "disable": 642 return []*tls.Config{nil}, nil 643 case "allow", "prefer": 644 tlsConfig.InsecureSkipVerify = true 645 case "require": 646 // According to PostgreSQL documentation, if a root CA file exists, 647 // the behavior of sslmode=require should be the same as that of verify-ca 648 // 649 // See https://www.postgresql.org/docs/12/libpq-ssl.html 650 if sslrootcert != "" { 651 goto nextCase 652 } 653 tlsConfig.InsecureSkipVerify = true 654 break 655 nextCase: 656 fallthrough 657 case "verify-ca": 658 // Don't perform the default certificate verification because it 659 // will verify the hostname. Instead, verify the server's 660 // certificate chain ourselves in VerifyPeerCertificate and 661 // ignore the server name. This emulates libpq's verify-ca 662 // behavior. 663 // 664 // See https://github.com/golang/go/issues/21971#issuecomment-332693931 665 // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate 666 // for more info. 667 tlsConfig.InsecureSkipVerify = true 668 tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { 669 certs := make([]*x509.Certificate, len(certificates)) 670 for i, asn1Data := range certificates { 671 cert, err := x509.ParseCertificate(asn1Data) 672 if err != nil { 673 return errors.New("failed to parse certificate from server: " + err.Error()) 674 } 675 certs[i] = cert 676 } 677 678 // Leave DNSName empty to skip hostname verification. 679 opts := x509.VerifyOptions{ 680 Roots: tlsConfig.RootCAs, 681 Intermediates: x509.NewCertPool(), 682 } 683 // Skip the first cert because it's the leaf. All others 684 // are intermediates. 685 for _, cert := range certs[1:] { 686 opts.Intermediates.AddCert(cert) 687 } 688 _, err := certs[0].Verify(opts) 689 return err 690 } 691 case "verify-full": 692 tlsConfig.ServerName = host 693 default: 694 return nil, errors.New("sslmode is invalid") 695 } 696 697 if sslrootcert != "" { 698 caCertPool := x509.NewCertPool() 699 700 caPath := sslrootcert 701 caCert, err := os.ReadFile(caPath) 702 if err != nil { 703 return nil, fmt.Errorf("unable to read CA file: %w", err) 704 } 705 706 if !caCertPool.AppendCertsFromPEM(caCert) { 707 return nil, errors.New("unable to add CA to cert pool") 708 } 709 710 tlsConfig.RootCAs = caCertPool 711 tlsConfig.ClientCAs = caCertPool 712 } 713 714 if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { 715 return nil, errors.New(`both "sslcert" and "sslkey" are required`) 716 } 717 718 if sslcert != "" && sslkey != "" { 719 buf, err := os.ReadFile(sslkey) 720 if err != nil { 721 return nil, fmt.Errorf("unable to read sslkey: %w", err) 722 } 723 block, _ := pem.Decode(buf) 724 if block == nil { 725 return nil, errors.New("failed to decode sslkey") 726 } 727 var pemKey []byte 728 var decryptedKey []byte 729 var decryptedError error 730 // If PEM is encrypted, attempt to decrypt using pass phrase 731 if x509.IsEncryptedPEMBlock(block) { 732 // Attempt decryption with pass phrase 733 // NOTE: only supports RSA (PKCS#1) 734 if sslpassword != "" { 735 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) 736 } 737 //if sslpassword not provided or has decryption error when use it 738 //try to find sslpassword with callback function 739 if sslpassword == "" || decryptedError != nil { 740 if parseConfigOptions.GetSSLPassword != nil { 741 sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) 742 } 743 if sslpassword == "" { 744 return nil, fmt.Errorf("unable to find sslpassword") 745 } 746 } 747 decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) 748 // Should we also provide warning for PKCS#1 needed? 749 if decryptedError != nil { 750 return nil, fmt.Errorf("unable to decrypt key: %w", err) 751 } 752 753 pemBytes := pem.Block{ 754 Type: "RSA PRIVATE KEY", 755 Bytes: decryptedKey, 756 } 757 pemKey = pem.EncodeToMemory(&pemBytes) 758 } else { 759 pemKey = pem.EncodeToMemory(block) 760 } 761 certfile, err := os.ReadFile(sslcert) 762 if err != nil { 763 return nil, fmt.Errorf("unable to read cert: %w", err) 764 } 765 cert, err := tls.X509KeyPair(certfile, pemKey) 766 if err != nil { 767 return nil, fmt.Errorf("unable to load cert: %w", err) 768 } 769 tlsConfig.Certificates = []tls.Certificate{cert} 770 } 771 772 // Set Server Name Indication (SNI), if enabled by connection parameters. 773 // Per RFC 6066, do not set it if the host is a literal IP address (IPv4 774 // or IPv6). 775 if sslsni == "1" && net.ParseIP(host) == nil { 776 tlsConfig.ServerName = host 777 } 778 779 switch sslmode { 780 case "allow": 781 return []*tls.Config{nil, tlsConfig}, nil 782 case "prefer": 783 return []*tls.Config{tlsConfig, nil}, nil 784 case "require", "verify-ca", "verify-full": 785 return []*tls.Config{tlsConfig}, nil 786 default: 787 panic("BUG: bad sslmode should already have been caught") 788 } 789 } 790 791 func parsePort(s string) (uint16, error) { 792 port, err := strconv.ParseUint(s, 10, 16) 793 if err != nil { 794 return 0, err 795 } 796 if port < 1 || port > math.MaxUint16 { 797 return 0, errors.New("outside range") 798 } 799 return uint16(port), nil 800 } 801 802 func makeDefaultDialer() *net.Dialer { 803 return &net.Dialer{KeepAlive: 5 * time.Minute} 804 } 805 806 func makeDefaultResolver() *net.Resolver { 807 return net.DefaultResolver 808 } 809 810 func parseConnectTimeoutSetting(s string) (time.Duration, error) { 811 timeout, err := strconv.ParseInt(s, 10, 64) 812 if err != nil { 813 return 0, err 814 } 815 if timeout < 0 { 816 return 0, errors.New("negative timeout") 817 } 818 return time.Duration(timeout) * time.Second, nil 819 } 820 821 func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { 822 d := makeDefaultDialer() 823 d.Timeout = timeout 824 return d.DialContext 825 } 826 827 // ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible 828 // target_session_attrs=read-write. 829 func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { 830 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() 831 if result.Err != nil { 832 return result.Err 833 } 834 835 if string(result.Rows[0][0]) == "on" { 836 return errors.New("read only connection") 837 } 838 839 return nil 840 } 841 842 // ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible 843 // target_session_attrs=read-only. 844 func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { 845 result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() 846 if result.Err != nil { 847 return result.Err 848 } 849 850 if string(result.Rows[0][0]) != "on" { 851 return errors.New("connection is not read only") 852 } 853 854 return nil 855 } 856 857 // ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible 858 // target_session_attrs=standby. 859 func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { 860 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 861 if result.Err != nil { 862 return result.Err 863 } 864 865 if string(result.Rows[0][0]) != "t" { 866 return errors.New("server is not in hot standby mode") 867 } 868 869 return nil 870 } 871 872 // ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible 873 // target_session_attrs=primary. 874 func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { 875 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 876 if result.Err != nil { 877 return result.Err 878 } 879 880 if string(result.Rows[0][0]) == "t" { 881 return errors.New("server is in standby mode") 882 } 883 884 return nil 885 } 886 887 // ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible 888 // target_session_attrs=prefer-standby. 889 func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { 890 result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() 891 if result.Err != nil { 892 return result.Err 893 } 894 895 if string(result.Rows[0][0]) != "t" { 896 return &NotPreferredError{err: errors.New("server is not in hot standby mode")} 897 } 898 899 return nil 900 }