github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/lib/pq/conn.go (about)

     1  package pq
     2  
     3  import (
     4  	"bufio"
     5  	"crypto/md5"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"database/sql"
     9  	"database/sql/driver"
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"net"
    16  	"os"
    17  	"os/user"
    18  	"path"
    19  	"path/filepath"
    20  	"strconv"
    21  	"strings"
    22  	"time"
    23  	"unicode"
    24  
    25  	"github.com/insionng/yougam/libraries/lib/pq/oid"
    26  )
    27  
    28  // Common error types
    29  var (
    30  	ErrNotSupported              = errors.New("pq: Unsupported command")
    31  	ErrInFailedTransaction       = errors.New("pq: Could not complete operation in a failed transaction")
    32  	ErrSSLNotSupported           = errors.New("pq: SSL is not enabled on the server")
    33  	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
    34  	ErrCouldNotDetectUsername    = errors.New("pq: Could not detect default username. Please provide one explicitly.")
    35  )
    36  
    37  type drv struct{}
    38  
    39  func (d *drv) Open(name string) (driver.Conn, error) {
    40  	return Open(name)
    41  }
    42  
    43  func init() {
    44  	sql.Register("postgres", &drv{})
    45  }
    46  
    47  type parameterStatus struct {
    48  	// server version in the same format as server_version_num, or 0 if
    49  	// unavailable
    50  	serverVersion int
    51  
    52  	// the current location based on the TimeZone value of the self.Session. if
    53  	// available
    54  	currentLocation *time.Location
    55  }
    56  
    57  type transactionStatus byte
    58  
    59  const (
    60  	txnStatusIdle                transactionStatus = 'I'
    61  	txnStatusIdleInTransaction   transactionStatus = 'T'
    62  	txnStatusInFailedTransaction transactionStatus = 'E'
    63  )
    64  
    65  func (s transactionStatus) String() string {
    66  	switch s {
    67  	case txnStatusIdle:
    68  		return "idle"
    69  	case txnStatusIdleInTransaction:
    70  		return "idle in transaction"
    71  	case txnStatusInFailedTransaction:
    72  		return "in a failed transaction"
    73  	default:
    74  		errorf("unknown transactionStatus %d", s)
    75  	}
    76  
    77  	panic("not reached")
    78  }
    79  
    80  type Dialer interface {
    81  	Dial(network, address string) (net.Conn, error)
    82  	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
    83  }
    84  
    85  type defaultDialer struct{}
    86  
    87  func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
    88  	return net.Dial(ntw, addr)
    89  }
    90  func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
    91  	return net.DialTimeout(ntw, addr, timeout)
    92  }
    93  
    94  type conn struct {
    95  	c         net.Conn
    96  	buf       *bufio.Reader
    97  	namei     int
    98  	scratch   [512]byte
    99  	txnStatus transactionStatus
   100  
   101  	parameterStatus parameterStatus
   102  
   103  	saveMessageType   byte
   104  	saveMessageBuffer []byte
   105  
   106  	// If true, this connection is bad and all public-facing functions should
   107  	// return ErrBadConn.
   108  	bad bool
   109  
   110  	// If set, this connection should never use the binary format when
   111  	// receiving query results from prepared statements.  Only provided for
   112  	// debugging.
   113  	disablePreparedBinaryResult bool
   114  
   115  	// Whether to always send []byte parameters over as binary.  Enables single
   116  	// round-trip mode for non-prepared Query calls.
   117  	binaryParameters bool
   118  }
   119  
   120  // Handle driver-side settings in parsed connection string.
   121  func (c *conn) handleDriverSettings(o values) (err error) {
   122  	boolSetting := func(key string, val *bool) error {
   123  		if value := o.Get(key); value != "" {
   124  			if value == "yes" {
   125  				*val = true
   126  			} else if value == "no" {
   127  				*val = false
   128  			} else {
   129  				return fmt.Errorf("unrecognized value %q for %s", value, key)
   130  			}
   131  		}
   132  		return nil
   133  	}
   134  
   135  	err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult)
   136  	if err != nil {
   137  		return err
   138  	}
   139  	err = boolSetting("binary_parameters", &c.binaryParameters)
   140  	if err != nil {
   141  		return err
   142  	}
   143  	return nil
   144  }
   145  
   146  func (c *conn) handlePgpass(o values) {
   147  	// if a password was supplied, do not process .pgpass
   148  	_, ok := o["password"]
   149  	if ok {
   150  		return
   151  	}
   152  	filename := os.Getenv("PGPASSFILE")
   153  	if filename == "" {
   154  		// XXX this code doesn't work on Windows where the default filename is
   155  		// XXX %APPDATA%\postgresql\pgpass.conf
   156  		user, err := user.Current()
   157  		if err != nil {
   158  			return
   159  		}
   160  		filename = filepath.Join(user.HomeDir, ".pgpass")
   161  	}
   162  	fileinfo, err := os.Stat(filename)
   163  	if err != nil {
   164  		return
   165  	}
   166  	mode := fileinfo.Mode()
   167  	if mode&(0x77) != 0 {
   168  		// XXX should warn about incorrect .pgpass permissions as psql does
   169  		return
   170  	}
   171  	file, err := os.Open(filename)
   172  	if err != nil {
   173  		return
   174  	}
   175  	defer file.Close()
   176  	scanner := bufio.NewScanner(io.Reader(file))
   177  	hostname := o.Get("host")
   178  	ntw, _ := network(o)
   179  	port := o.Get("port")
   180  	db := o.Get("dbname")
   181  	username := o.Get("user")
   182  	// From: https://yougam/libraries/tg/pgpass/blob/master/reader.go
   183  	getFields := func(s string) []string {
   184  		fs := make([]string, 0, 5)
   185  		f := make([]rune, 0, len(s))
   186  
   187  		var esc bool
   188  		for _, c := range s {
   189  			switch {
   190  			case esc:
   191  				f = append(f, c)
   192  				esc = false
   193  			case c == '\\':
   194  				esc = true
   195  			case c == ':':
   196  				fs = append(fs, string(f))
   197  				f = f[:0]
   198  			default:
   199  				f = append(f, c)
   200  			}
   201  		}
   202  		return append(fs, string(f))
   203  	}
   204  	for scanner.Scan() {
   205  		line := scanner.Text()
   206  		if len(line) == 0 || line[0] == '#' {
   207  			continue
   208  		}
   209  		split := getFields(line)
   210  		if len(split) != 5 {
   211  			continue
   212  		}
   213  		if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
   214  			o["password"] = split[4]
   215  			return
   216  		}
   217  	}
   218  }
   219  
   220  func (c *conn) writeBuf(b byte) *writeBuf {
   221  	c.scratch[0] = b
   222  	return &writeBuf{
   223  		buf: c.scratch[:5],
   224  		pos: 1,
   225  	}
   226  }
   227  
   228  func Open(name string) (_ driver.Conn, err error) {
   229  	return DialOpen(defaultDialer{}, name)
   230  }
   231  
   232  func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
   233  	// Handle any panics during connection initialization.  Note that we
   234  	// specifically do *not* want to use errRecover(), as that would turn any
   235  	// connection errors into ErrBadConns, hiding the real error message from
   236  	// the user.
   237  	defer errRecoverNoErrBadConn(&err)
   238  
   239  	o := make(values)
   240  
   241  	// A number of defaults are applied here, in this order:
   242  	//
   243  	// * Very low precedence defaults applied in every situation
   244  	// * Environment variables
   245  	// * Explicitly passed connection information
   246  	o.Set("host", "localhost")
   247  	o.Set("port", "5432")
   248  	// N.B.: Extra float digits should be set to 3, but that breaks
   249  	// Postgres 8.4 and older, where the max is 2.
   250  	o.Set("extra_float_digits", "2")
   251  	for k, v := range parseEnviron(os.Environ()) {
   252  		o.Set(k, v)
   253  	}
   254  
   255  	if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
   256  		name, err = ParseURL(name)
   257  		if err != nil {
   258  			return nil, err
   259  		}
   260  	}
   261  
   262  	if err := parseOpts(name, o); err != nil {
   263  		return nil, err
   264  	}
   265  
   266  	// Use the "fallback" application name if necessary
   267  	if fallback := o.Get("fallback_application_name"); fallback != "" {
   268  		if !o.Isset("application_name") {
   269  			o.Set("application_name", fallback)
   270  		}
   271  	}
   272  
   273  	// We can't work with any client_encoding other than UTF-8 currently.
   274  	// However, we have historically allowed the user to set it to UTF-8
   275  	// explicitly, and there's no reason to break such programs, so allow that.
   276  	// Note that the "options" setting could also set client_encoding, but
   277  	// parsing its value is not worth it.  Instead, we always explicitly send
   278  	// client_encoding as a separate run-time parameter, which should override
   279  	// anything set in options.
   280  	if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) {
   281  		return nil, errors.New("client_encoding must be absent or 'UTF8'")
   282  	}
   283  	o.Set("client_encoding", "UTF8")
   284  	// DateStyle needs a similar treatment.
   285  	if datestyle := o.Get("datestyle"); datestyle != "" {
   286  		if datestyle != "ISO, MDY" {
   287  			panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
   288  				"ISO, MDY", datestyle))
   289  		}
   290  	} else {
   291  		o.Set("datestyle", "ISO, MDY")
   292  	}
   293  
   294  	// If a user is not provided by any other means, the last
   295  	// resort is to use the current operating system provided user
   296  	// name.
   297  	if o.Get("user") == "" {
   298  		u, err := userCurrent()
   299  		if err != nil {
   300  			return nil, err
   301  		} else {
   302  			o.Set("user", u)
   303  		}
   304  	}
   305  
   306  	cn := &conn{}
   307  	err = cn.handleDriverSettings(o)
   308  	if err != nil {
   309  		return nil, err
   310  	}
   311  	cn.handlePgpass(o)
   312  
   313  	cn.c, err = dial(d, o)
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  	cn.ssl(o)
   318  	cn.buf = bufio.NewReader(cn.c)
   319  	cn.startup(o)
   320  
   321  	// reset the deadline, in case one was set (see dial)
   322  	if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
   323  		err = cn.c.SetDeadline(time.Time{})
   324  	}
   325  	return cn, err
   326  }
   327  
   328  func dial(d Dialer, o values) (net.Conn, error) {
   329  	ntw, addr := network(o)
   330  	// SSL is not necessary or supported over UNIX domain sockets
   331  	if ntw == "unix" {
   332  		o["sslmode"] = "disable"
   333  	}
   334  
   335  	// Zero or not specified means wait indefinitely.
   336  	if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
   337  		seconds, err := strconv.ParseInt(timeout, 10, 0)
   338  		if err != nil {
   339  			return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
   340  		}
   341  		duration := time.Duration(seconds) * time.Second
   342  		// connect_timeout should apply to the entire connection establishment
   343  		// procedure, so we both use a timeout for the TCP connection
   344  		// establishment and set a deadline for doing the initial handshake.
   345  		// The deadline is then reset after startup() is done.
   346  		deadline := time.Now().Add(duration)
   347  		conn, err := d.DialTimeout(ntw, addr, duration)
   348  		if err != nil {
   349  			return nil, err
   350  		}
   351  		err = conn.SetDeadline(deadline)
   352  		return conn, err
   353  	}
   354  	return d.Dial(ntw, addr)
   355  }
   356  
   357  func network(o values) (string, string) {
   358  	host := o.Get("host")
   359  
   360  	if strings.HasPrefix(host, "/") {
   361  		sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
   362  		return "unix", sockPath
   363  	}
   364  
   365  	return "tcp", net.JoinHostPort(host, o.Get("port"))
   366  }
   367  
   368  type values map[string]string
   369  
   370  func (vs values) Set(k, v string) {
   371  	vs[k] = v
   372  }
   373  
   374  func (vs values) Get(k string) (v string) {
   375  	return vs[k]
   376  }
   377  
   378  func (vs values) Isset(k string) bool {
   379  	_, ok := vs[k]
   380  	return ok
   381  }
   382  
   383  // scanner implements a tokenizer for libpq-style option strings.
   384  type scanner struct {
   385  	s []rune
   386  	i int
   387  }
   388  
   389  // newScanner returns a new scanner initialized with the option string s.
   390  func newScanner(s string) *scanner {
   391  	return &scanner{[]rune(s), 0}
   392  }
   393  
   394  // Next returns the next rune.
   395  // It returns 0, false if the end of the text has been reached.
   396  func (s *scanner) Next() (rune, bool) {
   397  	if s.i >= len(s.s) {
   398  		return 0, false
   399  	}
   400  	r := s.s[s.i]
   401  	s.i++
   402  	return r, true
   403  }
   404  
   405  // SkipSpaces returns the next non-whitespace rune.
   406  // It returns 0, false if the end of the text has been reached.
   407  func (s *scanner) SkipSpaces() (rune, bool) {
   408  	r, ok := s.Next()
   409  	for unicode.IsSpace(r) && ok {
   410  		r, ok = s.Next()
   411  	}
   412  	return r, ok
   413  }
   414  
   415  // parseOpts parses the options from name and adds them to the values.
   416  //
   417  // The parsing code is based on conninfo_parse from libpq's fe-connect.c
   418  func parseOpts(name string, o values) error {
   419  	s := newScanner(name)
   420  
   421  	for {
   422  		var (
   423  			keyRunes, valRunes []rune
   424  			r                  rune
   425  			ok                 bool
   426  		)
   427  
   428  		if r, ok = s.SkipSpaces(); !ok {
   429  			break
   430  		}
   431  
   432  		// Scan the key
   433  		for !unicode.IsSpace(r) && r != '=' {
   434  			keyRunes = append(keyRunes, r)
   435  			if r, ok = s.Next(); !ok {
   436  				break
   437  			}
   438  		}
   439  
   440  		// Skip any whitespace if we're not at the = yet
   441  		if r != '=' {
   442  			r, ok = s.SkipSpaces()
   443  		}
   444  
   445  		// The current character should be =
   446  		if r != '=' || !ok {
   447  			return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
   448  		}
   449  
   450  		// Skip any whitespace after the =
   451  		if r, ok = s.SkipSpaces(); !ok {
   452  			// If we reach the end here, the last value is just an empty string as per libpq.
   453  			o.Set(string(keyRunes), "")
   454  			break
   455  		}
   456  
   457  		if r != '\'' {
   458  			for !unicode.IsSpace(r) {
   459  				if r == '\\' {
   460  					if r, ok = s.Next(); !ok {
   461  						return fmt.Errorf(`missing character after backslash`)
   462  					}
   463  				}
   464  				valRunes = append(valRunes, r)
   465  
   466  				if r, ok = s.Next(); !ok {
   467  					break
   468  				}
   469  			}
   470  		} else {
   471  		quote:
   472  			for {
   473  				if r, ok = s.Next(); !ok {
   474  					return fmt.Errorf(`unterminated quoted string literal in connection string`)
   475  				}
   476  				switch r {
   477  				case '\'':
   478  					break quote
   479  				case '\\':
   480  					r, _ = s.Next()
   481  					fallthrough
   482  				default:
   483  					valRunes = append(valRunes, r)
   484  				}
   485  			}
   486  		}
   487  
   488  		o.Set(string(keyRunes), string(valRunes))
   489  	}
   490  
   491  	return nil
   492  }
   493  
   494  func (cn *conn) isInTransaction() bool {
   495  	return cn.txnStatus == txnStatusIdleInTransaction ||
   496  		cn.txnStatus == txnStatusInFailedTransaction
   497  }
   498  
   499  func (cn *conn) checkIsInTransaction(intxn bool) {
   500  	if cn.isInTransaction() != intxn {
   501  		cn.bad = true
   502  		errorf("unexpected transaction status %v", cn.txnStatus)
   503  	}
   504  }
   505  
   506  func (cn *conn) Begin() (_ driver.Tx, err error) {
   507  	if cn.bad {
   508  		return nil, driver.ErrBadConn
   509  	}
   510  	defer cn.errRecover(&err)
   511  
   512  	cn.checkIsInTransaction(false)
   513  	_, commandTag, err := cn.simpleExec("BEGIN")
   514  	if err != nil {
   515  		return nil, err
   516  	}
   517  	if commandTag != "BEGIN" {
   518  		cn.bad = true
   519  		return nil, fmt.Errorf("unexpected command tag %s", commandTag)
   520  	}
   521  	if cn.txnStatus != txnStatusIdleInTransaction {
   522  		cn.bad = true
   523  		return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
   524  	}
   525  	return cn, nil
   526  }
   527  
   528  func (cn *conn) Commit() (err error) {
   529  	if cn.bad {
   530  		return driver.ErrBadConn
   531  	}
   532  	defer cn.errRecover(&err)
   533  
   534  	cn.checkIsInTransaction(true)
   535  	// We don't want the client to think that everything is okay if it tries
   536  	// to commit a failed transaction.  However, no matter what we return,
   537  	// database/sql will release this connection back into the free connection
   538  	// pool so we have to abort the current transaction here.  Note that you
   539  	// would get the same behaviour if you issued a COMMIT in a failed
   540  	// transaction, so it's also the least surprising thing to do here.
   541  	if cn.txnStatus == txnStatusInFailedTransaction {
   542  		if err := cn.Rollback(); err != nil {
   543  			return err
   544  		}
   545  		return ErrInFailedTransaction
   546  	}
   547  
   548  	_, commandTag, err := cn.simpleExec("COMMIT")
   549  	if err != nil {
   550  		if cn.isInTransaction() {
   551  			cn.bad = true
   552  		}
   553  		return err
   554  	}
   555  	if commandTag != "COMMIT" {
   556  		cn.bad = true
   557  		return fmt.Errorf("unexpected command tag %s", commandTag)
   558  	}
   559  	cn.checkIsInTransaction(false)
   560  	return nil
   561  }
   562  
   563  func (cn *conn) Rollback() (err error) {
   564  	if cn.bad {
   565  		return driver.ErrBadConn
   566  	}
   567  	defer cn.errRecover(&err)
   568  
   569  	cn.checkIsInTransaction(true)
   570  	_, commandTag, err := cn.simpleExec("ROLLBACK")
   571  	if err != nil {
   572  		if cn.isInTransaction() {
   573  			cn.bad = true
   574  		}
   575  		return err
   576  	}
   577  	if commandTag != "ROLLBACK" {
   578  		return fmt.Errorf("unexpected command tag %s", commandTag)
   579  	}
   580  	cn.checkIsInTransaction(false)
   581  	return nil
   582  }
   583  
   584  func (cn *conn) gname() string {
   585  	cn.namei++
   586  	return strconv.FormatInt(int64(cn.namei), 10)
   587  }
   588  
   589  func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
   590  	b := cn.writeBuf('Q')
   591  	b.string(q)
   592  	cn.send(b)
   593  
   594  	for {
   595  		t, r := cn.recv1()
   596  		switch t {
   597  		case 'C':
   598  			res, commandTag = cn.parseComplete(r.string())
   599  		case 'Z':
   600  			cn.processReadyForQuery(r)
   601  			// done
   602  			return
   603  		case 'E':
   604  			err = parseError(r)
   605  		case 'T', 'D', 'I':
   606  			// ignore any results
   607  		default:
   608  			cn.bad = true
   609  			errorf("unknown response for simple query: %q", t)
   610  		}
   611  	}
   612  }
   613  
   614  func (cn *conn) simpleQuery(q string) (res *rows, err error) {
   615  	defer cn.errRecover(&err)
   616  
   617  	b := cn.writeBuf('Q')
   618  	b.string(q)
   619  	cn.send(b)
   620  
   621  	for {
   622  		t, r := cn.recv1()
   623  		switch t {
   624  		case 'C', 'I':
   625  			// We allow queries which don't return any results through Query as
   626  			// well as Exec.  We still have to give database/sql a rows object
   627  			// the user can close, though, to avoid connections from being
   628  			// leaked.  A "rows" with done=true works fine for that purpose.
   629  			if err != nil {
   630  				cn.bad = true
   631  				errorf("unexpected message %q in simple query execution", t)
   632  			}
   633  			if res == nil {
   634  				res = &rows{
   635  					cn: cn,
   636  				}
   637  			}
   638  			res.done = true
   639  		case 'Z':
   640  			cn.processReadyForQuery(r)
   641  			// done
   642  			return
   643  		case 'E':
   644  			res = nil
   645  			err = parseError(r)
   646  		case 'D':
   647  			if res == nil {
   648  				cn.bad = true
   649  				errorf("unexpected DataRow in simple query execution")
   650  			}
   651  			// the query didn't fail; kick off to Next
   652  			cn.saveMessage(t, r)
   653  			return
   654  		case 'T':
   655  			// res might be non-nil here if we received a previous
   656  			// CommandComplete, but that's fine; just overwrite it
   657  			res = &rows{cn: cn}
   658  			res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
   659  
   660  			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
   661  			// until the first DataRow has been received.
   662  		default:
   663  			cn.bad = true
   664  			errorf("unknown response for simple query: %q", t)
   665  		}
   666  	}
   667  }
   668  
   669  // Decides which column formats to use for a prepared statement.  The input is
   670  // an array of type oids, one element per result column.
   671  func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) {
   672  	if len(colTyps) == 0 {
   673  		return nil, colFmtDataAllText
   674  	}
   675  
   676  	colFmts = make([]format, len(colTyps))
   677  	if forceText {
   678  		return colFmts, colFmtDataAllText
   679  	}
   680  
   681  	allBinary := true
   682  	allText := true
   683  	for i, o := range colTyps {
   684  		switch o {
   685  		// This is the list of types to use binary mode for when receiving them
   686  		// through a prepared statement.  If a type appears in this list, it
   687  		// must also be implemented in binaryDecode in encode.go.
   688  		case oid.T_bytea:
   689  			fallthrough
   690  		case oid.T_int8:
   691  			fallthrough
   692  		case oid.T_int4:
   693  			fallthrough
   694  		case oid.T_int2:
   695  			colFmts[i] = formatBinary
   696  			allText = false
   697  
   698  		default:
   699  			allBinary = false
   700  		}
   701  	}
   702  
   703  	if allBinary {
   704  		return colFmts, colFmtDataAllBinary
   705  	} else if allText {
   706  		return colFmts, colFmtDataAllText
   707  	} else {
   708  		colFmtData = make([]byte, 2+len(colFmts)*2)
   709  		binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
   710  		for i, v := range colFmts {
   711  			binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
   712  		}
   713  		return colFmts, colFmtData
   714  	}
   715  }
   716  
   717  func (cn *conn) prepareTo(q, stmtName string) *stmt {
   718  	st := &stmt{cn: cn, name: stmtName}
   719  
   720  	b := cn.writeBuf('P')
   721  	b.string(st.name)
   722  	b.string(q)
   723  	b.int16(0)
   724  
   725  	b.next('D')
   726  	b.byte('S')
   727  	b.string(st.name)
   728  
   729  	b.next('S')
   730  	cn.send(b)
   731  
   732  	cn.readParseResponse()
   733  	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
   734  	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
   735  	cn.readReadyForQuery()
   736  	return st
   737  }
   738  
   739  func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
   740  	if cn.bad {
   741  		return nil, driver.ErrBadConn
   742  	}
   743  	defer cn.errRecover(&err)
   744  
   745  	if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
   746  		return cn.prepareCopyIn(q)
   747  	}
   748  	return cn.prepareTo(q, cn.gname()), nil
   749  }
   750  
   751  func (cn *conn) Close() (err error) {
   752  	if cn.bad {
   753  		return driver.ErrBadConn
   754  	}
   755  	defer cn.errRecover(&err)
   756  
   757  	// Don't go through send(); ListenerConn relies on us not scribbling on the
   758  	// scratch buffer of this connection.
   759  	err = cn.sendSimpleMessage('X')
   760  	if err != nil {
   761  		return err
   762  	}
   763  
   764  	return cn.c.Close()
   765  }
   766  
   767  // Implement the "Queryer" interface
   768  func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) {
   769  	if cn.bad {
   770  		return nil, driver.ErrBadConn
   771  	}
   772  	defer cn.errRecover(&err)
   773  
   774  	// Check to see if we can use the "simpleQuery" interface, which is
   775  	// *much* faster than going through prepare/exec
   776  	if len(args) == 0 {
   777  		return cn.simpleQuery(query)
   778  	}
   779  
   780  	if cn.binaryParameters {
   781  		cn.sendBinaryModeQuery(query, args)
   782  
   783  		cn.readParseResponse()
   784  		cn.readBindResponse()
   785  		rows := &rows{cn: cn}
   786  		rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
   787  		cn.postExecuteWorkaround()
   788  		return rows, nil
   789  	} else {
   790  		st := cn.prepareTo(query, "")
   791  		st.exec(args)
   792  		return &rows{
   793  			cn:       cn,
   794  			colNames: st.colNames,
   795  			colTyps:  st.colTyps,
   796  			colFmts:  st.colFmts,
   797  		}, nil
   798  	}
   799  }
   800  
   801  // Implement the optional "Execer" interface for one-shot queries
   802  func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
   803  	if cn.bad {
   804  		return nil, driver.ErrBadConn
   805  	}
   806  	defer cn.errRecover(&err)
   807  
   808  	// Check to see if we can use the "simpleExec" interface, which is
   809  	// *much* faster than going through prepare/exec
   810  	if len(args) == 0 {
   811  		// ignore commandTag, our caller doesn't care
   812  		r, _, err := cn.simpleExec(query)
   813  		return r, err
   814  	}
   815  
   816  	if cn.binaryParameters {
   817  		cn.sendBinaryModeQuery(query, args)
   818  
   819  		cn.readParseResponse()
   820  		cn.readBindResponse()
   821  		cn.readPortalDescribeResponse()
   822  		cn.postExecuteWorkaround()
   823  		res, _, err = cn.readExecuteResponse("Execute")
   824  		return res, err
   825  	} else {
   826  		// Use the unnamed statement to defer planning until bind
   827  		// time, or else value-based selectivity estimates cannot be
   828  		// used.
   829  		st := cn.prepareTo(query, "")
   830  		r, err := st.Exec(args)
   831  		if err != nil {
   832  			panic(err)
   833  		}
   834  		return r, err
   835  	}
   836  }
   837  
   838  func (cn *conn) send(m *writeBuf) {
   839  	_, err := cn.c.Write(m.wrap())
   840  	if err != nil {
   841  		panic(err)
   842  	}
   843  }
   844  
   845  func (cn *conn) sendStartupPacket(m *writeBuf) {
   846  	// sanity check
   847  	if m.buf[0] != 0 {
   848  		panic("oops")
   849  	}
   850  
   851  	_, err := cn.c.Write((m.wrap())[1:])
   852  	if err != nil {
   853  		panic(err)
   854  	}
   855  }
   856  
   857  // Send a message of type typ to the server on the other end of cn.  The
   858  // message should have no payload.  This method does not use the scratch
   859  // buffer.
   860  func (cn *conn) sendSimpleMessage(typ byte) (err error) {
   861  	_, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
   862  	return err
   863  }
   864  
   865  // saveMessage memorizes a message and its buffer in the conn struct.
   866  // recvMessage will then return these values on the next call to it.  This
   867  // method is useful in cases where you have to see what the next message is
   868  // going to be (e.g. to see whether it's an error or not) but you can't handle
   869  // the message yourself.
   870  func (cn *conn) saveMessage(typ byte, buf *readBuf) {
   871  	if cn.saveMessageType != 0 {
   872  		cn.bad = true
   873  		errorf("unexpected saveMessageType %d", cn.saveMessageType)
   874  	}
   875  	cn.saveMessageType = typ
   876  	cn.saveMessageBuffer = *buf
   877  }
   878  
   879  // recvMessage receives any message from the backend, or returns an error if
   880  // a problem occurred while reading the message.
   881  func (cn *conn) recvMessage(r *readBuf) (byte, error) {
   882  	// workaround for a QueryRow bug, see exec
   883  	if cn.saveMessageType != 0 {
   884  		t := cn.saveMessageType
   885  		*r = cn.saveMessageBuffer
   886  		cn.saveMessageType = 0
   887  		cn.saveMessageBuffer = nil
   888  		return t, nil
   889  	}
   890  
   891  	x := cn.scratch[:5]
   892  	_, err := io.ReadFull(cn.buf, x)
   893  	if err != nil {
   894  		return 0, err
   895  	}
   896  
   897  	// read the type and length of the message that follows
   898  	t := x[0]
   899  	n := int(binary.BigEndian.Uint32(x[1:])) - 4
   900  	var y []byte
   901  	if n <= len(cn.scratch) {
   902  		y = cn.scratch[:n]
   903  	} else {
   904  		y = make([]byte, n)
   905  	}
   906  	_, err = io.ReadFull(cn.buf, y)
   907  	if err != nil {
   908  		return 0, err
   909  	}
   910  	*r = y
   911  	return t, nil
   912  }
   913  
   914  // recv receives a message from the backend, but if an error happened while
   915  // reading the message or the received message was an ErrorResponse, it panics.
   916  // NoticeResponses are ignored.  This function should generally be used only
   917  // during the startup sequence.
   918  func (cn *conn) recv() (t byte, r *readBuf) {
   919  	for {
   920  		var err error
   921  		r = &readBuf{}
   922  		t, err = cn.recvMessage(r)
   923  		if err != nil {
   924  			panic(err)
   925  		}
   926  
   927  		switch t {
   928  		case 'E':
   929  			panic(parseError(r))
   930  		case 'N':
   931  			// ignore
   932  		default:
   933  			return
   934  		}
   935  	}
   936  }
   937  
   938  // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
   939  // the caller to avoid an allocation.
   940  func (cn *conn) recv1Buf(r *readBuf) byte {
   941  	for {
   942  		t, err := cn.recvMessage(r)
   943  		if err != nil {
   944  			panic(err)
   945  		}
   946  
   947  		switch t {
   948  		case 'A', 'N':
   949  			// ignore
   950  		case 'S':
   951  			cn.processParameterStatus(r)
   952  		default:
   953  			return t
   954  		}
   955  	}
   956  }
   957  
   958  // recv1 receives a message from the backend, panicking if an error occurs
   959  // while attempting to read it.  All asynchronous messages are ignored, with
   960  // the exception of ErrorResponse.
   961  func (cn *conn) recv1() (t byte, r *readBuf) {
   962  	r = &readBuf{}
   963  	t = cn.recv1Buf(r)
   964  	return t, r
   965  }
   966  
