// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"crypto/sha1"
	"crypto/tls"
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net/url"
	"strings"
	"time"
)

var (
	// tlsConfigRegister holds the custom tls.Configs, keyed by the name that
	// is looked up when a DSN carries tls=<name>.
	tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs

	// Errors reported for malformed DSN strings.
	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
	errInvalidDSNAddr      = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
	errInvalidDSNNoSlash   = errors.New("Invalid DSN: Missing the slash separating the database name")
)

// init allocates the TLS config register; writing to the nil map would panic.
func init() {
	tlsConfigRegister = make(map[string]*tls.Config)
}

// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
38 // 39 // rootCertPool := x509.NewCertPool() 40 // pem, err := ioutil.ReadFile("/path/ca-cert.pem") 41 // if err != nil { 42 // log.Fatal(err) 43 // } 44 // if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 45 // log.Fatal("Failed to append PEM.") 46 // } 47 // clientCert := make([]tls.Certificate, 0, 1) 48 // certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") 49 // if err != nil { 50 // log.Fatal(err) 51 // } 52 // clientCert = append(clientCert, certs) 53 // mysql.RegisterTLSConfig("custom", &tls.Config{ 54 // RootCAs: rootCertPool, 55 // Certificates: clientCert, 56 // }) 57 // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") 58 // 59 func RegisterTLSConfig(key string, config *tls.Config) error { 60 if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { 61 return fmt.Errorf("Key '%s' is reserved", key) 62 } 63 64 tlsConfigRegister[key] = config 65 return nil 66 } 67 68 // DeregisterTLSConfig removes the tls.Config associated with key. 
69 func DeregisterTLSConfig(key string) { 70 delete(tlsConfigRegister, key) 71 } 72 73 // parseDSN parses the DSN string to a config 74 func parseDSN(dsn string) (cfg *config, err error) { 75 // New config with some default values 76 cfg = &config{ 77 loc: time.UTC, 78 collation: defaultCollation, 79 } 80 81 // TODO: use strings.IndexByte when we can depend on Go 1.2 82 83 // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] 84 // Find the last '/' (since the password or the net addr might contain a '/') 85 foundSlash := false 86 for i := len(dsn) - 1; i >= 0; i-- { 87 if dsn[i] == '/' { 88 foundSlash = true 89 var j, k int 90 91 // left part is empty if i <= 0 92 if i > 0 { 93 // [username[:password]@][protocol[(address)]] 94 // Find the last '@' in dsn[:i] 95 for j = i; j >= 0; j-- { 96 if dsn[j] == '@' { 97 // username[:password] 98 // Find the first ':' in dsn[:j] 99 for k = 0; k < j; k++ { 100 if dsn[k] == ':' { 101 cfg.passwd = dsn[k+1 : j] 102 break 103 } 104 } 105 cfg.user = dsn[:k] 106 107 break 108 } 109 } 110 111 // [protocol[(address)]] 112 // Find the first '(' in dsn[j+1:i] 113 for k = j + 1; k < i; k++ { 114 if dsn[k] == '(' { 115 // dsn[i-1] must be == ')' if an address is specified 116 if dsn[i-1] != ')' { 117 if strings.ContainsRune(dsn[k+1:i], ')') { 118 return nil, errInvalidDSNUnescaped 119 } 120 return nil, errInvalidDSNAddr 121 } 122 cfg.addr = dsn[k+1 : i-1] 123 break 124 } 125 } 126 cfg.net = dsn[j+1 : k] 127 } 128 129 // dbname[?param1=value1&...¶mN=valueN] 130 // Find the first '?' in dsn[i+1:] 131 for j = i + 1; j < len(dsn); j++ { 132 if dsn[j] == '?' 
{ 133 if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { 134 return 135 } 136 break 137 } 138 } 139 cfg.dbname = dsn[i+1 : j] 140 141 break 142 } 143 } 144 145 if !foundSlash && len(dsn) > 0 { 146 return nil, errInvalidDSNNoSlash 147 } 148 149 // Set default network if empty 150 if cfg.net == "" { 151 cfg.net = "tcp" 152 } 153 154 // Set default address if empty 155 if cfg.addr == "" { 156 switch cfg.net { 157 case "tcp": 158 cfg.addr = "127.0.0.1:3306" 159 case "unix": 160 cfg.addr = "/tmp/mysql.sock" 161 default: 162 return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") 163 } 164 165 } 166 167 return 168 } 169 170 // parseDSNParams parses the DSN "query string" 171 // Values must be url.QueryEscape'ed 172 func parseDSNParams(cfg *config, params string) (err error) { 173 for _, v := range strings.Split(params, "&") { 174 param := strings.SplitN(v, "=", 2) 175 if len(param) != 2 { 176 continue 177 } 178 179 // cfg params 180 switch value := param[1]; param[0] { 181 182 // Disable INFILE whitelist / enable all files 183 case "allowAllFiles": 184 var isBool bool 185 cfg.allowAllFiles, isBool = readBool(value) 186 if !isBool { 187 return fmt.Errorf("Invalid Bool value: %s", value) 188 } 189 190 // Use old authentication mode (pre MySQL 4.1) 191 case "allowOldPasswords": 192 var isBool bool 193 cfg.allowOldPasswords, isBool = readBool(value) 194 if !isBool { 195 return fmt.Errorf("Invalid Bool value: %s", value) 196 } 197 198 // Switch "rowsAffected" mode 199 case "clientFoundRows": 200 var isBool bool 201 cfg.clientFoundRows, isBool = readBool(value) 202 if !isBool { 203 return fmt.Errorf("Invalid Bool value: %s", value) 204 } 205 206 // Collation 207 case "collation": 208 collation, ok := collations[value] 209 if !ok { 210 // Note possibility for false negatives: 211 // could be triggered although the collation is valid if the 212 // collations map does not contain entries the server supports. 
213 err = errors.New("unknown collation") 214 return 215 } 216 cfg.collation = collation 217 break 218 219 // Time Location 220 case "loc": 221 if value, err = url.QueryUnescape(value); err != nil { 222 return 223 } 224 cfg.loc, err = time.LoadLocation(value) 225 if err != nil { 226 return 227 } 228 229 // Dial Timeout 230 case "timeout": 231 cfg.timeout, err = time.ParseDuration(value) 232 if err != nil { 233 return 234 } 235 236 // TLS-Encryption 237 case "tls": 238 boolValue, isBool := readBool(value) 239 if isBool { 240 if boolValue { 241 cfg.tls = &tls.Config{} 242 } 243 } else { 244 if strings.ToLower(value) == "skip-verify" { 245 cfg.tls = &tls.Config{InsecureSkipVerify: true} 246 } else if tlsConfig, ok := tlsConfigRegister[value]; ok { 247 cfg.tls = tlsConfig 248 } else { 249 return fmt.Errorf("Invalid value / unknown config name: %s", value) 250 } 251 } 252 253 default: 254 // lazy init 255 if cfg.params == nil { 256 cfg.params = make(map[string]string) 257 } 258 259 if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { 260 return 261 } 262 } 263 } 264 265 return 266 } 267 268 // Returns the bool value of the input. 
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
	switch input {
	case "1", "true", "TRUE", "True":
		value, valid = true, true
	case "0", "false", "FALSE", "False":
		valid = true
	}

	// Anything else is not a valid bool value: both results stay false.
	return value, valid
}

