github.com/selefra/selefra-utils@v0.0.4/pkg/dsn_util/postgresql.go (about)

     1  package dsn_util
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"errors"
     8  	"fmt"
     9  	"io/ioutil"
    10  	"net"
    11  	"net/url"
    12  	"os"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  )
    17  
// Config describes how to reach and authenticate against a PostgreSQL server.
// It is produced by NewConfigByDSN (package defaults plus the DSN settings) or
// by applying Options to a Config, and can be rendered back to a URL with ToDSN.
type Config struct {
	// Network type, either tcp or unix.
	// Default is tcp.
	Network string
	// TCP host:port or Unix socket depending on Network.
	// Default is $PGHOST:$PGPORT, falling back to localhost:5432.
	Addr string
	// Dial timeout for establishing new connections.
	// Default is 5 seconds.
	DialTimeout time.Duration
	// Dialer creates new network connection and has priority over
	// Network and Addr options.
	// NOTE(review): that priority is up to the code consuming this Config;
	// nothing in this file enforces it. The default Dialer installed by
	// newDefaultConfig uses a net.Dialer with DialTimeout (read at dial time)
	// and a 5 minute keep-alive.
	Dialer func(ctx context.Context, network, addr string) (net.Conn, error)

	// TLS config for secure connections.
	// nil is what sslmode=disable and WithInsecure(true) produce. The default
	// is a non-nil config with InsecureSkipVerify set (no certificate checks).
	TLSConfig *tls.Config

	// User and Database default to $PGUSER / $PGDATABASE, falling back to
	// "postgres"; Password and AppName default to empty. AppName is filled
	// from the application_name DSN query parameter.
	User     string
	Password string
	Database string
	AppName  string
	// PostgreSQL session parameters updated with `SET` command when a connection is created.
	// parseDSN stores every DSN query parameter it does not recognize here;
	// actually sending them is left to the consumer of Config.
	ConnParams map[string]interface{}

	// Timeout for socket reads. If reached, commands fail with a timeout instead of blocking
	// (enforcement is up to the consuming driver). Default is 10 seconds.
	ReadTimeout time.Duration
	// Timeout for socket writes. If reached, commands fail with a timeout instead of blocking
	// (enforcement is up to the consuming driver). Default is 5 seconds.
	WriteTimeout time.Duration

	//// ResetSessionFunc is called prior to executing a query on a connection that has been used before.
	//ResetSessionFunc func(context.Context, *Conn) error
}
    49  
    50  func newDefaultConfig() *Config {
    51  	host := env("PGHOST", "localhost")
    52  	port := env("PGPORT", "5432")
    53  
    54  	cfg := &Config{
    55  		Network:     "tcp",
    56  		Addr:        net.JoinHostPort(host, port),
    57  		DialTimeout: 5 * time.Second,
    58  		TLSConfig:   &tls.Config{InsecureSkipVerify: true},
    59  
    60  		User:     env("PGUSER", "postgres"),
    61  		Database: env("PGDATABASE", "postgres"),
    62  
    63  		ReadTimeout:  10 * time.Second,
    64  		WriteTimeout: 5 * time.Second,
    65  	}
    66  
    67  	cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
    68  		netDialer := &net.Dialer{
    69  			Timeout:   cfg.DialTimeout,
    70  			KeepAlive: 5 * time.Minute,
    71  		}
    72  		return netDialer.DialContext(ctx, network, addr)
    73  	}
    74  
    75  	return cfg
    76  }
    77  
    78  func NewConfigByDSN(dsn string) (c *Config, err error) {
    79  
    80  	defer func() {
    81  		if r := recover(); r != nil {
    82  			e, ok := r.(error)
    83  			if ok {
    84  				err = e
    85  			} else {
    86  				err = fmt.Errorf("%v", r)
    87  			}
    88  		}
    89  	}()
    90  
    91  	c = newDefaultConfig()
    92  	WithDSN(dsn)(c)
    93  	return
    94  }
    95  
    96  func (x *Config) ToDSN(isPasswordMosaic ...bool) string {
    97  	buff := strings.Builder{}
    98  	buff.WriteString("postgres://")
    99  	if x.User != "" || x.Password != "" {
   100  		buff.WriteString(x.User)
   101  		buff.WriteString(":")
   102  
   103  		if len(isPasswordMosaic) != 0 && isPasswordMosaic[0] {
   104  			buff.WriteString("*******")
   105  		} else {
   106  			buff.WriteString(x.Password)
   107  		}
   108  
   109  		buff.WriteString("@")
   110  	}
   111  	buff.WriteString(x.Addr)
   112  	if x.Database != "" {
   113  		buff.WriteString("/")
   114  		buff.WriteString(x.Database)
   115  	}
   116  	return buff.String()
   117  }
   118  