// ssl negotiates TLS on cn.c according to the "sslmode" setting in o:
//
//   - "require" (or unset): encrypt, but do not verify the server certificate.
//   - "verify-ca": disable crypto/tls's own verification and instead check
//     the certificate chain against the configured CA in verifyCA; no
//     ServerName is set, so the host name is not checked.
//   - "verify-full": let crypto/tls verify the chain and that it matches the
//     "host" setting.
//   - "disable": return without touching the connection.
//
// Any other mode panics (via errorf), as does a server that refuses SSL
// (ErrSSLNotSupported) or an I/O or TLS failure.  On success cn.c is replaced
// by the TLS client connection.
func (cn *conn) ssl(o values) {
	verifyCaOnly := false
	tlsConf := tls.Config{}
	switch mode := o.Get("sslmode"); mode {
	case "require", "":
		tlsConf.InsecureSkipVerify = true
	case "verify-ca":
		// We must skip TLS's own verification since it requires full
		// verification since Go 1.3.
		tlsConf.InsecureSkipVerify = true
		verifyCaOnly = true
	case "verify-full":
		tlsConf.ServerName = o.Get("host")
	case "disable":
		return
	default:
		errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
	}

	cn.setupSSLClientCertificates(&tlsConf, o)
	cn.setupSSLCA(&tlsConf, o)

	// SSLRequest: the magic code 80877103 is (1234 << 16) | 5679.
	w := cn.writeBuf(0)
	w.int32(80877103)
	cn.sendStartupPacket(w)

	// The server answers with a single byte: 'S' if it is willing to start
	// TLS, anything else if not.
	b := cn.scratch[:1]
	_, err := io.ReadFull(cn.c, b)
	if err != nil {
		panic(err)
	}

	if b[0] != 'S' {
		panic(ErrSSLNotSupported)
	}

	client := tls.Client(cn.c, &tlsConf)
	if verifyCaOnly {
		// Handshake eagerly so the chain is checked before any query is sent.
		cn.verifyCA(client, &tlsConf)
	}
	cn.c = client
}
  1009  
  1010  // verifyCA carries out a TLS handshake to the server and verifies the
  1011  // presented certificate against the effective CA, i.e. the one specified in
  1012  // sslrootcert or the system CA if sslrootcert was not specified.
  1013  func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) {
  1014  	err := client.Handshake()
  1015  	if err != nil {
  1016  		panic(err)
  1017  	}
  1018  	certs := client.ConnectionState().PeerCertificates
  1019  	opts := x509.VerifyOptions{
  1020  		DNSName:       client.ConnectionState().ServerName,
  1021  		Intermediates: x509.NewCertPool(),
  1022  		Roots:         tlsConf.RootCAs,
  1023  	}
  1024  	for i, cert := range certs {
  1025  		if i == 0 {
  1026  			continue
  1027  		}
  1028  		opts.Intermediates.AddCert(cert)
  1029  	}
  1030  	_, err = certs[0].Verify(opts)
  1031  	if err != nil {
  1032  		panic(err)
  1033  	}
  1034  }
  1035  
  1036  // This function sets up SSL client certificates based on either the "sslkey"
  1037  // and "sslcert" settings (possibly set via the environment variables PGSSLKEY
  1038  // and PGSSLCERT, respectively), or if they aren't set, from the .postgresql
  1039  // directory in the user's home directory.  If the file paths are set
  1040  // explicitly, the files must exist.  The key file must also not be
  1041  // world-readable, or this function will panic with
  1042  // ErrSSLKeyHasWorldPermissions.
  1043  func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) {
  1044  	var missingOk bool
  1045  
  1046  	sslkey := o.Get("sslkey")
  1047  	sslcert := o.Get("sslcert")
  1048  	if sslkey != "" && sslcert != "" {
  1049  		// If the user has set an sslkey and sslcert, they *must* exist.
  1050  		missingOk = false
  1051  	} else {
  1052  		// Automatically load certificates from ~/.postgresql.
  1053  		user, err := user.Current()
  1054  		if err != nil {
  1055  			// user.Current() might fail when cross-compiling.  We have to
  1056  			// ignore the error and continue without client certificates, since
  1057  			// we wouldn't know where to load them from.
  1058  			return
  1059  		}
  1060  
  1061  		sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
  1062  		sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
  1063  		missingOk = true
  1064  	}
  1065  
  1066  	// Check that both files exist, and report the error or stop, depending on
  1067  	// which behaviour we want.  Note that we don't do any more extensive
  1068  	// checks than this (such as checking that the paths aren't directories);
  1069  	// LoadX509KeyPair() will take care of the rest.
  1070  	keyfinfo, err := os.Stat(sslkey)
  1071  	if err != nil && missingOk {
  1072  		return
  1073  	} else if err != nil {
  1074  		panic(err)
  1075  	}
  1076  	_, err = os.Stat(sslcert)
  1077  	if err != nil && missingOk {
  1078  		return
  1079  	} else if err != nil {
  1080  		panic(err)
  1081  	}
  1082  
  1083  	// If we got this far, the key file must also have the correct permissions
  1084  	kmode := keyfinfo.Mode()
  1085  	if kmode != kmode&0600 {
  1086  		panic(ErrSSLKeyHasWorldPermissions)
  1087  	}
  1088  
  1089  	cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
  1090  	if err != nil {
  1091  		panic(err)
  1092  	}
  1093  	tlsConf.Certificates = []tls.Certificate{cert}
  1094  }
  1095  
  1096  // Sets up RootCAs in the TLS configuration if sslrootcert is set.
  1097  func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) {
  1098  	if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" {
  1099  		tlsConf.RootCAs = x509.NewCertPool()
  1100  
  1101  		cert, err := ioutil.ReadFile(sslrootcert)
  1102  		if err != nil {
  1103  			panic(err)
  1104  		}
  1105  
  1106  		ok := tlsConf.RootCAs.AppendCertsFromPEM(cert)
  1107  		if !ok {
  1108  			errorf("couldn't parse pem in sslrootcert")
  1109  		}
  1110  	}
  1111  }
  1112  
  1113  // isDriverSetting returns true iff a setting is purely for configuring the
  1114  // driver's options and should not be sent to the server in the connection
  1115  // startup packet.
  1116  func isDriverSetting(key string) bool {
  1117  	switch key {
  1118  	case "host", "port":
  1119  		return true
  1120  	case "password":
  1121  		return true
  1122  	case "sslmode", "sslcert", "sslkey", "sslrootcert":
  1123  		return true
  1124  	case "fallback_application_name":
  1125  		return true
  1126  	case "connect_timeout":
  1127  		return true
  1128  	case "disable_prepared_binary_result":
  1129  		return true
  1130  	case "binary_parameters":
  1131  		return true
  1132  
  1133  	default:
  1134  		return false
  1135  	}
  1136  }
  1137  
