github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/go-sql-driver/mysql/connection.go (about)

     1  // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
     2  //
     3  // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package mysql
    10  
    11  import (
    12  	"context"
    13  	"database/sql"
    14  	"database/sql/driver"
    15  	"io"
    16  	"net"
    17  	"strconv"
    18  	"strings"
    19  	"time"
    20  )
    21  
// mysqlConn is a single client connection to a MySQL server. Its methods
// implement the database/sql/driver connection interfaces (Prepare, Begin,
// BeginTx, Exec/ExecContext, Query/QueryContext, Ping, ResetSession, ...).
//
// Notes on selected fields:
//   - affectedRows and insertId are zeroed by Exec before it runs a statement
//     and are then copied into the mysqlResult it returns (they are
//     presumably filled in while the server's response is read).
//   - maxAllowedPacket caps the size of client-side interpolated queries;
//     interpolateParams returns driver.ErrSkip rather than exceed it.
//   - status is consulted for statusNoBackslashEscapes to pick the string
//     escaping style used during interpolation.
//   - reset is set by ResetSession.
//
// The "context support" fields coordinate with the goroutine launched by
// startWatcher. watchCancel sets watching and sends the context on watcher.
// finish signals completion on finished and clears watching. cancel records
// the cause in canceled, which error() later reports. cleanup sets closed and
// then closes closech, which stops the goroutine.
type mysqlConn struct {
	buf              buffer
	netConn          net.Conn
	rawConn          net.Conn // underlying connection when netConn is TLS connection.
	affectedRows     uint64
	insertId         uint64
	cfg              *Config
	maxAllowedPacket int
	maxWriteSize     int
	writeTimeout     time.Duration
	flags            clientFlag
	status           statusFlag
	sequence         uint8
	parseTime        bool
	reset            bool // set when the Go SQL package calls ResetSession

	// for context support (Go 1.8+)
	watching bool
	watcher  chan<- context.Context
	closech  chan struct{}
	finished chan<- struct{}
	canceled atomicError // set non-nil if conn is canceled
	closed   atomicBool  // set when conn is closed, before closech is closed
}
    46  
    47  // Handles parameters set in DSN after the connection is established
    48  func (mc *mysqlConn) handleParams() (err error) {
    49  	for param, val := range mc.cfg.Params {
    50  		switch param {
    51  		// Charset
    52  		case "charset":
    53  			charsets := strings.Split(val, ",")
    54  			for i := range charsets {
    55  				// ignore errors here - a charset may not exist
    56  				err = mc.exec("SET NAMES " + charsets[i])
    57  				if err == nil {
    58  					break
    59  				}
    60  			}
    61  			if err != nil {
    62  				return
    63  			}
    64  
    65  		// System Vars
    66  		default:
    67  			err = mc.exec("SET " + param + "=" + val + "")
    68  			if err != nil {
    69  				return
    70  			}
    71  		}
    72  	}
    73  
    74  	return
    75  }
    76  
    77  func (mc *mysqlConn) markBadConn(err error) error {
    78  	if mc == nil {
    79  		return err
    80  	}
    81  	if err != errBadConnNoWrite {
    82  		return err
    83  	}
    84  	return driver.ErrBadConn
    85  }
    86  
    87  func (mc *mysqlConn) Begin() (driver.Tx, error) {
    88  	return mc.begin(false)
    89  }
    90  
    91  func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
    92  	if mc.closed.IsSet() {
    93  		errLog.Print(ErrInvalidConn)
    94  		return nil, driver.ErrBadConn
    95  	}
    96  	var q string
    97  	if readOnly {
    98  		q = "START TRANSACTION READ ONLY"
    99  	} else {
   100  		q = "START TRANSACTION"
   101  	}
   102  	err := mc.exec(q)
   103  	if err == nil {
   104  		return &mysqlTx{mc}, err
   105  	}
   106  	return nil, mc.markBadConn(err)
   107  }
   108  
   109  func (mc *mysqlConn) Close() (err error) {
   110  	// Makes Close idempotent
   111  	if !mc.closed.IsSet() {
   112  		err = mc.writeCommandPacket(comQuit)
   113  	}
   114  
   115  	mc.cleanup()
   116  
   117  	return
   118  }
   119  
   120  // Closes the network connection and unsets internal variables. Do not call this
   121  // function after successfully authentication, call Close instead. This function
   122  // is called before auth or on auth failure because MySQL will have already
   123  // closed the network connection.
   124  func (mc *mysqlConn) cleanup() {
   125  	if !mc.closed.TrySet(true) {
   126  		return
   127  	}
   128  
   129  	// Makes cleanup idempotent
   130  	close(mc.closech)
   131  	if mc.netConn == nil {
   132  		return
   133  	}
   134  	if err := mc.netConn.Close(); err != nil {
   135  		errLog.Print(err)
   136  	}
   137  }
   138  
   139  func (mc *mysqlConn) error() error {
   140  	if mc.closed.IsSet() {
   141  		if err := mc.canceled.Value(); err != nil {
   142  			return err
   143  		}
   144  		return ErrInvalidConn
   145  	}
   146  	return nil
   147  }
   148  
   149  func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
   150  	if mc.closed.IsSet() {
   151  		errLog.Print(ErrInvalidConn)
   152  		return nil, driver.ErrBadConn
   153  	}
   154  	// Send command
   155  	err := mc.writeCommandPacketStr(comStmtPrepare, query)
   156  	if err != nil {
   157  		return nil, mc.markBadConn(err)
   158  	}
   159  
   160  	stmt := &mysqlStmt{
   161  		mc: mc,
   162  	}
   163  
   164  	// Read Result
   165  	columnCount, err := stmt.readPrepareResultPacket()
   166  	if err == nil {
   167  		if stmt.paramCount > 0 {
   168  			if err = mc.readUntilEOF(); err != nil {
   169  				return nil, err
   170  			}
   171  		}
   172  
   173  		if columnCount > 0 {
   174  			err = mc.readUntilEOF()
   175  		}
   176  	}
   177  
   178  	return stmt, err
   179  }
   180  
   181  func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
   182  	// Number of ? should be same to len(args)
   183  	if strings.Count(query, "?") != len(args) {
   184  		return "", driver.ErrSkip
   185  	}
   186  
   187  	buf, err := mc.buf.takeCompleteBuffer()
   188  	if err != nil {
   189  		// can not take the buffer. Something must be wrong with the connection
   190  		errLog.Print(err)
   191  		return "", ErrInvalidConn
   192  	}
   193  	buf = buf[:0]
   194  	argPos := 0
   195  
   196  	for i := 0; i < len(query); i++ {
   197  		q := strings.IndexByte(query[i:], '?')
   198  		if q == -1 {
   199  			buf = append(buf, query[i:]...)
   200  			break
   201  		}
   202  		buf = append(buf, query[i:i+q]...)
   203  		i += q
   204  
   205  		arg := args[argPos]
   206  		argPos++
   207  
   208  		if arg == nil {
   209  			buf = append(buf, "NULL"...)
   210  			continue
   211  		}
   212  
   213  		switch v := arg.(type) {
   214  		case int64:
   215  			buf = strconv.AppendInt(buf, v, 10)
   216  		case uint64:
   217  			// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
   218  			buf = strconv.AppendUint(buf, v, 10)
   219  		case float64:
   220  			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
   221  		case bool:
   222  			if v {
   223  				buf = append(buf, '1')
   224  			} else {
   225  				buf = append(buf, '0')
   226  			}
   227  		case time.Time:
   228  			if v.IsZero() {
   229  				buf = append(buf, "'0000-00-00'"...)
   230  			} else {
   231  				v := v.In(mc.cfg.Loc)
   232  				v = v.Add(time.Nanosecond * 500) // To round under microsecond
   233  				year := v.Year()
   234  				year100 := year / 100
   235  				year1 := year % 100
   236  				month := v.Month()
   237  				day := v.Day()
   238  				hour := v.Hour()
   239  				minute := v.Minute()
   240  				second := v.Second()
   241  				micro := v.Nanosecond() / 1000
   242  
   243  				buf = append(buf, []byte{
   244  					'\'',
   245  					digits10[year100], digits01[year100],
   246  					digits10[year1], digits01[year1],
   247  					'-',
   248  					digits10[month], digits01[month],
   249  					'-',
   250  					digits10[day], digits01[day],
   251  					' ',
   252  					digits10[hour], digits01[hour],
   253  					':',
   254  					digits10[minute], digits01[minute],
   255  					':',
   256  					digits10[second], digits01[second],
   257  				}...)
   258  
   259  				if micro != 0 {
   260  					micro10000 := micro / 10000
   261  					micro100 := micro / 100 % 100
   262  					micro1 := micro % 100
   263  					buf = append(buf, []byte{
   264  						'.',
   265  						digits10[micro10000], digits01[micro10000],
   266  						digits10[micro100], digits01[micro100],
   267  						digits10[micro1], digits01[micro1],
   268  					}...)
   269  				}
   270  				buf = append(buf, '\'')
   271  			}
   272  		case []byte:
   273  			if v == nil {
   274  				buf = append(buf, "NULL"...)
   275  			} else {
   276  				buf = append(buf, "_binary'"...)
   277  				if mc.status&statusNoBackslashEscapes == 0 {
   278  					buf = escapeBytesBackslash(buf, v)
   279  				} else {
   280  					buf = escapeBytesQuotes(buf, v)
   281  				}
   282  				buf = append(buf, '\'')
   283  			}
   284  		case string:
   285  			buf = append(buf, '\'')
   286  			if mc.status&statusNoBackslashEscapes == 0 {
   287  				buf = escapeStringBackslash(buf, v)
   288  			} else {
   289  				buf = escapeStringQuotes(buf, v)
   290  			}
   291  			buf = append(buf, '\'')
   292  		default:
   293  			return "", driver.ErrSkip
   294  		}
   295  
   296  		if len(buf)+4 > mc.maxAllowedPacket {
   297  			return "", driver.ErrSkip
   298  		}
   299  	}
   300  	if argPos != len(args) {
   301  		return "", driver.ErrSkip
   302  	}
   303  	return string(buf), nil
   304  }
   305  
   306  func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   307  	if mc.closed.IsSet() {
   308  		errLog.Print(ErrInvalidConn)
   309  		return nil, driver.ErrBadConn
   310  	}
   311  	if len(args) != 0 {
   312  		if !mc.cfg.InterpolateParams {
   313  			return nil, driver.ErrSkip
   314  		}
   315  		// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
   316  		prepared, err := mc.interpolateParams(query, args)
   317  		if err != nil {
   318  			return nil, err
   319  		}
   320  		query = prepared
   321  	}
   322  	mc.affectedRows = 0
   323  	mc.insertId = 0
   324  
   325  	err := mc.exec(query)
   326  	if err == nil {
   327  		return &mysqlResult{
   328  			affectedRows: int64(mc.affectedRows),
   329  			insertId:     int64(mc.insertId),
   330  		}, err
   331  	}
   332  	return nil, mc.markBadConn(err)
   333  }
   334  
   335  // Internal function to execute commands
   336  func (mc *mysqlConn) exec(query string) error {
   337  	// Send command
   338  	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
   339  		return mc.markBadConn(err)
   340  	}
   341  
   342  	// Read Result
   343  	resLen, err := mc.readResultSetHeaderPacket()
   344  	if err != nil {
   345  		return err
   346  	}
   347  
   348  	if resLen > 0 {
   349  		// columns
   350  		if err := mc.readUntilEOF(); err != nil {
   351  			return err
   352  		}
   353  
   354  		// rows
   355  		if err := mc.readUntilEOF(); err != nil {
   356  			return err
   357  		}
   358  	}
   359  
   360  	return mc.discardResults()
   361  }
   362  
   363  func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   364  	return mc.query(query, args)
   365  }
   366  
   367  func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
   368  	if mc.closed.IsSet() {
   369  		errLog.Print(ErrInvalidConn)
   370  		return nil, driver.ErrBadConn
   371  	}
   372  	if len(args) != 0 {
   373  		if !mc.cfg.InterpolateParams {
   374  			return nil, driver.ErrSkip
   375  		}
   376  		// try client-side prepare to reduce roundtrip
   377  		prepared, err := mc.interpolateParams(query, args)
   378  		if err != nil {
   379  			return nil, err
   380  		}
   381  		query = prepared
   382  	}
   383  	// Send command
   384  	err := mc.writeCommandPacketStr(comQuery, query)
   385  	if err == nil {
   386  		// Read Result
   387  		var resLen int
   388  		resLen, err = mc.readResultSetHeaderPacket()
   389  		if err == nil {
   390  			rows := new(textRows)
   391  			rows.mc = mc
   392  
   393  			if resLen == 0 {
   394  				rows.rs.done = true
   395  
   396  				switch err := rows.NextResultSet(); err {
   397  				case nil, io.EOF:
   398  					return rows, nil
   399  				default:
   400  					return nil, err
   401  				}
   402  			}
   403  
   404  			// Columns
   405  			rows.rs.columns, err = mc.readColumns(resLen)
   406  			return rows, err
   407  		}
   408  	}
   409  	return nil, mc.markBadConn(err)
   410  }
   411  
   412  // Gets the value of the given MySQL System Variable
   413  // The returned byte slice is only valid until the next read
   414  func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
   415  	// Send command
   416  	if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
   417  		return nil, err
   418  	}
   419  
   420  	// Read Result
   421  	resLen, err := mc.readResultSetHeaderPacket()
   422  	if err == nil {
   423  		rows := new(textRows)
   424  		rows.mc = mc
   425  		rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
   426  
   427  		if resLen > 0 {
   428  			// Columns
   429  			if err := mc.readUntilEOF(); err != nil {
   430  				return nil, err
   431  			}
   432  		}
   433  
   434  		dest := make([]driver.Value, resLen)
   435  		if err = rows.readRow(dest); err == nil {
   436  			return dest[0].([]byte), mc.readUntilEOF()
   437  		}
   438  	}
   439  	return nil, err
   440  }
   441  
   442  // finish is called when the query has canceled.
   443  func (mc *mysqlConn) cancel(err error) {
   444  	mc.canceled.Set(err)
   445  	mc.cleanup()
   446  }
   447  
   448  // finish is called when the query has succeeded.
   449  func (mc *mysqlConn) finish() {
   450  	if !mc.watching || mc.finished == nil {
   451  		return
   452  	}
   453  	select {
   454  	case mc.finished <- struct{}{}:
   455  		mc.watching = false
   456  	case <-mc.closech:
   457  	}
   458  }
   459  
   460  // Ping implements driver.Pinger interface
   461  func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
   462  	if mc.closed.IsSet() {
   463  		errLog.Print(ErrInvalidConn)
   464  		return driver.ErrBadConn
   465  	}
   466  
   467  	if err = mc.watchCancel(ctx); err != nil {
   468  		return
   469  	}
   470  	defer mc.finish()
   471  
   472  	if err = mc.writeCommandPacket(comPing); err != nil {
   473  		return mc.markBadConn(err)
   474  	}
   475  
   476  	return mc.readResultOK()
   477  }
   478  
   479  // BeginTx implements driver.ConnBeginTx interface
   480  func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   481  	if err := mc.watchCancel(ctx); err != nil {
   482  		return nil, err
   483  	}
   484  	defer mc.finish()
   485  
   486  	if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
   487  		level, err := mapIsolationLevel(opts.Isolation)
   488  		if err != nil {
   489  			return nil, err
   490  		}
   491  		err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
   492  		if err != nil {
   493  			return nil, err
   494  		}
   495  	}
   496  
   497  	return mc.begin(opts.ReadOnly)
   498  }
   499  
   500  func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   501  	dargs, err := namedValueToValue(args)
   502  	if err != nil {
   503  		return nil, err
   504  	}
   505  
   506  	if err := mc.watchCancel(ctx); err != nil {
   507  		return nil, err
   508  	}
   509  
   510  	rows, err := mc.query(query, dargs)
   511  	if err != nil {
   512  		mc.finish()
   513  		return nil, err
   514  	}
   515  	rows.finish = mc.finish
   516  	return rows, err
   517  }
   518  
   519  func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   520  	dargs, err := namedValueToValue(args)
   521  	if err != nil {
   522  		return nil, err
   523  	}
   524  
   525  	if err := mc.watchCancel(ctx); err != nil {
   526  		return nil, err
   527  	}
   528  	defer mc.finish()
   529  
   530  	return mc.Exec(query, dargs)
   531  }
   532  
   533  func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   534  	if err := mc.watchCancel(ctx); err != nil {
   535  		return nil, err
   536  	}
   537  
   538  	stmt, err := mc.Prepare(query)
   539  	mc.finish()
   540  	if err != nil {
   541  		return nil, err
   542  	}
   543  
   544  	select {
   545  	default:
   546  	case <-ctx.Done():
   547  		stmt.Close()
   548  		return nil, ctx.Err()
   549  	}
   550  	return stmt, nil
   551  }
   552  
   553  func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   554  	dargs, err := namedValueToValue(args)
   555  	if err != nil {
   556  		return nil, err
   557  	}
   558  
   559  	if err := stmt.mc.watchCancel(ctx); err != nil {
   560  		return nil, err
   561  	}
   562  
   563  	rows, err := stmt.query(dargs)
   564  	if err != nil {
   565  		stmt.mc.finish()
   566  		return nil, err
   567  	}
   568  	rows.finish = stmt.mc.finish
   569  	return rows, err
   570  }
   571  
   572  func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   573  	dargs, err := namedValueToValue(args)
   574  	if err != nil {
   575  		return nil, err
   576  	}
   577  
   578  	if err := stmt.mc.watchCancel(ctx); err != nil {
   579  		return nil, err
   580  	}
   581  	defer stmt.mc.finish()
   582  
   583  	return stmt.Exec(dargs)
   584  }
   585  
   586  func (mc *mysqlConn) watchCancel(ctx context.Context) error {
   587  	if mc.watching {
   588  		// Reach here if canceled,
   589  		// so the connection is already invalid
   590  		mc.cleanup()
   591  		return nil
   592  	}
   593  	// When ctx is already cancelled, don't watch it.
   594  	if err := ctx.Err(); err != nil {
   595  		return err
   596  	}
   597  	// When ctx is not cancellable, don't watch it.
   598  	if ctx.Done() == nil {
   599  		return nil
   600  	}
   601  	// When watcher is not alive, can't watch it.
   602  	if mc.watcher == nil {
   603  		return nil
   604  	}
   605  
   606  	mc.watching = true
   607  	mc.watcher <- ctx
   608  	return nil
   609  }
   610  
   611  func (mc *mysqlConn) startWatcher() {
   612  	watcher := make(chan context.Context, 1)
   613  	mc.watcher = watcher
   614  	finished := make(chan struct{})
   615  	mc.finished = finished
   616  	go func() {
   617  		for {
   618  			var ctx context.Context
   619  			select {
   620  			case ctx = <-watcher:
   621  			case <-mc.closech:
   622  				return
   623  			}
   624  
   625  			select {
   626  			case <-ctx.Done():
   627  				mc.cancel(ctx.Err())
   628  			case <-finished:
   629  			case <-mc.closech:
   630  				return
   631  			}
   632  		}
   633  	}()
   634  }
   635  
   636  func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
   637  	nv.Value, err = converter{}.ConvertValue(nv.Value)
   638  	return
   639  }
   640  
   641  // ResetSession implements driver.SessionResetter.
   642  // (From Go 1.10)
   643  func (mc *mysqlConn) ResetSession(ctx context.Context) error {
   644  	if mc.closed.IsSet() {
   645  		return driver.ErrBadConn
   646  	}
   647  	mc.reset = true
   648  	return nil
   649  }