/******************************************************************************
*                             Authentication                                  *
******************************************************************************/

// Encrypt password using 4.1+ method
//
// The result is SHA1(scramble + SHA1(SHA1(password))) XOR SHA1(password).
// An empty password yields nil.
func scramblePassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	hasher := sha1.New()

	// stage1 = SHA1(password)
	hasher.Write(password)
	stage1 := hasher.Sum(nil)

	// stage2 = SHA1(stage1)
	hasher.Reset()
	hasher.Write(stage1)
	stage2 := hasher.Sum(nil)

	// token = SHA1(scramble + stage2)
	hasher.Reset()
	hasher.Write(scramble)
	hasher.Write(stage2)
	token := hasher.Sum(nil)

	// token XOR stage1; both are SHA1 digests of equal length
	for i := range token {
		token[i] ^= stage1[i]
	}
	return token
}

// myRnd is the pseudo random number generator used by the pre 4.1
// (old password) method.
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
	seed1, seed2 uint32
}

// myRndMaxVal is the modulus applied to both seeds.
const myRndMaxVal = 0x3FFFFFFF

// newMyRnd creates a generator whose seeds are the given values reduced
// modulo myRndMaxVal.
func newMyRnd(seed1, seed2 uint32) *myRnd {
	rnd := new(myRnd)
	rnd.seed1 = seed1 % myRndMaxVal
	rnd.seed2 = seed2 % myRndMaxVal
	return rnd
}

