github.com/dolthub/go-mysql-server@v0.18.0/server/golden/proxy.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package golden
    16  
    17  import (
    18  	dsql "database/sql"
    19  	"fmt"
    20  	"math"
    21  	"reflect"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/dolthub/vitess/go/mysql"
    26  	"github.com/dolthub/vitess/go/sqltypes"
    27  	querypb "github.com/dolthub/vitess/go/vt/proto/query"
    28  	"github.com/dolthub/vitess/go/vt/sqlparser"
    29  	mysql2 "github.com/go-sql-driver/mysql"
    30  	"github.com/gocraft/dbr/v2"
    31  	"github.com/sirupsen/logrus"
    32  
    33  	"github.com/dolthub/go-mysql-server/sql"
    34  	"github.com/dolthub/go-mysql-server/sql/planbuilder"
    35  )
    36  
// MySqlProxy forwards client queries to an upstream MySQL server reached
// through connStr and streams the upstream results back to the client.
//
// Fields:
//   - ctx: base context from which each query's sql.Context is derived.
//   - connStr: go-sql-driver DSN of the upstream server.
//   - logger: logger from which per-connection log entries are built.
//   - conns: upstream connection cached per client connection ID.
//
// NOTE(review): conns is a plain map that the handler methods read and write
// without any locking. If the server invokes these methods from several
// goroutines at once, this is a data race — confirm and guard if so.
type MySqlProxy struct {
	ctx     *sql.Context
	connStr string
	logger  *logrus.Logger
	conns   map[uint32]proxyConn
}
    43  
    44  func (h MySqlProxy) ParserOptionsForConnection(_ *mysql.Conn) (sqlparser.ParserOptions, error) {
    45  	return sqlparser.ParserOptions{}, nil
    46  }
    47  
// proxyConn pairs an upstream database connection with the log entry used
// for it. Both are embedded, so their methods (Query, Exec, Ping, Close,
// Debugf, Errorf, ...) are called directly on a proxyConn value.
type proxyConn struct {
	*dbr.Connection
	*logrus.Entry
}
    52  
    53  // NewMySqlProxyHandler creates a new MySqlProxy.
    54  func NewMySqlProxyHandler(logger *logrus.Logger, connStr string) (MySqlProxy, error) {
    55  	// ensure parseTime=true
    56  	cfg, err := mysql2.ParseDSN(connStr)
    57  	if err != nil {
    58  		return MySqlProxy{}, err
    59  	}
    60  	cfg.ParseTime = true
    61  	connStr = cfg.FormatDSN()
    62  
    63  	conn, err := newConn(connStr, 0, logger)
    64  	if err != nil {
    65  		return MySqlProxy{}, err
    66  	}
    67  	defer func() { _ = conn.Close() }()
    68  
    69  	if err = conn.Ping(); err != nil {
    70  		return MySqlProxy{}, err
    71  	}
    72  
    73  	return MySqlProxy{
    74  		ctx:     sql.NewEmptyContext(),
    75  		connStr: connStr,
    76  		logger:  logger,
    77  		conns:   make(map[uint32]proxyConn),
    78  	}, nil
    79  }
    80  