// Option mutates a Config in place. When several Options are applied in
// sequence (as WithDSN does with the ones it parses), later ones overwrite
// fields set by earlier ones.
type Option func(cfg *Config)

// DriverOption is an alias of Option kept for backward compatibility.
//
// Deprecated: Use Option instead.
type DriverOption = Option
   123  
   124  func WithNetwork(network string) Option {
   125  	if network == "" {
   126  		panic("network is empty")
   127  	}
   128  	return func(cfg *Config) {
   129  		cfg.Network = network
   130  	}
   131  }
   132  
   133  func WithAddr(addr string) Option {
   134  	if addr == "" {
   135  		panic("addr is empty")
   136  	}
   137  	return func(cfg *Config) {
   138  		cfg.Addr = addr
   139  	}
   140  }
   141  
   142  func WithTLSConfig(tlsConfig *tls.Config) Option {
   143  	return func(cfg *Config) {
   144  		cfg.TLSConfig = tlsConfig
   145  	}
   146  }
   147  
   148  func WithInsecure(on bool) Option {
   149  	return func(cfg *Config) {
   150  		if on {
   151  			cfg.TLSConfig = nil
   152  		} else {
   153  			cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true}
   154  		}
   155  	}
   156  }
   157  
   158  func WithUser(user string) Option {
   159  	if user == "" {
   160  		panic("user is empty")
   161  	}
   162  	return func(cfg *Config) {
   163  		cfg.User = user
   164  	}
   165  }
   166  
   167  func WithPassword(password string) Option {
   168  	return func(cfg *Config) {
   169  		cfg.Password = password
   170  	}
   171  }
   172  
   173  func WithDatabase(database string) Option {
   174  	if database == "" {
   175  		panic("database is empty")
   176  	}
   177  	return func(cfg *Config) {
   178  		cfg.Database = database
   179  	}
   180  }
   181  
   182  func WithApplicationName(appName string) Option {
   183  	return func(cfg *Config) {
   184  		cfg.AppName = appName
   185  	}
   186  }
   187  
   188  func WithConnParams(params map[string]interface{}) Option {
   189  	return func(cfg *Config) {
   190  		cfg.ConnParams = params
   191  	}
   192  }
   193  
   194  func WithTimeout(timeout time.Duration) Option {
   195  	return func(cfg *Config) {
   196  		cfg.DialTimeout = timeout
   197  		cfg.ReadTimeout = timeout
   198  		cfg.WriteTimeout = timeout
   199  	}
   200  }
   201  
   202  func WithDialTimeout(dialTimeout time.Duration) Option {
   203  	return func(cfg *Config) {
   204  		cfg.DialTimeout = dialTimeout
   205  	}
   206  }
   207  
   208  func WithReadTimeout(readTimeout time.Duration) Option {
   209  	return func(cfg *Config) {
   210  		cfg.ReadTimeout = readTimeout
   211  	}
   212  }
   213  
   214  func WithWriteTimeout(writeTimeout time.Duration) Option {
   215  	return func(cfg *Config) {
   216  		cfg.WriteTimeout = writeTimeout
   217  	}
   218  }
   219  
   220  //// WithResetSessionFunc configures a function that is called prior to executing
   221  //// a query on a connection that has been used before.
   222  //// If the func returns driver.ErrBadConn, the connection is discarded.
   223  //func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option {
   224  //	return func(cfg *Config) {
   225  //		cfg.ResetSessionFunc = fn
   226  //	}
   227  //}
   228  
   229  func WithDSN(dsn string) Option {
   230  	return func(cfg *Config) {
   231  		opts, err := parseDSN(dsn)
   232  		if err != nil {
   233  			opts, err := parseKVDSN(dsn)
   234  			if err == nil {
   235  				for _, opt := range opts {
   236  					opt(cfg)
   237  				}
   238  				return
   239  			}
   240  			panic(err)
   241  		}
   242  		for _, opt := range opts {
   243  			opt(cfg)
   244  		}
   245  	}
   246  }
   247  
   248  func env(key, defValue string) string {
   249  	if s := os.Getenv(key); s != "" {
   250  		return s
   251  	}
   252  	return defValue
   253  }
   254  
   255  // It's just a temporary solution
   256  func parseKVDSN(dsn string) ([]Option, error) {
   257  	opts := make([]Option, 0)
   258  	host := ""
   259  	port := ""
   260  	for _, pair := range strings.Split(dsn, " ") {
   261  
   262  		pair = strings.TrimSpace(pair)
   263  		if pair == "" {
   264  			continue
   265  		}
   266  
   267  		kv := strings.SplitN(pair, "=", 2)
   268  		if len(kv) != 2 {
   269  			return nil, fmt.Errorf("dsn %s not key value pairs", dsn)
   270  		}
   271  		key := strings.ToLower(kv[0])
   272  		value := kv[1]
   273  		switch key {
   274  		case "host":
   275  			host = value
   276  		case "user":
   277  			opts = append(opts, func(cfg *Config) {
   278  				cfg.User = value
   279  			})
   280  		case "password":
   281  			opts = append(opts, func(cfg *Config) {
   282  				cfg.Password = value
   283  			})
   284  		case "port":
   285  			port = value
   286  		case "dbname":
   287  			opts = append(opts, func(cfg *Config) {
   288  				cfg.Database = value
   289  			})
   290  		}
   291  	}
   292  
   293  	if port != "" {
   294  		host = host + ":" + port
   295  	}
   296  	opts = append(opts, func(cfg *Config) {
   297  		cfg.Addr = host
   298  	})
   299  
   300  	return opts, nil
   301  }
   302  
   303  func parseDSN(dsn string) ([]Option, error) {
   304  	u, err := url.Parse(dsn)
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  
   309  	q := queryOptions{q: u.Query()}
   310  	var opts []Option
   311  
   312  	switch u.Scheme {
   313  	case "postgres", "postgresql":
   314  		if u.Host != "" {
   315  			addr := u.Host
   316  			if !strings.Contains(addr, ":") {
   317  				addr += ":5432"
   318  			}
   319  			opts = append(opts, WithAddr(addr))
   320  		}
   321  
   322  		if len(u.Path) > 1 {
   323  			opts = append(opts, WithDatabase(u.Path[1:]))
   324  		}
   325  
   326  		if host := q.string("host"); host != "" {
   327  			opts = append(opts, WithAddr(host))
   328  			if host[0] == '/' {
   329  				opts = append(opts, WithNetwork("unix"))
   330  			}
   331  		}
   332  	case "unix":
   333  		if len(u.Path) == 0 {
   334  			return nil, fmt.Errorf("unix socket DSN requires a path: %s", dsn)
   335  		}
   336  
   337  		opts = append(opts, WithNetwork("unix"))
   338  		if u.Host != "" {
   339  			opts = append(opts, WithDatabase(u.Host))
   340  		}
   341  		opts = append(opts, WithAddr(u.Path))
   342  	default:
   343  		return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
   344  	}
   345  
   346  	if u.User != nil {
   347  		opts = append(opts, WithUser(u.User.Username()))
   348  		if password, ok := u.User.Password(); ok {
   349  			opts = append(opts, WithPassword(password))
   350  		}
   351  	}
   352  
   353  	if appName := q.string("application_name"); appName != "" {
   354  		opts = append(opts, WithApplicationName(appName))
   355  	}
   356  
   357  	if sslMode, sslRootCert := q.string("sslmode"), q.string("sslrootcert"); sslMode != "" || sslRootCert != "" {
   358  		tlsConfig := &tls.Config{}
   359  		switch sslMode {
   360  		case "disable":
   361  			tlsConfig = nil
   362  		case "allow", "prefer", "":
   363  			tlsConfig.InsecureSkipVerify = true
   364  		case "require":
   365  			if sslRootCert == "" {
   366  				tlsConfig.InsecureSkipVerify = true
   367  				break
   368  			}
   369  			// For backwards compatibility reasons, in the presence of `sslrootcert`,
   370  			// `sslmode` = `require` must act as if `sslmode` = `verify-ca`. See the note at
   371  			// https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES .
   372  			fallthrough
   373  		case "verify-ca":
   374  			// The default certificate verification will also verify the host name
   375  			// which is not the behavior of `verify-ca`. As such, we need to manually
   376  			// check the certificate chain.
   377  			// At the time of writing, tls.Config has no option for this behavior
   378  			// (verify chain, but skip server name).
   379  			// See https://github.com/golang/go/issues/21971 .
   380  			tlsConfig.InsecureSkipVerify = true
   381  			tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   382  				certs := make([]*x509.Certificate, 0, len(rawCerts))
   383  				for _, rawCert := range rawCerts {
   384  					cert, err := x509.ParseCertificate(rawCert)
   385  					if err != nil {
   386  						return fmt.Errorf("pgdriver: failed to parse certificate: %w", err)
   387  					}
   388  					certs = append(certs, cert)
   389  				}
   390  				intermediates := x509.NewCertPool()
   391  				for _, cert := range certs[1:] {
   392  					intermediates.AddCert(cert)
   393  				}
   394  				_, err := certs[0].Verify(x509.VerifyOptions{
   395  					Roots:         tlsConfig.RootCAs,
   396  					Intermediates: intermediates,
   397  				})
   398  				return err
   399  			}
   400  		case "verify-full":
   401  			tlsConfig.ServerName = u.Host
   402  			if host, _, err := net.SplitHostPort(u.Host); err == nil {
   403  				tlsConfig.ServerName = host
   404  			}
   405  		default:
   406  			return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
   407  		}
   408  		if tlsConfig != nil && sslRootCert != "" {
   409  			rawCA, err := ioutil.ReadFile(sslRootCert)
   410  			if err != nil {
   411  				return nil, fmt.Errorf("pgdriver: failed to read root CA: %w", err)
   412  			}
   413  			certPool := x509.NewCertPool()
   414  			if !certPool.AppendCertsFromPEM(rawCA) {
   415  				return nil, fmt.Errorf("pgdriver: failed to append root CA")
   416  			}
   417  			tlsConfig.RootCAs = certPool
   418  		}
   419  		opts = append(opts, WithTLSConfig(tlsConfig))
   420  	}
   421  
   422  	if d := q.duration("timeout"); d != 0 {
   423  		opts = append(opts, WithTimeout(d))
   424  	}
   425  	if d := q.duration("dial_timeout"); d != 0 {
   426  		opts = append(opts, WithDialTimeout(d))
   427  	}
   428  	if d := q.duration("connect_timeout"); d != 0 {
   429  		opts = append(opts, WithDialTimeout(d))
   430  	}
   431  	if d := q.duration("read_timeout"); d != 0 {
   432  		opts = append(opts, WithReadTimeout(d))
   433  	}
   434  	if d := q.duration("write_timeout"); d != 0 {
   435  		opts = append(opts, WithWriteTimeout(d))
   436  	}
   437  
   438  	rem, err := q.remaining()
   439  	if err != nil {
   440  		return nil, q.err
   441  	}
   442  
   443  	if len(rem) > 0 {
   444  		params := make(map[string]interface{}, len(rem))
   445  		for k, v := range rem {
   446  			params[k] = v
   447  		}
   448  		opts = append(opts, WithConnParams(params))
   449  	}
   450  
   451  	return opts, nil
   452  }
   453  
   454  // ------------------------------------------------- --------------------------------------------------------------------
   455  
   456  type queryOptions struct {
   457  	q   url.Values
   458  	err error
   459  }
   460  
   461  func (o *queryOptions) string(name string) string {
   462  	vs := o.q[name]
   463  	if len(vs) == 0 {
   464  		return ""
   465  	}
   466  	delete(o.q, name) // enable detection of unknown parameters
   467  	return vs[len(vs)-1]
   468  }
   469  
   470  func (o *queryOptions) duration(name string) time.Duration {
   471  	s := o.string(name)
   472  	if s == "" {
   473  		return 0
   474  	}
   475  	// try plain number first
   476  	if i, err := strconv.Atoi(s); err == nil {
   477  		if i <= 0 {
   478  			// disable timeouts
   479  			return -1
   480  		}
   481  		return time.Duration(i) * time.Second
   482  	}
   483  	dur, err := time.ParseDuration(s)
   484  	if err == nil {
   485  		return dur
   486  	}
   487  	if o.err == nil {
   488  		o.err = fmt.Errorf("pgdriver: invalid %s duration: %w", name, err)
   489  	}
   490  	return 0
   491  }
   492  
   493  func (o *queryOptions) remaining() (map[string]string, error) {
   494  	if o.err != nil {
   495  		return nil, o.err
   496  	}
   497  	if len(o.q) == 0 {
   498  		return nil, nil
   499  	}
   500  	m := make(map[string]string, len(o.q))
   501  	for k, ss := range o.q {
   502  		m[k] = ss[len(ss)-1]
   503  	}
   504  	return m, nil
   505  }
   506  
   507  // ------------------------------------------------- --------------------------------------------------------------------