github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/go-sql-driver/mysql/utils.go (about) 1 // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 // 3 // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 // 5 // This Source Code Form is subject to the terms of the Mozilla Public 6 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 // You can obtain one at http://mozilla.org/MPL/2.0/. 8 9 package mysql 10 11 import ( 12 "database/sql" 13 "database/sql/driver" 14 "encoding/binary" 15 "errors" 16 "fmt" 17 "github.com/hellobchain/newcryptosm/tls" 18 "io" 19 "strconv" 20 "strings" 21 "sync" 22 "sync/atomic" 23 "time" 24 ) 25 26 // Registry for custom tls.Configs 27 var ( 28 tlsConfigLock sync.RWMutex 29 tlsConfigRegistry map[string]*tls.Config 30 ) 31 32 // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. 33 // Use the key as a value in the DSN where tls=value. 34 // 35 // Note: The provided tls.Config is exclusively owned by the driver after 36 // registering it. 
37 // 38 // rootCertPool := x509.NewCertPool() 39 // pem, err := ioutil.ReadFile("/path/ca-cert.pem") 40 // if err != nil { 41 // log.Fatal(err) 42 // } 43 // if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 44 // log.Fatal("Failed to append PEM.") 45 // } 46 // clientCert := make([]tls.Certificate, 0, 1) 47 // certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") 48 // if err != nil { 49 // log.Fatal(err) 50 // } 51 // clientCert = append(clientCert, certs) 52 // mysql.RegisterTLSConfig("custom", &tls.Config{ 53 // RootCAs: rootCertPool, 54 // Certificates: clientCert, 55 // }) 56 // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") 57 // 58 func RegisterTLSConfig(key string, config *tls.Config) error { 59 if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { 60 return fmt.Errorf("key '%s' is reserved", key) 61 } 62 63 tlsConfigLock.Lock() 64 if tlsConfigRegistry == nil { 65 tlsConfigRegistry = make(map[string]*tls.Config) 66 } 67 68 tlsConfigRegistry[key] = config 69 tlsConfigLock.Unlock() 70 return nil 71 } 72 73 // DeregisterTLSConfig removes the tls.Config associated with key. 74 func DeregisterTLSConfig(key string) { 75 tlsConfigLock.Lock() 76 if tlsConfigRegistry != nil { 77 delete(tlsConfigRegistry, key) 78 } 79 tlsConfigLock.Unlock() 80 } 81 82 func getTLSConfigClone(key string) (config *tls.Config) { 83 tlsConfigLock.RLock() 84 if v, ok := tlsConfigRegistry[key]; ok { 85 config = v.Clone() 86 } 87 tlsConfigLock.RUnlock() 88 return 89 } 90 91 // Returns the bool value of the input. 
92 // The 2nd return value indicates if the input was a valid bool value 93 func readBool(input string) (value bool, valid bool) { 94 switch input { 95 case "1", "true", "TRUE", "True": 96 return true, true 97 case "0", "false", "FALSE", "False": 98 return false, true 99 } 100 101 // Not a valid bool value 102 return 103 } 104 105 /****************************************************************************** 106 * Time related utils * 107 ******************************************************************************/ 108 109 // NullTime represents a time.Time that may be NULL. 110 // NullTime implements the Scanner interface so 111 // it can be used as a scan destination: 112 // 113 // var nt NullTime 114 // err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) 115 // ... 116 // if nt.Valid { 117 // // use nt.Time 118 // } else { 119 // // NULL value 120 // } 121 // 122 // This NullTime implementation is not driver-specific 123 type NullTime struct { 124 Time time.Time 125 Valid bool // Valid is true if Time is not NULL 126 } 127 128 // Scan implements the Scanner interface. 129 // The value type must be time.Time or string / []byte (formatted time-string), 130 // otherwise Scan fails. 131 func (nt *NullTime) Scan(value interface{}) (err error) { 132 if value == nil { 133 nt.Time, nt.Valid = time.Time{}, false 134 return 135 } 136 137 switch v := value.(type) { 138 case time.Time: 139 nt.Time, nt.Valid = v, true 140 return 141 case []byte: 142 nt.Time, err = parseDateTime(string(v), time.UTC) 143 nt.Valid = (err == nil) 144 return 145 case string: 146 nt.Time, err = parseDateTime(v, time.UTC) 147 nt.Valid = (err == nil) 148 return 149 } 150 151 nt.Valid = false 152 return fmt.Errorf("Can't convert %T to time.Time", value) 153 } 154 155 // Value implements the driver Valuer interface. 
156 func (nt NullTime) Value() (driver.Value, error) { 157 if !nt.Valid { 158 return nil, nil 159 } 160 return nt.Time, nil 161 } 162 163 func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { 164 base := "0000-00-00 00:00:00.0000000" 165 switch len(str) { 166 case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" 167 if str == base[:len(str)] { 168 return 169 } 170 t, err = time.Parse(timeFormat[:len(str)], str) 171 default: 172 err = fmt.Errorf("invalid time string: %s", str) 173 return 174 } 175 176 // Adjust location 177 if err == nil && loc != time.UTC { 178 y, mo, d := t.Date() 179 h, mi, s := t.Clock() 180 t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil 181 } 182 183 return 184 } 185 186 func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { 187 switch num { 188 case 0: 189 return time.Time{}, nil 190 case 4: 191 return time.Date( 192 int(binary.LittleEndian.Uint16(data[:2])), // year 193 time.Month(data[2]), // month 194 int(data[3]), // day 195 0, 0, 0, 0, 196 loc, 197 ), nil 198 case 7: 199 return time.Date( 200 int(binary.LittleEndian.Uint16(data[:2])), // year 201 time.Month(data[2]), // month 202 int(data[3]), // day 203 int(data[4]), // hour 204 int(data[5]), // minutes 205 int(data[6]), // seconds 206 0, 207 loc, 208 ), nil 209 case 11: 210 return time.Date( 211 int(binary.LittleEndian.Uint16(data[:2])), // year 212 time.Month(data[2]), // month 213 int(data[3]), // day 214 int(data[4]), // hour 215 int(data[5]), // minutes 216 int(data[6]), // seconds 217 int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds 218 loc, 219 ), nil 220 } 221 return nil, fmt.Errorf("invalid DATETIME packet length %d", num) 222 } 223 224 // zeroDateTime is used in formatBinaryDateTime to avoid an allocation 225 // if the DATE or DATETIME has the zero value. 226 // It must never be changed. 227 // The current behavior depends on database/sql copying the result. 
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
// if the DATE or DATETIME has the zero value.
// It must never be changed.
// The current behavior depends on database/sql copying the result.
var zeroDateTime = []byte("0000-00-00 00:00:00.000000")

// digits01 and digits10 give, for n in 0..99, the ones digit (digits01[n])
// and the tens digit (digits10[n]) of n, so a zero-padded two-digit number is
// written as digits10[n], digits01[n].
const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"

// appendMicrosecs appends '.' followed by the first `decimals` digits of the
// six-digit microsecond value stored little-endian in src[:4].
// decimals <= 0 appends nothing; an empty src appends '.' and `decimals`
// zeros. Callers in this file pass at most 6.
func appendMicrosecs(dst, src []byte, decimals int) []byte {
	if decimals <= 0 {
		return dst
	}
	if len(src) == 0 {
		return append(dst, ".000000"[:decimals+1]...)
	}

	// Split the value into three two-digit groups p1 p2 p3.
	microsecs := binary.LittleEndian.Uint32(src[:4])
	p1 := byte(microsecs / 10000)
	microsecs -= 10000 * uint32(p1)
	p2 := byte(microsecs / 100)
	microsecs -= 100 * uint32(p2)
	p3 := byte(microsecs)

	switch decimals {
	default:
		return append(dst, '.',
			digits10[p1], digits01[p1],
			digits10[p2], digits01[p2],
			digits10[p3], digits01[p3],
		)
	case 1:
		return append(dst, '.',
			digits10[p1],
		)
	case 2:
		return append(dst, '.',
			digits10[p1], digits01[p1],
		)
	case 3:
		return append(dst, '.',
			digits10[p1], digits01[p1],
			digits10[p2],
		)
	case 4:
		return append(dst, '.',
			digits10[p1], digits01[p1],
			digits10[p2], digits01[p2],
		)
	case 5:
		return append(dst, '.',
			digits10[p1], digits01[p1],
			digits10[p2], digits01[p2],
			digits10[p3],
		)
	}
}

// formatBinaryDateTime renders a binary-protocol DATE/DATETIME value as text.
//
// length is the width of the textual column: 10 for DATE, 19 for DATETIME,
// 21-26 for DATETIME with 1-6 fractional digits. src is the raw value of
// 0, 4, 7 or 11 bytes; an empty src yields the matching prefix of
// zeroDateTime (shared, not copied). Other lengths are rejected.
func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
	if len(src) == 0 {
		return zeroDateTime[:length], nil
	}
	var dst []byte      // return value
	var p1, p2, p3 byte // current digit pair

	switch length {
	case 10, 19, 21, 22, 23, 24, 25, 26:
	default:
		t := "DATE"
		if length > 10 {
			t += "TIME"
		}
		return nil, fmt.Errorf("illegal %s length %d", t, length)
	}
	switch len(src) {
	case 4, 7, 11:
	default:
		t := "DATE"
		if length > 10 {
			t += "TIME"
		}
		return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
	}
	dst = make([]byte, 0, length)
	// start with the date
	year := binary.LittleEndian.Uint16(src[:2])
	pt := year / 100
	p1 = byte(year - 100*uint16(pt))
	p2, p3 = src[2], src[3]
	dst = append(dst,
		digits10[pt], digits01[pt],
		digits10[p1], digits01[p1], '-',
		digits10[p2], digits01[p2], '-',
		digits10[p3], digits01[p3],
	)
	if length == 10 {
		return dst, nil
	}
	if len(src) == 4 {
		// Date only on the wire: pad the time part with zeros.
		return append(dst, zeroDateTime[10:length]...), nil
	}
	dst = append(dst, ' ')
	p1 = src[4] // hour
	src = src[5:]

	// p1 is 2-digit hour, src is after hour
	p2, p3 = src[0], src[1]
	dst = append(dst,
		digits10[p1], digits01[p1], ':',
		digits10[p2], digits01[p2], ':',
		digits10[p3], digits01[p3],
	)
	return appendMicrosecs(dst, src[2:], int(length)-20), nil
}

// formatBinaryTime renders a binary-protocol TIME value as text.
//
// length is the deterministic width of the zero value (8 for "HH:MM:SS",
// 10-15 with 1-6 fractional digits); a leading '-' and hour counts of 100
// or more are added as needed. src is the raw value of 0, 8 or 12 bytes:
// sign, days (4 bytes LE), hour, minute, second and optional microseconds.
func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
	if len(src) == 0 {
		return zeroDateTime[11 : 11+length], nil
	}
	var dst []byte // return value

	switch length {
	case
		8,                      // time (can be up to 10 when negative and 100+ hours)
		10, 11, 12, 13, 14, 15: // time with fractional seconds
	default:
		return nil, fmt.Errorf("illegal TIME length %d", length)
	}
	switch len(src) {
	case 8, 12:
	default:
		return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
	}
	// +2 to enable negative time and 100+ hours
	dst = make([]byte, 0, length+2)
	if src[0] == 1 {
		dst = append(dst, '-')
	}
	days := binary.LittleEndian.Uint32(src[1:5])
	hours := int64(days)*24 + int64(src[5])

	if hours >= 100 {
		dst = strconv.AppendInt(dst, hours, 10)
	} else {
		dst = append(dst, digits10[hours], digits01[hours])
	}

	minute, second := src[6], src[7]
	dst = append(dst, ':',
		digits10[minute], digits01[minute], ':',
		digits10[second], digits01[second],
	)
	return appendMicrosecs(dst, src[8:], int(length)-9), nil
}