// startup sends the StartupMessage and runs the startup/authentication
// exchange until the server reports ReadyForQuery.
//
// The StartupMessage carries protocol version 3.0 (196608 == 3 << 16) and
// every setting in o that is not driver-only (see isDriverSetting), with
// "dbname" renamed to "database".  Replies are handled as follows:
//   - BackendKeyData ('K') is ignored.
//   - ParameterStatus ('S') is recorded.
//   - Authentication requests ('R') are answered by auth.
//   - ReadyForQuery ('Z') ends the exchange.
//
// An ErrorResponse panics inside recv; any other message type panics via
// errorf.
func (cn *conn) startup(o values) {
	w := cn.writeBuf(0)
	w.int32(196608)
	// Send the backend the name of the database we want to connect to, and the
	// user we want to connect as.  Additionally, we send over any run-time
	// parameters potentially included in the connection string.  If the server
	// doesn't recognize any of them, it will reply with an error.
	for k, v := range o {
		if isDriverSetting(k) {
			// skip options which can't be run-time parameters
			continue
		}
		// The protocol requires us to supply the database name as "database"
		// instead of "dbname".
		if k == "dbname" {
			k = "database"
		}
		w.string(k)
		w.string(v)
	}
	// An empty name terminates the parameter list.
	w.string("")
	cn.sendStartupPacket(w)

	for {
		t, r := cn.recv()
		switch t {
		case 'K':
			// BackendKeyData: ignored here.
		case 'S':
			cn.processParameterStatus(r)
		case 'R':
			cn.auth(r, o)
		case 'Z':
			cn.processReadyForQuery(r)
			return
		default:
			errorf("unknown response for startup: %q", t)
		}
	}
}
  1177  
  1178  func (cn *conn) auth(r *readBuf, o values) {
  1179  	switch code := r.int32(); code {
  1180  	case 0:
  1181  		// OK
  1182  	case 3:
  1183  		w := cn.writeBuf('p')
  1184  		w.string(o.Get("password"))
  1185  		cn.send(w)
  1186  
  1187  		t, r := cn.recv()
  1188  		if t != 'R' {
  1189  			errorf("unexpected password response: %q", t)
  1190  		}
  1191  
  1192  		if r.int32() != 0 {
  1193  			errorf("unexpected authentication response: %q", t)
  1194  		}
  1195  	case 5:
  1196  		s := string(r.next(4))
  1197  		w := cn.writeBuf('p')
  1198  		w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
  1199  		cn.send(w)
  1200  
  1201  		t, r := cn.recv()
  1202  		if t != 'R' {
  1203  			errorf("unexpected password response: %q", t)
  1204  		}
  1205  
  1206  		if r.int32() != 0 {
  1207  			errorf("unexpected authentication response: %q", t)
  1208  		}
  1209  	default:
  1210  		errorf("unknown authentication response: %d", code)
  1211  	}
  1212  }
  1213  
  1214  type format int
  1215  
  1216  const formatText format = 0
  1217  const formatBinary format = 1
  1218  
  1219  // One result-column format code with the value 1 (i.e. all binary).
  1220  var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1}
  1221  
  1222  // No result-column format codes (i.e. all text).
  1223  var colFmtDataAllText []byte = []byte{0, 0}
  1224  
