github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/go-sql-driver/mysql/connection.go (about) 1 // Go MySQL Driver - A MySQL-Driver for Go's database/sql package 2 // 3 // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 4 // 5 // This Source Code Form is subject to the terms of the Mozilla Public 6 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 // You can obtain one at http://mozilla.org/MPL/2.0/. 8 9 package mysql 10 11 import ( 12 "context" 13 "database/sql" 14 "database/sql/driver" 15 "io" 16 "net" 17 "strconv" 18 "strings" 19 "time" 20 ) 21 22 type mysqlConn struct { 23 buf buffer 24 netConn net.Conn 25 rawConn net.Conn // underlying connection when netConn is TLS connection. 26 affectedRows uint64 27 insertId uint64 28 cfg *Config 29 maxAllowedPacket int 30 maxWriteSize int 31 writeTimeout time.Duration 32 flags clientFlag 33 status statusFlag 34 sequence uint8 35 parseTime bool 36 reset bool // set when the Go SQL package calls ResetSession 37 38 // for context support (Go 1.8+) 39 watching bool 40 watcher chan<- context.Context 41 closech chan struct{} 42 finished chan<- struct{} 43 canceled atomicError // set non-nil if conn is canceled 44 closed atomicBool // set when conn is closed, before closech is closed 45 } 46 47 // Handles parameters set in DSN after the connection is established 48 func (mc *mysqlConn) handleParams() (err error) { 49 for param, val := range mc.cfg.Params { 50 switch param { 51 // Charset 52 case "charset": 53 charsets := strings.Split(val, ",") 54 for i := range charsets { 55 // ignore errors here - a charset may not exist 56 err = mc.exec("SET NAMES " + charsets[i]) 57 if err == nil { 58 break 59 } 60 } 61 if err != nil { 62 return 63 } 64 65 // System Vars 66 default: 67 err = mc.exec("SET " + param + "=" + val + "") 68 if err != nil { 69 return 70 } 71 } 72 } 73 74 return 75 } 76 77 func (mc *mysqlConn) markBadConn(err error) error { 78 if mc == nil { 79 return err 80 } 81 if err != errBadConnNoWrite { 82 return err 83 } 84 return driver.ErrBadConn 85 } 86 87 func (mc *mysqlConn) Begin() (driver.Tx, error) { 88 return mc.begin(false) 89 } 90 91 func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { 92 if mc.closed.IsSet() { 93 errLog.Print(ErrInvalidConn) 94 return nil, driver.ErrBadConn 95 } 96 var q string 97 if readOnly { 98 q = "START TRANSACTION READ ONLY" 99 } else { 100 q = "START TRANSACTION" 101 } 102 err := mc.exec(q) 103 if err == nil { 104 return &mysqlTx{mc}, err 105 } 106 return nil, mc.markBadConn(err) 107 } 108 109 func (mc *mysqlConn) Close() (err error) { 110 // Makes Close idempotent 111 if !mc.closed.IsSet() { 112 err = mc.writeCommandPacket(comQuit) 113 } 114 115 mc.cleanup() 116 117 return 118 } 119 120 // Closes the network connection and unsets internal variables. Do not call this 121 // function after successfully authentication, call Close instead. This function 122 // is called before auth or on auth failure because MySQL will have already 123 // closed the network connection. 124 func (mc *mysqlConn) cleanup() { 125 if !mc.closed.TrySet(true) { 126 return 127 } 128 129 // Makes cleanup idempotent 130 close(mc.closech) 131 if mc.netConn == nil { 132 return 133 } 134 if err := mc.netConn.Close(); err != nil { 135 errLog.Print(err) 136 } 137 } 138 139 func (mc *mysqlConn) error() error { 140 if mc.closed.IsSet() { 141 if err := mc.canceled.Value(); err != nil { 142 return err 143 } 144 return ErrInvalidConn 145 } 146 return nil 147 } 148 149 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { 150 if mc.closed.IsSet() { 151 errLog.Print(ErrInvalidConn) 152 return nil, driver.ErrBadConn 153 } 154 // Send command 155 err := mc.writeCommandPacketStr(comStmtPrepare, query) 156 if err != nil { 157 return nil, mc.markBadConn(err) 158 } 159 160 stmt := &mysqlStmt{ 161 mc: mc, 162 } 163 164 // Read Result 165 columnCount, err := stmt.readPrepareResultPacket() 166 if err == nil { 167 if stmt.paramCount > 0 { 168 if err = mc.readUntilEOF(); err != nil { 169 return nil, err 170 } 171 } 172 173 if columnCount > 0 { 174 err = mc.readUntilEOF() 175 } 176 } 177 178 return stmt, err 179 } 180 181 func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { 182 // Number of ? should be same to len(args) 183 if strings.Count(query, "?") != len(args) { 184 return "", driver.ErrSkip 185 } 186 187 buf, err := mc.buf.takeCompleteBuffer() 188 if err != nil { 189 // can not take the buffer. Something must be wrong with the connection 190 errLog.Print(err) 191 return "", ErrInvalidConn 192 } 193 buf = buf[:0] 194 argPos := 0 195 196 for i := 0; i < len(query); i++ { 197 q := strings.IndexByte(query[i:], '?') 198 if q == -1 { 199 buf = append(buf, query[i:]...) 200 break 201 } 202 buf = append(buf, query[i:i+q]...) 203 i += q 204 205 arg := args[argPos] 206 argPos++ 207 208 if arg == nil { 209 buf = append(buf, "NULL"...) 210 continue 211 } 212 213 switch v := arg.(type) { 214 case int64: 215 buf = strconv.AppendInt(buf, v, 10) 216 case uint64: 217 // Handle uint64 explicitly because our custom ConvertValue emits unsigned values 218 buf = strconv.AppendUint(buf, v, 10) 219 case float64: 220 buf = strconv.AppendFloat(buf, v, 'g', -1, 64) 221 case bool: 222 if v { 223 buf = append(buf, '1') 224 } else { 225 buf = append(buf, '0') 226 } 227 case time.Time: 228 if v.IsZero() { 229 buf = append(buf, "'0000-00-00'"...) 230 } else { 231 v := v.In(mc.cfg.Loc) 232 v = v.Add(time.Nanosecond * 500) // To round under microsecond 233 year := v.Year() 234 year100 := year / 100 235 year1 := year % 100 236 month := v.Month() 237 day := v.Day() 238 hour := v.Hour() 239 minute := v.Minute() 240 second := v.Second() 241 micro := v.Nanosecond() / 1000 242 243 buf = append(buf, []byte{ 244 '\'', 245 digits10[year100], digits01[year100], 246 digits10[year1], digits01[year1], 247 '-', 248 digits10[month], digits01[month], 249 '-', 250 digits10[day], digits01[day], 251 ' ', 252 digits10[hour], digits01[hour], 253 ':', 254 digits10[minute], digits01[minute], 255 ':', 256 digits10[second], digits01[second], 257 }...) 258 259 if micro != 0 { 260 micro10000 := micro / 10000 261 micro100 := micro / 100 % 100 262 micro1 := micro % 100 263 buf = append(buf, []byte{ 264 '.', 265 digits10[micro10000], digits01[micro10000], 266 digits10[micro100], digits01[micro100], 267 digits10[micro1], digits01[micro1], 268 }...) 269 } 270 buf = append(buf, '\'') 271 } 272 case []byte: 273 if v == nil { 274 buf = append(buf, "NULL"...) 275 } else { 276 buf = append(buf, "_binary'"...) 277 if mc.status&statusNoBackslashEscapes == 0 { 278 buf = escapeBytesBackslash(buf, v) 279 } else { 280 buf = escapeBytesQuotes(buf, v) 281 } 282 buf = append(buf, '\'') 283 } 284 case string: 285 buf = append(buf, '\'') 286 if mc.status&statusNoBackslashEscapes == 0 { 287 buf = escapeStringBackslash(buf, v) 288 } else { 289 buf = escapeStringQuotes(buf, v) 290 } 291 buf = append(buf, '\'') 292 default: 293 return "", driver.ErrSkip 294 } 295 296 if len(buf)+4 > mc.maxAllowedPacket { 297 return "", driver.ErrSkip 298 } 299 } 300 if argPos != len(args) { 301 return "", driver.ErrSkip 302 } 303 return string(buf), nil 304 } 305 306 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { 307 if mc.closed.IsSet() { 308 errLog.Print(ErrInvalidConn) 309 return nil, driver.ErrBadConn 310 } 311 if len(args) != 0 { 312 if !mc.cfg.InterpolateParams { 313 return nil, driver.ErrSkip 314 } 315 // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement 316 prepared, err := mc.interpolateParams(query, args) 317 if err != nil { 318 return nil, err 319 } 320 query = prepared 321 } 322 mc.affectedRows = 0 323 mc.insertId = 0 324 325 err := mc.exec(query) 326 if err == nil { 327 return &mysqlResult{ 328 affectedRows: int64(mc.affectedRows), 329 insertId: int64(mc.insertId), 330 }, err 331 } 332 return nil, mc.markBadConn(err) 333 } 334 335 // Internal function to execute commands 336 func (mc *mysqlConn) exec(query string) error { 337 // Send command 338 if err := mc.writeCommandPacketStr(comQuery, query); err != nil { 339 return mc.markBadConn(err) 340 } 341 342 // Read Result 343 resLen, err := mc.readResultSetHeaderPacket() 344 if err != nil { 345 return err 346 } 347 348 if resLen > 0 { 349 // columns 350 if err := mc.readUntilEOF(); err != nil { 351 return err 352 } 353 354 // rows 355 if err := mc.readUntilEOF(); err != nil { 356 return err 357 } 358 } 359 360 return mc.discardResults() 361 } 362 363 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { 364 return mc.query(query, args) 365 } 366 367 func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { 368 if mc.closed.IsSet() { 369 errLog.Print(ErrInvalidConn) 370 return nil, driver.ErrBadConn 371 } 372 if len(args) != 0 { 373 if !mc.cfg.InterpolateParams { 374 return nil, driver.ErrSkip 375 } 376 // try client-side prepare to reduce roundtrip 377 prepared, err := mc.interpolateParams(query, args) 378 if err != nil { 379 return nil, err 380 } 381 query = prepared 382 } 383 // Send command 384 err := mc.writeCommandPacketStr(comQuery, query) 385 if err == nil { 386 // Read Result 387 var resLen int 388 resLen, err = mc.readResultSetHeaderPacket() 389 if err == nil { 390 rows := new(textRows) 391 rows.mc = mc 392 393 if resLen == 0 { 394 rows.rs.done = true 395 396 switch err := rows.NextResultSet(); err { 397 case nil, io.EOF: 398 return rows, nil 399 default: 400 return nil, err 401 } 402 } 403 404 // Columns 405 rows.rs.columns, err = mc.readColumns(resLen) 406 return rows, err 407 } 408 } 409 return nil, mc.markBadConn(err) 410 } 411 412 // Gets the value of the given MySQL System Variable 413 // The returned byte slice is only valid until the next read 414 func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { 415 // Send command 416 if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { 417 return nil, err 418 } 419 420 // Read Result 421 resLen, err := mc.readResultSetHeaderPacket() 422 if err == nil { 423 rows := new(textRows) 424 rows.mc = mc 425 rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} 426 427 if resLen > 0 { 428 // Columns 429 if err := mc.readUntilEOF(); err != nil { 430 return nil, err 431 } 432 } 433 434 dest := make([]driver.Value, resLen) 435 if err = rows.readRow(dest); err == nil { 436 return dest[0].([]byte), mc.readUntilEOF() 437 } 438 } 439 return nil, err 440 } 441 442 // finish is called when the query has canceled. 443 func (mc *mysqlConn) cancel(err error) { 444 mc.canceled.Set(err) 445 mc.cleanup() 446 } 447 448 // finish is called when the query has succeeded. 449 func (mc *mysqlConn) finish() { 450 if !mc.watching || mc.finished == nil { 451 return 452 } 453 select { 454 case mc.finished <- struct{}{}: 455 mc.watching = false 456 case <-mc.closech: 457 } 458 } 459 460 // Ping implements driver.Pinger interface 461 func (mc *mysqlConn) Ping(ctx context.Context) (err error) { 462 if mc.closed.IsSet() { 463 errLog.Print(ErrInvalidConn) 464 return driver.ErrBadConn 465 } 466 467 if err = mc.watchCancel(ctx); err != nil { 468 return 469 } 470 defer mc.finish() 471 472 if err = mc.writeCommandPacket(comPing); err != nil { 473 return mc.markBadConn(err) 474 } 475 476 return mc.readResultOK() 477 } 478 479 // BeginTx implements driver.ConnBeginTx interface 480 func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 481 if err := mc.watchCancel(ctx); err != nil { 482 return nil, err 483 } 484 defer mc.finish() 485 486 if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { 487 level, err := mapIsolationLevel(opts.Isolation) 488 if err != nil { 489 return nil, err 490 } 491 err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) 492 if err != nil { 493 return nil, err 494 } 495 } 496 497 return mc.begin(opts.ReadOnly) 498 } 499 500 func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 501 dargs, err := namedValueToValue(args) 502 if err != nil { 503 return nil, err 504 } 505 506 if err := mc.watchCancel(ctx); err != nil { 507 return nil, err 508 } 509 510 rows, err := mc.query(query, dargs) 511 if err != nil { 512 mc.finish() 513 return nil, err 514 } 515 rows.finish = mc.finish 516 return rows, err 517 } 518 519 func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 520 dargs, err := namedValueToValue(args) 521 if err != nil { 522 return nil, err 523 } 524 525 if err := mc.watchCancel(ctx); err != nil { 526 return nil, err 527 } 528 defer mc.finish() 529 530 return mc.Exec(query, dargs) 531 } 532 533 func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 534 if err := mc.watchCancel(ctx); err != nil { 535 return nil, err 536 } 537 538 stmt, err := mc.Prepare(query) 539 mc.finish() 540 if err != nil { 541 return nil, err 542 } 543 544 select { 545 default: 546 case <-ctx.Done(): 547 stmt.Close() 548 return nil, ctx.Err() 549 } 550 return stmt, nil 551 } 552 553 func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 554 dargs, err := namedValueToValue(args) 555 if err != nil { 556 return nil, err 557 } 558 559 if err := stmt.mc.watchCancel(ctx); err != nil { 560 return nil, err 561 } 562 563 rows, err := stmt.query(dargs) 564 if err != nil { 565 stmt.mc.finish() 566 return nil, err 567 } 568 rows.finish = stmt.mc.finish 569 return rows, err 570 } 571 572 func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 573 dargs, err := namedValueToValue(args) 574 if err != nil { 575 return nil, err 576 } 577 578 if err := stmt.mc.watchCancel(ctx); err != nil { 579 return nil, err 580 } 581 defer stmt.mc.finish() 582 583 return stmt.Exec(dargs) 584 } 585 586 func (mc *mysqlConn) watchCancel(ctx context.Context) error { 587 if mc.watching { 588 // Reach here if canceled, 589 // so the connection is already invalid 590 mc.cleanup() 591 return nil 592 } 593 // When ctx is already cancelled, don't watch it. 594 if err := ctx.Err(); err != nil { 595 return err 596 } 597 // When ctx is not cancellable, don't watch it. 598 if ctx.Done() == nil { 599 return nil 600 } 601 // When watcher is not alive, can't watch it. 602 if mc.watcher == nil { 603 return nil 604 } 605 606 mc.watching = true 607 mc.watcher <- ctx 608 return nil 609 } 610 611 func (mc *mysqlConn) startWatcher() { 612 watcher := make(chan context.Context, 1) 613 mc.watcher = watcher 614 finished := make(chan struct{}) 615 mc.finished = finished 616 go func() { 617 for { 618 var ctx context.Context 619 select { 620 case ctx = <-watcher: 621 case <-mc.closech: 622 return 623 } 624 625 select { 626 case <-ctx.Done(): 627 mc.cancel(ctx.Err()) 628 case <-finished: 629 case <-mc.closech: 630 return 631 } 632 } 633 }() 634 } 635 636 func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { 637 nv.Value, err = converter{}.ConvertValue(nv.Value) 638 return 639 } 640 641 // ResetSession implements driver.SessionResetter. 642 // (From Go 1.10) 643 func (mc *mysqlConn) ResetSession(ctx context.Context) error { 644 if mc.closed.IsSet() { 645 return driver.ErrBadConn 646 } 647 mc.reset = true 648 return nil 649 }