/******************************************************************************
*                       Convert from and to bytes                             *
******************************************************************************/

// uint64ToBytes returns the 8-byte little-endian encoding of n.
func uint64ToBytes(n uint64) []byte {
	return []byte{
		byte(n),
		byte(n >> 8),
		byte(n >> 16),
		byte(n >> 24),
		byte(n >> 32),
		byte(n >> 40),
		byte(n >> 48),
		byte(n >> 56),
	}
}

// uint64ToString returns the decimal ASCII representation of n.
func uint64ToString(n uint64) []byte {
	var a [20]byte
	i := 20

	// U+0030 = 0
	// ...
	// U+0039 = 9

	var q uint64
	for n >= 10 {
		i--
		q = n / 10
		a[i] = uint8(n-q*10) + 0x30
		n = q
	}

	i--
	a[i] = uint8(n) + 0x30

	return a[i:]
}

// stringToInt treats b as the decimal representation of an unsigned integer.
// No validation is done: non-digit bytes and overflow produce garbage.
func stringToInt(b []byte) int {
	val := 0
	for i := range b {
		val *= 10
		val += int(b[i] - 0x30)
	}
	return val
}

// readLengthEncodedString reads a length-encoded string from the start of b.
// It returns the string as a sub-slice of b (capacity-limited so appending to
// it cannot overwrite the following bytes), whether the value is NULL, the
// number of bytes consumed, and io.EOF if b is shorter than the encoded
// length.
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
	// Get length
	num, isNull, n := readLengthEncodedInteger(b)
	if num < 1 {
		// For an empty b, readLengthEncodedInteger reports NULL with n == 1
		// (issue #349), and b[n:n] would slice past cap(b) and panic.
		if n > len(b) {
			return b[len(b):], isNull, n, nil
		}
		return b[n:n], isNull, n, nil
	}

	// Compare in uint64 before converting: a malformed length >= 2^63 would
	// otherwise wrap int(num) negative and defeat the bounds check.
	end := n + int(num)
	if num > uint64(len(b)-n) {
		return nil, false, end, io.EOF
	}
	return b[n:end:end], false, end, nil
}