// stmt is a server-side prepared statement together with the parameter and
// result descriptions needed to bind arguments and decode rows.
type stmt struct {
	cn         *conn     // connection used to send the statement's messages
	name       string    // statement name written into Bind and Close messages
	colNames   []string  // result column names
	colFmts    []format  // per-column result format codes used when decoding
	colFmtData []byte    // result-column format-code section appended to Bind
	colTyps    []oid.Oid // result column type OIDs
	paramTyps  []oid.Oid // parameter type OIDs; len is the parameter count
	closed     bool      // set once the server has confirmed Close
}
  1235  
// Close deallocates the prepared statement on the server.  It sends Close
// (statement) followed by Sync and then waits for CloseComplete ('3') and
// ReadyForQuery ('Z').
//
// Closing an already-closed statement is a no-op.  On a bad connection it
// returns driver.ErrBadConn.  An unexpected reply marks the connection bad.
// Panics raised while talking to the server are recovered by the deferred
// errRecover(&err), presumably turning them into the returned error.
func (st *stmt) Close() (err error) {
	if st.closed {
		return nil
	}
	if st.cn.bad {
		return driver.ErrBadConn
	}
	defer st.cn.errRecover(&err)

	// Close ('C') targeting a statement ('S') by name.
	w := st.cn.writeBuf('C')
	w.byte('S')
	w.string(st.name)
	st.cn.send(w)

	// Sync ('S') so the server answers with ReadyForQuery.
	st.cn.send(st.cn.writeBuf('S'))

	t, _ := st.cn.recv1()
	if t != '3' {
		st.cn.bad = true
		errorf("unexpected close response: %q", t)
	}
	st.closed = true

	t, r := st.cn.recv1()
	if t != 'Z' {
		st.cn.bad = true
		errorf("expected ready for query, but got: %q", t)
	}
	st.cn.processReadyForQuery(r)

	return nil
}
  1268  
  1269  func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1270  	if st.cn.bad {
  1271  		return nil, driver.ErrBadConn
  1272  	}
  1273  	defer st.cn.errRecover(&err)
  1274  
  1275  	st.exec(v)
  1276  	return &rows{
  1277  		cn:       st.cn,
  1278  		colNames: st.colNames,
  1279  		colTyps:  st.colTyps,
  1280  		colFmts:  st.colFmts,
  1281  	}, nil
  1282  }
  1283  
  1284  func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1285  	if st.cn.bad {
  1286  		return nil, driver.ErrBadConn
  1287  	}
  1288  	defer st.cn.errRecover(&err)
  1289  
  1290  	st.exec(v)
  1291  	res, _, err = st.cn.readExecuteResponse("simple query")
  1292  	return res, err
  1293  }
  1294  