// NextByte advances the generator and returns the next value, which is the
// new seed1 scaled by 31/myRndMaxVal.
//
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
	next1 := (r.seed1*3 + r.seed2) % myRndMaxVal
	next2 := (next1 + r.seed2 + 33) % myRndMaxVal
	r.seed1, r.seed2 = next1, next2

	// Widen before multiplying so the product is not reduced modulo 2^32.
	return byte(uint64(next1) * 31 / myRndMaxVal)
}

// Generate binary hash from byte string using insecure pre 4.1 method
//
// Spaces and tabs in the password do not contribute to the hash.
func pwHash(password []byte) (result [2]uint32) {
	var (
		nr  uint32 = 1345345333
		nr2 uint32 = 0x12345671
		add uint32 = 7
	)

	for _, c := range password {
		switch c {
		case ' ', '\t':
			// skip spaces and tabs in password
			continue
		}

		ch := uint32(c)
		nr ^= ((nr&63)+add)*ch + (nr << 8)
		nr2 += (nr2 << 8) ^ nr
		add += ch
	}

	// Remove sign bit (1<<31)-1)
	result[0] = nr & 0x7FFFFFFF
	result[1] = nr2 & 0x7FFFFFFF
	return result
}

// Encrypt password using insecure pre 4.1 method
//
// Only the first 8 bytes of scramble are used. An empty password yields
// nil, otherwise the result is 8 bytes long.
func scrambleOldPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	seedScramble := pwHash(scramble[:8])
	seedPassword := pwHash(password)
	rnd := newMyRnd(seedPassword[0]^seedScramble[0], seedPassword[1]^seedScramble[1])

	token := make([]byte, 8)
	for i := range token {
		token[i] = rnd.NextByte() + 64
	}

	// One more generator step supplies the mask XORed over every byte.
	mask := rnd.NextByte()
	for i := range token {
		token[i] ^= mask
	}

	return token
}

/******************************************************************************
*                           Time related utils                                *
******************************************************************************/

// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
//  var nt NullTime
//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
//  ...
//  if nt.Valid {
//     // use nt.Time
//  } else {
//     // NULL value
//  }
//
// This NullTime implementation is not driver-specific
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if Time is not NULL
}

