github.com/dolthub/go-mysql-server@v0.18.0/server/handler.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
    17  import (
    18  	"context"
    19  	"encoding/base64"
    20  	"fmt"
    21  	"io"
    22  	"net"
    23  	"regexp"
    24  	"sync"
    25  	"time"
    26  
    27  	"github.com/dolthub/vitess/go/mysql"
    28  	"github.com/dolthub/vitess/go/netutil"
    29  	"github.com/dolthub/vitess/go/sqltypes"
    30  	querypb "github.com/dolthub/vitess/go/vt/proto/query"
    31  	"github.com/dolthub/vitess/go/vt/sqlparser"
    32  	"github.com/go-kit/kit/metrics/discard"
    33  	"github.com/sirupsen/logrus"
    34  	"go.opentelemetry.io/otel/attribute"
    35  	"go.opentelemetry.io/otel/trace"
    36  	"gopkg.in/src-d/go-errors.v1"
    37  
    38  	sqle "github.com/dolthub/go-mysql-server"
    39  	"github.com/dolthub/go-mysql-server/internal/sockstate"
    40  	"github.com/dolthub/go-mysql-server/sql"
    41  	"github.com/dolthub/go-mysql-server/sql/analyzer"
    42  	"github.com/dolthub/go-mysql-server/sql/plan"
    43  	"github.com/dolthub/go-mysql-server/sql/planbuilder"
    44  	"github.com/dolthub/go-mysql-server/sql/types"
    45  )
    46  
    47  var errConnectionNotFound = errors.NewKind("connection not found: %c")
    48  
    49  // ErrRowTimeout will be returned if the wait for the row is longer than the connection timeout
    50  var ErrRowTimeout = errors.NewKind("row read wait bigger than connection timeout")
    51  
    52  // ErrConnectionWasClosed will be returned if we try to use a previously closed connection
    53  var ErrConnectionWasClosed = errors.NewKind("connection was closed")
    54  
    55  const rowsBatch = 128
    56  
    57  var tcpCheckerSleepDuration time.Duration = 1 * time.Second
    58  
    59  type MultiStmtMode int
    60  
    61  const (
    62  	MultiStmtModeOff MultiStmtMode = 0
    63  	MultiStmtModeOn  MultiStmtMode = 1
    64  )
    65  
    66  // Handler is a connection handler for a SQLe engine, implementing the Vitess mysql.Handler interface.
    67  type Handler struct {
    68  	e                 *sqle.Engine
    69  	sm                *SessionManager
    70  	readTimeout       time.Duration
    71  	disableMultiStmts bool
    72  	maxLoggedQueryLen int
    73  	encodeLoggedQuery bool
    74  	sel               ServerEventListener
    75  }
    76  
    77  var _ mysql.Handler = (*Handler)(nil)
    78  var _ mysql.ExtendedHandler = (*Handler)(nil)
    79  
    80  // NewConnection reports that a new connection has been established.
    81  func (h *Handler) NewConnection(c *mysql.Conn) {
    82  	if h.sel != nil {
    83  		h.sel.ClientConnected()
    84  	}
    85  
    86  	h.sm.AddConn(c)
    87  
    88  	c.DisableClientMultiStatements = h.disableMultiStmts
    89  	logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).WithField("DisableClientMultiStatements", c.DisableClientMultiStatements).Infof("NewConnection")
    90  }
    91  
    92  func (h *Handler) ComInitDB(c *mysql.Conn, schemaName string) error {
    93  	return h.sm.SetDB(c, schemaName)
    94  }
    95  
    96  // ComPrepare parses, partially analyzes, and caches a prepared statement's plan
    97  // with the given [c.ConnectionID].
    98  func (h *Handler) ComPrepare(c *mysql.Conn, query string, prepare *mysql.PrepareData) ([]*querypb.Field, error) {
    99  	logrus.WithField("query", query).
   100  		WithField("paramsCount", prepare.ParamsCount).
   101  		WithField("statementId", prepare.StatementID).Debugf("preparing query")
   102  
   103  	ctx, err := h.sm.NewContextWithQuery(c, query)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	var analyzed sql.Node
   108  	if analyzer.PreparedStmtDisabled {
   109  		analyzed, err = h.e.AnalyzeQuery(ctx, query)
   110  	} else {
   111  		analyzed, err = h.e.PrepareQuery(ctx, query)
   112  	}
   113  	if err != nil {
   114  		logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error())
   115  		err := sql.CastSQLError(err)
   116  		return nil, err
   117  	}
   118  
   119  	// A nil result signals to the handler that the query is not a SELECT statement.
   120  	if nodeReturnsOkResultSchema(analyzed) || types.IsOkResultSchema(analyzed.Schema()) {
   121  		return nil, nil
   122  	}
   123  
   124  	return schemaToFields(ctx, analyzed.Schema()), nil
   125  }
   126  
   127  // These nodes will eventually return an OK result, but their intermediate forms here return a different schema
   128  // than they will at execution time.
   129  func nodeReturnsOkResultSchema(node sql.Node) bool {
   130  	switch node.(type) {
   131  	case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom:
   132  		return true
   133  	}
   134  	return false
   135  }
   136  
   137  func (h *Handler) ComPrepareParsed(c *mysql.Conn, query string, parsed sqlparser.Statement, prepare *mysql.PrepareData) (mysql.ParsedQuery, []*querypb.Field, error) {
   138  	logrus.WithField("query", query).
   139  		WithField("paramsCount", prepare.ParamsCount).
   140  		WithField("statementId", prepare.StatementID).Debugf("preparing query")
   141  
   142  	ctx, err := h.sm.NewContextWithQuery(c, query)
   143  	if err != nil {
   144  		return nil, nil, err
   145  	}
   146  
   147  	analyzed, err := h.e.PrepareParsedQuery(ctx, query, query, parsed)
   148  	if err != nil {
   149  		logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error())
   150  		err := sql.CastSQLError(err)
   151  		return nil, nil, err
   152  	}
   153  
   154  	var fields []*querypb.Field
   155  	// The return result fields should only be directly translated if it doesn't correspond to an OK result.
   156  	// See comment in ComPrepare
   157  	if !(nodeReturnsOkResultSchema(analyzed) || types.IsOkResultSchema(analyzed.Schema())) {
   158  		fields = nil
   159  	} else {
   160  		fields = schemaToFields(ctx, analyzed.Schema())
   161  	}
   162  
   163  	return analyzed, fields, nil
   164  }
   165  
   166  func (h *Handler) ComBind(c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, prepare *mysql.PrepareData) (mysql.BoundQuery, []*querypb.Field, error) {
   167  	ctx, err := h.sm.NewContextWithQuery(c, query)
   168  	if err != nil {
   169  		return nil, nil, err
   170  	}
   171  
   172  	stmt, ok := parsedQuery.(sqlparser.Statement)
   173  	if !ok {
   174  		return nil, nil, fmt.Errorf("parsedQuery must be a sqlparser.Statement, but got %T", parsedQuery)
   175  	}
   176  
   177  	queryPlan, err := h.e.BoundQueryPlan(ctx, query, stmt, prepare.BindVars)
   178  	if err != nil {
   179  		return nil, nil, err
   180  	}
   181  
   182  	return queryPlan, schemaToFields(ctx, queryPlan.Schema()), nil
   183  }
   184  
   185  func (h *Handler) ComExecuteBound(c *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback mysql.ResultSpoolFn) error {
   186  	plan, ok := boundQuery.(sql.Node)
   187  	if !ok {
   188  		return fmt.Errorf("boundQuery must be a sql.Node, but got %T", boundQuery)
   189  	}
   190  
   191  	return h.errorWrappedComExec(c, query, plan, callback)
   192  }
   193  
   194  func (h *Handler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
   195  	_, err := h.errorWrappedDoQuery(c, prepare.PrepareStmt, nil, MultiStmtModeOff, prepare.BindVars, func(res *sqltypes.Result, more bool) error {
   196  		return callback(res)
   197  	})
   198  	return err
   199  }
   200  
   201  // ComResetConnection implements the mysql.Handler interface.
   202  //
   203  // This command resets the connection's session, clearing out any cached prepared statements, locks, user and
   204  // session variables. The currently selected database is preserved.
   205  //
   206  // The COM_RESET command can be sent manually through the mysql client by issuing the "resetconnection" (or "\x")
   207  // client command.
   208  func (h *Handler) ComResetConnection(c *mysql.Conn) error {
   209  	logrus.WithField("connectionId", c.ConnectionID).Debug("COM_RESET_CONNECTION command received")
   210  
   211  	// Grab the currently selected database name
   212  	s := h.sm.session(c)
   213  	db := s.GetCurrentDatabase()
   214  
   215  	// Dispose of the connection's current session
   216  	h.maybeReleaseAllLocks(c)
   217  	h.e.CloseSession(c.ConnectionID)
   218  
   219  	// Create a new session and set the current database
   220  	err := h.sm.NewSession(context.Background(), c)
   221  	if err != nil {
   222  		return err
   223  	}
   224  	s = h.sm.session(c)
   225  	s.SetCurrentDatabase(db)
   226  	return nil
   227  }
   228  
   229  func (h *Handler) ParserOptionsForConnection(c *mysql.Conn) (sqlparser.ParserOptions, error) {
   230  	ctx, err := h.sm.NewContext(c)
   231  	if err != nil {
   232  		return sqlparser.ParserOptions{}, err
   233  	}
   234  	return sql.LoadSqlMode(ctx).ParserOptions(), nil
   235  }
   236  
   237  // ConnectionClosed reports that a connection has been closed.
   238  func (h *Handler) ConnectionClosed(c *mysql.Conn) {
   239  	defer func() {
   240  		if h.sel != nil {
   241  			h.sel.ClientDisconnected()
   242  		}
   243  	}()
   244  
   245  	defer h.sm.RemoveConn(c)
   246  	defer h.e.CloseSession(c.ConnectionID)
   247  
   248  	h.maybeReleaseAllLocks(c)
   249  
   250  	logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).Infof("ConnectionClosed")
   251  }
   252  
   253  // maybeReleaseAllLocks makes a best effort attempt to release all locks on the given connection. If the attempt fails,
   254  // an error is logged but not returned.
   255  func (h *Handler) maybeReleaseAllLocks(c *mysql.Conn) {
   256  	if ctx, err := h.sm.NewContextWithQuery(c, ""); err != nil {
   257  		logrus.Errorf("unable to release all locks on session close: %s", err)
   258  		logrus.Errorf("unable to unlock tables on session close: %s", err)
   259  	} else {
   260  		_, err = h.e.LS.ReleaseAll(ctx)
   261  		if err != nil {
   262  			logrus.Errorf("unable to release all locks on session close: %s", err)
   263  		}
   264  		if err = h.e.Analyzer.Catalog.UnlockTables(ctx, c.ConnectionID); err != nil {
   265  			logrus.Errorf("unable to unlock tables on session close: %s", err)
   266  		}
   267  	}
   268  }
   269  
   270  func (h *Handler) ComMultiQuery(
   271  	c *mysql.Conn,
   272  	query string,
   273  	callback mysql.ResultSpoolFn,
   274  ) (string, error) {
   275  	return h.errorWrappedDoQuery(c, query, nil, MultiStmtModeOn, nil, callback)
   276  }
   277  
   278  // ComQuery executes a SQL query on the SQLe engine.
   279  func (h *Handler) ComQuery(
   280  	c *mysql.Conn,
   281  	query string,
   282  	callback mysql.ResultSpoolFn,
   283  ) error {
   284  	_, err := h.errorWrappedDoQuery(c, query, nil, MultiStmtModeOff, nil, callback)
   285  	return err
   286  }
   287  
   288  // ComParsedQuery executes a pre-parsed SQL query on the SQLe engine.
   289  func (h *Handler) ComParsedQuery(
   290  	c *mysql.Conn,
   291  	query string,
   292  	parsed sqlparser.Statement,
   293  	callback mysql.ResultSpoolFn,
   294  ) error {
   295  	_, err := h.errorWrappedDoQuery(c, query, parsed, MultiStmtModeOff, nil, callback)
   296  	return err
   297  }
   298  
   299  var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`)
   300  
   301  func (h *Handler) doQuery(
   302  	c *mysql.Conn,
   303  	query string,
   304  	parsed sqlparser.Statement,
   305  	analyzedPlan sql.Node,
   306  	mode MultiStmtMode,
   307  	queryExec QueryExecutor,
   308  	bindings map[string]*querypb.BindVariable,
   309  	callback func(*sqltypes.Result, bool) error,
   310  ) (string, error) {
   311  	ctx, err := h.sm.NewContext(c)
   312  	if err != nil {
   313  		return "", err
   314  	}
   315  
   316  	var remainder string
   317  	var prequery string
   318  	if parsed == nil {
   319  		_, inPreparedCache := h.e.PreparedDataCache.GetCachedStmt(ctx.Session.ID(), query)
   320  		if mode == MultiStmtModeOn && !inPreparedCache {
   321  			parsed, prequery, remainder, err = planbuilder.ParseOnly(ctx, query, true)
   322  			if prequery != "" {
   323  				query = prequery
   324  			}
   325  		}
   326  	}
   327  
   328  	ctx = ctx.WithQuery(query)
   329  	more := remainder != ""
   330  
   331  	var queryStr string
   332  	if h.encodeLoggedQuery {
   333  		queryStr = base64.StdEncoding.EncodeToString([]byte(query))
   334  	} else if logrus.IsLevelEnabled(logrus.DebugLevel) {
   335  		// this is expensive, so skip this unless we're logging at DEBUG level
   336  		queryStr = string(queryLoggingRegex.ReplaceAll([]byte(query), []byte(" ")))
   337  		if h.maxLoggedQueryLen > 0 && len(queryStr) > h.maxLoggedQueryLen {
   338  			queryStr = queryStr[:h.maxLoggedQueryLen] + "..."
   339  		}
   340  	}
   341  
   342  	if queryStr != "" {
   343  		ctx.SetLogger(ctx.GetLogger().WithField("query", queryStr))
   344  	}
   345  	ctx.GetLogger().Debugf("Starting query")
   346  
   347  	finish := observeQuery(ctx, query)
   348  	defer finish(err)
   349  
   350  	start := time.Now()
   351  
   352  	ctx.GetLogger().Tracef("beginning execution")
   353  
   354  	oCtx := ctx
   355  	eg, ctx := ctx.NewErrgroup()
   356  
   357  	// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
   358  	//  marked done until we're done spooling rows over the wire
   359  	ctx, err = ctx.ProcessList.BeginQuery(ctx, query)
   360  	defer func() {
   361  		if err != nil && ctx != nil {
   362  			ctx.ProcessList.EndQuery(ctx)
   363  		}
   364  	}()
   365  
   366  	// TODO (next): this method needs a function param that produces the following elements, rather than hard-coding
   367  	schema, rowIter, err := queryExec(ctx, query, parsed, analyzedPlan, bindings)
   368  	if err != nil {
   369  		ctx.GetLogger().WithError(err).Warn("error running query")
   370  		return remainder, err
   371  	}
   372  
   373  	var rowChan chan sql.Row
   374  
   375  	rowChan = make(chan sql.Row, 512)
   376  
   377  	wg := sync.WaitGroup{}
   378  	wg.Add(2)
   379  	// Read rows off the row iterator and send them to the row channel.
   380  	eg.Go(func() error {
   381  		defer wg.Done()
   382  		defer close(rowChan)
   383  		for {
   384  			select {
   385  			case <-ctx.Done():
   386  				return nil
   387  			default:
   388  				row, err := rowIter.Next(ctx)
   389  				if err == io.EOF {
   390  					return nil
   391  				}
   392  				if err != nil {
   393  					return err
   394  				}
   395  				select {
   396  				case rowChan <- row:
   397  				case <-ctx.Done():
   398  					return nil
   399  				}
   400  			}
   401  		}
   402  
   403  	})
   404  
   405  	pollCtx, cancelF := ctx.NewSubContext()
   406  	eg.Go(func() error {
   407  		return h.pollForClosedConnection(pollCtx, c)
   408  	})
   409  
   410  	// Default waitTime is one minute if there is no timeout configured, in which case
   411  	// it will loop to iterate again unless the socket died by the OS timeout or other problems.
   412  	// If there is a timeout, it will be enforced to ensure that Vitess has a chance to
   413  	// call Handler.CloseConnection()
   414  	waitTime := 1 * time.Minute
   415  	if h.readTimeout > 0 {
   416  		waitTime = h.readTimeout
   417  	}
   418  	timer := time.NewTimer(waitTime)
   419  	defer timer.Stop()
   420  
   421  	var r *sqltypes.Result
   422  	var processedAtLeastOneBatch bool
   423  
   424  	// reads rows from the channel, converts them to wire format,
   425  	// and calls |callback| to give them to vitess.
   426  	eg.Go(func() error {
   427  		defer cancelF()
   428  		defer wg.Done()
   429  		for {
   430  			if r == nil {
   431  				r = &sqltypes.Result{Fields: schemaToFields(ctx, schema)}
   432  			}
   433  
   434  			if r.RowsAffected == rowsBatch {
   435  				if err := callback(r, more); err != nil {
   436  					return err
   437  				}
   438  				r = nil
   439  				processedAtLeastOneBatch = true
   440  				continue
   441  			}
   442  
   443  			select {
   444  			case <-ctx.Done():
   445  				return nil
   446  			case row, ok := <-rowChan:
   447  				if !ok {
   448  					return nil
   449  				}
   450  				if types.IsOkResult(row) {
   451  					if len(r.Rows) > 0 {
   452  						panic("Got OkResult mixed with RowResult")
   453  					}
   454  					r = resultFromOkResult(row[0].(types.OkResult))
   455  					continue
   456  				}
   457  
   458  				outputRow, err := rowToSQL(ctx, schema, row)
   459  				if err != nil {
   460  					return err
   461  				}
   462  
   463  				ctx.GetLogger().Tracef("spooling result row %s", outputRow)
   464  				r.Rows = append(r.Rows, outputRow)
   465  				r.RowsAffected++
   466  			case <-timer.C:
   467  				if h.readTimeout != 0 {
   468  					// Cancel and return so Vitess can call the CloseConnection callback
   469  					ctx.GetLogger().Tracef("connection timeout")
   470  					return ErrRowTimeout.New()
   471  				}
   472  			}
   473  			if !timer.Stop() {
   474  				<-timer.C
   475  			}
   476  			timer.Reset(waitTime)
   477  		}
   478  	})
   479  
   480  	// Close() kills this PID in the process list,
   481  	// wait until all rows have be sent over the wire
   482  	eg.Go(func() error {
   483  		wg.Wait()
   484  		return rowIter.Close(ctx)
   485  	})
   486  
   487  	err = eg.Wait()
   488  	if err != nil {
   489  		ctx.GetLogger().WithError(err).Warn("error running query")
   490  		return remainder, err
   491  	}
   492  
   493  	// errGroup context is now canceled
   494  	ctx = oCtx
   495  
   496  	if err = setConnStatusFlags(ctx, c); err != nil {
   497  		return remainder, err
   498  	}
   499  
   500  	switch len(r.Rows) {
   501  	case 0:
   502  		if len(r.Info) > 0 {
   503  			ctx.GetLogger().Tracef("returning result %s", r.Info)
   504  		} else {
   505  			ctx.GetLogger().Tracef("returning empty result")
   506  		}
   507  	case 1:
   508  		ctx.GetLogger().Tracef("returning result %v", r)
   509  	}
   510  
   511  	ctx.GetLogger().Debugf("Query finished in %d ms", time.Since(start).Milliseconds())
   512  
   513  	// processedAtLeastOneBatch means we already called callback() at least
   514  	// once, so no need to call it if RowsAffected == 0.
   515  	if r != nil && (r.RowsAffected == 0 && processedAtLeastOneBatch) {
   516  		return remainder, nil
   517  	}
   518  
   519  	return remainder, callback(r, more)
   520  }
   521  
   522  // See https://dev.mysql.com/doc/internals/en/status-flags.html
   523  func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error {
   524  	ok, err := isSessionAutocommit(ctx)
   525  	if err != nil {
   526  		return err
   527  	}
   528  	if ok {
   529  		c.StatusFlags |= uint16(mysql.ServerStatusAutocommit)
   530  	} else {
   531  		c.StatusFlags &= ^uint16(mysql.ServerStatusAutocommit)
   532  	}
   533  
   534  	if t := ctx.GetTransaction(); t != nil {
   535  		c.StatusFlags |= uint16(mysql.ServerInTransaction)
   536  	} else {
   537  		c.StatusFlags &= ^uint16(mysql.ServerInTransaction)
   538  	}
   539  
   540  	return nil
   541  }
   542  
   543  func isSessionAutocommit(ctx *sql.Context) (bool, error) {
   544  	autoCommitSessionVar, err := ctx.GetSessionVariable(ctx, sql.AutoCommitSessionVar)
   545  	if err != nil {
   546  		return false, err
   547  	}
   548  	return sql.ConvertToBool(ctx, autoCommitSessionVar)
   549  }
   550  
   551  // Call doQuery and cast known errors to SQLError
   552  func (h *Handler) errorWrappedDoQuery(
   553  	c *mysql.Conn,
   554  	query string,
   555  	parsed sqlparser.Statement,
   556  	mode MultiStmtMode,
   557  	bindings map[string]*querypb.BindVariable,
   558  	callback func(*sqltypes.Result, bool) error,
   559  ) (string, error) {
   560  	start := time.Now()
   561  	if h.sel != nil {
   562  		h.sel.QueryStarted()
   563  	}
   564  
   565  	remainder, err := h.doQuery(c, query, parsed, nil, mode, h.executeQuery, bindings, callback)
   566  	if err != nil {
   567  		err = sql.CastSQLError(err)
   568  	}
   569  
   570  	if h.sel != nil {
   571  		h.sel.QueryCompleted(err == nil, time.Since(start))
   572  	}
   573  
   574  	return remainder, err
   575  }
   576  
   577  // Call doQuery and cast known errors to SQLError
   578  func (h *Handler) errorWrappedComExec(
   579  	c *mysql.Conn,
   580  	query string,
   581  	analyzedPlan sql.Node,
   582  	callback func(*sqltypes.Result, bool) error,
   583  ) error {
   584  	start := time.Now()
   585  	if h.sel != nil {
   586  		h.sel.QueryStarted()
   587  	}
   588  
   589  	_, err := h.doQuery(c, query, nil, analyzedPlan, MultiStmtModeOff, h.executeBoundPlan, nil, callback)
   590  
   591  	if err != nil {
   592  		err = sql.CastSQLError(err)
   593  	}
   594  
   595  	if h.sel != nil {
   596  		h.sel.QueryCompleted(err == nil, time.Since(start))
   597  	}
   598  
   599  	return err
   600  }
   601  
   602  // Periodically polls the connection socket to determine if it is has been closed by the client, returning an error
   603  // if it has been. Meant to be run in an errgroup from the query handler routine. Returns immediately with no error
   604  // on platforms that can't support TCP socket checks.
   605  func (h *Handler) pollForClosedConnection(ctx *sql.Context, c *mysql.Conn) error {
   606  	tcpConn, ok := maybeGetTCPConn(c.Conn)
   607  	if !ok {
   608  		ctx.GetLogger().Trace("Connection checker exiting, connection isn't TCP")
   609  		return nil
   610  	}
   611  
   612  	inode, err := sockstate.GetConnInode(tcpConn)
   613  	if err != nil || inode == 0 {
   614  		if !sockstate.ErrSocketCheckNotImplemented.Is(err) {
   615  			ctx.GetLogger().Trace("Connection checker exiting, connection isn't TCP")
   616  		}
   617  		return nil
   618  	}
   619  
   620  	t, ok := tcpConn.LocalAddr().(*net.TCPAddr)
   621  	if !ok {
   622  		ctx.GetLogger().Trace("Connection checker exiting, could not get local port")
   623  		return nil
   624  	}
   625  
   626  	timer := time.NewTimer(tcpCheckerSleepDuration)
   627  	defer timer.Stop()
   628  
   629  	for {
   630  		select {
   631  		case <-ctx.Done():
   632  			return nil
   633  		case <-timer.C:
   634  		}
   635  
   636  		st, err := sockstate.GetInodeSockState(t.Port, inode)
   637  		switch st {
   638  		case sockstate.Broken:
   639  			ctx.GetLogger().Warn("socket state is broken, returning error")
   640  			return ErrConnectionWasClosed.New()
   641  		case sockstate.Error:
   642  			ctx.GetLogger().WithError(err).Warn("Connection checker exiting, got err checking sockstate")
   643  			return nil
   644  		default: // Established
   645  			// (juanjux) this check is not free, each iteration takes about 9 milliseconds to run on my machine
   646  			// thus the small wait between checks
   647  			timer.Reset(tcpCheckerSleepDuration)
   648  		}
   649  	}
   650  }
   651  
   652  func maybeGetTCPConn(conn net.Conn) (*net.TCPConn, bool) {
   653  	wrap, ok := conn.(netutil.ConnWithTimeouts)
   654  	if ok {
   655  		conn = wrap.Conn
   656  	}
   657  
   658  	tcp, ok := conn.(*net.TCPConn)
   659  	if ok {
   660  		return tcp, true
   661  	}
   662  
   663  	return nil, false
   664  }
   665  
   666  func resultFromOkResult(result types.OkResult) *sqltypes.Result {
   667  	infoStr := ""
   668  	if result.Info != nil {
   669  		infoStr = result.Info.String()
   670  	}
   671  	return &sqltypes.Result{
   672  		RowsAffected: result.RowsAffected,
   673  		InsertID:     result.InsertID,
   674  		Info:         infoStr,
   675  	}
   676  }
   677  
   678  // WarningCount is called at the end of each query to obtain
   679  // the value to be returned to the client in the EOF packet.
   680  // Note that this will be called either in the context of the
   681  // ComQuery callback if the result does not contain any fields,
   682  // or after the last ComQuery call completes.
   683  func (h *Handler) WarningCount(c *mysql.Conn) uint16 {
   684  	if sess := h.sm.session(c); sess != nil {
   685  		return sess.WarningCount()
   686  	}
   687  
   688  	return 0
   689  }
   690  
   691  func rowToSQL(ctx *sql.Context, s sql.Schema, row sql.Row) ([]sqltypes.Value, error) {
   692  	o := make([]sqltypes.Value, len(row))
   693  	var err error
   694  	for i, v := range row {
   695  		if v == nil {
   696  			o[i] = sqltypes.NULL
   697  			continue
   698  		}
   699  		// need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock)
   700  		if s != nil {
   701  			o[i], err = s[i].Type.SQL(ctx, nil, v)
   702  			if err != nil {
   703  				return nil, err
   704  			}
   705  		}
   706  	}
   707  
   708  	return o, nil
   709  }
   710  
   711  func row2ToSQL(s sql.Schema, row sql.Row2) ([]sqltypes.Value, error) {
   712  	o := make([]sqltypes.Value, len(row))
   713  	var err error
   714  	for i := 0; i < row.Len(); i++ {
   715  		v := row.GetField(i)
   716  		if v.IsNull() {
   717  			o[i] = sqltypes.NULL
   718  			continue
   719  		}
   720  
   721  		// need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock)
   722  		if s != nil {
   723  			o[i], err = s[i].Type.(sql.Type2).SQL2(v)
   724  			if err != nil {
   725  				return nil, err
   726  			}
   727  		}
   728  	}
   729  
   730  	return o, nil
   731  }
   732  
   733  func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field {
   734  	charSetResults := ctx.GetCharacterSetResults()
   735  	fields := make([]*querypb.Field, len(s))
   736  	for i, c := range s {
   737  		charset := uint32(sql.Collation_Default.CharacterSet())
   738  		if collatedType, ok := c.Type.(sql.TypeWithCollation); ok {
   739  			charset = uint32(collatedType.Collation().CharacterSet())
   740  		}
   741  
   742  		// Binary types always use a binary collation, but non-binary types must
   743  		// respect character_set_results if it is set.
   744  		if types.IsBinaryType(c.Type) {
   745  			charset = uint32(sql.Collation_binary)
   746  		} else if charSetResults != sql.CharacterSet_Unspecified {
   747  			charset = uint32(charSetResults)
   748  		}
   749  
   750  		var flags querypb.MySqlFlag
   751  		if !c.Nullable {
   752  			flags = flags | querypb.MySqlFlag_NOT_NULL_FLAG
   753  		}
   754  		if c.AutoIncrement {
   755  			flags = flags | querypb.MySqlFlag_AUTO_INCREMENT_FLAG
   756  		}
   757  		if c.PrimaryKey {
   758  			flags = flags | querypb.MySqlFlag_PRI_KEY_FLAG
   759  		}
   760  		if types.IsUnsigned(c.Type) {
   761  			flags = flags | querypb.MySqlFlag_UNSIGNED_FLAG
   762  		}
   763  
   764  		fields[i] = &querypb.Field{
   765  			Name:         c.Name,
   766  			OrgName:      c.Name,
   767  			Table:        c.Source,
   768  			OrgTable:     c.Source,
   769  			Database:     c.DatabaseSource,
   770  			Type:         c.Type.Type(),
   771  			Charset:      charset,
   772  			ColumnLength: c.Type.MaxTextResponseByteLength(ctx),
   773  			Flags:        uint32(flags),
   774  		}
   775  
   776  		if types.IsDecimal(c.Type) {
   777  			decimalType := c.Type.(sql.DecimalType)
   778  			fields[i].Decimals = uint32(decimalType.Scale())
   779  		} else if types.IsDatetimeType(c.Type) {
   780  			dtType := c.Type.(sql.DatetimeType)
   781  			fields[i].Decimals = uint32(dtType.Precision())
   782  		}
   783  	}
   784  
   785  	return fields
   786  }
   787  
   788  var (
   789  	// QueryCounter describes a metric that accumulates number of queries monotonically.
   790  	QueryCounter = discard.NewCounter()
   791  
   792  	// QueryErrorCounter describes a metric that accumulates number of failed queries monotonically.
   793  	QueryErrorCounter = discard.NewCounter()
   794  
   795  	// QueryHistogram describes a queries latency.
   796  	QueryHistogram = discard.NewHistogram()
   797  )
   798  
   799  func observeQuery(ctx *sql.Context, query string) func(err error) {
   800  	span, ctx := ctx.Span("query", trace.WithAttributes(attribute.String("query", query)))
   801  
   802  	t := time.Now()
   803  	return func(err error) {
   804  		if err != nil {
   805  			QueryErrorCounter.With("query", query, "error", err.Error()).Add(1)
   806  		} else {
   807  			QueryCounter.With("query", query).Add(1)
   808  			QueryHistogram.With("query", query, "duration", "seconds").Observe(time.Since(t).Seconds())
   809  		}
   810  
   811  		span.End()
   812  	}
   813  }
   814  
   815  // QueryExecutor is a function that executes a query and returns the result as a schema and iterator. Either of
   816  // |parsed| or |analyzed| can be nil depending on the use case
   817  type QueryExecutor func(
   818  	ctx *sql.Context,
   819  	query string,
   820  	parsed sqlparser.Statement,
   821  	analyzed sql.Node,
   822  	bindings map[string]*querypb.BindVariable,
   823  ) (sql.Schema, sql.RowIter, error)
   824  
   825  // executeQuery is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed
   826  // statement, which may be nil.
   827  func (h *Handler) executeQuery(
   828  	ctx *sql.Context,
   829  	query string,
   830  	parsed sqlparser.Statement,
   831  	_ sql.Node,
   832  	bindings map[string]*querypb.BindVariable,
   833  ) (sql.Schema, sql.RowIter, error) {
   834  	return h.e.QueryWithBindings(ctx, query, parsed, bindings)
   835  }
   836  
   837  // executeQuery is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed
   838  // statement, which may be nil.
   839  func (h *Handler) executeBoundPlan(
   840  	ctx *sql.Context,
   841  	query string,
   842  	_ sqlparser.Statement,
   843  	plan sql.Node,
   844  	_ map[string]*querypb.BindVariable,
   845  ) (sql.Schema, sql.RowIter, error) {
   846  	return h.e.PrepQueryPlanForExecution(ctx, query, plan)
   847  }