// exec binds the arguments v to the prepared statement and starts executing
// it with the extended query protocol.  Bind (into the unnamed portal),
// Execute with no row limit, and Sync are sent in a single write.  It then
// reads BindComplete and peeks at the first result message (see
// postExecuteWorkaround).  The remaining results are left for the caller to
// consume.  It panics (via errorf) if there are 65536 or more arguments, or
// if len(v) differs from the statement's parameter count.
func (st *stmt) exec(v []driver.Value) {
	if len(v) >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
	}
	if len(v) != len(st.paramTyps) {
		errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
	}

	cn := st.cn
	w := cn.writeBuf('B')
	w.byte(0) // unnamed portal
	w.string(st.name)

	if cn.binaryParameters {
		cn.sendBinaryParameters(w, v)
	} else {
		// Zero parameter format codes, i.e. every parameter is text, encoded
		// according to the statement's parameter types.
		w.int16(0)
		w.int16(len(v))
		for i, x := range v {
			if x == nil {
				// length -1 means SQL NULL
				w.int32(-1)
			} else {
				b := encode(&cn.parameterStatus, x, st.paramTyps[i])
				w.int32(len(b))
				w.bytes(b)
			}
		}
	}
	// Result-column format codes for this statement.
	w.bytes(st.colFmtData)

	// Execute the unnamed portal with no row limit (0).
	w.next('E')
	w.byte(0)
	w.int32(0)

	w.next('S')
	cn.send(w)

	cn.readBindResponse()
	cn.postExecuteWorkaround()

}
  1336  
  1337  func (st *stmt) NumInput() int {
  1338  	return len(st.paramTyps)
  1339  }
  1340  
  1341  // parseComplete parses the "command tag" from a CommandComplete message, and
  1342  // returns the number of rows affected (if applicable) and a string
  1343  // identifying only the command that was executed, e.g. "ALTER TABLE".  If the
  1344  // command tag could not be parsed, parseComplete panics.
  1345  func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1346  	commandsWithAffectedRows := []string{
  1347  		"SELECT ",
  1348  		// INSERT is handled below
  1349  		"UPDATE ",
  1350  		"DELETE ",
  1351  		"FETCH ",
  1352  		"MOVE ",
  1353  		"COPY ",
  1354  	}
  1355  
  1356  	var affectedRows *string
  1357  	for _, tag := range commandsWithAffectedRows {
  1358  		if strings.HasPrefix(commandTag, tag) {
  1359  			t := commandTag[len(tag):]
  1360  			affectedRows = &t
  1361  			commandTag = tag[:len(tag)-1]
  1362  			break
  1363  		}
  1364  	}
  1365  	// INSERT also includes the oid of the inserted row in its command tag.
  1366  	// Oids in user tables are deprecated, and the oid is only returned when
  1367  	// exactly one row is inserted, so it's unlikely to be of value to any
  1368  	// real-world application and we can ignore it.
  1369  	if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1370  		parts := strings.Split(commandTag, " ")
  1371  		if len(parts) != 3 {
  1372  			cn.bad = true
  1373  			errorf("unexpected INSERT command tag %s", commandTag)
  1374  		}
  1375  		affectedRows = &parts[len(parts)-1]
  1376  		commandTag = "INSERT"
  1377  	}
  1378  	// There should be no affected rows attached to the tag, just return it
  1379  	if affectedRows == nil {
  1380  		return driver.RowsAffected(0), commandTag
  1381  	}
  1382  	n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1383  	if err != nil {
  1384  		cn.bad = true
  1385  		errorf("could not parse commandTag: %s", err)
  1386  	}
  1387  	return driver.RowsAffected(n), commandTag
  1388  }
  1389  