// skipLengthEncodedString returns the number of bytes occupied by the
// length-encoded string at the start of b, and io.EOF if b is shorter than
// the encoded length.
func skipLengthEncodedString(b []byte) (int, error) {
	// Get length
	num, _, n := readLengthEncodedInteger(b)
	if num < 1 {
		return n, nil
	}

	// Same overflow-safe comparison as in readLengthEncodedString.
	end := n + int(num)
	if num > uint64(len(b)-n) {
		return end, io.EOF
	}
	return end, nil
}

// readLengthEncodedInteger decodes a length-encoded integer from the start of
// b and returns the value, whether it denotes NULL, and the bytes consumed.
// An empty b is treated as NULL (issue #349).
// NOTE(review): a truncated multi-byte prefix (e.g. 0xfc with fewer than two
// following bytes) still panics with index out of range.
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
	// See issue #349
	if len(b) == 0 {
		return 0, true, 1
	}

	switch b[0] {
	// 251: NULL
	case 0xfb:
		return 0, true, 1

	// 252: value of following 2
	case 0xfc:
		return uint64(b[1]) | uint64(b[2])<<8, false, 3

	// 253: value of following 3
	case 0xfd:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4

	// 254: value of following 8
	case 0xfe:
		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
				uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
				uint64(b[7])<<48 | uint64(b[8])<<56,
			false, 9
	}

	// 0-250: value of first byte
	return uint64(b[0]), false, 1
}

