github.com/XiaoMi/Gaea@v1.2.5/backend/direct_connection.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package backend 16 17 import ( 18 "bytes" 19 "encoding/binary" 20 "errors" 21 "fmt" 22 "net" 23 "strings" 24 25 sqlerr "github.com/XiaoMi/Gaea/core/errors" 26 "github.com/XiaoMi/Gaea/log" 27 "github.com/XiaoMi/Gaea/mysql" 28 "github.com/XiaoMi/Gaea/util/sync2" 29 ) 30 31 // DirectConnection means connection to backend mysql 32 type DirectConnection struct { 33 conn *mysql.Conn 34 35 addr string 36 user string 37 password string 38 db string 39 40 capability uint32 41 42 sessionVariables *mysql.SessionVariables 43 44 status uint16 45 46 collation mysql.CollationID 47 charset string 48 salt []byte 49 50 defaultCollation mysql.CollationID 51 defaultCharset string 52 53 pkgErr error 54 closed sync2.AtomicBool 55 } 56 57 // NewDirectConnection return direct and authorised connection to mysql with real net connection 58 func NewDirectConnection(addr string, user string, password string, db string, charset string, collationID mysql.CollationID) (*DirectConnection, error) { 59 dc := &DirectConnection{ 60 addr: addr, 61 user: user, 62 password: password, 63 db: db, 64 charset: charset, 65 collation: collationID, 66 defaultCharset: charset, 67 defaultCollation: collationID, 68 closed: sync2.NewAtomicBool(false), 69 sessionVariables: mysql.NewSessionVariables(), 70 } 71 err := dc.connect() 72 return dc, err 73 } 74 75 // connect means real connection to backend mysql after authorization 76 func (dc *DirectConnection) connect() error { 77 if dc.conn != nil { 78 dc.conn.Close() 79 } 80 81 typ := "tcp" 82 if strings.Contains(dc.addr, "/") { 83 typ = "unix" 84 } 85 86 netConn, err := net.Dial(typ, dc.addr) 87 if err != nil { 88 return err 89 } 90 91 tcpConn := netConn.(*net.TCPConn) 92 // SetNoDelay controls whether the operating system should delay packet transmission 93 // in hopes of sending fewer packets (Nagle's algorithm). 94 // The default is true (no delay), 95 // meaning that data is sent as soon as possible after a Write. 96 tcpConn.SetNoDelay(true) 97 tcpConn.SetKeepAlive(true) 98 dc.conn = mysql.NewConn(tcpConn) 99 100 // step1: read handshake requirements 101 if err := dc.readInitialHandshake(); err != nil { 102 dc.conn.Close() 103 return err 104 } 105 106 // step2: write handshake response 107 if err := dc.writeHandshakeResponse41(); err != nil { 108 dc.conn.Close() 109 110 return err 111 } 112 113 response, err := dc.readPacket() 114 if err != nil { 115 dc.conn.Close() 116 return err 117 } 118 119 switch response[0] { 120 case mysql.OKHeader: 121 default: 122 return errors.New("dc connection handshake failed with mysql") 123 } 124 125 // we must always use autocommit 126 if !dc.IsAutoCommit() { 127 if _, err := dc.exec("set autocommit = 1", 0); err != nil { 128 dc.conn.Close() 129 130 return err 131 } 132 } 133 134 return nil 135 } 136 137 // Close close connection to backend mysql and reset conn structure 138 func (dc *DirectConnection) Close() { 139 if dc.conn != nil { 140 dc.conn.Close() 141 } 142 143 dc.conn = nil 144 dc.salt = nil 145 dc.pkgErr = nil 146 dc.closed.Set(true) 147 148 return 149 } 150 151 // IsClosed check if connection closed 152 func (dc *DirectConnection) IsClosed() bool { 153 return dc.closed.Get() 154 } 155 156 // readPacket doesn't use EphemeralBuffer 157 func (dc *DirectConnection) readPacket() ([]byte, error) { 158 data, err := dc.conn.ReadPacket() 159 dc.pkgErr = err 160 return data, err 161 } 162 163 // writePacket doesn't use EphemeralBuffer 164 func (dc *DirectConnection) writePacket(data []byte) error { 165 err := dc.conn.WritePacket(data) 166 if err != nil && strings.Contains(err.Error(), "broken pipe") { 167 // retry 3 times, close dc's conn、reset dc's stats and reconnect 168 for i := 0; i < 3; i++ { 169 dc.Close() 170 e := dc.connect() 171 if e == nil { // no need to write data again 172 break 173 } 174 } 175 176 } 177 return err 178 } 179 180 // writeEphemeralPacket 181 func (dc *DirectConnection) writeEphemeralPacket() error { 182 err := dc.conn.WriteEphemeralPacket() 183 if err != nil && strings.Contains(err.Error(), "broken pipe") { 184 // retry 3 times, close dc's conn、reset dc's stats and reconnect 185 for i := 0; i < 3; i++ { 186 dc.Close() 187 e := dc.connect() 188 if e == nil { // no need to write data again and ephemeral buffer is recycled 189 break 190 } 191 } 192 } 193 return err 194 } 195 196 func (dc *DirectConnection) readInitialHandshake() error { 197 data, err := dc.readPacket() 198 if err != nil { 199 return err 200 } 201 202 if data[0] == mysql.ErrHeader { 203 return errors.New("read initial handshake error") 204 } 205 206 if data[0] < mysql.MinProtocolVersion { 207 return fmt.Errorf("invalid protocol version %d, must >= 10", data[0]) 208 } 209 210 //skip mysql version 211 //mysql version end with 0x00 212 pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 213 214 // get connection id 215 dc.conn.ConnectionID = binary.LittleEndian.Uint32(data[pos : pos+4]) 216 217 pos += 4 218 219 dc.salt = append(dc.salt, data[pos:pos+8]...) 220 221 //skip filter 222 pos += 8 + 1 223 224 //capability lower 2 bytes 225 dc.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2])) 226 227 pos += 2 228 229 if len(data) > pos { 230 //skip server charset 231 //c.charset = data[pos] 232 pos++ 233 234 dc.status = binary.LittleEndian.Uint16(data[pos : pos+2]) 235 pos += 2 236 237 dc.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | dc.capability 238 239 pos += 2 240 241 //skip auth data len or [00] 242 //skip reserved (all [00]) 243 pos += 10 + 1 244 245 // The documentation is ambiguous about the length. 246 // The official Python library uses the fixed length 12 247 // mysql-proxy also use 12 248 // which is not documented but seems to work. 249 dc.salt = append(dc.salt, data[pos:pos+12]...) 250 } 251 252 return nil 253 } 254 255 // writeHandshakeResponse41 writes the handshake response. 256 func (dc *DirectConnection) writeHandshakeResponse41() error { 257 // Adjust client capability flags based on server support 258 capability := mysql.ClientProtocol41 | mysql.ClientSecureConnection | 259 mysql.ClientLongPassword | mysql.ClientTransactions | mysql.ClientLongFlag 260 capability &= dc.capability 261 262 //we only support secure connection 263 auth := mysql.CalcPassword(dc.salt, []byte(dc.password)) 264 265 length := 4 + // Client capability flags 266 4 + // Max-packet size. 267 1 + // Character set. 268 23 + // Reserved. 269 mysql.LenNullString(dc.user) + // user 270 1 + 271 len(auth) 272 273 if len(dc.db) > 0 { 274 capability |= mysql.ClientConnectWithDB 275 length += mysql.LenNullString(dc.db) 276 } 277 278 dc.capability = capability 279 280 data := make([]byte, length, length) 281 pos := 0 282 283 // Client capability flags. 284 pos = mysql.WriteUint32(data, pos, capability) 285 286 // Max-packet size, always 0. See doc.go. 287 pos = mysql.WriteZeroes(data, pos, 4) 288 289 // Character set. 290 pos = mysql.WriteByte(data, pos, byte(dc.collation)) 291 292 // 23 reserved bytes, all 0. 293 pos = mysql.WriteZeroes(data, pos, 23) 294 295 // user type: null terminated string 296 pos = mysql.WriteNullString(data, pos, dc.user) 297 298 // auth [length encoded integer] 299 data[pos] = byte(len(auth)) 300 pos++ 301 pos += copy(data[pos:], auth) 302 303 // db type: null terminated string 304 if len(dc.db) > 0 { 305 pos = mysql.WriteNullString(data, pos, dc.db) 306 } 307 308 if err := dc.writePacket(data); err != nil { 309 return err 310 } 311 312 return nil 313 } 314 315 // writeComInitDB changes the default database to use. 316 // Client -> Server.DirectConnection 317 // Returns SQLError(CRServerGone) if it can't. 318 func (dc *DirectConnection) writeComInitDB(db string) error { 319 dc.conn.SetSequence(0) 320 data := make([]byte, len(db)+1, len(db)+1) 321 data[0] = mysql.ComInitDB 322 copy(data[1:], db) 323 if err := dc.writePacket(data); err != nil { 324 return err 325 } 326 return nil 327 } 328 329 // writeComQuery send ComQuery request use EphemeralBuffer 330 func (dc *DirectConnection) writeComQuery(sql string) error { 331 dc.conn.SetSequence(0) 332 data := dc.conn.StartEphemeralPacket(len(sql) + 1) 333 data[0] = mysql.ComQuery 334 copy(data[1:], sql) 335 if err := dc.writeEphemeralPacket(); err != nil { 336 return err 337 } 338 return nil 339 } 340 341 func (dc *DirectConnection) writeComFieldList(table string, wildcard string) error { 342 dc.conn.SetSequence(0) 343 length := 1 + 344 mysql.LenNullString(table) + 345 mysql.LenNullString(wildcard) 346 347 data := make([]byte, length, length) 348 pos := 0 349 350 pos = mysql.WriteByte(data, 0, mysql.ComFieldList) 351 pos = mysql.WriteNullString(data, pos, table) 352 pos = mysql.WriteNullString(data, pos, wildcard) 353 354 if err := dc.writePacket(data); err != nil { 355 return err 356 } 357 358 return nil 359 } 360 361 // Ping implements mysql ping command. 362 func (dc *DirectConnection) Ping() error { 363 dc.conn.SetSequence(0) 364 if err := dc.writePacket([]byte{mysql.ComPing}); err != nil { 365 return err 366 } 367 data, err := dc.readPacket() 368 if err != nil { 369 return err 370 } 371 switch data[0] { 372 case mysql.OKHeader: 373 return nil 374 case mysql.ErrHeader: 375 return errors.New("dc connection ping failed") 376 } 377 return fmt.Errorf("unexpected packet type: %d", data[0]) 378 } 379 380 // UseDB send ComInitDB to backend mysql 381 func (dc *DirectConnection) UseDB(dbName string) error { 382 dc.conn.SetSequence(0) 383 if dc.db == dbName || len(dbName) == 0 { 384 return nil 385 } 386 387 if err := dc.writeComInitDB(dbName); err != nil { 388 return err 389 } 390 391 if r, err := dc.readPacket(); err != nil { 392 return err 393 } else if !mysql.IsOKPacket(r) { 394 return errors.New("dc connection use db failed") 395 } 396 397 dc.db = dbName 398 return nil 399 } 400 401 // GetDB return database name 402 func (dc *DirectConnection) GetDB() string { 403 return dc.db 404 } 405 406 // GetAddr return addr of backend mysql 407 func (dc *DirectConnection) GetAddr() string { 408 return dc.addr 409 } 410 411 // Execute send ComQuery or ComStmtPrepare/ComStmtExecute/ComStmtClose to backend mysql 412 func (dc *DirectConnection) Execute(sql string, maxRows int) (*mysql.Result, error) { 413 return dc.exec(sql, maxRows) 414 } 415 416 // Begin send ComQuery with 'begin' to backend mysql to start transaction 417 func (dc *DirectConnection) Begin() error { 418 _, err := dc.exec("begin", 0) 419 return err 420 } 421 422 // Commit send ComQuery with 'commit' to backend mysql to commit transaction 423 func (dc *DirectConnection) Commit() error { 424 _, err := dc.exec("commit", 0) 425 return err 426 } 427 428 // Rollback send ComQuery with 'rollback' to backend mysql to rollback transaction 429 func (dc *DirectConnection) Rollback() error { 430 _, err := dc.exec("rollback", 0) 431 return err 432 } 433 434 // SetAutoCommit trun on/off autocommit 435 func (dc *DirectConnection) SetAutoCommit(v uint8) error { 436 if v == 0 { 437 if _, err := dc.exec("set autocommit = 0", 0); err != nil { 438 dc.conn.Close() 439 440 return err 441 } 442 } else { 443 if _, err := dc.exec("set autocommit = 1", 0); err != nil { 444 dc.conn.Close() 445 446 return err 447 } 448 } 449 return nil 450 } 451 452 // SetCharset set charset of connection to backend mysql 453 func (dc *DirectConnection) SetCharset(charset string, collation mysql.CollationID) ( /*changed*/ bool, error) { 454 charset = strings.Trim(charset, "\"'`") 455 456 if collation == 0 || collation > 247 { 457 collation = mysql.CollationNames[mysql.Charsets[charset]] 458 } 459 460 if dc.charset == charset && dc.collation == collation { 461 return false, nil 462 } 463 464 _, ok := mysql.CharsetIds[charset] 465 if !ok { 466 return false, fmt.Errorf("invalid charset %s", charset) 467 } 468 469 _, ok = mysql.Collations[collation] 470 if !ok { 471 return false, fmt.Errorf("invalid collation %d", collation) 472 } 473 474 dc.collation = collation 475 dc.charset = charset 476 return true, nil 477 } 478 479 // ResetConnection reset connection stattus, include transaction、autocommit、charset、sql_mode .etc 480 func (dc *DirectConnection) ResetConnection() error { 481 if dc.IsInTransaction() { 482 log.Debug("get transaction connection from pool, addr: %s, user: %s, db: %s, status: %d", dc.addr, dc.user, dc.db, dc.status) 483 if err := dc.Rollback(); err != nil { 484 log.Warn("rollback in reset connection error, addr: %s, user: %s, db: %s, status: %d, err: %v", dc.addr, dc.user, dc.db, dc.status, err) 485 return err 486 } 487 } 488 489 if !dc.IsAutoCommit() { 490 log.Debug("get autocommit = 0 connection from pool, addr: %s, user: %s, db: %s, status: %d", dc.addr, dc.user, dc.db, dc.status) 491 if err := dc.SetAutoCommit(1); err != nil { 492 log.Warn("set autocommit = 1 in reset connection error, addr: %s, user: %s, db: %s, status: %d, err: %v", dc.addr, dc.user, dc.db, dc.status, err) 493 return err 494 } 495 } 496 497 return nil 498 } 499 500 // SetSessionVariables set direction variables according to Session 501 func (dc *DirectConnection) SetSessionVariables(frontend *mysql.SessionVariables) (bool, error) { 502 return dc.sessionVariables.SetEqualsWith(frontend) 503 } 504 505 // WriteSetStatement execute sql 506 func (dc *DirectConnection) WriteSetStatement() error { 507 var setVariableSQL bytes.Buffer 508 collation, ok := mysql.Collations[dc.collation] 509 if !ok { 510 return fmt.Errorf("invalid collationId: %v", dc.collation) 511 } 512 appendSetCharset(&setVariableSQL, dc.charset, collation) 513 514 for _, v := range dc.sessionVariables.GetAll() { 515 appendSetVariable(&setVariableSQL, v.Name(), v.Get()) 516 } 517 518 for _, v := range dc.sessionVariables.GetUnusedAndClear() { 519 appendSetVariableToDefault(&setVariableSQL, v.Name()) 520 } 521 522 setSQL := setVariableSQL.String() 523 if setSQL == "" { 524 return nil 525 } 526 if _, err := dc.exec(setSQL, 0); err != nil { 527 return err 528 } 529 return nil 530 } 531 532 // FieldList send ComFieldList to backend mysql 533 func (dc *DirectConnection) FieldList(table string, wildcard string) ([]*mysql.Field, error) { 534 if err := dc.writeComFieldList(table, wildcard); err != nil { 535 return nil, err 536 } 537 fs := make([]*mysql.Field, 0, 4) 538 var f *mysql.Field 539 for { 540 data, err := dc.readPacket() 541 if err != nil { 542 return nil, err 543 } 544 545 // EOF Packet 546 if dc.isEOFPacket(data) { 547 return fs, nil 548 } 549 550 if data[0] == mysql.ErrHeader { 551 return nil, dc.handleErrorPacket(data) 552 } 553 554 if f, err = mysql.FieldData(data).Parse(); err != nil { 555 return nil, err 556 } 557 fs = append(fs, f) 558 } 559 } 560 561 // execute ComQuery command 562 func (dc *DirectConnection) exec(query string, maxRows int) (*mysql.Result, error) { 563 if err := dc.writeComQuery(query); err != nil { 564 return nil, err 565 } 566 567 return dc.readResult(false, maxRows) 568 } 569 570 // read resultset from mysql 571 func (dc *DirectConnection) readResultSet(data []byte, binary bool, maxRows int) (*mysql.Result, error) { 572 result := &mysql.Result{ 573 Status: 0, 574 InsertID: 0, 575 AffectedRows: 0, 576 577 Resultset: &mysql.Resultset{}, 578 } 579 580 // column count 581 pos := 0 582 count, pos, _, _ := mysql.ReadLenEncInt(data, pos) 583 584 if pos-len(data) != 0 { 585 return nil, mysql.ErrMalformPacket 586 } 587 588 result.Fields = make([]*mysql.Field, count) 589 result.FieldNames = make(map[string]int, count) 590 591 if err := dc.readResultColumns(result); err != nil { 592 return nil, err 593 } 594 595 if err := dc.readResultRows(result, binary, maxRows); err != nil { 596 return nil, err 597 } 598 599 return result, nil 600 } 601 602 // readResultColumns read column information 603 func (dc *DirectConnection) readResultColumns(result *mysql.Result) (err error) { 604 var i = 0 605 var data []byte 606 607 for { 608 data, err = dc.readPacket() 609 if err != nil { 610 return 611 } 612 613 // EOF Packet 614 if dc.isEOFPacket(data) { 615 if dc.capability&mysql.ClientProtocol41 > 0 { 616 //result.Warnings = binary.LittleEndian.Uint16(data[1:]) 617 //todo add strict_mode, warning will be treat as error 618 result.Status = binary.LittleEndian.Uint16(data[3:]) 619 dc.status = result.Status 620 } 621 622 if i != len(result.Fields) { 623 err = mysql.ErrMalformPacket 624 } 625 626 return 627 } 628 629 if data[0] == mysql.ErrHeader { 630 return dc.handleErrorPacket(data) 631 } 632 633 result.Fields[i], err = mysql.FieldData(data).Parse() 634 if err != nil { 635 return 636 } 637 638 result.FieldNames[string(result.Fields[i].Name)] = i 639 640 i++ 641 } 642 } 643 644 // readResultRows read result rows 645 func (dc *DirectConnection) readResultRows(result *mysql.Result, isBinary bool, maxRows int) (err error) { 646 var data []byte 647 648 for { 649 data, err = dc.readPacket() 650 if err != nil { 651 return 652 } 653 654 // EOF Packet 655 if dc.isEOFPacket(data) { 656 if dc.capability&mysql.ClientProtocol41 > 0 { 657 //result.Warnings = binary.LittleEndian.Uint16(data[1:]) 658 //todo add strict_mode, warning will be treat as error 659 result.Status = binary.LittleEndian.Uint16(data[3:]) 660 dc.status = result.Status 661 } 662 663 break 664 } 665 666 if data[0] == mysql.ErrHeader { 667 return dc.handleErrorPacket(data) 668 } 669 670 result.RowDatas = append(result.RowDatas, data) 671 if maxRows > 0 && len(result.RowDatas) >= maxRows { 672 if err := dc.drainResults(); err != nil { 673 return fmt.Errorf("%v %d, drain error: %v", sqlerr.ErrRowsLimitExceeded, maxRows, err) 674 } 675 return fmt.Errorf("%v %d", sqlerr.ErrRowsLimitExceeded, maxRows) 676 } 677 } 678 679 result.Values = make([][]interface{}, len(result.RowDatas)) 680 for i := range result.Values { 681 result.Values[i], err = result.RowDatas[i].Parse(result.Fields, isBinary) 682 if err != nil { 683 return err 684 } 685 } 686 687 return nil 688 } 689 690 // drainResults will read all packets for a result set and ignore them. 691 func (dc *DirectConnection) drainResults() error { 692 for { 693 data, err := dc.conn.ReadEphemeralPacket() 694 if err != nil { 695 dc.conn.RecycleReadPacket() 696 return err 697 } 698 699 if dc.isEOFPacket(data) { 700 dc.conn.RecycleReadPacket() 701 return nil 702 } else if data[0] == mysql.ErrHeader { 703 err := dc.handleErrorPacket(data) 704 dc.conn.RecycleReadPacket() 705 return err 706 } 707 dc.conn.RecycleReadPacket() 708 } 709 } 710 711 func (dc *DirectConnection) isEOFPacket(data []byte) bool { 712 return data[0] == mysql.EOFHeader && len(data) <= 5 713 } 714 715 func (dc *DirectConnection) handleOKPacket(data []byte) (*mysql.Result, error) { 716 var pos = 1 717 718 r := new(mysql.Result) 719 720 r.AffectedRows, pos, _, _ = mysql.ReadLenEncInt(data, pos) 721 r.InsertID, pos, _, _ = mysql.ReadLenEncInt(data, pos) 722 723 if dc.capability&mysql.ClientProtocol41 > 0 { 724 r.Status = binary.LittleEndian.Uint16(data[pos:]) 725 dc.status = r.Status 726 pos += 2 727 728 // TODO strict_mode, check warnings as error 729 // Warnings := binary.LittleEndian.Uint16(data[pos:]) 730 // pos += 2 731 } else if dc.capability&mysql.ClientTransactions > 0 { 732 r.Status = binary.LittleEndian.Uint16(data[pos:]) 733 dc.status = r.Status 734 pos += 2 735 } 736 737 //info 738 return r, nil 739 } 740 741 func (dc *DirectConnection) handleErrorPacket(data []byte) error { 742 e := new(mysql.SQLError) 743 744 var pos = 1 745 746 e.Code = binary.LittleEndian.Uint16(data[pos:]) 747 pos += 2 748 749 if dc.capability&mysql.ClientProtocol41 > 0 { 750 // skip '#' 751 pos++ 752 e.State = string(data[pos : pos+5]) 753 pos += 5 754 } 755 756 e.Message = string(data[pos:]) 757 758 return e 759 } 760 761 func (dc *DirectConnection) readResult(binary bool, maxRows int) (*mysql.Result, error) { 762 data, err := dc.readPacket() 763 if err != nil { 764 return nil, err 765 } 766 if data[0] == mysql.OKHeader { 767 return dc.handleOKPacket(data) 768 } else if data[0] == mysql.ErrHeader { 769 return nil, dc.handleErrorPacket(data) 770 } else if data[0] == mysql.LocalInFileHeader { 771 return nil, mysql.ErrMalformPacket 772 } 773 774 return dc.readResultSet(data, binary, maxRows) 775 } 776 777 // IsAutoCommit check if autocommit 778 func (dc *DirectConnection) IsAutoCommit() bool { 779 return dc.status&mysql.ServerStatusAutocommit > 0 780 } 781 782 // IsInTransaction check if in transaction 783 func (dc *DirectConnection) IsInTransaction() bool { 784 return dc.status&mysql.ServerStatusInTrans > 0 785 } 786 787 // GetCharset return charset of specific connection 788 func (dc *DirectConnection) GetCharset() string { 789 return dc.charset 790 } 791 792 func appendSetCharset(buf *bytes.Buffer, charset string, collation string) { 793 if buf.Len() != 0 { 794 buf.WriteString(",") 795 } else { 796 buf.WriteString("SET NAMES '") 797 } 798 buf.WriteString(charset) 799 buf.WriteString("' COLLATE '") 800 buf.WriteString(collation) 801 buf.WriteString("'") 802 } 803 804 func appendSetVariable(buf *bytes.Buffer, key string, value interface{}) { 805 if buf.Len() != 0 { 806 buf.WriteString(",") 807 } else { 808 buf.WriteString("SET ") 809 } 810 buf.WriteString(key) 811 buf.WriteString(" = ") 812 switch v := value.(type) { 813 case string: 814 if strings.ToLower(v) == mysql.KeywordDefault { 815 buf.WriteString(v) 816 } else { 817 buf.WriteString("'") 818 buf.WriteString(v) 819 buf.WriteString("'") 820 } 821 case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: 822 buf.WriteString(fmt.Sprintf("%d", v)) 823 default: 824 buf.WriteString("'") 825 buf.WriteString(fmt.Sprintf("%v", v)) 826 buf.WriteString("'") 827 } 828 } 829 830 func appendSetVariableToDefault(buf *bytes.Buffer, key string) { 831 if buf.Len() != 0 { 832 buf.WriteString(",") 833 } else { 834 buf.WriteString("SET ") 835 } 836 buf.WriteString(key) 837 buf.WriteString(" = ") 838 buf.WriteString("DEFAULT") 839 }