// rows iterates over the DataRow messages of a statement's result, decoding
// each column with the statement's type OIDs and format codes.
type rows struct {
	cn       *conn     // connection the result is read from
	colNames []string  // result column names returned by Columns
	colTyps  []oid.Oid // per-column type OIDs used when decoding
	colFmts  []format  // per-column format codes used when decoding
	done     bool      // set once ReadyForQuery has been received
	rb       readBuf   // buffer reused for every message read by Next
}
  1398  
  1399  func (rs *rows) Close() error {
  1400  	// no need to look at cn.bad as Next() will
  1401  	for {
  1402  		err := rs.Next(nil)
  1403  		switch err {
  1404  		case nil:
  1405  		case io.EOF:
  1406  			return nil
  1407  		default:
  1408  			return err
  1409  		}
  1410  	}
  1411  }
  1412  
  1413  func (rs *rows) Columns() []string {
  1414  	return rs.colNames
  1415  }
  1416  
// Next reads the next DataRow of the result into dest, decoding each column
// according to its type OID and format code.  Columns beyond len(dest) are
// not decoded.  CommandComplete ('C') and EmptyQueryResponse ('I') messages
// are skipped.
//
// Once ReadyForQuery arrives, Next returns io.EOF, both then and on every
// later call.  If an ErrorResponse was received before that point, its error
// is returned instead of io.EOF.  A DataRow following an error marks the
// connection bad.  Any unexpected message type panics via errorf, which the
// deferred errRecover recovers.
func (rs *rows) Next(dest []driver.Value) (err error) {
	if rs.done {
		return io.EOF
	}

	conn := rs.cn
	if conn.bad {
		return driver.ErrBadConn
	}
	defer conn.errRecover(&err)

	for {
		t := conn.recv1Buf(&rs.rb)
		switch t {
		case 'E':
			// Remember the error and keep reading until ReadyForQuery ('Z'),
			// where it is returned.
			err = parseError(&rs.rb)
		case 'C', 'I':
			continue
		case 'Z':
			conn.processReadyForQuery(&rs.rb)
			rs.done = true
			if err != nil {
				return err
			}
			return io.EOF
		case 'D':
			n := rs.rb.int16()
			if err != nil {
				conn.bad = true
				errorf("unexpected DataRow after error %s", err)
			}
			if n < len(dest) {
				dest = dest[:n]
			}
			for i := range dest {
				l := rs.rb.int32()
				if l == -1 {
					// length -1 is SQL NULL
					dest[i] = nil
					continue
				}
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i])
			}
			return
		default:
			errorf("unexpected message after execute: %q", t)
		}
	}
}
  1465  
  1466  // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  1467  // used as part of an SQL statement.  For example:
  1468  //
  1469  //    tblname := "my_table"
  1470  //    data := "my_data"
  1471  //    err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data)
  1472  //
  1473  // Any double quotes in name will be escaped.  The quoted identifier will be
  1474  // case sensitive when used in a query.  If the input string contains a zero
  1475  // byte, the result will be truncated immediately before it.
  1476  func QuoteIdentifier(name string) string {
  1477  	end := strings.IndexRune(name, 0)
  1478  	if end > -1 {
  1479  		name = name[:end]
  1480  	}
  1481  	return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
  1482  }
  1483  
  1484  func md5s(s string) string {
  1485  	h := md5.New()
  1486  	h.Write([]byte(s))
  1487  	return fmt.Sprintf("%x", h.Sum(nil))
  1488  }
  1489  
