github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/lib/pq/conn.go (about) 1 package pq 2 3 import ( 4 "bufio" 5 "crypto/md5" 6 "crypto/tls" 7 "crypto/x509" 8 "database/sql" 9 "database/sql/driver" 10 "encoding/binary" 11 "errors" 12 "fmt" 13 "io" 14 "io/ioutil" 15 "net" 16 "os" 17 "os/user" 18 "path" 19 "path/filepath" 20 "strconv" 21 "strings" 22 "time" 23 "unicode" 24 25 "github.com/insionng/yougam/libraries/lib/pq/oid" 26 ) 27 28 // Common error types 29 var ( 30 ErrNotSupported = errors.New("pq: Unsupported command") 31 ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") 32 ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") 33 ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") 34 ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") 35 ) 36 37 type drv struct{} 38 39 func (d *drv) Open(name string) (driver.Conn, error) { 40 return Open(name) 41 } 42 43 func init() { 44 sql.Register("postgres", &drv{}) 45 } 46 47 type parameterStatus struct { 48 // server version in the same format as server_version_num, or 0 if 49 // unavailable 50 serverVersion int 51 52 // the current location based on the TimeZone value of the self.Session. if 53 // available 54 currentLocation *time.Location 55 } 56 57 type transactionStatus byte 58 59 const ( 60 txnStatusIdle transactionStatus = 'I' 61 txnStatusIdleInTransaction transactionStatus = 'T' 62 txnStatusInFailedTransaction transactionStatus = 'E' 63 ) 64 65 func (s transactionStatus) String() string { 66 switch s { 67 case txnStatusIdle: 68 return "idle" 69 case txnStatusIdleInTransaction: 70 return "idle in transaction" 71 case txnStatusInFailedTransaction: 72 return "in a failed transaction" 73 default: 74 errorf("unknown transactionStatus %d", s) 75 } 76 77 panic("not reached") 78 } 79 80 type Dialer interface { 81 Dial(network, address string) (net.Conn, error) 82 DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) 83 } 84 85 type defaultDialer struct{} 86 87 func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) { 88 return net.Dial(ntw, addr) 89 } 90 func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { 91 return net.DialTimeout(ntw, addr, timeout) 92 } 93 94 type conn struct { 95 c net.Conn 96 buf *bufio.Reader 97 namei int 98 scratch [512]byte 99 txnStatus transactionStatus 100 101 parameterStatus parameterStatus 102 103 saveMessageType byte 104 saveMessageBuffer []byte 105 106 // If true, this connection is bad and all public-facing functions should 107 // return ErrBadConn. 108 bad bool 109 110 // If set, this connection should never use the binary format when 111 // receiving query results from prepared statements. Only provided for 112 // debugging. 113 disablePreparedBinaryResult bool 114 115 // Whether to always send []byte parameters over as binary. Enables single 116 // round-trip mode for non-prepared Query calls. 117 binaryParameters bool 118 } 119 120 // Handle driver-side settings in parsed connection string. 121 func (c *conn) handleDriverSettings(o values) (err error) { 122 boolSetting := func(key string, val *bool) error { 123 if value := o.Get(key); value != "" { 124 if value == "yes" { 125 *val = true 126 } else if value == "no" { 127 *val = false 128 } else { 129 return fmt.Errorf("unrecognized value %q for %s", value, key) 130 } 131 } 132 return nil 133 } 134 135 err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) 136 if err != nil { 137 return err 138 } 139 err = boolSetting("binary_parameters", &c.binaryParameters) 140 if err != nil { 141 return err 142 } 143 return nil 144 } 145 146 func (c *conn) handlePgpass(o values) { 147 // if a password was supplied, do not process .pgpass 148 _, ok := o["password"] 149 if ok { 150 return 151 } 152 filename := os.Getenv("PGPASSFILE") 153 if filename == "" { 154 // XXX this code doesn't work on Windows where the default filename is 155 // XXX %APPDATA%\postgresql\pgpass.conf 156 user, err := user.Current() 157 if err != nil { 158 return 159 } 160 filename = filepath.Join(user.HomeDir, ".pgpass") 161 } 162 fileinfo, err := os.Stat(filename) 163 if err != nil { 164 return 165 } 166 mode := fileinfo.Mode() 167 if mode&(0x77) != 0 { 168 // XXX should warn about incorrect .pgpass permissions as psql does 169 return 170 } 171 file, err := os.Open(filename) 172 if err != nil { 173 return 174 } 175 defer file.Close() 176 scanner := bufio.NewScanner(io.Reader(file)) 177 hostname := o.Get("host") 178 ntw, _ := network(o) 179 port := o.Get("port") 180 db := o.Get("dbname") 181 username := o.Get("user") 182 // From: https://yougam/libraries/tg/pgpass/blob/master/reader.go 183 getFields := func(s string) []string { 184 fs := make([]string, 0, 5) 185 f := make([]rune, 0, len(s)) 186 187 var esc bool 188 for _, c := range s { 189 switch { 190 case esc: 191 f = append(f, c) 192 esc = false 193 case c == '\\': 194 esc = true 195 case c == ':': 196 fs = append(fs, string(f)) 197 f = f[:0] 198 default: 199 f = append(f, c) 200 } 201 } 202 return append(fs, string(f)) 203 } 204 for scanner.Scan() { 205 line := scanner.Text() 206 if len(line) == 0 || line[0] == '#' { 207 continue 208 } 209 split := getFields(line) 210 if len(split) != 5 { 211 continue 212 } 213 if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { 214 o["password"] = split[4] 215 return 216 } 217 } 218 } 219 220 func (c *conn) writeBuf(b byte) *writeBuf { 221 c.scratch[0] = b 222 return &writeBuf{ 223 buf: c.scratch[:5], 224 pos: 1, 225 } 226 } 227 228 func Open(name string) (_ driver.Conn, err error) { 229 return DialOpen(defaultDialer{}, name) 230 } 231 232 func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { 233 // Handle any panics during connection initialization. Note that we 234 // specifically do *not* want to use errRecover(), as that would turn any 235 // connection errors into ErrBadConns, hiding the real error message from 236 // the user. 237 defer errRecoverNoErrBadConn(&err) 238 239 o := make(values) 240 241 // A number of defaults are applied here, in this order: 242 // 243 // * Very low precedence defaults applied in every situation 244 // * Environment variables 245 // * Explicitly passed connection information 246 o.Set("host", "localhost") 247 o.Set("port", "5432") 248 // N.B.: Extra float digits should be set to 3, but that breaks 249 // Postgres 8.4 and older, where the max is 2. 250 o.Set("extra_float_digits", "2") 251 for k, v := range parseEnviron(os.Environ()) { 252 o.Set(k, v) 253 } 254 255 if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { 256 name, err = ParseURL(name) 257 if err != nil { 258 return nil, err 259 } 260 } 261 262 if err := parseOpts(name, o); err != nil { 263 return nil, err 264 } 265 266 // Use the "fallback" application name if necessary 267 if fallback := o.Get("fallback_application_name"); fallback != "" { 268 if !o.Isset("application_name") { 269 o.Set("application_name", fallback) 270 } 271 } 272 273 // We can't work with any client_encoding other than UTF-8 currently. 274 // However, we have historically allowed the user to set it to UTF-8 275 // explicitly, and there's no reason to break such programs, so allow that. 276 // Note that the "options" setting could also set client_encoding, but 277 // parsing its value is not worth it. Instead, we always explicitly send 278 // client_encoding as a separate run-time parameter, which should override 279 // anything set in options. 280 if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { 281 return nil, errors.New("client_encoding must be absent or 'UTF8'") 282 } 283 o.Set("client_encoding", "UTF8") 284 // DateStyle needs a similar treatment. 285 if datestyle := o.Get("datestyle"); datestyle != "" { 286 if datestyle != "ISO, MDY" { 287 panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", 288 "ISO, MDY", datestyle)) 289 } 290 } else { 291 o.Set("datestyle", "ISO, MDY") 292 } 293 294 // If a user is not provided by any other means, the last 295 // resort is to use the current operating system provided user 296 // name. 297 if o.Get("user") == "" { 298 u, err := userCurrent() 299 if err != nil { 300 return nil, err 301 } else { 302 o.Set("user", u) 303 } 304 } 305 306 cn := &conn{} 307 err = cn.handleDriverSettings(o) 308 if err != nil { 309 return nil, err 310 } 311 cn.handlePgpass(o) 312 313 cn.c, err = dial(d, o) 314 if err != nil { 315 return nil, err 316 } 317 cn.ssl(o) 318 cn.buf = bufio.NewReader(cn.c) 319 cn.startup(o) 320 321 // reset the deadline, in case one was set (see dial) 322 if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { 323 err = cn.c.SetDeadline(time.Time{}) 324 } 325 return cn, err 326 } 327 328 func dial(d Dialer, o values) (net.Conn, error) { 329 ntw, addr := network(o) 330 // SSL is not necessary or supported over UNIX domain sockets 331 if ntw == "unix" { 332 o["sslmode"] = "disable" 333 } 334 335 // Zero or not specified means wait indefinitely. 336 if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { 337 seconds, err := strconv.ParseInt(timeout, 10, 0) 338 if err != nil { 339 return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) 340 } 341 duration := time.Duration(seconds) * time.Second 342 // connect_timeout should apply to the entire connection establishment 343 // procedure, so we both use a timeout for the TCP connection 344 // establishment and set a deadline for doing the initial handshake. 345 // The deadline is then reset after startup() is done. 346 deadline := time.Now().Add(duration) 347 conn, err := d.DialTimeout(ntw, addr, duration) 348 if err != nil { 349 return nil, err 350 } 351 err = conn.SetDeadline(deadline) 352 return conn, err 353 } 354 return d.Dial(ntw, addr) 355 } 356 357 func network(o values) (string, string) { 358 host := o.Get("host") 359 360 if strings.HasPrefix(host, "/") { 361 sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) 362 return "unix", sockPath 363 } 364 365 return "tcp", net.JoinHostPort(host, o.Get("port")) 366 } 367 368 type values map[string]string 369 370 func (vs values) Set(k, v string) { 371 vs[k] = v 372 } 373 374 func (vs values) Get(k string) (v string) { 375 return vs[k] 376 } 377 378 func (vs values) Isset(k string) bool { 379 _, ok := vs[k] 380 return ok 381 } 382 383 // scanner implements a tokenizer for libpq-style option strings. 384 type scanner struct { 385 s []rune 386 i int 387 } 388 389 // newScanner returns a new scanner initialized with the option string s. 390 func newScanner(s string) *scanner { 391 return &scanner{[]rune(s), 0} 392 } 393 394 // Next returns the next rune. 395 // It returns 0, false if the end of the text has been reached. 396 func (s *scanner) Next() (rune, bool) { 397 if s.i >= len(s.s) { 398 return 0, false 399 } 400 r := s.s[s.i] 401 s.i++ 402 return r, true 403 } 404 405 // SkipSpaces returns the next non-whitespace rune. 406 // It returns 0, false if the end of the text has been reached. 407 func (s *scanner) SkipSpaces() (rune, bool) { 408 r, ok := s.Next() 409 for unicode.IsSpace(r) && ok { 410 r, ok = s.Next() 411 } 412 return r, ok 413 } 414 415 // parseOpts parses the options from name and adds them to the values. 416 // 417 // The parsing code is based on conninfo_parse from libpq's fe-connect.c 418 func parseOpts(name string, o values) error { 419 s := newScanner(name) 420 421 for { 422 var ( 423 keyRunes, valRunes []rune 424 r rune 425 ok bool 426 ) 427 428 if r, ok = s.SkipSpaces(); !ok { 429 break 430 } 431 432 // Scan the key 433 for !unicode.IsSpace(r) && r != '=' { 434 keyRunes = append(keyRunes, r) 435 if r, ok = s.Next(); !ok { 436 break 437 } 438 } 439 440 // Skip any whitespace if we're not at the = yet 441 if r != '=' { 442 r, ok = s.SkipSpaces() 443 } 444 445 // The current character should be = 446 if r != '=' || !ok { 447 return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) 448 } 449 450 // Skip any whitespace after the = 451 if r, ok = s.SkipSpaces(); !ok { 452 // If we reach the end here, the last value is just an empty string as per libpq. 453 o.Set(string(keyRunes), "") 454 break 455 } 456 457 if r != '\'' { 458 for !unicode.IsSpace(r) { 459 if r == '\\' { 460 if r, ok = s.Next(); !ok { 461 return fmt.Errorf(`missing character after backslash`) 462 } 463 } 464 valRunes = append(valRunes, r) 465 466 if r, ok = s.Next(); !ok { 467 break 468 } 469 } 470 } else { 471 quote: 472 for { 473 if r, ok = s.Next(); !ok { 474 return fmt.Errorf(`unterminated quoted string literal in connection string`) 475 } 476 switch r { 477 case '\'': 478 break quote 479 case '\\': 480 r, _ = s.Next() 481 fallthrough 482 default: 483 valRunes = append(valRunes, r) 484 } 485 } 486 } 487 488 o.Set(string(keyRunes), string(valRunes)) 489 } 490 491 return nil 492 } 493 494 func (cn *conn) isInTransaction() bool { 495 return cn.txnStatus == txnStatusIdleInTransaction || 496 cn.txnStatus == txnStatusInFailedTransaction 497 } 498 499 func (cn *conn) checkIsInTransaction(intxn bool) { 500 if cn.isInTransaction() != intxn { 501 cn.bad = true 502 errorf("unexpected transaction status %v", cn.txnStatus) 503 } 504 } 505 506 func (cn *conn) Begin() (_ driver.Tx, err error) { 507 if cn.bad { 508 return nil, driver.ErrBadConn 509 } 510 defer cn.errRecover(&err) 511 512 cn.checkIsInTransaction(false) 513 _, commandTag, err := cn.simpleExec("BEGIN") 514 if err != nil { 515 return nil, err 516 } 517 if commandTag != "BEGIN" { 518 cn.bad = true 519 return nil, fmt.Errorf("unexpected command tag %s", commandTag) 520 } 521 if cn.txnStatus != txnStatusIdleInTransaction { 522 cn.bad = true 523 return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) 524 } 525 return cn, nil 526 } 527 528 func (cn *conn) Commit() (err error) { 529 if cn.bad { 530 return driver.ErrBadConn 531 } 532 defer cn.errRecover(&err) 533 534 cn.checkIsInTransaction(true) 535 // We don't want the client to think that everything is okay if it tries 536 // to commit a failed transaction. However, no matter what we return, 537 // database/sql will release this connection back into the free connection 538 // pool so we have to abort the current transaction here. Note that you 539 // would get the same behaviour if you issued a COMMIT in a failed 540 // transaction, so it's also the least surprising thing to do here. 541 if cn.txnStatus == txnStatusInFailedTransaction { 542 if err := cn.Rollback(); err != nil { 543 return err 544 } 545 return ErrInFailedTransaction 546 } 547 548 _, commandTag, err := cn.simpleExec("COMMIT") 549 if err != nil { 550 if cn.isInTransaction() { 551 cn.bad = true 552 } 553 return err 554 } 555 if commandTag != "COMMIT" { 556 cn.bad = true 557 return fmt.Errorf("unexpected command tag %s", commandTag) 558 } 559 cn.checkIsInTransaction(false) 560 return nil 561 } 562 563 func (cn *conn) Rollback() (err error) { 564 if cn.bad { 565 return driver.ErrBadConn 566 } 567 defer cn.errRecover(&err) 568 569 cn.checkIsInTransaction(true) 570 _, commandTag, err := cn.simpleExec("ROLLBACK") 571 if err != nil { 572 if cn.isInTransaction() { 573 cn.bad = true 574 } 575 return err 576 } 577 if commandTag != "ROLLBACK" { 578 return fmt.Errorf("unexpected command tag %s", commandTag) 579 } 580 cn.checkIsInTransaction(false) 581 return nil 582 } 583 584 func (cn *conn) gname() string { 585 cn.namei++ 586 return strconv.FormatInt(int64(cn.namei), 10) 587 } 588 589 func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { 590 b := cn.writeBuf('Q') 591 b.string(q) 592 cn.send(b) 593 594 for { 595 t, r := cn.recv1() 596 switch t { 597 case 'C': 598 res, commandTag = cn.parseComplete(r.string()) 599 case 'Z': 600 cn.processReadyForQuery(r) 601 // done 602 return 603 case 'E': 604 err = parseError(r) 605 case 'T', 'D', 'I': 606 // ignore any results 607 default: 608 cn.bad = true 609 errorf("unknown response for simple query: %q", t) 610 } 611 } 612 } 613 614 func (cn *conn) simpleQuery(q string) (res *rows, err error) { 615 defer cn.errRecover(&err) 616 617 b := cn.writeBuf('Q') 618 b.string(q) 619 cn.send(b) 620 621 for { 622 t, r := cn.recv1() 623 switch t { 624 case 'C', 'I': 625 // We allow queries which don't return any results through Query as 626 // well as Exec. We still have to give database/sql a rows object 627 // the user can close, though, to avoid connections from being 628 // leaked. A "rows" with done=true works fine for that purpose. 629 if err != nil { 630 cn.bad = true 631 errorf("unexpected message %q in simple query execution", t) 632 } 633 if res == nil { 634 res = &rows{ 635 cn: cn, 636 } 637 } 638 res.done = true 639 case 'Z': 640 cn.processReadyForQuery(r) 641 // done 642 return 643 case 'E': 644 res = nil 645 err = parseError(r) 646 case 'D': 647 if res == nil { 648 cn.bad = true 649 errorf("unexpected DataRow in simple query execution") 650 } 651 // the query didn't fail; kick off to Next 652 cn.saveMessage(t, r) 653 return 654 case 'T': 655 // res might be non-nil here if we received a previous 656 // CommandComplete, but that's fine; just overwrite it 657 res = &rows{cn: cn} 658 res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) 659 660 // To work around a bug in QueryRow in Go 1.2 and earlier, wait 661 // until the first DataRow has been received. 662 default: 663 cn.bad = true 664 errorf("unknown response for simple query: %q", t) 665 } 666 } 667 } 668 669 // Decides which column formats to use for a prepared statement. The input is 670 // an array of type oids, one element per result column. 671 func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { 672 if len(colTyps) == 0 { 673 return nil, colFmtDataAllText 674 } 675 676 colFmts = make([]format, len(colTyps)) 677 if forceText { 678 return colFmts, colFmtDataAllText 679 } 680 681 allBinary := true 682 allText := true 683 for i, o := range colTyps { 684 switch o { 685 // This is the list of types to use binary mode for when receiving them 686 // through a prepared statement. If a type appears in this list, it 687 // must also be implemented in binaryDecode in encode.go. 688 case oid.T_bytea: 689 fallthrough 690 case oid.T_int8: 691 fallthrough 692 case oid.T_int4: 693 fallthrough 694 case oid.T_int2: 695 colFmts[i] = formatBinary 696 allText = false 697 698 default: 699 allBinary = false 700 } 701 } 702 703 if allBinary { 704 return colFmts, colFmtDataAllBinary 705 } else if allText { 706 return colFmts, colFmtDataAllText 707 } else { 708 colFmtData = make([]byte, 2+len(colFmts)*2) 709 binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) 710 for i, v := range colFmts { 711 binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) 712 } 713 return colFmts, colFmtData 714 } 715 } 716 717 func (cn *conn) prepareTo(q, stmtName string) *stmt { 718 st := &stmt{cn: cn, name: stmtName} 719 720 b := cn.writeBuf('P') 721 b.string(st.name) 722 b.string(q) 723 b.int16(0) 724 725 b.next('D') 726 b.byte('S') 727 b.string(st.name) 728 729 b.next('S') 730 cn.send(b) 731 732 cn.readParseResponse() 733 st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() 734 st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) 735 cn.readReadyForQuery() 736 return st 737 } 738 739 func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { 740 if cn.bad { 741 return nil, driver.ErrBadConn 742 } 743 defer cn.errRecover(&err) 744 745 if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { 746 return cn.prepareCopyIn(q) 747 } 748 return cn.prepareTo(q, cn.gname()), nil 749 } 750 751 func (cn *conn) Close() (err error) { 752 if cn.bad { 753 return driver.ErrBadConn 754 } 755 defer cn.errRecover(&err) 756 757 // Don't go through send(); ListenerConn relies on us not scribbling on the 758 // scratch buffer of this connection. 759 err = cn.sendSimpleMessage('X') 760 if err != nil { 761 return err 762 } 763 764 return cn.c.Close() 765 } 766 767 // Implement the "Queryer" interface 768 func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) { 769 if cn.bad { 770 return nil, driver.ErrBadConn 771 } 772 defer cn.errRecover(&err) 773 774 // Check to see if we can use the "simpleQuery" interface, which is 775 // *much* faster than going through prepare/exec 776 if len(args) == 0 { 777 return cn.simpleQuery(query) 778 } 779 780 if cn.binaryParameters { 781 cn.sendBinaryModeQuery(query, args) 782 783 cn.readParseResponse() 784 cn.readBindResponse() 785 rows := &rows{cn: cn} 786 rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() 787 cn.postExecuteWorkaround() 788 return rows, nil 789 } else { 790 st := cn.prepareTo(query, "") 791 st.exec(args) 792 return &rows{ 793 cn: cn, 794 colNames: st.colNames, 795 colTyps: st.colTyps, 796 colFmts: st.colFmts, 797 }, nil 798 } 799 } 800 801 // Implement the optional "Execer" interface for one-shot queries 802 func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { 803 if cn.bad { 804 return nil, driver.ErrBadConn 805 } 806 defer cn.errRecover(&err) 807 808 // Check to see if we can use the "simpleExec" interface, which is 809 // *much* faster than going through prepare/exec 810 if len(args) == 0 { 811 // ignore commandTag, our caller doesn't care 812 r, _, err := cn.simpleExec(query) 813 return r, err 814 } 815 816 if cn.binaryParameters { 817 cn.sendBinaryModeQuery(query, args) 818 819 cn.readParseResponse() 820 cn.readBindResponse() 821 cn.readPortalDescribeResponse() 822 cn.postExecuteWorkaround() 823 res, _, err = cn.readExecuteResponse("Execute") 824 return res, err 825 } else { 826 // Use the unnamed statement to defer planning until bind 827 // time, or else value-based selectivity estimates cannot be 828 // used. 829 st := cn.prepareTo(query, "") 830 r, err := st.Exec(args) 831 if err != nil { 832 panic(err) 833 } 834 return r, err 835 } 836 } 837 838 func (cn *conn) send(m *writeBuf) { 839 _, err := cn.c.Write(m.wrap()) 840 if err != nil { 841 panic(err) 842 } 843 } 844 845 func (cn *conn) sendStartupPacket(m *writeBuf) { 846 // sanity check 847 if m.buf[0] != 0 { 848 panic("oops") 849 } 850 851 _, err := cn.c.Write((m.wrap())[1:]) 852 if err != nil { 853 panic(err) 854 } 855 } 856 857 // Send a message of type typ to the server on the other end of cn. The 858 // message should have no payload. This method does not use the scratch 859 // buffer. 860 func (cn *conn) sendSimpleMessage(typ byte) (err error) { 861 _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) 862 return err 863 } 864 865 // saveMessage memorizes a message and its buffer in the conn struct. 866 // recvMessage will then return these values on the next call to it. This 867 // method is useful in cases where you have to see what the next message is 868 // going to be (e.g. to see whether it's an error or not) but you can't handle 869 // the message yourself. 870 func (cn *conn) saveMessage(typ byte, buf *readBuf) { 871 if cn.saveMessageType != 0 { 872 cn.bad = true 873 errorf("unexpected saveMessageType %d", cn.saveMessageType) 874 } 875 cn.saveMessageType = typ 876 cn.saveMessageBuffer = *buf 877 } 878 879 // recvMessage receives any message from the backend, or returns an error if 880 // a problem occurred while reading the message. 881 func (cn *conn) recvMessage(r *readBuf) (byte, error) { 882 // workaround for a QueryRow bug, see exec 883 if cn.saveMessageType != 0 { 884 t := cn.saveMessageType 885 *r = cn.saveMessageBuffer 886 cn.saveMessageType = 0 887 cn.saveMessageBuffer = nil 888 return t, nil 889 } 890 891 x := cn.scratch[:5] 892 _, err := io.ReadFull(cn.buf, x) 893 if err != nil { 894 return 0, err 895 } 896 897 // read the type and length of the message that follows 898 t := x[0] 899 n := int(binary.BigEndian.Uint32(x[1:])) - 4 900 var y []byte 901 if n <= len(cn.scratch) { 902 y = cn.scratch[:n] 903 } else { 904 y = make([]byte, n) 905 } 906 _, err = io.ReadFull(cn.buf, y) 907 if err != nil { 908 return 0, err 909 } 910 *r = y 911 return t, nil 912 } 913 914 // recv receives a message from the backend, but if an error happened while 915 // reading the message or the received message was an ErrorResponse, it panics. 916 // NoticeResponses are ignored. This function should generally be used only 917 // during the startup sequence. 918 func (cn *conn) recv() (t byte, r *readBuf) { 919 for { 920 var err error 921 r = &readBuf{} 922 t, err = cn.recvMessage(r) 923 if err != nil { 924 panic(err) 925 } 926 927 switch t { 928 case 'E': 929 panic(parseError(r)) 930 case 'N': 931 // ignore 932 default: 933 return 934 } 935 } 936 } 937 938 // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by 939 // the caller to avoid an allocation. 940 func (cn *conn) recv1Buf(r *readBuf) byte { 941 for { 942 t, err := cn.recvMessage(r) 943 if err != nil { 944 panic(err) 945 } 946 947 switch t { 948 case 'A', 'N': 949 // ignore 950 case 'S': 951 cn.processParameterStatus(r) 952 default: 953 return t 954 } 955 } 956 } 957 958 // recv1 receives a message from the backend, panicking if an error occurs 959 // while attempting to read it. All asynchronous messages are ignored, with 960 // the exception of ErrorResponse. 961 func (cn *conn) recv1() (t byte, r *readBuf) { 962 r = &readBuf{} 963 t = cn.recv1Buf(r) 964 return t, r 965 } 966 967 func (cn *conn) ssl(o values) { 968 verifyCaOnly := false 969 tlsConf := tls.Config{} 970 switch mode := o.Get("sslmode"); mode { 971 case "require", "": 972 tlsConf.InsecureSkipVerify = true 973 case "verify-ca": 974 // We must skip TLS's own verification since it requires full 975 // verification since Go 1.3. 976 tlsConf.InsecureSkipVerify = true 977 verifyCaOnly = true 978 case "verify-full": 979 tlsConf.ServerName = o.Get("host") 980 case "disable": 981 return 982 default: 983 errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) 984 } 985 986 cn.setupSSLClientCertificates(&tlsConf, o) 987 cn.setupSSLCA(&tlsConf, o) 988 989 w := cn.writeBuf(0) 990 w.int32(80877103) 991 cn.sendStartupPacket(w) 992 993 b := cn.scratch[:1] 994 _, err := io.ReadFull(cn.c, b) 995 if err != nil { 996 panic(err) 997 } 998 999 if b[0] != 'S' { 1000 panic(ErrSSLNotSupported) 1001 } 1002 1003 client := tls.Client(cn.c, &tlsConf) 1004 if verifyCaOnly { 1005 cn.verifyCA(client, &tlsConf) 1006 } 1007 cn.c = client 1008 } 1009 1010 // verifyCA carries out a TLS handshake to the server and verifies the 1011 // presented certificate against the effective CA, i.e. the one specified in 1012 // sslrootcert or the system CA if sslrootcert was not specified. 1013 func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) { 1014 err := client.Handshake() 1015 if err != nil { 1016 panic(err) 1017 } 1018 certs := client.ConnectionState().PeerCertificates 1019 opts := x509.VerifyOptions{ 1020 DNSName: client.ConnectionState().ServerName, 1021 Intermediates: x509.NewCertPool(), 1022 Roots: tlsConf.RootCAs, 1023 } 1024 for i, cert := range certs { 1025 if i == 0 { 1026 continue 1027 } 1028 opts.Intermediates.AddCert(cert) 1029 } 1030 _, err = certs[0].Verify(opts) 1031 if err != nil { 1032 panic(err) 1033 } 1034 } 1035 1036 // This function sets up SSL client certificates based on either the "sslkey" 1037 // and "sslcert" settings (possibly set via the environment variables PGSSLKEY 1038 // and PGSSLCERT, respectively), or if they aren't set, from the .postgresql 1039 // directory in the user's home directory. If the file paths are set 1040 // explicitly, the files must exist. The key file must also not be 1041 // world-readable, or this function will panic with 1042 // ErrSSLKeyHasWorldPermissions. 1043 func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) { 1044 var missingOk bool 1045 1046 sslkey := o.Get("sslkey") 1047 sslcert := o.Get("sslcert") 1048 if sslkey != "" && sslcert != "" { 1049 // If the user has set an sslkey and sslcert, they *must* exist. 1050 missingOk = false 1051 } else { 1052 // Automatically load certificates from ~/.postgresql. 1053 user, err := user.Current() 1054 if err != nil { 1055 // user.Current() might fail when cross-compiling. We have to 1056 // ignore the error and continue without client certificates, since 1057 // we wouldn't know where to load them from. 1058 return 1059 } 1060 1061 sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") 1062 sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") 1063 missingOk = true 1064 } 1065 1066 // Check that both files exist, and report the error or stop, depending on 1067 // which behaviour we want. Note that we don't do any more extensive 1068 // checks than this (such as checking that the paths aren't directories); 1069 // LoadX509KeyPair() will take care of the rest. 1070 keyfinfo, err := os.Stat(sslkey) 1071 if err != nil && missingOk { 1072 return 1073 } else if err != nil { 1074 panic(err) 1075 } 1076 _, err = os.Stat(sslcert) 1077 if err != nil && missingOk { 1078 return 1079 } else if err != nil { 1080 panic(err) 1081 } 1082 1083 // If we got this far, the key file must also have the correct permissions 1084 kmode := keyfinfo.Mode() 1085 if kmode != kmode&0600 { 1086 panic(ErrSSLKeyHasWorldPermissions) 1087 } 1088 1089 cert, err := tls.LoadX509KeyPair(sslcert, sslkey) 1090 if err != nil { 1091 panic(err) 1092 } 1093 tlsConf.Certificates = []tls.Certificate{cert} 1094 } 1095 1096 // Sets up RootCAs in the TLS configuration if sslrootcert is set. 1097 func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) { 1098 if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { 1099 tlsConf.RootCAs = x509.NewCertPool() 1100 1101 cert, err := ioutil.ReadFile(sslrootcert) 1102 if err != nil { 1103 panic(err) 1104 } 1105 1106 ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) 1107 if !ok { 1108 errorf("couldn't parse pem in sslrootcert") 1109 } 1110 } 1111 } 1112 1113 // isDriverSetting returns true iff a setting is purely for configuring the 1114 // driver's options and should not be sent to the server in the connection 1115 // startup packet. 1116 func isDriverSetting(key string) bool { 1117 switch key { 1118 case "host", "port": 1119 return true 1120 case "password": 1121 return true 1122 case "sslmode", "sslcert", "sslkey", "sslrootcert": 1123 return true 1124 case "fallback_application_name": 1125 return true 1126 case "connect_timeout": 1127 return true 1128 case "disable_prepared_binary_result": 1129 return true 1130 case "binary_parameters": 1131 return true 1132 1133 default: 1134 return false 1135 } 1136 } 1137 1138 func (cn *conn) startup(o values) { 1139 w := cn.writeBuf(0) 1140 w.int32(196608) 1141 // Send the backend the name of the database we want to connect to, and the 1142 // user we want to connect as. Additionally, we send over any run-time 1143 // parameters potentially included in the connection string. If the server 1144 // doesn't recognize any of them, it will reply with an error. 1145 for k, v := range o { 1146 if isDriverSetting(k) { 1147 // skip options which can't be run-time parameters 1148 continue 1149 } 1150 // The protocol requires us to supply the database name as "database" 1151 // instead of "dbname". 1152 if k == "dbname" { 1153 k = "database" 1154 } 1155 w.string(k) 1156 w.string(v) 1157 } 1158 w.string("") 1159 cn.sendStartupPacket(w) 1160 1161 for { 1162 t, r := cn.recv() 1163 switch t { 1164 case 'K': 1165 case 'S': 1166 cn.processParameterStatus(r) 1167 case 'R': 1168 cn.auth(r, o) 1169 case 'Z': 1170 cn.processReadyForQuery(r) 1171 return 1172 default: 1173 errorf("unknown response for startup: %q", t) 1174 } 1175 } 1176 } 1177 1178 func (cn *conn) auth(r *readBuf, o values) { 1179 switch code := r.int32(); code { 1180 case 0: 1181 // OK 1182 case 3: 1183 w := cn.writeBuf('p') 1184 w.string(o.Get("password")) 1185 cn.send(w) 1186 1187 t, r := cn.recv() 1188 if t != 'R' { 1189 errorf("unexpected password response: %q", t) 1190 } 1191 1192 if r.int32() != 0 { 1193 errorf("unexpected authentication response: %q", t) 1194 } 1195 case 5: 1196 s := string(r.next(4)) 1197 w := cn.writeBuf('p') 1198 w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) 1199 cn.send(w) 1200 1201 t, r := cn.recv() 1202 if t != 'R' { 1203 errorf("unexpected password response: %q", t) 1204 } 1205 1206 if r.int32() != 0 { 1207 errorf("unexpected authentication response: %q", t) 1208 } 1209 default: 1210 errorf("unknown authentication response: %d", code) 1211 } 1212 } 1213 1214 type format int 1215 1216 const formatText format = 0 1217 const formatBinary format = 1 1218 1219 // One result-column format code with the value 1 (i.e. all binary). 1220 var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} 1221 1222 // No result-column format codes (i.e. all text). 1223 var colFmtDataAllText []byte = []byte{0, 0} 1224 1225 type stmt struct { 1226 cn *conn 1227 name string 1228 colNames []string 1229 colFmts []format 1230 colFmtData []byte 1231 colTyps []oid.Oid 1232 paramTyps []oid.Oid 1233 closed bool 1234 } 1235 1236 func (st *stmt) Close() (err error) { 1237 if st.closed { 1238 return nil 1239 } 1240 if st.cn.bad { 1241 return driver.ErrBadConn 1242 } 1243 defer st.cn.errRecover(&err) 1244 1245 w := st.cn.writeBuf('C') 1246 w.byte('S') 1247 w.string(st.name) 1248 st.cn.send(w) 1249 1250 st.cn.send(st.cn.writeBuf('S')) 1251 1252 t, _ := st.cn.recv1() 1253 if t != '3' { 1254 st.cn.bad = true 1255 errorf("unexpected close response: %q", t) 1256 } 1257 st.closed = true 1258 1259 t, r := st.cn.recv1() 1260 if t != 'Z' { 1261 st.cn.bad = true 1262 errorf("expected ready for query, but got: %q", t) 1263 } 1264 st.cn.processReadyForQuery(r) 1265 1266 return nil 1267 } 1268 1269 func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { 1270 if st.cn.bad { 1271 return nil, driver.ErrBadConn 1272 } 1273 defer st.cn.errRecover(&err) 1274 1275 st.exec(v) 1276 return &rows{ 1277 cn: st.cn, 1278 colNames: st.colNames, 1279 colTyps: st.colTyps, 1280 colFmts: st.colFmts, 1281 }, nil 1282 } 1283 1284 func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { 1285 if st.cn.bad { 1286 return nil, driver.ErrBadConn 1287 } 1288 defer st.cn.errRecover(&err) 1289 1290 st.exec(v) 1291 res, _, err = st.cn.readExecuteResponse("simple query") 1292 return res, err 1293 } 1294 1295 func (st *stmt) exec(v []driver.Value) { 1296 if len(v) >= 65536 { 1297 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) 1298 } 1299 if len(v) != len(st.paramTyps) { 1300 errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) 1301 } 1302 1303 cn := st.cn 1304 w := cn.writeBuf('B') 1305 w.byte(0) // unnamed portal 1306 w.string(st.name) 1307 1308 if cn.binaryParameters { 1309 cn.sendBinaryParameters(w, v) 1310 } else { 1311 w.int16(0) 1312 w.int16(len(v)) 1313 for i, x := range v { 1314 if x == nil { 1315 w.int32(-1) 1316 } else { 1317 b := encode(&cn.parameterStatus, x, st.paramTyps[i]) 1318 w.int32(len(b)) 1319 w.bytes(b) 1320 } 1321 } 1322 } 1323 w.bytes(st.colFmtData) 1324 1325 w.next('E') 1326 w.byte(0) 1327 w.int32(0) 1328 1329 w.next('S') 1330 cn.send(w) 1331 1332 cn.readBindResponse() 1333 cn.postExecuteWorkaround() 1334 1335 } 1336 1337 func (st *stmt) NumInput() int { 1338 return len(st.paramTyps) 1339 } 1340 1341 // parseComplete parses the "command tag" from a CommandComplete message, and 1342 // returns the number of rows affected (if applicable) and a string 1343 // identifying only the command that was executed, e.g. "ALTER TABLE". If the 1344 // command tag could not be parsed, parseComplete panics. 1345 func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { 1346 commandsWithAffectedRows := []string{ 1347 "SELECT ", 1348 // INSERT is handled below 1349 "UPDATE ", 1350 "DELETE ", 1351 "FETCH ", 1352 "MOVE ", 1353 "COPY ", 1354 } 1355 1356 var affectedRows *string 1357 for _, tag := range commandsWithAffectedRows { 1358 if strings.HasPrefix(commandTag, tag) { 1359 t := commandTag[len(tag):] 1360 affectedRows = &t 1361 commandTag = tag[:len(tag)-1] 1362 break 1363 } 1364 } 1365 // INSERT also includes the oid of the inserted row in its command tag. 1366 // Oids in user tables are deprecated, and the oid is only returned when 1367 // exactly one row is inserted, so it's unlikely to be of value to any 1368 // real-world application and we can ignore it. 1369 if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { 1370 parts := strings.Split(commandTag, " ") 1371 if len(parts) != 3 { 1372 cn.bad = true 1373 errorf("unexpected INSERT command tag %s", commandTag) 1374 } 1375 affectedRows = &parts[len(parts)-1] 1376 commandTag = "INSERT" 1377 } 1378 // There should be no affected rows attached to the tag, just return it 1379 if affectedRows == nil { 1380 return driver.RowsAffected(0), commandTag 1381 } 1382 n, err := strconv.ParseInt(*affectedRows, 10, 64) 1383 if err != nil { 1384 cn.bad = true 1385 errorf("could not parse commandTag: %s", err) 1386 } 1387 return driver.RowsAffected(n), commandTag 1388 } 1389 1390 type rows struct { 1391 cn *conn 1392 colNames []string 1393 colTyps []oid.Oid 1394 colFmts []format 1395 done bool 1396 rb readBuf 1397 } 1398 1399 func (rs *rows) Close() error { 1400 // no need to look at cn.bad as Next() will 1401 for { 1402 err := rs.Next(nil) 1403 switch err { 1404 case nil: 1405 case io.EOF: 1406 return nil 1407 default: 1408 return err 1409 } 1410 } 1411 } 1412 1413 func (rs *rows) Columns() []string { 1414 return rs.colNames 1415 } 1416 1417 func (rs *rows) Next(dest []driver.Value) (err error) { 1418 if rs.done { 1419 return io.EOF 1420 } 1421 1422 conn := rs.cn 1423 if conn.bad { 1424 return driver.ErrBadConn 1425 } 1426 defer conn.errRecover(&err) 1427 1428 for { 1429 t := conn.recv1Buf(&rs.rb) 1430 switch t { 1431 case 'E': 1432 err = parseError(&rs.rb) 1433 case 'C', 'I': 1434 continue 1435 case 'Z': 1436 conn.processReadyForQuery(&rs.rb) 1437 rs.done = true 1438 if err != nil { 1439 return err 1440 } 1441 return io.EOF 1442 case 'D': 1443 n := rs.rb.int16() 1444 if err != nil { 1445 conn.bad = true 1446 errorf("unexpected DataRow after error %s", err) 1447 } 1448 if n < len(dest) { 1449 dest = dest[:n] 1450 } 1451 for i := range dest { 1452 l := rs.rb.int32() 1453 if l == -1 { 1454 dest[i] = nil 1455 continue 1456 } 1457 dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) 1458 } 1459 return 1460 default: 1461 errorf("unexpected message after execute: %q", t) 1462 } 1463 } 1464 } 1465 1466 // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be 1467 // used as part of an SQL statement. For example: 1468 // 1469 // tblname := "my_table" 1470 // data := "my_data" 1471 // err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data) 1472 // 1473 // Any double quotes in name will be escaped. The quoted identifier will be 1474 // case sensitive when used in a query. If the input string contains a zero 1475 // byte, the result will be truncated immediately before it. 1476 func QuoteIdentifier(name string) string { 1477 end := strings.IndexRune(name, 0) 1478 if end > -1 { 1479 name = name[:end] 1480 } 1481 return `"` + strings.Replace(name, `"`, `""`, -1) + `"` 1482 } 1483 1484 func md5s(s string) string { 1485 h := md5.New() 1486 h.Write([]byte(s)) 1487 return fmt.Sprintf("%x", h.Sum(nil)) 1488 } 1489 1490 func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { 1491 // Do one pass over the parameters to see if we're going to send any of 1492 // them over in binary. If we are, create a paramFormats array at the 1493 // same time. 1494 var paramFormats []int 1495 for i, x := range args { 1496 _, ok := x.([]byte) 1497 if ok { 1498 if paramFormats == nil { 1499 paramFormats = make([]int, len(args)) 1500 } 1501 paramFormats[i] = 1 1502 } 1503 } 1504 if paramFormats == nil { 1505 b.int16(0) 1506 } else { 1507 b.int16(len(paramFormats)) 1508 for _, x := range paramFormats { 1509 b.int16(x) 1510 } 1511 } 1512 1513 b.int16(len(args)) 1514 for _, x := range args { 1515 if x == nil { 1516 b.int32(-1) 1517 } else { 1518 datum := binaryEncode(&cn.parameterStatus, x) 1519 b.int32(len(datum)) 1520 b.bytes(datum) 1521 } 1522 } 1523 } 1524 1525 func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { 1526 if len(args) >= 65536 { 1527 errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) 1528 } 1529 1530 b := cn.writeBuf('P') 1531 b.byte(0) // unnamed statement 1532 b.string(query) 1533 b.int16(0) 1534 1535 b.next('B') 1536 b.int16(0) // unnamed portal and statement 1537 cn.sendBinaryParameters(b, args) 1538 b.bytes(colFmtDataAllText) 1539 1540 b.next('D') 1541 b.byte('P') 1542 b.byte(0) // unnamed portal 1543 1544 b.next('E') 1545 b.byte(0) 1546 b.int32(0) 1547 1548 b.next('S') 1549 cn.send(b) 1550 } 1551 1552 func (c *conn) processParameterStatus(r *readBuf) { 1553 var err error 1554 1555 param := r.string() 1556 switch param { 1557 case "server_version": 1558 var major1 int 1559 var major2 int 1560 var minor int 1561 _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) 1562 if err == nil { 1563 c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor 1564 } 1565 1566 case "TimeZone": 1567 c.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) 1568 if err != nil { 1569 c.parameterStatus.currentLocation = nil 1570 } 1571 1572 default: 1573 // ignore 1574 } 1575 } 1576 1577 func (c *conn) processReadyForQuery(r *readBuf) { 1578 c.txnStatus = transactionStatus(r.byte()) 1579 } 1580 1581 func (cn *conn) readReadyForQuery() { 1582 t, r := cn.recv1() 1583 switch t { 1584 case 'Z': 1585 cn.processReadyForQuery(r) 1586 return 1587 default: 1588 cn.bad = true 1589 errorf("unexpected message %q; expected ReadyForQuery", t) 1590 } 1591 } 1592 1593 func (cn *conn) readParseResponse() { 1594 t, r := cn.recv1() 1595 switch t { 1596 case '1': 1597 return 1598 case 'E': 1599 err := parseError(r) 1600 cn.readReadyForQuery() 1601 panic(err) 1602 default: 1603 cn.bad = true 1604 errorf("unexpected Parse response %q", t) 1605 } 1606 } 1607 1608 func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { 1609 for { 1610 t, r := cn.recv1() 1611 switch t { 1612 case 't': 1613 nparams := r.int16() 1614 paramTyps = make([]oid.Oid, nparams) 1615 for i := range paramTyps { 1616 paramTyps[i] = r.oid() 1617 } 1618 case 'n': 1619 return paramTyps, nil, nil 1620 case 'T': 1621 colNames, colTyps = parseStatementRowDescribe(r) 1622 return paramTyps, colNames, colTyps 1623 case 'E': 1624 err := parseError(r) 1625 cn.readReadyForQuery() 1626 panic(err) 1627 default: 1628 cn.bad = true 1629 errorf("unexpected Describe statement response %q", t) 1630 } 1631 } 1632 } 1633 1634 func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { 1635 t, r := cn.recv1() 1636 switch t { 1637 case 'T': 1638 return parsePortalRowDescribe(r) 1639 case 'n': 1640 return nil, nil, nil 1641 case 'E': 1642 err := parseError(r) 1643 cn.readReadyForQuery() 1644 panic(err) 1645 default: 1646 cn.bad = true 1647 errorf("unexpected Describe response %q", t) 1648 } 1649 panic("not reached") 1650 } 1651 1652 func (cn *conn) readBindResponse() { 1653 t, r := cn.recv1() 1654 switch t { 1655 case '2': 1656 return 1657 case 'E': 1658 err := parseError(r) 1659 cn.readReadyForQuery() 1660 panic(err) 1661 default: 1662 cn.bad = true 1663 errorf("unexpected Bind response %q", t) 1664 } 1665 } 1666 1667 func (cn *conn) postExecuteWorkaround() { 1668 // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores 1669 // any errors from rows.Next, which masks errors that happened during the 1670 // execution of the query. To avoid the problem in common cases, we wait 1671 // here for one more message from the database. If it's not an error the 1672 // query will likely succeed (or perhaps has already, if it's a 1673 // CommandComplete), so we push the message into the conn struct; recv1 1674 // will return it as the next message for rows.Next or rows.Close. 1675 // However, if it's an error, we wait until ReadyForQuery and then return 1676 // the error to our caller. 1677 for { 1678 t, r := cn.recv1() 1679 switch t { 1680 case 'E': 1681 err := parseError(r) 1682 cn.readReadyForQuery() 1683 panic(err) 1684 case 'C', 'D', 'I': 1685 // the query didn't fail, but we can't process this message 1686 cn.saveMessage(t, r) 1687 return 1688 default: 1689 cn.bad = true 1690 errorf("unexpected message during extended query execution: %q", t) 1691 } 1692 } 1693 } 1694 1695 // Only for Exec(), since we ignore the returned data 1696 func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { 1697 for { 1698 t, r := cn.recv1() 1699 switch t { 1700 case 'C': 1701 if err != nil { 1702 cn.bad = true 1703 errorf("unexpected CommandComplete after error %s", err) 1704 } 1705 res, commandTag = cn.parseComplete(r.string()) 1706 case 'Z': 1707 cn.processReadyForQuery(r) 1708 return res, commandTag, err 1709 case 'E': 1710 err = parseError(r) 1711 case 'T', 'D', 'I': 1712 if err != nil { 1713 cn.bad = true 1714 errorf("unexpected %q after error %s", t, err) 1715 } 1716 // ignore any results 1717 default: 1718 cn.bad = true 1719 errorf("unknown %s response: %q", protocolState, t) 1720 } 1721 } 1722 } 1723 1724 func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { 1725 n := r.int16() 1726 colNames = make([]string, n) 1727 colTyps = make([]oid.Oid, n) 1728 for i := range colNames { 1729 colNames[i] = r.string() 1730 r.next(6) 1731 colTyps[i] = r.oid() 1732 r.next(6) 1733 // format code not known when describing a statement; always 0 1734 r.next(2) 1735 } 1736 return 1737 } 1738 1739 func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { 1740 n := r.int16() 1741 colNames = make([]string, n) 1742 colFmts = make([]format, n) 1743 colTyps = make([]oid.Oid, n) 1744 for i := range colNames { 1745 colNames[i] = r.string() 1746 r.next(6) 1747 colTyps[i] = r.oid() 1748 r.next(6) 1749 colFmts[i] = format(r.int16()) 1750 } 1751 return 1752 } 1753 1754 // parseEnviron tries to mimic some of libpq's environment handling 1755 // 1756 // To ease testing, it does not directly reference os.Environ, but is 1757 // designed to accept its output. 1758 // 1759 // Environment-set connection information is intended to have a higher 1760 // precedence than a library default but lower than any explicitly 1761 // passed information (such as in the URL or connection string). 1762 func parseEnviron(env []string) (out map[string]string) { 1763 out = make(map[string]string) 1764 1765 for _, v := range env { 1766 parts := strings.SplitN(v, "=", 2) 1767 1768 accrue := func(keyname string) { 1769 out[keyname] = parts[1] 1770 } 1771 unsupported := func() { 1772 panic(fmt.Sprintf("setting %v not supported", parts[0])) 1773 } 1774 1775 // The order of these is the same as is seen in the 1776 // PostgreSQL 9.1 manual. Unsupported but well-defined 1777 // keys cause a panic; these should be unset prior to 1778 // execution. Options which pq expects to be set to a 1779 // certain value are allowed, but must be set to that 1780 // value if present (they can, of course, be absent). 1781 switch parts[0] { 1782 case "PGHOST": 1783 accrue("host") 1784 case "PGHOSTADDR": 1785 unsupported() 1786 case "PGPORT": 1787 accrue("port") 1788 case "PGDATABASE": 1789 accrue("dbname") 1790 case "PGUSER": 1791 accrue("user") 1792 case "PGPASSWORD": 1793 accrue("password") 1794 case "PGSERVICE", "PGSERVICEFILE", "PGREALM": 1795 unsupported() 1796 case "PGOPTIONS": 1797 accrue("options") 1798 case "PGAPPNAME": 1799 accrue("application_name") 1800 case "PGSSLMODE": 1801 accrue("sslmode") 1802 case "PGSSLCERT": 1803 accrue("sslcert") 1804 case "PGSSLKEY": 1805 accrue("sslkey") 1806 case "PGSSLROOTCERT": 1807 accrue("sslrootcert") 1808 case "PGREQUIRESSL", "PGSSLCRL": 1809 unsupported() 1810 case "PGREQUIREPEER": 1811 unsupported() 1812 case "PGKRBSRVNAME", "PGGSSLIB": 1813 unsupported() 1814 case "PGCONNECT_TIMEOUT": 1815 accrue("connect_timeout") 1816 case "PGCLIENTENCODING": 1817 accrue("client_encoding") 1818 case "PGDATESTYLE": 1819 accrue("datestyle") 1820 case "PGTZ": 1821 accrue("timezone") 1822 case "PGGEQO": 1823 accrue("geqo") 1824 case "PGSYSCONFDIR", "PGLOCALEDIR": 1825 unsupported() 1826 } 1827 } 1828 1829 return out 1830 } 1831 1832 // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". 1833 func isUTF8(name string) bool { 1834 // Recognize all sorts of silly things as "UTF-8", like Postgres does 1835 s := strings.Map(alnumLowerASCII, name) 1836 return s == "utf8" || s == "unicode" 1837 } 1838 1839 func alnumLowerASCII(ch rune) rune { 1840 if 'A' <= ch && ch <= 'Z' { 1841 return ch + ('a' - 'A') 1842 } 1843 if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { 1844 return ch 1845 } 1846 return -1 // discard 1847 }