github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/lib/pq/encode.go (about)

     1  package pq
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql/driver"
     6  	"encoding/binary"
     7  	"encoding/hex"
     8  	"errors"
     9  	"fmt"
    10  	"math"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/insionng/yougam/libraries/lib/pq/oid"
    17  )
    18  
    19  func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte {
    20  	switch v := x.(type) {
    21  	case []byte:
    22  		return v
    23  	default:
    24  		return encode(parameterStatus, x, oid.T_unknown)
    25  	}
    26  	panic("not reached")
    27  }
    28  
    29  func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
    30  	switch v := x.(type) {
    31  	case int64:
    32  		return strconv.AppendInt(nil, v, 10)
    33  	case float64:
    34  		return strconv.AppendFloat(nil, v, 'f', -1, 64)
    35  	case []byte:
    36  		if pgtypOid == oid.T_bytea {
    37  			return encodeBytea(parameterStatus.serverVersion, v)
    38  		}
    39  
    40  		return v
    41  	case string:
    42  		if pgtypOid == oid.T_bytea {
    43  			return encodeBytea(parameterStatus.serverVersion, []byte(v))
    44  		}
    45  
    46  		return []byte(v)
    47  	case bool:
    48  		return strconv.AppendBool(nil, v)
    49  	case time.Time:
    50  		return formatTs(v)
    51  
    52  	default:
    53  		errorf("encode: unknown type for %T", v)
    54  	}
    55  
    56  	panic("not reached")
    57  }
    58  
    59  func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
    60  	if f == formatBinary {
    61  		return binaryDecode(parameterStatus, s, typ)
    62  	} else {
    63  		return textDecode(parameterStatus, s, typ)
    64  	}
    65  }
    66  
    67  func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
    68  	switch typ {
    69  	case oid.T_bytea:
    70  		return s
    71  	case oid.T_int8:
    72  		return int64(binary.BigEndian.Uint64(s))
    73  	case oid.T_int4:
    74  		return int64(int32(binary.BigEndian.Uint32(s)))
    75  	case oid.T_int2:
    76  		return int64(int16(binary.BigEndian.Uint16(s)))
    77  
    78  	default:
    79  		errorf("don't know how to decode binary parameter of type %u", uint32(typ))
    80  	}
    81  
    82  	panic("not reached")
    83  }
    84  
    85  func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
    86  	switch typ {
    87  	case oid.T_bytea:
    88  		return parseBytea(s)
    89  	case oid.T_timestamptz:
    90  		return parseTs(parameterStatus.currentLocation, string(s))
    91  	case oid.T_timestamp, oid.T_date:
    92  		return parseTs(nil, string(s))
    93  	case oid.T_time:
    94  		return mustParse("15:04:05", typ, s)
    95  	case oid.T_timetz:
    96  		return mustParse("15:04:05-07", typ, s)
    97  	case oid.T_bool:
    98  		return s[0] == 't'
    99  	case oid.T_int8, oid.T_int4, oid.T_int2:
   100  		i, err := strconv.ParseInt(string(s), 10, 64)
   101  		if err != nil {
   102  			errorf("%s", err)
   103  		}
   104  		return i
   105  	case oid.T_float4, oid.T_float8:
   106  		bits := 64
   107  		if typ == oid.T_float4 {
   108  			bits = 32
   109  		}
   110  		f, err := strconv.ParseFloat(string(s), bits)
   111  		if err != nil {
   112  			errorf("%s", err)
   113  		}
   114  		return f
   115  	}
   116  
   117  	return s
   118  }
   119  
   120  // appendEncodedText encodes item in text format as required by COPY
   121  // and appends to buf
   122  func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
   123  	switch v := x.(type) {
   124  	case int64:
   125  		return strconv.AppendInt(buf, v, 10)
   126  	case float64:
   127  		return strconv.AppendFloat(buf, v, 'f', -1, 64)
   128  	case []byte:
   129  		encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
   130  		return appendEscapedText(buf, string(encodedBytea))
   131  	case string:
   132  		return appendEscapedText(buf, v)
   133  	case bool:
   134  		return strconv.AppendBool(buf, v)
   135  	case time.Time:
   136  		return append(buf, formatTs(v)...)
   137  	case nil:
   138  		return append(buf, "\\N"...)
   139  	default:
   140  		errorf("encode: unknown type for %T", v)
   141  	}
   142  
   143  	panic("not reached")
   144  }
   145  
   146  func appendEscapedText(buf []byte, text string) []byte {
   147  	escapeNeeded := false
   148  	startPos := 0
   149  	var c byte
   150  
   151  	// check if we need to escape
   152  	for i := 0; i < len(text); i++ {
   153  		c = text[i]
   154  		if c == '\\' || c == '\n' || c == '\r' || c == '\t' {
   155  			escapeNeeded = true
   156  			startPos = i
   157  			break
   158  		}
   159  	}
   160  	if !escapeNeeded {
   161  		return append(buf, text...)
   162  	}
   163  
   164  	// copy till first char to escape, iterate the rest
   165  	result := append(buf, text[:startPos]...)
   166  	for i := startPos; i < len(text); i++ {
   167  		c = text[i]
   168  		switch c {
   169  		case '\\':
   170  			result = append(result, '\\', '\\')
   171  		case '\n':
   172  			result = append(result, '\\', 'n')
   173  		case '\r':
   174  			result = append(result, '\\', 'r')
   175  		case '\t':
   176  			result = append(result, '\\', 't')
   177  		default:
   178  			result = append(result, c)
   179  		}
   180  	}
   181  	return result
   182  }
   183  
   184  func mustParse(f string, typ oid.Oid, s []byte) time.Time {
   185  	str := string(s)
   186  
   187  	// check for a 30-minute-offset timezone
   188  	if (typ == oid.T_timestamptz || typ == oid.T_timetz) &&
   189  		str[len(str)-3] == ':' {
   190  		f += ":00"
   191  	}
   192  	t, err := time.Parse(f, str)
   193  	if err != nil {
   194  		errorf("decode: %s", err)
   195  	}
   196  	return t
   197  }
   198  
   199  var invalidTimestampErr = errors.New("invalid timestamp")
   200  
   201  type timestampParser struct {
   202  	err error
   203  }
   204  
   205  func (p *timestampParser) expect(str, char string, pos int) {
   206  	if p.err != nil {
   207  		return
   208  	}
   209  	if pos+1 > len(str) {
   210  		p.err = invalidTimestampErr
   211  		return
   212  	}
   213  	if c := str[pos : pos+1]; c != char && p.err == nil {
   214  		p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c)
   215  	}
   216  }
   217  
   218  func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
   219  	if p.err != nil {
   220  		return 0
   221  	}
   222  	if begin < 0 || end < 0 || begin > end || end > len(str) {
   223  		p.err = invalidTimestampErr
   224  		return 0
   225  	}
   226  	result, err := strconv.Atoi(str[begin:end])
   227  	if err != nil {
   228  		if p.err == nil {
   229  			p.err = fmt.Errorf("expected number; got '%v'", str)
   230  		}
   231  		return 0
   232  	}
   233  	return result
   234  }
   235  
   236  // The location cache caches the time zones typically used by the client.
   237  type locationCache struct {
   238  	cache map[int]*time.Location
   239  	lock  sync.Mutex
   240  }
   241  
   242  // All connections share the same list of timezones. Benchmarking shows that
   243  // about 5% speed could be gained by putting the cache in the connection and
   244  // losing the mutex, at the cost of a small amount of memory and a somewhat
   245  // significant increase in code complexity.
   246  var globalLocationCache *locationCache = newLocationCache()
   247  
   248  func newLocationCache() *locationCache {
   249  	return &locationCache{cache: make(map[int]*time.Location)}
   250  }
   251  
   252  // Returns the cached timezone for the specified offset, creating and caching
   253  // it if necessary.
   254  func (c *locationCache) getLocation(offset int) *time.Location {
   255  	c.lock.Lock()
   256  	defer c.lock.Unlock()
   257  
   258  	location, ok := c.cache[offset]
   259  	if !ok {
   260  		location = time.FixedZone("", offset)
   261  		c.cache[offset] = location
   262  	}
   263  
   264  	return location
   265  }
   266  
   267  var infinityTsEnabled = false
   268  var infinityTsNegative time.Time
   269  var infinityTsPositive time.Time
   270  
   271  const (
   272  	infinityTsEnabledAlready        = "pq: infinity timestamp enabled already"
   273  	infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
   274  )
   275  
   276  /*
   277   * If EnableInfinityTs is not called, "-infinity" and "infinity" will return
   278   * []byte("-infinity") and []byte("infinity") respectively, and potentially
   279   * cause error "sql: Scan error on column index 0: unsupported driver -> Scan pair: []uint8 -> *time.Time",
   280   * when scanning into a time.Time value.
   281   *
   282   * Once EnableInfinityTs has been called, all connections created using this
   283   * driver will decode Postgres' "-infinity" and "infinity" for "timestamp",
   284   * "timestamp with time zone" and "date" types to the predefined minimum and
   285   * maximum times, respectively.  When encoding time.Time values, any time which
   286   * equals or precedes the predefined minimum time will be encoded to
   287   * "-infinity".  Any values at or past the maximum time will similarly be
   288   * encoded to "infinity".
   289   *
   290   *
   291   * If EnableInfinityTs is called with negative >= positive, it will panic.
   292   * Calling EnableInfinityTs after a connection has been established results in
   293   * undefined behavior.  If EnableInfinityTs is called more than once, it will
   294   * panic.
   295   */
   296  func EnableInfinityTs(negative time.Time, positive time.Time) {
   297  	if infinityTsEnabled {
   298  		panic(infinityTsEnabledAlready)
   299  	}
   300  	if !negative.Before(positive) {
   301  		panic(infinityTsNegativeMustBeSmaller)
   302  	}
   303  	infinityTsEnabled = true
   304  	infinityTsNegative = negative
   305  	infinityTsPositive = positive
   306  }
   307  
   308  /*
   309   * Testing might want to toggle infinityTsEnabled
   310   */
   311  func disableInfinityTs() {
   312  	infinityTsEnabled = false
   313  }
   314  
   315  // This is a time function specific to the Postgres default DateStyle
   316  // setting ("ISO, MDY"), the only one we currently support. This
   317  // accounts for the discrepancies between the parsing available with
   318  // time.Parse and the Postgres date formatting quirks.
   319  func parseTs(currentLocation *time.Location, str string) interface{} {
   320  	switch str {
   321  	case "-infinity":
   322  		if infinityTsEnabled {
   323  			return infinityTsNegative
   324  		}
   325  		return []byte(str)
   326  	case "infinity":
   327  		if infinityTsEnabled {
   328  			return infinityTsPositive
   329  		}
   330  		return []byte(str)
   331  	}
   332  	t, err := ParseTimestamp(currentLocation, str)
   333  	if err != nil {
   334  		panic(err)
   335  	}
   336  	return t
   337  }
   338  
   339  func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
   340  	p := timestampParser{}
   341  
   342  	monSep := strings.IndexRune(str, '-')
   343  	// this is Gregorian year, not ISO Year
   344  	// In Gregorian system, the year 1 BC is followed by AD 1
   345  	year := p.mustAtoi(str, 0, monSep)
   346  	daySep := monSep + 3
   347  	month := p.mustAtoi(str, monSep+1, daySep)
   348  	p.expect(str, "-", daySep)
   349  	timeSep := daySep + 3
   350  	day := p.mustAtoi(str, daySep+1, timeSep)
   351  
   352  	var hour, minute, second int
   353  	if len(str) > monSep+len("01-01")+1 {
   354  		p.expect(str, " ", timeSep)
   355  		minSep := timeSep + 3
   356  		p.expect(str, ":", minSep)
   357  		hour = p.mustAtoi(str, timeSep+1, minSep)
   358  		secSep := minSep + 3
   359  		p.expect(str, ":", secSep)
   360  		minute = p.mustAtoi(str, minSep+1, secSep)
   361  		secEnd := secSep + 3
   362  		second = p.mustAtoi(str, secSep+1, secEnd)
   363  	}
   364  	remainderIdx := monSep + len("01-01 00:00:00") + 1
   365  	// Three optional (but ordered) sections follow: the
   366  	// fractional seconds, the time zone offset, and the BC
   367  	// designation. We set them up here and adjust the other
   368  	// offsets if the preceding sections exist.
   369  
   370  	nanoSec := 0
   371  	tzOff := 0
   372  
   373  	if remainderIdx+1 <= len(str) && str[remainderIdx:remainderIdx+1] == "." {
   374  		fracStart := remainderIdx + 1
   375  		fracOff := strings.IndexAny(str[fracStart:], "-+ ")
   376  		if fracOff < 0 {
   377  			fracOff = len(str) - fracStart
   378  		}
   379  		fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
   380  		nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
   381  
   382  		remainderIdx += fracOff + 1
   383  	}
   384  	if tzStart := remainderIdx; tzStart+1 <= len(str) && (str[tzStart:tzStart+1] == "-" || str[tzStart:tzStart+1] == "+") {
   385  		// time zone separator is always '-' or '+' (UTC is +00)
   386  		var tzSign int
   387  		if c := str[tzStart : tzStart+1]; c == "-" {
   388  			tzSign = -1
   389  		} else if c == "+" {
   390  			tzSign = +1
   391  		} else {
   392  			return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
   393  		}
   394  		tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
   395  		remainderIdx += 3
   396  		var tzMin, tzSec int
   397  		if tzStart+4 <= len(str) && str[tzStart+3:tzStart+4] == ":" {
   398  			tzMin = p.mustAtoi(str, tzStart+4, tzStart+6)
   399  			remainderIdx += 3
   400  		}
   401  		if tzStart+7 <= len(str) && str[tzStart+6:tzStart+7] == ":" {
   402  			tzSec = p.mustAtoi(str, tzStart+7, tzStart+9)
   403  			remainderIdx += 3
   404  		}
   405  		tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
   406  	}
   407  	var isoYear int
   408  	if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" {
   409  		isoYear = 1 - year
   410  		remainderIdx += 3
   411  	} else {
   412  		isoYear = year
   413  	}
   414  	if remainderIdx < len(str) {
   415  		return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
   416  	}
   417  	t := time.Date(isoYear, time.Month(month), day,
   418  		hour, minute, second, nanoSec,
   419  		globalLocationCache.getLocation(tzOff))
   420  
   421  	if currentLocation != nil {
   422  		// Set the location of the returned Time based on the self.Session.s
   423  		// TimeZone value, but only if the local time zone database agrees with
   424  		// the remote database on the offset.
   425  		lt := t.In(currentLocation)
   426  		_, newOff := lt.Zone()
   427  		if newOff == tzOff {
   428  			t = lt
   429  		}
   430  	}
   431  
   432  	return t, p.err
   433  }
   434  
   435  // formatTs formats t into a format postgres understands.
   436  func formatTs(t time.Time) (b []byte) {
   437  	if infinityTsEnabled {
   438  		// t <= -infinity : ! (t > -infinity)
   439  		if !t.After(infinityTsNegative) {
   440  			return []byte("-infinity")
   441  		}
   442  		// t >= infinity : ! (!t < infinity)
   443  		if !t.Before(infinityTsPositive) {
   444  			return []byte("infinity")
   445  		}
   446  	}
   447  	// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
   448  	// minus sign preferred by Go.
   449  	// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
   450  	bc := false
   451  	if t.Year() <= 0 {
   452  		// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
   453  		t = t.AddDate((-t.Year())*2+1, 0, 0)
   454  		bc = true
   455  	}
   456  	b = []byte(t.Format(time.RFC3339Nano))
   457  
   458  	_, offset := t.Zone()
   459  	offset = offset % 60
   460  	if offset != 0 {
   461  		// RFC3339Nano already printed the minus sign
   462  		if offset < 0 {
   463  			offset = -offset
   464  		}
   465  
   466  		b = append(b, ':')
   467  		if offset < 10 {
   468  			b = append(b, '0')
   469  		}
   470  		b = strconv.AppendInt(b, int64(offset), 10)
   471  	}
   472  
   473  	if bc {
   474  		b = append(b, " BC"...)
   475  	}
   476  	return b
   477  }
   478  
   479  // Parse a bytea value received from the server.  Both "hex" and the legacy
   480  // "escape" format are supported.
   481  func parseBytea(s []byte) (result []byte) {
   482  	if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
   483  		// bytea_output = hex
   484  		s = s[2:] // trim off leading "\\x"
   485  		result = make([]byte, hex.DecodedLen(len(s)))
   486  		_, err := hex.Decode(result, s)
   487  		if err != nil {
   488  			errorf("%s", err)
   489  		}
   490  	} else {
   491  		// bytea_output = escape
   492  		for len(s) > 0 {
   493  			if s[0] == '\\' {
   494  				// escaped '\\'
   495  				if len(s) >= 2 && s[1] == '\\' {
   496  					result = append(result, '\\')
   497  					s = s[2:]
   498  					continue
   499  				}
   500  
   501  				// '\\' followed by an octal number
   502  				if len(s) < 4 {
   503  					errorf("invalid bytea sequence %v", s)
   504  				}
   505  				r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
   506  				if err != nil {
   507  					errorf("could not parse bytea value: %s", err.Error())
   508  				}
   509  				result = append(result, byte(r))
   510  				s = s[4:]
   511  			} else {
   512  				// We hit an unescaped, raw byte.  Try to read in as many as
   513  				// possible in one go.
   514  				i := bytes.IndexByte(s, '\\')
   515  				if i == -1 {
   516  					result = append(result, s...)
   517  					break
   518  				}
   519  				result = append(result, s[:i]...)
   520  				s = s[i:]
   521  			}
   522  		}
   523  	}
   524  
   525  	return result
   526  }
   527  
   528  func encodeBytea(serverVersion int, v []byte) (result []byte) {
   529  	if serverVersion >= 90000 {
   530  		// Use the hex format if we know that the server supports it
   531  		result = make([]byte, 2+hex.EncodedLen(len(v)))
   532  		result[0] = '\\'
   533  		result[1] = 'x'
   534  		hex.Encode(result[2:], v)
   535  	} else {
   536  		// .. or resort to "escape"
   537  		for _, b := range v {
   538  			if b == '\\' {
   539  				result = append(result, '\\', '\\')
   540  			} else if b < 0x20 || b > 0x7e {
   541  				result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
   542  			} else {
   543  				result = append(result, b)
   544  			}
   545  		}
   546  	}
   547  
   548  	return result
   549  }
   550  
   551  // NullTime represents a time.Time that may be null. NullTime implements the
   552  // sql.Scanner interface so it can be used as a scan destination, similar to
   553  // sql.NullString.
   554  type NullTime struct {
   555  	Time  time.Time
   556  	Valid bool // Valid is true if Time is not NULL
   557  }
   558  
   559  // Scan implements the Scanner interface.
   560  func (nt *NullTime) Scan(value interface{}) error {
   561  	nt.Time, nt.Valid = value.(time.Time)
   562  	return nil
   563  }
   564  
   565  // Value implements the driver Valuer interface.
   566  func (nt NullTime) Value() (driver.Value, error) {
   567  	if !nt.Valid {
   568  		return nil, nil
   569  	}
   570  	return nt.Time, nil
   571  }