github.com/selefra/selefra-utils@v0.0.4/pkg/dsn_util/postgresql.go (about) 1 package dsn_util 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "errors" 8 "fmt" 9 "io/ioutil" 10 "net" 11 "net/url" 12 "os" 13 "strconv" 14 "strings" 15 "time" 16 ) 17 18 type Config struct { 19 // Network type, either tcp or unix. 20 // Default is tcp. 21 Network string 22 // TCP host:port or Unix socket depending on Network. 23 Addr string 24 // Dial timeout for establishing new connections. 25 // Default is 5 seconds. 26 DialTimeout time.Duration 27 // Dialer creates new network connection and has priority over 28 // Network and Addr options. 29 Dialer func(ctx context.Context, network, addr string) (net.Conn, error) 30 31 // TLS config for secure connections. 32 TLSConfig *tls.Config 33 34 User string 35 Password string 36 Database string 37 AppName string 38 // PostgreSQL session parameters updated with `SET` command when a connection is created. 39 ConnParams map[string]interface{} 40 41 // Timeout for socket reads. If reached, commands fail with a timeout instead of blocking. 42 ReadTimeout time.Duration 43 // Timeout for socket writes. If reached, commands fail with a timeout instead of blocking. 44 WriteTimeout time.Duration 45 46 //// ResetSessionFunc is called prior to executing a query on a connection that has been used before. 47 //ResetSessionFunc func(context.Context, *Conn) error 48 } 49 50 func newDefaultConfig() *Config { 51 host := env("PGHOST", "localhost") 52 port := env("PGPORT", "5432") 53 54 cfg := &Config{ 55 Network: "tcp", 56 Addr: net.JoinHostPort(host, port), 57 DialTimeout: 5 * time.Second, 58 TLSConfig: &tls.Config{InsecureSkipVerify: true}, 59 60 User: env("PGUSER", "postgres"), 61 Database: env("PGDATABASE", "postgres"), 62 63 ReadTimeout: 10 * time.Second, 64 WriteTimeout: 5 * time.Second, 65 } 66 67 cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { 68 netDialer := &net.Dialer{ 69 Timeout: cfg.DialTimeout, 70 KeepAlive: 5 * time.Minute, 71 } 72 return netDialer.DialContext(ctx, network, addr) 73 } 74 75 return cfg 76 } 77 78 func NewConfigByDSN(dsn string) (c *Config, err error) { 79 80 defer func() { 81 if r := recover(); r != nil { 82 e, ok := r.(error) 83 if ok { 84 err = e 85 } else { 86 err = fmt.Errorf("%v", r) 87 } 88 } 89 }() 90 91 c = newDefaultConfig() 92 WithDSN(dsn)(c) 93 return 94 } 95 96 func (x *Config) ToDSN(isPasswordMosaic ...bool) string { 97 buff := strings.Builder{} 98 buff.WriteString("postgres://") 99 if x.User != "" || x.Password != "" { 100 buff.WriteString(x.User) 101 buff.WriteString(":") 102 103 if len(isPasswordMosaic) != 0 && isPasswordMosaic[0] { 104 buff.WriteString("*******") 105 } else { 106 buff.WriteString(x.Password) 107 } 108 109 buff.WriteString("@") 110 } 111 buff.WriteString(x.Addr) 112 if x.Database != "" { 113 buff.WriteString("/") 114 buff.WriteString(x.Database) 115 } 116 return buff.String() 117 } 118 119 type Option func(cfg *Config) 120 121 // Deprecated. Use Option instead. 122 type DriverOption = Option 123 124 func WithNetwork(network string) Option { 125 if network == "" { 126 panic("network is empty") 127 } 128 return func(cfg *Config) { 129 cfg.Network = network 130 } 131 } 132 133 func WithAddr(addr string) Option { 134 if addr == "" { 135 panic("addr is empty") 136 } 137 return func(cfg *Config) { 138 cfg.Addr = addr 139 } 140 } 141 142 func WithTLSConfig(tlsConfig *tls.Config) Option { 143 return func(cfg *Config) { 144 cfg.TLSConfig = tlsConfig 145 } 146 } 147 148 func WithInsecure(on bool) Option { 149 return func(cfg *Config) { 150 if on { 151 cfg.TLSConfig = nil 152 } else { 153 cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true} 154 } 155 } 156 } 157 158 func WithUser(user string) Option { 159 if user == "" { 160 panic("user is empty") 161 } 162 return func(cfg *Config) { 163 cfg.User = user 164 } 165 } 166 167 func WithPassword(password string) Option { 168 return func(cfg *Config) { 169 cfg.Password = password 170 } 171 } 172 173 func WithDatabase(database string) Option { 174 if database == "" { 175 panic("database is empty") 176 } 177 return func(cfg *Config) { 178 cfg.Database = database 179 } 180 } 181 182 func WithApplicationName(appName string) Option { 183 return func(cfg *Config) { 184 cfg.AppName = appName 185 } 186 } 187 188 func WithConnParams(params map[string]interface{}) Option { 189 return func(cfg *Config) { 190 cfg.ConnParams = params 191 } 192 } 193 194 func WithTimeout(timeout time.Duration) Option { 195 return func(cfg *Config) { 196 cfg.DialTimeout = timeout 197 cfg.ReadTimeout = timeout 198 cfg.WriteTimeout = timeout 199 } 200 } 201 202 func WithDialTimeout(dialTimeout time.Duration) Option { 203 return func(cfg *Config) { 204 cfg.DialTimeout = dialTimeout 205 } 206 } 207 208 func WithReadTimeout(readTimeout time.Duration) Option { 209 return func(cfg *Config) { 210 cfg.ReadTimeout = readTimeout 211 } 212 } 213 214 func WithWriteTimeout(writeTimeout time.Duration) Option { 215 return func(cfg *Config) { 216 cfg.WriteTimeout = writeTimeout 217 } 218 } 219 220 //// WithResetSessionFunc configures a function that is called prior to executing 221 //// a query on a connection that has been used before. 222 //// If the func returns driver.ErrBadConn, the connection is discarded. 223 //func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option { 224 // return func(cfg *Config) { 225 // cfg.ResetSessionFunc = fn 226 // } 227 //} 228 229 func WithDSN(dsn string) Option { 230 return func(cfg *Config) { 231 opts, err := parseDSN(dsn) 232 if err != nil { 233 opts, err := parseKVDSN(dsn) 234 if err == nil { 235 for _, opt := range opts { 236 opt(cfg) 237 } 238 return 239 } 240 panic(err) 241 } 242 for _, opt := range opts { 243 opt(cfg) 244 } 245 } 246 } 247 248 func env(key, defValue string) string { 249 if s := os.Getenv(key); s != "" { 250 return s 251 } 252 return defValue 253 } 254 255 // It's just a temporary solution 256 func parseKVDSN(dsn string) ([]Option, error) { 257 opts := make([]Option, 0) 258 host := "" 259 port := "" 260 for _, pair := range strings.Split(dsn, " ") { 261 262 pair = strings.TrimSpace(pair) 263 if pair == "" { 264 continue 265 } 266 267 kv := strings.SplitN(pair, "=", 2) 268 if len(kv) != 2 { 269 return nil, fmt.Errorf("dsn %s not key value pairs", dsn) 270 } 271 key := strings.ToLower(kv[0]) 272 value := kv[1] 273 switch key { 274 case "host": 275 host = value 276 case "user": 277 opts = append(opts, func(cfg *Config) { 278 cfg.User = value 279 }) 280 case "password": 281 opts = append(opts, func(cfg *Config) { 282 cfg.Password = value 283 }) 284 case "port": 285 port = value 286 case "dbname": 287 opts = append(opts, func(cfg *Config) { 288 cfg.Database = value 289 }) 290 } 291 } 292 293 if port != "" { 294 host = host + ":" + port 295 } 296 opts = append(opts, func(cfg *Config) { 297 cfg.Addr = host 298 }) 299 300 return opts, nil 301 } 302 303 func parseDSN(dsn string) ([]Option, error) { 304 u, err := url.Parse(dsn) 305 if err != nil { 306 return nil, err 307 } 308 309 q := queryOptions{q: u.Query()} 310 var opts []Option 311 312 switch u.Scheme { 313 case "postgres", "postgresql": 314 if u.Host != "" { 315 addr := u.Host 316 if !strings.Contains(addr, ":") { 317 addr += ":5432" 318 } 319 opts = append(opts, WithAddr(addr)) 320 } 321 322 if len(u.Path) > 1 { 323 opts = append(opts, WithDatabase(u.Path[1:])) 324 } 325 326 if host := q.string("host"); host != "" { 327 opts = append(opts, WithAddr(host)) 328 if host[0] == '/' { 329 opts = append(opts, WithNetwork("unix")) 330 } 331 } 332 case "unix": 333 if len(u.Path) == 0 { 334 return nil, fmt.Errorf("unix socket DSN requires a path: %s", dsn) 335 } 336 337 opts = append(opts, WithNetwork("unix")) 338 if u.Host != "" { 339 opts = append(opts, WithDatabase(u.Host)) 340 } 341 opts = append(opts, WithAddr(u.Path)) 342 default: 343 return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme) 344 } 345 346 if u.User != nil { 347 opts = append(opts, WithUser(u.User.Username())) 348 if password, ok := u.User.Password(); ok { 349 opts = append(opts, WithPassword(password)) 350 } 351 } 352 353 if appName := q.string("application_name"); appName != "" { 354 opts = append(opts, WithApplicationName(appName)) 355 } 356 357 if sslMode, sslRootCert := q.string("sslmode"), q.string("sslrootcert"); sslMode != "" || sslRootCert != "" { 358 tlsConfig := &tls.Config{} 359 switch sslMode { 360 case "disable": 361 tlsConfig = nil 362 case "allow", "prefer", "": 363 tlsConfig.InsecureSkipVerify = true 364 case "require": 365 if sslRootCert == "" { 366 tlsConfig.InsecureSkipVerify = true 367 break 368 } 369 // For backwards compatibility reasons, in the presence of `sslrootcert`, 370 // `sslmode` = `require` must act as if `sslmode` = `verify-ca`. See the note at 371 // https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES . 372 fallthrough 373 case "verify-ca": 374 // The default certificate verification will also verify the host name 375 // which is not the behavior of `verify-ca`. As such, we need to manually 376 // check the certificate chain. 377 // At the time of writing, tls.Config has no option for this behavior 378 // (verify chain, but skip server name). 379 // See https://github.com/golang/go/issues/21971 . 380 tlsConfig.InsecureSkipVerify = true 381 tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 382 certs := make([]*x509.Certificate, 0, len(rawCerts)) 383 for _, rawCert := range rawCerts { 384 cert, err := x509.ParseCertificate(rawCert) 385 if err != nil { 386 return fmt.Errorf("pgdriver: failed to parse certificate: %w", err) 387 } 388 certs = append(certs, cert) 389 } 390 intermediates := x509.NewCertPool() 391 for _, cert := range certs[1:] { 392 intermediates.AddCert(cert) 393 } 394 _, err := certs[0].Verify(x509.VerifyOptions{ 395 Roots: tlsConfig.RootCAs, 396 Intermediates: intermediates, 397 }) 398 return err 399 } 400 case "verify-full": 401 tlsConfig.ServerName = u.Host 402 if host, _, err := net.SplitHostPort(u.Host); err == nil { 403 tlsConfig.ServerName = host 404 } 405 default: 406 return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode) 407 } 408 if tlsConfig != nil && sslRootCert != "" { 409 rawCA, err := ioutil.ReadFile(sslRootCert) 410 if err != nil { 411 return nil, fmt.Errorf("pgdriver: failed to read root CA: %w", err) 412 } 413 certPool := x509.NewCertPool() 414 if !certPool.AppendCertsFromPEM(rawCA) { 415 return nil, fmt.Errorf("pgdriver: failed to append root CA") 416 } 417 tlsConfig.RootCAs = certPool 418 } 419 opts = append(opts, WithTLSConfig(tlsConfig)) 420 } 421 422 if d := q.duration("timeout"); d != 0 { 423 opts = append(opts, WithTimeout(d)) 424 } 425 if d := q.duration("dial_timeout"); d != 0 { 426 opts = append(opts, WithDialTimeout(d)) 427 } 428 if d := q.duration("connect_timeout"); d != 0 { 429 opts = append(opts, WithDialTimeout(d)) 430 } 431 if d := q.duration("read_timeout"); d != 0 { 432 opts = append(opts, WithReadTimeout(d)) 433 } 434 if d := q.duration("write_timeout"); d != 0 { 435 opts = append(opts, WithWriteTimeout(d)) 436 } 437 438 rem, err := q.remaining() 439 if err != nil { 440 return nil, q.err 441 } 442 443 if len(rem) > 0 { 444 params := make(map[string]interface{}, len(rem)) 445 for k, v := range rem { 446 params[k] = v 447 } 448 opts = append(opts, WithConnParams(params)) 449 } 450 451 return opts, nil 452 } 453 454 // ------------------------------------------------- -------------------------------------------------------------------- 455 456 type queryOptions struct { 457 q url.Values 458 err error 459 } 460 461 func (o *queryOptions) string(name string) string { 462 vs := o.q[name] 463 if len(vs) == 0 { 464 return "" 465 } 466 delete(o.q, name) // enable detection of unknown parameters 467 return vs[len(vs)-1] 468 } 469 470 func (o *queryOptions) duration(name string) time.Duration { 471 s := o.string(name) 472 if s == "" { 473 return 0 474 } 475 // try plain number first 476 if i, err := strconv.Atoi(s); err == nil { 477 if i <= 0 { 478 // disable timeouts 479 return -1 480 } 481 return time.Duration(i) * time.Second 482 } 483 dur, err := time.ParseDuration(s) 484 if err == nil { 485 return dur 486 } 487 if o.err == nil { 488 o.err = fmt.Errorf("pgdriver: invalid %s duration: %w", name, err) 489 } 490 return 0 491 } 492 493 func (o *queryOptions) remaining() (map[string]string, error) { 494 if o.err != nil { 495 return nil, o.err 496 } 497 if len(o.q) == 0 { 498 return nil, nil 499 } 500 m := make(map[string]string, len(o.q)) 501 for k, ss := range o.q { 502 m[k] = ss[len(ss)-1] 503 } 504 return m, nil 505 } 506 507 // ------------------------------------------------- --------------------------------------------------------------------