// sendBinaryParameters appends the parameter part of a Bind message to b:
// the parameter format codes followed by the parameter values.
//
// Arguments of type []byte are marked as binary (format code 1) and all
// others as text (code 0).  When no argument is a []byte, zero format codes
// are written, meaning every parameter is text.  Values are encoded with
// binaryEncode, and a nil argument is sent with length -1 (SQL NULL).
func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
	// Do one pass over the parameters to see if we're going to send any of
	// them over in binary.  If we are, create a paramFormats array at the
	// same time.
	var paramFormats []int
	for i, x := range args {
		_, ok := x.([]byte)
		if ok {
			if paramFormats == nil {
				paramFormats = make([]int, len(args))
			}
			paramFormats[i] = 1
		}
	}
	if paramFormats == nil {
		b.int16(0)
	} else {
		b.int16(len(paramFormats))
		for _, x := range paramFormats {
			b.int16(x)
		}
	}

	b.int16(len(args))
	for _, x := range args {
		if x == nil {
			b.int32(-1)
		} else {
			datum := binaryEncode(&cn.parameterStatus, x)
			b.int32(len(datum))
			b.bytes(datum)
		}
	}
}
  1524  
// sendBinaryModeQuery sends query and args as one batch of extended-protocol
// messages, in this order:
//   - Parse: unnamed statement, no parameter types given.
//   - Bind: unnamed portal, parameters via sendBinaryParameters, all result
//     columns in text format.
//   - Describe of the portal.
//   - Execute with no row limit.
//   - Sync.
//
// It only writes; the caller must read the responses.  It panics (via
// errorf) if there are 65536 or more arguments.
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
	if len(args) >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
	}

	b := cn.writeBuf('P')
	b.byte(0) // unnamed statement
	b.string(query)
	// no parameter type OIDs; the server infers them
	b.int16(0)

	b.next('B')
	b.int16(0) // unnamed portal and statement (two empty, NUL-terminated names)
	cn.sendBinaryParameters(b, args)
	b.bytes(colFmtDataAllText)

	b.next('D')
	b.byte('P')
	b.byte(0) // unnamed portal

	// Execute the unnamed portal with no row limit (0).
	b.next('E')
	b.byte(0)
	b.int32(0)

	b.next('S')
	cn.send(b)
}
  1551  
  1552  func (c *conn) processParameterStatus(r *readBuf) {
  1553  	var err error
  1554  
  1555  	param := r.string()
  1556  	switch param {
  1557  	case "server_version":
  1558  		var major1 int
  1559  		var major2 int
  1560  		var minor int
  1561  		_, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1562  		if err == nil {
  1563  			c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1564  		}
  1565  
  1566  	case "TimeZone":
  1567  		c.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1568  		if err != nil {
  1569  			c.parameterStatus.currentLocation = nil
  1570  		}
  1571  
  1572  	default:
  1573  		// ignore
  1574  	}
  1575  }
  1576  
  1577  func (c *conn) processReadyForQuery(r *readBuf) {
  1578  	c.txnStatus = transactionStatus(r.byte())
  1579  }
  1580  
  1581  func (cn *conn) readReadyForQuery() {
  1582  	t, r := cn.recv1()
  1583  	switch t {
  1584  	case 'Z':
  1585  		cn.processReadyForQuery(r)
  1586  		return
  1587  	default:
  1588  		cn.bad = true
  1589  		errorf("unexpected message %q; expected ReadyForQuery", t)
  1590  	}
  1591  }
  1592  
// readParseResponse consumes the server's reply to a Parse message.
// ParseComplete ('1') means success.  On ErrorResponse, the ReadyForQuery
// that follows is consumed first (see readReadyForQuery) and then the
// server's error is raised with panic.  Any other message marks the
// connection bad and panics via errorf.
func (cn *conn) readParseResponse() {
	t, r := cn.recv1()
	switch t {
	case '1':
		return
	case 'E':
		err := parseError(r)
		cn.readReadyForQuery()
		panic(err)
	default:
		cn.bad = true
		errorf("unexpected Parse response %q", t)
	}
}
  1607  
// readStatementDescribeResponse reads the reply to Describe (statement).
// ParameterDescription ('t') supplies the parameter type OIDs.  It is followed
// either by NoData ('n') for statements that return no rows, or by
// RowDescription ('T') with the result column names and types.  On
// ErrorResponse, the following ReadyForQuery is consumed and the server error
// is raised with panic.  Any other message marks the connection bad and
// panics via errorf.
func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) {
	for {
		t, r := cn.recv1()
		switch t {
		case 't':
			nparams := r.int16()
			paramTyps = make([]oid.Oid, nparams)
			for i := range paramTyps {
				paramTyps[i] = r.oid()
			}
		case 'n':
			return paramTyps, nil, nil
		case 'T':
			colNames, colTyps = parseStatementRowDescribe(r)
			return paramTyps, colNames, colTyps
		case 'E':
			err := parseError(r)
			cn.readReadyForQuery()
			panic(err)
		default:
			cn.bad = true
			errorf("unexpected Describe statement response %q", t)
		}
	}
}
  1633  
// readPortalDescribeResponse reads the reply to Describe (portal).  It returns
// the column names, format codes and type OIDs from a RowDescription ('T'),
// or nils on NoData ('n').  On ErrorResponse, the following ReadyForQuery is
// consumed and the server error is raised with panic.  Any other message
// marks the connection bad and panics via errorf.
func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) {
	t, r := cn.recv1()
	switch t {
	case 'T':
		return parsePortalRowDescribe(r)
	case 'n':
		return nil, nil, nil
	case 'E':
		err := parseError(r)
		cn.readReadyForQuery()
		panic(err)
	default:
		cn.bad = true
		errorf("unexpected Describe response %q", t)
	}
	// errorf panics, so this is unreachable; it satisfies the compiler's
	// missing-return check.
	panic("not reached")
}
  1651  