// appendLengthEncodedInteger appends the length-encoded form of n to b
// (1, 3, 4 or 9 bytes depending on magnitude) and returns the extended slice.
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
	switch {
	case n <= 250:
		return append(b, byte(n))

	case n <= 0xffff:
		return append(b, 0xfc, byte(n), byte(n>>8))

	case n <= 0xffffff:
		return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
	}
	return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}
// reserveBuffer returns buf extended to len(buf)+appendSize bytes, keeping
// its existing contents. If cap(buf) is large enough the same backing array
// is reused (the newly exposed tail is not cleared); otherwise a new array of
// len(buf)*2+appendSize bytes is allocated and the old contents copied in.
func reserveBuffer(buf []byte, appendSize int) []byte {
	want := len(buf) + appendSize
	if cap(buf) >= want {
		return buf[:want]
	}
	// Grow exponentially to avoid reallocating on every call.
	grown := make([]byte, len(buf)*2+appendSize)
	copy(grown, buf)
	return grown[:want]
}

// escapeBytesBackslash appends v to buf, escaping it for use inside a quoted
// MySQL string literal: NUL, '\n', '\r' and '\x1a' become \0, \n, \r and \Z,
// and ', " and \ are prefixed with a backslash. All other bytes are copied
// unchanged.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
func escapeBytesBackslash(buf, v []byte) []byte {
	pos := len(buf)
	// Worst case every byte expands to two.
	buf = reserveBuffer(buf, len(v)*2)

	for _, c := range v {
		var esc byte
		switch c {
		case '\x00':
			esc = '0'
		case '\n':
			esc = 'n'
		case '\r':
			esc = 'r'
		case '\x1a':
			esc = 'Z'
		case '\'', '"', '\\':
			esc = c
		default:
			buf[pos] = c
			pos++
			continue
		}
		buf[pos] = '\\'
		buf[pos+1] = esc
		pos += 2
	}

	return buf[:pos]
}
579 func escapeStringBackslash(buf []byte, v string) []byte { 580 pos := len(buf) 581 buf = reserveBuffer(buf, len(v)*2) 582 583 for i := 0; i < len(v); i++ { 584 c := v[i] 585 switch c { 586 case '\x00': 587 buf[pos] = '\\' 588 buf[pos+1] = '0' 589 pos += 2 590 case '\n': 591 buf[pos] = '\\' 592 buf[pos+1] = 'n' 593 pos += 2 594 case '\r': 595 buf[pos] = '\\' 596 buf[pos+1] = 'r' 597 pos += 2 598 case '\x1a': 599 buf[pos] = '\\' 600 buf[pos+1] = 'Z' 601 pos += 2 602 case '\'': 603 buf[pos] = '\\' 604 buf[pos+1] = '\'' 605 pos += 2 606 case '"': 607 buf[pos] = '\\' 608 buf[pos+1] = '"' 609 pos += 2 610 case '\\': 611 buf[pos] = '\\' 612 buf[pos+1] = '\\' 613 pos += 2 614 default: 615 buf[pos] = c 616 pos++ 617 } 618 } 619 620 return buf[:pos] 621 } 622 623 // escapeBytesQuotes escapes apostrophes in []byte by doubling them up. 624 // This escapes the contents of a string by doubling up any apostrophes that 625 // it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in 626 // effect on the server. 627 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 628 func escapeBytesQuotes(buf, v []byte) []byte { 629 pos := len(buf) 630 buf = reserveBuffer(buf, len(v)*2) 631 632 for _, c := range v { 633 if c == '\'' { 634 buf[pos] = '\'' 635 buf[pos+1] = '\'' 636 pos += 2 637 } else { 638 buf[pos] = c 639 pos++ 640 } 641 } 642 643 return buf[:pos] 644 } 645 646 // escapeStringQuotes is similar to escapeBytesQuotes but for string. 
647 func escapeStringQuotes(buf []byte, v string) []byte { 648 pos := len(buf) 649 buf = reserveBuffer(buf, len(v)*2) 650 651 for i := 0; i < len(v); i++ { 652 c := v[i] 653 if c == '\'' { 654 buf[pos] = '\'' 655 buf[pos+1] = '\'' 656 pos += 2 657 } else { 658 buf[pos] = c 659 pos++ 660 } 661 } 662 663 return buf[:pos] 664 } 665 666 /****************************************************************************** 667 * Sync utils * 668 ******************************************************************************/ 669 670 // noCopy may be embedded into structs which must not be copied 671 // after the first use. 672 // 673 // See https://github.com/golang/go/issues/8005#issuecomment-190753527 674 // for details. 675 type noCopy struct{} 676 677 // Lock is a no-op used by -copylocks checker from `go vet`. 678 func (*noCopy) Lock() {} 679 680 // atomicBool is a wrapper around uint32 for usage as a boolean value with 681 // atomic access. 682 type atomicBool struct { 683 _noCopy noCopy 684 value uint32 685 } 686 687 // IsSet returns whether the current boolean value is true 688 func (ab *atomicBool) IsSet() bool { 689 return atomic.LoadUint32(&ab.value) > 0 690 } 691 692 // Set sets the value of the bool regardless of the previous value 693 func (ab *atomicBool) Set(value bool) { 694 if value { 695 atomic.StoreUint32(&ab.value, 1) 696 } else { 697 atomic.StoreUint32(&ab.value, 0) 698 } 699 } 700 701 // TrySet sets the value of the bool and returns whether the value changed 702 func (ab *atomicBool) TrySet(value bool) bool { 703 if value { 704 return atomic.SwapUint32(&ab.value, 1) == 0 705 } 706 return atomic.SwapUint32(&ab.value, 0) > 0 707 } 708 709 // atomicError is a wrapper for atomically accessed error values 710 type atomicError struct { 711 _noCopy noCopy 712 value atomic.Value 713 } 714 715 // Set sets the error value regardless of the previous value. 
716 // The value must not be nil 717 func (ae *atomicError) Set(value error) { 718 ae.value.Store(value) 719 } 720 721 // Value returns the current error value 722 func (ae *atomicError) Value() error { 723 if v := ae.value.Load(); v != nil { 724 // this will panic if the value doesn't implement the error interface 725 return v.(error) 726 } 727 return nil 728 } 729 730 func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { 731 dargs := make([]driver.Value, len(named)) 732 for n, param := range named { 733 if len(param.Name) > 0 { 734 // TODO: support the use of Named Parameters #561 735 return nil, errors.New("mysql: driver does not support the use of Named Parameters") 736 } 737 dargs[n] = param.Value 738 } 739 return dargs, nil 740 } 741 742 func mapIsolationLevel(level driver.IsolationLevel) (string, error) { 743 switch sql.IsolationLevel(level) { 744 case sql.LevelRepeatableRead: 745 return "REPEATABLE READ", nil 746 case sql.LevelReadCommitted: 747 return "READ COMMITTED", nil 748 case sql.LevelReadUncommitted: 749 return "READ UNCOMMITTED", nil 750 case sql.LevelSerializable: 751 return "SERIALIZABLE", nil 752 default: 753 return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) 754 } 755 }