github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/lib/pq/encode.go (about) 1 package pq 2 3 import ( 4 "bytes" 5 "database/sql/driver" 6 "encoding/binary" 7 "encoding/hex" 8 "errors" 9 "fmt" 10 "math" 11 "strconv" 12 "strings" 13 "sync" 14 "time" 15 16 "github.com/insionng/yougam/libraries/lib/pq/oid" 17 ) 18 19 func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { 20 switch v := x.(type) { 21 case []byte: 22 return v 23 default: 24 return encode(parameterStatus, x, oid.T_unknown) 25 } 26 panic("not reached") 27 } 28 29 func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { 30 switch v := x.(type) { 31 case int64: 32 return strconv.AppendInt(nil, v, 10) 33 case float64: 34 return strconv.AppendFloat(nil, v, 'f', -1, 64) 35 case []byte: 36 if pgtypOid == oid.T_bytea { 37 return encodeBytea(parameterStatus.serverVersion, v) 38 } 39 40 return v 41 case string: 42 if pgtypOid == oid.T_bytea { 43 return encodeBytea(parameterStatus.serverVersion, []byte(v)) 44 } 45 46 return []byte(v) 47 case bool: 48 return strconv.AppendBool(nil, v) 49 case time.Time: 50 return formatTs(v) 51 52 default: 53 errorf("encode: unknown type for %T", v) 54 } 55 56 panic("not reached") 57 } 58 59 func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { 60 if f == formatBinary { 61 return binaryDecode(parameterStatus, s, typ) 62 } else { 63 return textDecode(parameterStatus, s, typ) 64 } 65 } 66 67 func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { 68 switch typ { 69 case oid.T_bytea: 70 return s 71 case oid.T_int8: 72 return int64(binary.BigEndian.Uint64(s)) 73 case oid.T_int4: 74 return int64(int32(binary.BigEndian.Uint32(s))) 75 case oid.T_int2: 76 return int64(int16(binary.BigEndian.Uint16(s))) 77 78 default: 79 errorf("don't know how to decode binary parameter of type %u", uint32(typ)) 80 } 81 82 panic("not reached") 83 } 84 85 func 
textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { 86 switch typ { 87 case oid.T_bytea: 88 return parseBytea(s) 89 case oid.T_timestamptz: 90 return parseTs(parameterStatus.currentLocation, string(s)) 91 case oid.T_timestamp, oid.T_date: 92 return parseTs(nil, string(s)) 93 case oid.T_time: 94 return mustParse("15:04:05", typ, s) 95 case oid.T_timetz: 96 return mustParse("15:04:05-07", typ, s) 97 case oid.T_bool: 98 return s[0] == 't' 99 case oid.T_int8, oid.T_int4, oid.T_int2: 100 i, err := strconv.ParseInt(string(s), 10, 64) 101 if err != nil { 102 errorf("%s", err) 103 } 104 return i 105 case oid.T_float4, oid.T_float8: 106 bits := 64 107 if typ == oid.T_float4 { 108 bits = 32 109 } 110 f, err := strconv.ParseFloat(string(s), bits) 111 if err != nil { 112 errorf("%s", err) 113 } 114 return f 115 } 116 117 return s 118 } 119 120 // appendEncodedText encodes item in text format as required by COPY 121 // and appends to buf 122 func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte { 123 switch v := x.(type) { 124 case int64: 125 return strconv.AppendInt(buf, v, 10) 126 case float64: 127 return strconv.AppendFloat(buf, v, 'f', -1, 64) 128 case []byte: 129 encodedBytea := encodeBytea(parameterStatus.serverVersion, v) 130 return appendEscapedText(buf, string(encodedBytea)) 131 case string: 132 return appendEscapedText(buf, v) 133 case bool: 134 return strconv.AppendBool(buf, v) 135 case time.Time: 136 return append(buf, formatTs(v)...) 137 case nil: 138 return append(buf, "\\N"...) 
139 default: 140 errorf("encode: unknown type for %T", v) 141 } 142 143 panic("not reached") 144 } 145 146 func appendEscapedText(buf []byte, text string) []byte { 147 escapeNeeded := false 148 startPos := 0 149 var c byte 150 151 // check if we need to escape 152 for i := 0; i < len(text); i++ { 153 c = text[i] 154 if c == '\\' || c == '\n' || c == '\r' || c == '\t' { 155 escapeNeeded = true 156 startPos = i 157 break 158 } 159 } 160 if !escapeNeeded { 161 return append(buf, text...) 162 } 163 164 // copy till first char to escape, iterate the rest 165 result := append(buf, text[:startPos]...) 166 for i := startPos; i < len(text); i++ { 167 c = text[i] 168 switch c { 169 case '\\': 170 result = append(result, '\\', '\\') 171 case '\n': 172 result = append(result, '\\', 'n') 173 case '\r': 174 result = append(result, '\\', 'r') 175 case '\t': 176 result = append(result, '\\', 't') 177 default: 178 result = append(result, c) 179 } 180 } 181 return result 182 } 183 184 func mustParse(f string, typ oid.Oid, s []byte) time.Time { 185 str := string(s) 186 187 // check for a 30-minute-offset timezone 188 if (typ == oid.T_timestamptz || typ == oid.T_timetz) && 189 str[len(str)-3] == ':' { 190 f += ":00" 191 } 192 t, err := time.Parse(f, str) 193 if err != nil { 194 errorf("decode: %s", err) 195 } 196 return t 197 } 198 199 var invalidTimestampErr = errors.New("invalid timestamp") 200 201 type timestampParser struct { 202 err error 203 } 204 205 func (p *timestampParser) expect(str, char string, pos int) { 206 if p.err != nil { 207 return 208 } 209 if pos+1 > len(str) { 210 p.err = invalidTimestampErr 211 return 212 } 213 if c := str[pos : pos+1]; c != char && p.err == nil { 214 p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) 215 } 216 } 217 218 func (p *timestampParser) mustAtoi(str string, begin int, end int) int { 219 if p.err != nil { 220 return 0 221 } 222 if begin < 0 || end < 0 || begin > end || end > len(str) { 223 p.err = 
invalidTimestampErr 224 return 0 225 } 226 result, err := strconv.Atoi(str[begin:end]) 227 if err != nil { 228 if p.err == nil { 229 p.err = fmt.Errorf("expected number; got '%v'", str) 230 } 231 return 0 232 } 233 return result 234 } 235 236 // The location cache caches the time zones typically used by the client. 237 type locationCache struct { 238 cache map[int]*time.Location 239 lock sync.Mutex 240 } 241 242 // All connections share the same list of timezones. Benchmarking shows that 243 // about 5% speed could be gained by putting the cache in the connection and 244 // losing the mutex, at the cost of a small amount of memory and a somewhat 245 // significant increase in code complexity. 246 var globalLocationCache *locationCache = newLocationCache() 247 248 func newLocationCache() *locationCache { 249 return &locationCache{cache: make(map[int]*time.Location)} 250 } 251 252 // Returns the cached timezone for the specified offset, creating and caching 253 // it if necessary. 254 func (c *locationCache) getLocation(offset int) *time.Location { 255 c.lock.Lock() 256 defer c.lock.Unlock() 257 258 location, ok := c.cache[offset] 259 if !ok { 260 location = time.FixedZone("", offset) 261 c.cache[offset] = location 262 } 263 264 return location 265 } 266 267 var infinityTsEnabled = false 268 var infinityTsNegative time.Time 269 var infinityTsPositive time.Time 270 271 const ( 272 infinityTsEnabledAlready = "pq: infinity timestamp enabled already" 273 infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" 274 ) 275 276 /* 277 * If EnableInfinityTs is not called, "-infinity" and "infinity" will return 278 * []byte("-infinity") and []byte("infinity") respectively, and potentially 279 * cause error "sql: Scan error on column index 0: unsupported driver -> Scan pair: []uint8 -> *time.Time", 280 * when scanning into a time.Time value. 
281 * 282 * Once EnableInfinityTs has been called, all connections created using this 283 * driver will decode Postgres' "-infinity" and "infinity" for "timestamp", 284 * "timestamp with time zone" and "date" types to the predefined minimum and 285 * maximum times, respectively. When encoding time.Time values, any time which 286 * equals or precedes the predefined minimum time will be encoded to 287 * "-infinity". Any values at or past the maximum time will similarly be 288 * encoded to "infinity". 289 * 290 * 291 * If EnableInfinityTs is called with negative >= positive, it will panic. 292 * Calling EnableInfinityTs after a connection has been established results in 293 * undefined behavior. If EnableInfinityTs is called more than once, it will 294 * panic. 295 */ 296 func EnableInfinityTs(negative time.Time, positive time.Time) { 297 if infinityTsEnabled { 298 panic(infinityTsEnabledAlready) 299 } 300 if !negative.Before(positive) { 301 panic(infinityTsNegativeMustBeSmaller) 302 } 303 infinityTsEnabled = true 304 infinityTsNegative = negative 305 infinityTsPositive = positive 306 } 307 308 /* 309 * Testing might want to toggle infinityTsEnabled 310 */ 311 func disableInfinityTs() { 312 infinityTsEnabled = false 313 } 314 315 // This is a time function specific to the Postgres default DateStyle 316 // setting ("ISO, MDY"), the only one we currently support. This 317 // accounts for the discrepancies between the parsing available with 318 // time.Parse and the Postgres date formatting quirks. 
319 func parseTs(currentLocation *time.Location, str string) interface{} { 320 switch str { 321 case "-infinity": 322 if infinityTsEnabled { 323 return infinityTsNegative 324 } 325 return []byte(str) 326 case "infinity": 327 if infinityTsEnabled { 328 return infinityTsPositive 329 } 330 return []byte(str) 331 } 332 t, err := ParseTimestamp(currentLocation, str) 333 if err != nil { 334 panic(err) 335 } 336 return t 337 } 338 339 func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { 340 p := timestampParser{} 341 342 monSep := strings.IndexRune(str, '-') 343 // this is Gregorian year, not ISO Year 344 // In Gregorian system, the year 1 BC is followed by AD 1 345 year := p.mustAtoi(str, 0, monSep) 346 daySep := monSep + 3 347 month := p.mustAtoi(str, monSep+1, daySep) 348 p.expect(str, "-", daySep) 349 timeSep := daySep + 3 350 day := p.mustAtoi(str, daySep+1, timeSep) 351 352 var hour, minute, second int 353 if len(str) > monSep+len("01-01")+1 { 354 p.expect(str, " ", timeSep) 355 minSep := timeSep + 3 356 p.expect(str, ":", minSep) 357 hour = p.mustAtoi(str, timeSep+1, minSep) 358 secSep := minSep + 3 359 p.expect(str, ":", secSep) 360 minute = p.mustAtoi(str, minSep+1, secSep) 361 secEnd := secSep + 3 362 second = p.mustAtoi(str, secSep+1, secEnd) 363 } 364 remainderIdx := monSep + len("01-01 00:00:00") + 1 365 // Three optional (but ordered) sections follow: the 366 // fractional seconds, the time zone offset, and the BC 367 // designation. We set them up here and adjust the other 368 // offsets if the preceding sections exist. 369 370 nanoSec := 0 371 tzOff := 0 372 373 if remainderIdx+1 <= len(str) && str[remainderIdx:remainderIdx+1] == "." 
{ 374 fracStart := remainderIdx + 1 375 fracOff := strings.IndexAny(str[fracStart:], "-+ ") 376 if fracOff < 0 { 377 fracOff = len(str) - fracStart 378 } 379 fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) 380 nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) 381 382 remainderIdx += fracOff + 1 383 } 384 if tzStart := remainderIdx; tzStart+1 <= len(str) && (str[tzStart:tzStart+1] == "-" || str[tzStart:tzStart+1] == "+") { 385 // time zone separator is always '-' or '+' (UTC is +00) 386 var tzSign int 387 if c := str[tzStart : tzStart+1]; c == "-" { 388 tzSign = -1 389 } else if c == "+" { 390 tzSign = +1 391 } else { 392 return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) 393 } 394 tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) 395 remainderIdx += 3 396 var tzMin, tzSec int 397 if tzStart+4 <= len(str) && str[tzStart+3:tzStart+4] == ":" { 398 tzMin = p.mustAtoi(str, tzStart+4, tzStart+6) 399 remainderIdx += 3 400 } 401 if tzStart+7 <= len(str) && str[tzStart+6:tzStart+7] == ":" { 402 tzSec = p.mustAtoi(str, tzStart+7, tzStart+9) 403 remainderIdx += 3 404 } 405 tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) 406 } 407 var isoYear int 408 if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" { 409 isoYear = 1 - year 410 remainderIdx += 3 411 } else { 412 isoYear = year 413 } 414 if remainderIdx < len(str) { 415 return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:]) 416 } 417 t := time.Date(isoYear, time.Month(month), day, 418 hour, minute, second, nanoSec, 419 globalLocationCache.getLocation(tzOff)) 420 421 if currentLocation != nil { 422 // Set the location of the returned Time based on the self.Session.s 423 // TimeZone value, but only if the local time zone database agrees with 424 // the remote database on the offset. 
425 lt := t.In(currentLocation) 426 _, newOff := lt.Zone() 427 if newOff == tzOff { 428 t = lt 429 } 430 } 431 432 return t, p.err 433 } 434 435 // formatTs formats t into a format postgres understands. 436 func formatTs(t time.Time) (b []byte) { 437 if infinityTsEnabled { 438 // t <= -infinity : ! (t > -infinity) 439 if !t.After(infinityTsNegative) { 440 return []byte("-infinity") 441 } 442 // t >= infinity : ! (!t < infinity) 443 if !t.Before(infinityTsPositive) { 444 return []byte("infinity") 445 } 446 } 447 // Need to send dates before 0001 A.D. with " BC" suffix, instead of the 448 // minus sign preferred by Go. 449 // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on 450 bc := false 451 if t.Year() <= 0 { 452 // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" 453 t = t.AddDate((-t.Year())*2+1, 0, 0) 454 bc = true 455 } 456 b = []byte(t.Format(time.RFC3339Nano)) 457 458 _, offset := t.Zone() 459 offset = offset % 60 460 if offset != 0 { 461 // RFC3339Nano already printed the minus sign 462 if offset < 0 { 463 offset = -offset 464 } 465 466 b = append(b, ':') 467 if offset < 10 { 468 b = append(b, '0') 469 } 470 b = strconv.AppendInt(b, int64(offset), 10) 471 } 472 473 if bc { 474 b = append(b, " BC"...) 475 } 476 return b 477 } 478 479 // Parse a bytea value received from the server. Both "hex" and the legacy 480 // "escape" format are supported. 
481 func parseBytea(s []byte) (result []byte) { 482 if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { 483 // bytea_output = hex 484 s = s[2:] // trim off leading "\\x" 485 result = make([]byte, hex.DecodedLen(len(s))) 486 _, err := hex.Decode(result, s) 487 if err != nil { 488 errorf("%s", err) 489 } 490 } else { 491 // bytea_output = escape 492 for len(s) > 0 { 493 if s[0] == '\\' { 494 // escaped '\\' 495 if len(s) >= 2 && s[1] == '\\' { 496 result = append(result, '\\') 497 s = s[2:] 498 continue 499 } 500 501 // '\\' followed by an octal number 502 if len(s) < 4 { 503 errorf("invalid bytea sequence %v", s) 504 } 505 r, err := strconv.ParseInt(string(s[1:4]), 8, 9) 506 if err != nil { 507 errorf("could not parse bytea value: %s", err.Error()) 508 } 509 result = append(result, byte(r)) 510 s = s[4:] 511 } else { 512 // We hit an unescaped, raw byte. Try to read in as many as 513 // possible in one go. 514 i := bytes.IndexByte(s, '\\') 515 if i == -1 { 516 result = append(result, s...) 517 break 518 } 519 result = append(result, s[:i]...) 520 s = s[i:] 521 } 522 } 523 } 524 525 return result 526 } 527 528 func encodeBytea(serverVersion int, v []byte) (result []byte) { 529 if serverVersion >= 90000 { 530 // Use the hex format if we know that the server supports it 531 result = make([]byte, 2+hex.EncodedLen(len(v))) 532 result[0] = '\\' 533 result[1] = 'x' 534 hex.Encode(result[2:], v) 535 } else { 536 // .. or resort to "escape" 537 for _, b := range v { 538 if b == '\\' { 539 result = append(result, '\\', '\\') 540 } else if b < 0x20 || b > 0x7e { 541 result = append(result, []byte(fmt.Sprintf("\\%03o", b))...) 542 } else { 543 result = append(result, b) 544 } 545 } 546 } 547 548 return result 549 } 550 551 // NullTime represents a time.Time that may be null. NullTime implements the 552 // sql.Scanner interface so it can be used as a scan destination, similar to 553 // sql.NullString. 
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if Time is not NULL
}

// Scan implements the sql.Scanner interface. A time.Time value is stored
// and marks nt valid; any other value (including nil) resets nt to the zero
// time and marks it invalid. Scan always returns a nil error.
func (nt *NullTime) Scan(value interface{}) error {
	t, ok := value.(time.Time)
	*nt = NullTime{Time: t, Valid: ok}
	return nil
}

// Value implements the driver.Valuer interface: nil when nt is not valid,
// otherwise the wrapped time.Time.
func (nt NullTime) Value() (driver.Value, error) {
	if nt.Valid {
		return nt.Time, nil
	}
	return nil, nil
}