// Compile-time check that the MySqlProxy value type (not just a pointer to
// it) satisfies mysql.Handler.
var _ mysql.Handler = MySqlProxy{}
    82  
    83  func newConn(connStr string, connId uint32, lgr *logrus.Logger) (conn proxyConn, err error) {
    84  	l := logrus.NewEntry(lgr).WithField("dsn", connStr).WithField(sql.ConnectionIdLogField, connId)
    85  	var c *dbr.Connection
    86  	for d := 100.0; d < 10000.0; d *= 1.6 {
    87  		l.Debugf("Attempting connection to MySQL")
    88  		if c, err = dbr.Open("mysql", connStr, nil); err == nil {
    89  			if err = c.Ping(); err == nil {
    90  				break
    91  			}
    92  		}
    93  		time.Sleep(time.Duration(d) * time.Millisecond)
    94  	}
    95  	if err != nil {
    96  		l.Debugf("Failed to establish connection %d", connId)
    97  		return proxyConn{}, err
    98  	}
    99  	l.Debugf("Succesfully established connection")
   100  	return proxyConn{Connection: c, Entry: l}, nil
   101  }
   102  
   103  // NewConnection implements mysql.Handler.
   104  func (h MySqlProxy) NewConnection(c *mysql.Conn) {
   105  	conn, err := newConn(h.connStr, c.ConnectionID, h.logger)
   106  	if err == nil {
   107  		h.conns[c.ConnectionID] = conn
   108  	}
   109  }
   110  
   111  func (h MySqlProxy) getConn(connId uint32) (conn proxyConn, err error) {
   112  	var ok bool
   113  	conn, ok = h.conns[connId]
   114  	if ok {
   115  		return conn, nil
   116  	} else {
   117  		conn, err = newConn(h.connStr, connId, h.logger)
   118  		if err != nil {
   119  			return proxyConn{}, err
   120  		}
   121  	}
   122  	if err = conn.Ping(); err != nil {
   123  		return proxyConn{}, err
   124  	}
   125  	h.conns[connId] = conn
   126  	return conn, nil
   127  }
   128  
   129  // ComInitDB implements mysql.Handler.
   130  func (h MySqlProxy) ComInitDB(c *mysql.Conn, schemaName string) error {
   131  	conn, err := h.getConn(c.ConnectionID)
   132  	if err != nil {
   133  		return err
   134  	}
   135  	if schemaName != "" {
   136  		_, err = conn.Exec("USE " + schemaName + " ;")
   137  	}
   138  	return err
   139  }
   140  
   141  // ComPrepare implements mysql.Handler.
   142  func (h MySqlProxy) ComPrepare(_ *mysql.Conn, _ string, _ *mysql.PrepareData) ([]*querypb.Field, error) {
   143  	return nil, fmt.Errorf("ComPrepare unsupported")
   144  }
   145  
   146  // ComStmtExecute implements mysql.Handler.
   147  func (h MySqlProxy) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
   148  	return fmt.Errorf("ComStmtExecute unsupported")
   149  }
   150  
   151  // ComResetConnection implements mysql.Handler.
   152  func (h MySqlProxy) ComResetConnection(_ *mysql.Conn) error {
   153  	return nil
   154  }
   155  
   156  // ConnectionClosed implements mysql.Handler.
   157  func (h MySqlProxy) ConnectionClosed(c *mysql.Conn) {
   158  	conn, ok := h.conns[c.ConnectionID]
   159  	if !ok {
   160  		return
   161  	}
   162  	if err := conn.Close(); err != nil {
   163  		lgr := logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID)
   164  		lgr.Errorf("Error closing connection")
   165  	}
   166  	delete(h.conns, c.ConnectionID)
   167  }
   168  
   169  // ComMultiQuery implements mysql.Handler.
   170  func (h MySqlProxy) ComMultiQuery(
   171  	c *mysql.Conn,
   172  	query string,
   173  	callback mysql.ResultSpoolFn,
   174  ) (string, error) {
   175  	conn, err := h.getConn(c.ConnectionID)
   176  	if err != nil {
   177  		return "", err
   178  	}
   179  	conn.Entry = conn.Entry.WithField("query", query)
   180  
   181  	remainder, err := h.processQuery(c, conn, query, true, callback)
   182  	if err != nil {
   183  		conn.Errorf("Failed to process MySQL results: %s", err)
   184  	}
   185  	return remainder, err
   186  }
   187  
   188  // ComQuery implements mysql.Handler.
   189  func (h MySqlProxy) ComQuery(
   190  	c *mysql.Conn,
   191  	query string,
   192  	callback mysql.ResultSpoolFn,
   193  ) error {
   194  	conn, err := h.getConn(c.ConnectionID)
   195  	if err != nil {
   196  		return err
   197  	}
   198  	conn.Entry = conn.Entry.WithField("query", query)
   199  
   200  	_, err = h.processQuery(c, conn, query, false, callback)
   201  	if err != nil {
   202  		conn.Errorf("Failed to process MySQL results: %s", err)
   203  	}
   204  	return err
   205  }
   206  
   207  // ComParsedQuery implements mysql.Handler.
   208  func (h MySqlProxy) ComParsedQuery(
   209  	c *mysql.Conn,
   210  	query string,
   211  	parsed sqlparser.Statement,
   212  	callback func(*sqltypes.Result, bool) error,
   213  ) error {
   214  	return h.ComQuery(c, query, callback)
   215  }
   216  
   217  func (h MySqlProxy) processQuery(
   218  	c *mysql.Conn,
   219  	proxy proxyConn,
   220  	query string,
   221  	isMultiStatement bool,
   222  	callback func(*sqltypes.Result, bool) error,
   223  ) (string, error) {
   224  	ctx := sql.NewContext(h.ctx)
   225  	var remainder string
   226  	if isMultiStatement {
   227  		_, ri, err := sqlparser.ParseOne(query)
   228  		if err != nil {
   229  			return "", err
   230  		}
   231  		if ri != 0 && ri < len(query) {
   232  			remainder = query[ri:]
   233  			query = query[:ri]
   234  			query = planbuilder.RemoveSpaceAndDelimiter(query, ';')
   235  		}
   236  	}
   237  
   238  	ctx = ctx.WithQuery(query)
   239  	more := remainder != ""
   240  
   241  	proxy.Debugf("Sending query to MySQL")
   242  	rows, err := proxy.Query(query)
   243  	if err != nil {
   244  		return "", err
   245  	}
   246  	defer func() {
   247  		if cerr := rows.Close(); cerr != nil {
   248  			err = cerr
   249  		}
   250  	}()
   251  
   252  	var processedAtLeastOneBatch bool
   253  	res := &sqltypes.Result{}
   254  	ok := true
   255  	for ok {
   256  		if res, ok, err = fetchMySqlRows(ctx, rows, 128); err != nil {
   257  			return "", err
   258  		}
   259  		if err := callback(res, more); err != nil {
   260  			return "", err
   261  		}
   262  		processedAtLeastOneBatch = true
   263  	}
   264  
   265  	if err := setConnStatusFlags(ctx, c); err != nil {
   266  		return remainder, err
   267  	}
   268  
   269  	switch len(res.Rows) {
   270  	case 0:
   271  		if len(res.Info) > 0 {
   272  			ctx.GetLogger().Tracef("returning result %s", res.Info)
   273  		} else {
   274  			ctx.GetLogger().Tracef("returning empty result")
   275  		}
   276  	case 1:
   277  		ctx.GetLogger().Tracef("returning result %v", res)
   278  	}
   279  
   280  	// processedAtLeastOneBatch means we already called resultsCB() at least
   281  	// once, so no need to call it if RowsAffected == 0.
   282  	if res != nil && (res.RowsAffected == 0 && processedAtLeastOneBatch) {
   283  		return remainder, nil
   284  	}
   285  
   286  	return remainder, nil
   287  }
   288  
   289  // WarningCount is called at the end of each query to obtain
   290  // the value to be returned to the client in the EOF packet.
   291  // Note that this will be called either in the context of the
   292  // ComQuery resultsCB if the result does not contain any fields,
   293  // or after the last ComQuery call completes.
   294  func (h MySqlProxy) WarningCount(c *mysql.Conn) uint16 {
   295  	return 0
   296  }
   297  
   298  // See https://dev.mysql.com/doc/internals/en/status-flags.html
   299  func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error {
   300  	ok, err := isSessionAutocommit(ctx)
   301  	if err != nil {
   302  		return err
   303  	}
   304  	if ok {
   305  		c.StatusFlags |= uint16(mysql.ServerStatusAutocommit)
   306  	} else {
   307  		c.StatusFlags &= ^uint16(mysql.ServerStatusAutocommit)
   308  	}
   309  	if t := ctx.GetTransaction(); t != nil {
   310  		c.StatusFlags |= uint16(mysql.ServerInTransaction)
   311  	} else {
   312  		c.StatusFlags &= ^uint16(mysql.ServerInTransaction)
   313  	}
   314  	return nil
   315  }
   316  
   317  func isSessionAutocommit(ctx *sql.Context) (bool, error) {
   318  	autoCommitSessionVar, err := ctx.GetSessionVariable(ctx, sql.AutoCommitSessionVar)
   319  	if err != nil {
   320  		return false, err
   321  	}
   322  	return sql.ConvertToBool(ctx, autoCommitSessionVar)
   323  }
   324  
   325  func fetchMySqlRows(ctx *sql.Context, results *dsql.Rows, count int) (res *sqltypes.Result, more bool, err error) {
   326  	cols, err := results.ColumnTypes()
   327  	if err != nil {
   328  		return nil, false, err
   329  	}
   330  
   331  	types, fields, err := schemaToFields(ctx, cols)
   332  	if err != nil {
   333  		return nil, false, err
   334  	}
   335  
   336  	rows := make([][]sqltypes.Value, 0, count)
   337  	for results.Next() {
   338  		if len(rows) == count {
   339  			more = true
   340  			break
   341  		}
   342  
   343  		scanRow, err := scanResultRow(results)
   344  		if err != nil {
   345  			return nil, false, err
   346  		}
   347  
   348  		row := make([]sqltypes.Value, len(fields))
   349  		for i := range row {
   350  			scanRow[i], _, err = types[i].Convert(scanRow[i])
   351  			if err != nil {
   352  				return nil, false, err
   353  			}
   354  			row[i], err = types[i].SQL(ctx, nil, scanRow[i])
   355  			if err != nil {
   356  				return nil, false, err
   357  			}
   358  		}
   359  		rows = append(rows, row)
   360  	}
   361  
   362  	res = &sqltypes.Result{
   363  		Fields:       fields,
   364  		RowsAffected: uint64(len(rows)),
   365  		Rows:         rows,
   366  	}
   367  	return
   368  }
   369  
   370  var typeDefaults = map[string]string{
   371  	"char":      "char(255)",
   372  	"binary":    "binary(255)",
   373  	"varchar":   "varchar(65535)",
   374  	"varbinary": "varbinary(65535)",
   375  }
   376  
   377  func schemaToFields(ctx *sql.Context, cols []*dsql.ColumnType) ([]sql.Type, []*querypb.Field, error) {
   378  	types := make([]sql.Type, len(cols))
   379  	fields := make([]*querypb.Field, len(cols))
   380  
   381  	var err error
   382  	for i, col := range cols {
   383  		typeStr := strings.ToLower(col.DatabaseTypeName())
   384  		if length, ok := col.Length(); ok {
   385  			// append length specifier to type
   386  			typeStr = fmt.Sprintf("%s(%d)", typeStr, length)
   387  		} else if ts, ok := typeDefaults[typeStr]; ok {
   388  			// if no length specifier if given,
   389  			// default to the maximum width
   390  			typeStr = ts
   391  		}
   392  		types[i], err = planbuilder.ParseColumnTypeString(typeStr)
   393  		if err != nil {
   394  			return nil, nil, err
   395  		}
   396  
   397  		var charset uint32
   398  		switch types[i].Type() {
   399  		case sqltypes.Binary, sqltypes.VarBinary, sqltypes.Blob:
   400  			charset = mysql.CharacterSetBinary
   401  		default:
   402  			charset = mysql.CharacterSetUtf8
   403  		}
   404  
   405  		fields[i] = &querypb.Field{
   406  			Name:         col.Name(),
   407  			Type:         types[i].Type(),
   408  			Charset:      charset,
   409  			ColumnLength: math.MaxUint32,
   410  		}
   411  	}
   412  	return types, fields, nil
   413  }
   414  
   415  func scanResultRow(results *dsql.Rows) (sql.Row, error) {
   416  	cols, err := results.ColumnTypes()
   417  	if err != nil {
   418  		return nil, err
   419  	}
   420  
   421  	scanRow := make(sql.Row, len(cols))
   422  	for i := range cols {
   423  		scanRow[i] = reflect.New(cols[i].ScanType()).Interface()
   424  	}
   425  
   426  	for i, columnType := range cols {
   427  		scanRow[i] = reflect.New(columnType.ScanType()).Interface()
   428  	}
   429  
   430  	if err = results.Scan(scanRow...); err != nil {
   431  		return nil, err
   432  	}
   433  	for i, val := range scanRow {
   434  		v := reflect.ValueOf(val).Elem().Interface()
   435  		switch t := v.(type) {
   436  		case dsql.RawBytes:
   437  			if t == nil {
   438  				scanRow[i] = nil
   439  			} else {
   440  				scanRow[i] = string(t)
   441  			}
   442  		case dsql.NullBool:
   443  			if t.Valid {
   444  				scanRow[i] = t.Bool
   445  			} else {
   446  				scanRow[i] = nil
   447  			}
   448  		case dsql.NullByte:
   449  			if t.Valid {
   450  				scanRow[i] = t.Byte
   451  			} else {
   452  				scanRow[i] = nil
   453  			}
   454  		case dsql.NullFloat64:
   455  			if t.Valid {
   456  				scanRow[i] = t.Float64
   457  			} else {
   458  				scanRow[i] = nil
   459  			}
   460  		case dsql.NullInt16:
   461  			if t.Valid {
   462  				scanRow[i] = t.Int16
   463  			} else {
   464  				scanRow[i] = nil
   465  			}
   466  		case dsql.NullInt32:
   467  			if t.Valid {
   468  				scanRow[i] = t.Int32
   469  			} else {
   470  				scanRow[i] = nil
   471  			}
   472  		case dsql.NullInt64:
   473  			if t.Valid {
   474  				scanRow[i] = t.Int64
   475  			} else {
   476  				scanRow[i] = nil
   477  			}
   478  		case dsql.NullString:
   479  			if t.Valid {
   480  				scanRow[i] = t.String
   481  			} else {
   482  				scanRow[i] = nil
   483  			}
   484  		case dsql.NullTime:
   485  			if t.Valid {
   486  				scanRow[i] = t.Time
   487  			} else {
   488  				scanRow[i] = nil
   489  			}
   490  		default:
   491  			scanRow[i] = t
   492  		}
   493  	}
   494  	return scanRow, nil
   495  }