// Scan implements the Scanner interface.
419 // The value type must be time.Time or string / []byte (formatted time-string), 420 // otherwise Scan fails. 421 func (nt *NullTime) Scan(value interface{}) (err error) { 422 if value == nil { 423 nt.Time, nt.Valid = time.Time{}, false 424 return 425 } 426 427 switch v := value.(type) { 428 case time.Time: 429 nt.Time, nt.Valid = v, true 430 return 431 case []byte: 432 nt.Time, err = parseDateTime(string(v), time.UTC) 433 nt.Valid = (err == nil) 434 return 435 case string: 436 nt.Time, err = parseDateTime(v, time.UTC) 437 nt.Valid = (err == nil) 438 return 439 } 440 441 nt.Valid = false 442 return fmt.Errorf("Can't convert %T to time.Time", value) 443 } 444 445 // Value implements the driver Valuer interface. 446 func (nt NullTime) Value() (driver.Value, error) { 447 if !nt.Valid { 448 return nil, nil 449 } 450 return nt.Time, nil 451 } 452 453 func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { 454 switch len(str) { 455 case 10: // YYYY-MM-DD 456 if str == "0000-00-00" { 457 return 458 } 459 t, err = time.Parse(timeFormat[:10], str) 460 case 19: // YYYY-MM-DD HH:MM:SS 461 if str == "0000-00-00 00:00:00" { 462 return 463 } 464 t, err = time.Parse(timeFormat, str) 465 default: 466 err = fmt.Errorf("Invalid Time-String: %s", str) 467 return 468 } 469 470 // Adjust location 471 if err == nil && loc != time.UTC { 472 y, mo, d := t.Date() 473 h, mi, s := t.Clock() 474 t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil 475 } 476 477 return 478 } 479 480 func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { 481 switch num { 482 case 0: 483 return time.Time{}, nil 484 case 4: 485 return time.Date( 486 int(binary.LittleEndian.Uint16(data[:2])), // year 487 time.Month(data[2]), // month 488 int(data[3]), // day 489 0, 0, 0, 0, 490 loc, 491 ), nil 492 case 7: 493 return time.Date( 494 int(binary.LittleEndian.Uint16(data[:2])), // year 495 time.Month(data[2]), // month 496 int(data[3]), // day 
497 int(data[4]), // hour 498 int(data[5]), // minutes 499 int(data[6]), // seconds 500 0, 501 loc, 502 ), nil 503 case 11: 504 return time.Date( 505 int(binary.LittleEndian.Uint16(data[:2])), // year 506 time.Month(data[2]), // month 507 int(data[3]), // day 508 int(data[4]), // hour 509 int(data[5]), // minutes 510 int(data[6]), // seconds 511 int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds 512 loc, 513 ), nil 514 } 515 return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) 516 } 517 518 // zeroDateTime is used in formatBinaryDateTime to avoid an allocation 519 // if the DATE or DATETIME has the zero value. 520 // It must never be changed. 521 // The current behavior depends on database/sql copying the result. 522 var zeroDateTime = []byte("0000-00-00 00:00:00") 523 524 func formatBinaryDateTime(src []byte, withTime bool) (driver.Value, error) { 525 if len(src) == 0 { 526 if withTime { 527 return zeroDateTime, nil 528 } 529 return zeroDateTime[:10], nil 530 } 531 var dst []byte 532 if withTime { 533 if len(src) == 11 { 534 dst = []byte("0000-00-00 00:00:00.000000") 535 } else { 536 dst = []byte("0000-00-00 00:00:00") 537 } 538 } else { 539 dst = []byte("0000-00-00") 540 } 541 switch len(src) { 542 case 11: 543 microsecs := binary.LittleEndian.Uint32(src[7:11]) 544 tmp32 := microsecs / 10 545 dst[25] += byte(microsecs - 10*tmp32) 546 tmp32, microsecs = tmp32/10, tmp32 547 dst[24] += byte(microsecs - 10*tmp32) 548 tmp32, microsecs = tmp32/10, tmp32 549 dst[23] += byte(microsecs - 10*tmp32) 550 tmp32, microsecs = tmp32/10, tmp32 551 dst[22] += byte(microsecs - 10*tmp32) 552 tmp32, microsecs = tmp32/10, tmp32 553 dst[21] += byte(microsecs - 10*tmp32) 554 dst[20] += byte(microsecs / 10) 555 fallthrough 556 case 7: 557 second := src[6] 558 tmp := second / 10 559 dst[18] += second - 10*tmp 560 dst[17] += tmp 561 minute := src[5] 562 tmp = minute / 10 563 dst[15] += minute - 10*tmp 564 dst[14] += tmp 565 hour := src[4] 566 tmp = hour / 10 567 
dst[12] += hour - 10*tmp 568 dst[11] += tmp 569 fallthrough 570 case 4: 571 day := src[3] 572 tmp := day / 10 573 dst[9] += day - 10*tmp 574 dst[8] += tmp 575 month := src[2] 576 tmp = month / 10 577 dst[6] += month - 10*tmp 578 dst[5] += tmp 579 year := binary.LittleEndian.Uint16(src[:2]) 580 tmp16 := year / 10 581 dst[3] += byte(year - 10*tmp16) 582 tmp16, year = tmp16/10, tmp16 583 dst[2] += byte(year - 10*tmp16) 584 tmp16, year = tmp16/10, tmp16 585 dst[1] += byte(year - 10*tmp16) 586 dst[0] += byte(tmp16) 587 return dst, nil 588 } 589 var t string 590 if withTime { 591 t = "DATETIME" 592 } else { 593 t = "DATE" 594 } 595 return nil, fmt.Errorf("invalid %s-packet length %d", t, len(src)) 596 } 597 598 /****************************************************************************** 599 * Convert from and to bytes * 600 ******************************************************************************/ 601 602 func uint64ToBytes(n uint64) []byte { 603 return []byte{ 604 byte(n), 605 byte(n >> 8), 606 byte(n >> 16), 607 byte(n >> 24), 608 byte(n >> 32), 609 byte(n >> 40), 610 byte(n >> 48), 611 byte(n >> 56), 612 } 613 } 614 615 func uint64ToString(n uint64) []byte { 616 var a [20]byte 617 i := 20 618 619 // U+0030 = 0 620 // ... 
621 // U+0039 = 9 622 623 var q uint64 624 for n >= 10 { 625 i-- 626 q = n / 10 627 a[i] = uint8(n-q*10) + 0x30 628 n = q 629 } 630 631 i-- 632 a[i] = uint8(n) + 0x30 633 634 return a[i:] 635 } 636 637 // treats string value as unsigned integer representation 638 func stringToInt(b []byte) int { 639 val := 0 640 for i := range b { 641 val *= 10 642 val += int(b[i] - 0x30) 643 } 644 return val 645 } 646 647 // returns the string read as a bytes slice, wheter the value is NULL, 648 // the number of bytes read and an error, in case the string is longer than 649 // the input slice 650 func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { 651 // Get length 652 num, isNull, n := readLengthEncodedInteger(b) 653 if num < 1 { 654 return b[n:n], isNull, n, nil 655 } 656 657 n += int(num) 658 659 // Check data length 660 if len(b) >= n { 661 return b[n-int(num) : n], false, n, nil 662 } 663 return nil, false, n, io.EOF 664 } 665 666 // returns the number of bytes skipped and an error, in case the string is 667 // longer than the input slice 668 func skipLengthEncodedString(b []byte) (int, error) { 669 // Get length 670 num, _, n := readLengthEncodedInteger(b) 671 if num < 1 { 672 return n, nil 673 } 674 675 n += int(num) 676 677 // Check data length 678 if len(b) >= n { 679 return n, nil 680 } 681 return n, io.EOF 682 } 683 684 // returns the number read, whether the value is NULL and the number of bytes read 685 func readLengthEncodedInteger(b []byte) (uint64, bool, int) { 686 switch b[0] { 687 688 // 251: NULL 689 case 0xfb: 690 return 0, true, 1 691 692 // 252: value of following 2 693 case 0xfc: 694 return uint64(b[1]) | uint64(b[2])<<8, false, 3 695 696 // 253: value of following 3 697 case 0xfd: 698 return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 699 700 // 254: value of following 8 701 case 0xfe: 702 return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | 703 uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | 704 
uint64(b[7])<<48 | uint64(b[8])<<56, 705 false, 9 706 } 707 708 // 0-250: value of first byte 709 return uint64(b[0]), false, 1 710 } 711 712 // encodes a uint64 value and appends it to the given bytes slice 713 func appendLengthEncodedInteger(b []byte, n uint64) []byte { 714 switch { 715 case n <= 250: 716 return append(b, byte(n)) 717 718 case n <= 0xffff: 719 return append(b, 0xfc, byte(n), byte(n>>8)) 720 721 case n <= 0xffffff: 722 return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) 723 } 724 return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), 725 byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) 726 }