// readBindResponse consumes the server's reply to a Bind message.
// BindComplete ('2') means success.  On ErrorResponse, the following
// ReadyForQuery is consumed and the server error is raised with panic.  Any
// other message marks the connection bad and panics via errorf.
func (cn *conn) readBindResponse() {
	t, r := cn.recv1()
	switch t {
	case '2':
		return
	case 'E':
		err := parseError(r)
		cn.readReadyForQuery()
		panic(err)
	default:
		cn.bad = true
		errorf("unexpected Bind response %q", t)
	}
}
  1666  
// postExecuteWorkaround peeks at the first message after Bind/Execute:
//   - ErrorResponse: consume the following ReadyForQuery, then panic with the
//     server error.
//   - CommandComplete ('C'), DataRow ('D') or EmptyQueryResponse ('I'): save
//     the message for the next reader and return.
//   - Anything else: mark the connection bad and panic via errorf.
//
// The comment in the body explains why this is needed.
func (cn *conn) postExecuteWorkaround() {
	// Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
	// any errors from rows.Next, which masks errors that happened during the
	// execution of the query.  To avoid the problem in common cases, we wait
	// here for one more message from the database.  If it's not an error the
	// query will likely succeed (or perhaps has already, if it's a
	// CommandComplete), so we push the message into the conn struct; recv1
	// will return it as the next message for rows.Next or rows.Close.
	// However, if it's an error, we wait until ReadyForQuery and then return
	// the error to our caller.
	for {
		t, r := cn.recv1()
		switch t {
		case 'E':
			err := parseError(r)
			cn.readReadyForQuery()
			panic(err)
		case 'C', 'D', 'I':
			// the query didn't fail, but we can't process this message
			cn.saveMessage(t, r)
			return
		default:
			cn.bad = true
			errorf("unexpected message during extended query execution: %q", t)
		}
	}
}
  1694  
// Only for Exec(), since we ignore the returned data.
//
// readExecuteResponse reads messages until ReadyForQuery.  It returns the
// result and command tag parsed from the last CommandComplete message, plus
// the server error if an ErrorResponse was received.  RowDescription,
// DataRow and EmptyQueryResponse messages are discarded.
//
// CommandComplete or row data arriving after an error marks the connection
// bad and panics via errorf, as does an unknown message type.  protocolState
// names the protocol phase in that last error message.
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
	for {
		t, r := cn.recv1()
		switch t {
		case 'C':
			if err != nil {
				cn.bad = true
				errorf("unexpected CommandComplete after error %s", err)
			}
			res, commandTag = cn.parseComplete(r.string())
		case 'Z':
			cn.processReadyForQuery(r)
			return res, commandTag, err
		case 'E':
			err = parseError(r)
		case 'T', 'D', 'I':
			if err != nil {
				cn.bad = true
				errorf("unexpected %q after error %s", t, err)
			}
			// ignore any results
		default:
			cn.bad = true
			errorf("unknown %s response: %q", protocolState, t)
		}
	}
}
  1723  
// parseStatementRowDescribe decodes a RowDescription message obtained by
// describing a prepared statement and returns the column names and type OIDs.
// The per-column format code is skipped because it is not yet known when a
// statement (rather than a portal) is described.
func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) {
	n := r.int16()
	colNames = make([]string, n)
	colTyps = make([]oid.Oid, n)
	for i := range colNames {
		colNames[i] = r.string()
		// skip table OID (int32) and column attribute number (int16)
		r.next(6)
		colTyps[i] = r.oid()
		// skip data type size (int16) and type modifier (int32)
		r.next(6)
		// format code not known when describing a statement; always 0
		r.next(2)
	}
	return
}
  1738  
// parsePortalRowDescribe decodes a RowDescription message obtained by
// describing a bound portal.  It returns the column names, their format codes
// (known once the portal is bound) and their type OIDs.
func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) {
	n := r.int16()
	colNames = make([]string, n)
	colFmts = make([]format, n)
	colTyps = make([]oid.Oid, n)
	for i := range colNames {
		colNames[i] = r.string()
		// skip table OID (int32) and column attribute number (int16)
		r.next(6)
		colTyps[i] = r.oid()
		// skip data type size (int16) and type modifier (int32)
		r.next(6)
		colFmts[i] = format(r.int16())
	}
	return
}
  1753  
  1754  // parseEnviron tries to mimic some of libpq's environment handling
  1755  //
  1756  // To ease testing, it does not directly reference os.Environ, but is
  1757  // designed to accept its output.
  1758  //
  1759  // Environment-set connection information is intended to have a higher
  1760  // precedence than a library default but lower than any explicitly
  1761  // passed information (such as in the URL or connection string).
  1762  func parseEnviron(env []string) (out map[string]string) {
  1763  	out = make(map[string]string)
  1764  
  1765  	for _, v := range env {
  1766  		parts := strings.SplitN(v, "=", 2)
  1767  
  1768  		accrue := func(keyname string) {
  1769  			out[keyname] = parts[1]
  1770  		}
  1771  		unsupported := func() {
  1772  			panic(fmt.Sprintf("setting %v not supported", parts[0]))
  1773  		}
  1774  
  1775  		// The order of these is the same as is seen in the
  1776  		// PostgreSQL 9.1 manual. Unsupported but well-defined
  1777  		// keys cause a panic; these should be unset prior to
  1778  		// execution. Options which pq expects to be set to a
  1779  		// certain value are allowed, but must be set to that
  1780  		// value if present (they can, of course, be absent).
  1781  		switch parts[0] {
  1782  		case "PGHOST":
  1783  			accrue("host")
  1784  		case "PGHOSTADDR":
  1785  			unsupported()
  1786  		case "PGPORT":
  1787  			accrue("port")
  1788  		case "PGDATABASE":
  1789  			accrue("dbname")
  1790  		case "PGUSER":
  1791  			accrue("user")
  1792  		case "PGPASSWORD":
  1793  			accrue("password")
  1794  		case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  1795  			unsupported()
  1796  		case "PGOPTIONS":
  1797  			accrue("options")
  1798  		case "PGAPPNAME":
  1799  			accrue("application_name")
  1800  		case "PGSSLMODE":
  1801  			accrue("sslmode")
  1802  		case "PGSSLCERT":
  1803  			accrue("sslcert")
  1804  		case "PGSSLKEY":
  1805  			accrue("sslkey")
  1806  		case "PGSSLROOTCERT":
  1807  			accrue("sslrootcert")
  1808  		case "PGREQUIRESSL", "PGSSLCRL":
  1809  			unsupported()
  1810  		case "PGREQUIREPEER":
  1811  			unsupported()
  1812  		case "PGKRBSRVNAME", "PGGSSLIB":
  1813  			unsupported()
  1814  		case "PGCONNECT_TIMEOUT":
  1815  			accrue("connect_timeout")
  1816  		case "PGCLIENTENCODING":
  1817  			accrue("client_encoding")
  1818  		case "PGDATESTYLE":
  1819  			accrue("datestyle")
  1820  		case "PGTZ":
  1821  			accrue("timezone")
  1822  		case "PGGEQO":
  1823  			accrue("geqo")
  1824  		case "PGSYSCONFDIR", "PGLOCALEDIR":
  1825  			unsupported()
  1826  		}
  1827  	}
  1828  
  1829  	return out
  1830  }
  1831  
  1832  // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1833  func isUTF8(name string) bool {
  1834  	// Recognize all sorts of silly things as "UTF-8", like Postgres does
  1835  	s := strings.Map(alnumLowerASCII, name)
  1836  	return s == "utf8" || s == "unicode"
  1837  }
  1838  
  1839  func alnumLowerASCII(ch rune) rune {
  1840  	if 'A' <= ch && ch <= 'Z' {
  1841  		return ch + ('a' - 'A')
  1842  	}
  1843  	if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
  1844  		return ch
  1845  	}
  1846  	return -1 // discard
  1847  }