github.com/dolthub/go-mysql-server@v0.18.0/enginetest/server_engine.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package enginetest
    16  
    17  import (
    18  	gosql "database/sql"
    19  	"encoding/json"
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"strconv"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/dolthub/vitess/go/sqltypes"
    29  	"github.com/dolthub/vitess/go/vt/proto/query"
    30  	"github.com/dolthub/vitess/go/vt/sqlparser"
    31  	"github.com/go-sql-driver/mysql"
    32  	_ "github.com/go-sql-driver/mysql"
    33  
    34  	sqle "github.com/dolthub/go-mysql-server"
    35  	"github.com/dolthub/go-mysql-server/server"
    36  	"github.com/dolthub/go-mysql-server/sql"
    37  	"github.com/dolthub/go-mysql-server/sql/analyzer"
    38  	"github.com/dolthub/go-mysql-server/sql/mysql_db"
    39  	"github.com/dolthub/go-mysql-server/sql/types"
    40  )
    41  
// ServerQueryEngine is a QueryEngine that starts a server over a go-mysql-server engine and sends queries to it
// through the go-sql-driver MySQL client instead of calling the engine in-process. Some operations (e.g.
// AnalyzeQuery) still go to the wrapped engine directly.
//
// Fields:
//   - engine: the engine the server was created for.
//   - server: the server created in NewServerQueryEngine.
//   - t: the test, passed along to injectBindVarsAndPrepare by Query.
//   - port: the TCP port the server was configured to listen on.
//   - conn: the client connection pool; nil until NewConnection is first called.
type ServerQueryEngine struct {
	engine *sqle.Engine
	server *server.Server
	t      *testing.T
	port   int
	conn   *gosql.DB
}

// Compile-time check that ServerQueryEngine implements QueryEngine.
var _ QueryEngine = (*ServerQueryEngine)(nil)

// address is the host part of the address the test server listens on. Note that NewConnection dials 127.0.0.1
// directly rather than using this value.
var address = "localhost"
    53  
    54  func NewServerQueryEngine(t *testing.T, engine *sqle.Engine, builder server.SessionBuilder) (*ServerQueryEngine, error) {
    55  	ctx := sql.NewEmptyContext()
    56  
    57  	if err := enableUserAccounts(ctx, engine); err != nil {
    58  		panic(err)
    59  	}
    60  
    61  	p, err := findEmptyPort()
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	config := server.Config{
    67  		Protocol: "tcp",
    68  		Address:  fmt.Sprintf("%s:%d", address, p),
    69  	}
    70  	s, err := server.NewServer(config, engine, builder, nil)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	go func() {
    76  		_ = s.Start()
    77  	}()
    78  
    79  	return &ServerQueryEngine{
    80  		t:      t,
    81  		engine: engine,
    82  		server: s,
    83  		port:   p,
    84  	}, nil
    85  }
    86  
    87  // NewConnection creates a new connection to the server regardless of whether there is an existing connection.
    88  // If there is an existing connection, it closes it and creates a new connection. New connection uses new session
    89  // that the previous session state data will not persist. This function is also called when there is no connection
    90  // when running a query.
    91  func (s *ServerQueryEngine) NewConnection(ctx *sql.Context) error {
    92  	if s.conn != nil {
    93  		err := s.conn.Close()
    94  		if err != nil {
    95  			return err
    96  		}
    97  	}
    98  
    99  	db := ctx.GetCurrentDatabase()
   100  	// https://stackoverflow.com/questions/29341590/how-to-parse-time-from-database/29343013#29343013
   101  	conn, err := gosql.Open("mysql", fmt.Sprintf("root:@tcp(127.0.0.1:%d)/%s?parseTime=true", s.port, db))
   102  	if err != nil {
   103  		return err
   104  	}
   105  	s.conn = conn
   106  	return nil
   107  }
   108  
   109  func (s *ServerQueryEngine) AnalyzeQuery(ctx *sql.Context, query string) (sql.Node, error) {
   110  	return s.engine.AnalyzeQuery(ctx, query)
   111  }
   112  
   113  func (s *ServerQueryEngine) PrepareQuery(ctx *sql.Context, query string) (sql.Node, error) {
   114  	if s.conn == nil {
   115  		err := s.NewConnection(ctx)
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  	}
   120  	// TODO
   121  	// q, bindVars, err := injectBindVarsAndPrepare(s.t, ctx, s.engine, query)
   122  	return nil, nil
   123  }
   124  
   125  func (s *ServerQueryEngine) Query(ctx *sql.Context, query string) (sql.Schema, sql.RowIter, error) {
   126  	if s.conn == nil {
   127  		err := s.NewConnection(ctx)
   128  		if err != nil {
   129  			return nil, nil, err
   130  		}
   131  	}
   132  
   133  	// we prepare each query as prepared statement if possible to add more coverage to prepared tests
   134  	q, bindVars, err := injectBindVarsAndPrepare(s.t, ctx, s.engine, query)
   135  	if err != nil {
   136  		// TODO: ctx being used does not get updated when running the queries through go sql driver.
   137  		//  we can try preparing and if it errors, then pass the original query
   138  		// For example, `USE db` does not change the db in the ctx.
   139  		return s.QueryWithBindings(ctx, query, nil, nil)
   140  	}
   141  	if _, ok := cannotBePrepared[query]; ok {
   142  		return s.QueryWithBindings(ctx, query, nil, nil)
   143  	}
   144  	return s.QueryWithBindings(ctx, q, nil, bindVars)
   145  }
   146  
   147  func (s *ServerQueryEngine) EngineAnalyzer() *analyzer.Analyzer {
   148  	return s.engine.Analyzer
   149  }
   150  
   151  func (s *ServerQueryEngine) EnginePreparedDataCache() *sqle.PreparedDataCache {
   152  	return s.engine.PreparedDataCache
   153  }
   154  
   155  func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlparser.Statement, bindings map[string]*query.BindVariable) (sql.Schema, sql.RowIter, error) {
   156  	if s.conn == nil {
   157  		err := s.NewConnection(ctx)
   158  		if err != nil {
   159  			return nil, nil, err
   160  		}
   161  	}
   162  
   163  	var err error
   164  	if parsed == nil {
   165  		parsed, err = sqlparser.Parse(query)
   166  		if err != nil {
   167  			// TODO: conn.Query() empty query does not error
   168  			if strings.HasSuffix(err.Error(), "empty statement") {
   169  				return nil, sql.RowsToRowIter(), nil
   170  			}
   171  			// Note: we cannot access sql_mode when using ServerEngine
   172  			//  to use ParseWithOptions() method. Replacing double quotes
   173  			//  because the 'ANSI' mode is not on by default and will not
   174  			//  be set on the context after SET @@sql_mode = 'ANSI' query.
   175  			ansiQuery := strings.Replace(query, "\"", "`", -1)
   176  			parsed, err = sqlparser.Parse(ansiQuery)
   177  			if err != nil {
   178  				return nil, nil, err
   179  			}
   180  		}
   181  	}
   182  
   183  	// NOTE: MySQL does not support LOAD DATA query as PREPARED STATEMENT.
   184  	//  However, Dolt supports, but not go-sql-driver client
   185  	switch parsed.(type) {
   186  	case *sqlparser.Load, *sqlparser.Execute, *sqlparser.Prepare:
   187  		return s.queryOrExec(nil, parsed, query, []any{})
   188  	}
   189  
   190  	stmt, err := s.conn.Prepare(query)
   191  	if err != nil {
   192  		return nil, nil, trimMySQLErrCodePrefix(err)
   193  	}
   194  
   195  	args := prepareBindingArgs(bindings)
   196  
   197  	return s.queryOrExec(stmt, parsed, query, args)
   198  }
   199  
   200  // queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan.
   201  // If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot
   202  // be run as prepared.
   203  // TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds.
   204  //
   205  //	for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result.
   206  func (s *ServerQueryEngine) queryOrExec(stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, error) {
   207  	var err error
   208  	switch parsed.(type) {
   209  	// TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned.
   210  	case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush:
   211  		var rows *gosql.Rows
   212  		if stmt != nil {
   213  			rows, err = stmt.Query(args...)
   214  		} else {
   215  			rows, err = s.conn.Query(query, args...)
   216  		}
   217  		if err != nil {
   218  			return nil, nil, trimMySQLErrCodePrefix(err)
   219  		}
   220  		return convertRowsResult(rows)
   221  	default:
   222  		var res gosql.Result
   223  		if stmt != nil {
   224  			res, err = stmt.Exec(args...)
   225  		} else {
   226  			res, err = s.conn.Exec(query, args...)
   227  		}
   228  		if err != nil {
   229  			return nil, nil, trimMySQLErrCodePrefix(err)
   230  		}
   231  		return convertExecResult(res)
   232  	}
   233  }
   234  
   235  // trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server.
   236  // This allows us to assert the error message strings in the enginetest.
   237  func trimMySQLErrCodePrefix(err error) error {
   238  	errMsg := err.Error()
   239  	r := strings.Split(errMsg, "(HY000): ")
   240  	if len(r) == 2 {
   241  		return errors.New(r[1])
   242  	}
   243  	if e, ok := err.(*mysql.MySQLError); ok {
   244  		// Note: the error msg can be fixed to match with MySQLError at https://github.com/dolthub/vitess/blob/main/go/mysql/sql_error.go#L62
   245  		return errors.New(fmt.Sprintf("%s (errno %v) (sqlstate %s)", e.Message, e.Number, e.SQLState))
   246  	}
   247  	if strings.HasPrefix(errMsg, "sql: expected") && strings.Contains(errMsg, "arguments, got") {
   248  		// TODO: needs better error message for non matching number of binding argument
   249  		//  for Dolt, this error is caught on the first binding variable
   250  		err = sql.ErrUnboundPreparedStatementVariable.New("v1")
   251  	}
   252  	return err
   253  }
   254  
   255  func convertExecResult(exec gosql.Result) (sql.Schema, sql.RowIter, error) {
   256  	affected, err := exec.RowsAffected()
   257  	if err != nil {
   258  		return nil, nil, err
   259  	}
   260  	lastInsertId, err := exec.LastInsertId()
   261  	if err != nil {
   262  		return nil, nil, err
   263  	}
   264  
   265  	okResult := types.OkResult{
   266  		RowsAffected: uint64(affected),
   267  		InsertID:     uint64(lastInsertId),
   268  		Info:         nil,
   269  	}
   270  
   271  	return types.OkResultSchema, sql.RowsToRowIter(sql.NewRow(okResult)), nil
   272  }
   273  
   274  func convertRowsResult(rows *gosql.Rows) (sql.Schema, sql.RowIter, error) {
   275  	sch, err := schemaForRows(rows)
   276  	if err != nil {
   277  		return nil, nil, err
   278  	}
   279  
   280  	rowIter, err := rowIterForGoSqlRows(sch, rows)
   281  	if err != nil {
   282  		return nil, nil, err
   283  	}
   284  
   285  	return sch, rowIter, nil
   286  }
   287  
   288  func rowIterForGoSqlRows(sch sql.Schema, rows *gosql.Rows) (sql.RowIter, error) {
   289  	result := make([]sql.Row, 0)
   290  	r, err := emptyRowForSchema(sch)
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  
   295  	for rows.Next() {
   296  		err = rows.Scan(r...)
   297  		if err != nil {
   298  			return nil, err
   299  		}
   300  
   301  		row, err := derefRow(r)
   302  		if err != nil {
   303  			return nil, err
   304  		}
   305  
   306  		row = convertValue(sch, row)
   307  
   308  		result = append(result, row)
   309  	}
   310  
   311  	return sql.RowsToRowIter(result...), nil
   312  }
   313  
   314  // convertValue converts the row value scanned from go sql driver client to type that we expect.
   315  // This method helps with testing existing enginetests that expects specific type as returned value.
   316  func convertValue(sch sql.Schema, row sql.Row) sql.Row {
   317  	for i, col := range sch {
   318  		switch col.Type.Type() {
   319  		case query.Type_GEOMETRY:
   320  			if row[i] != nil {
   321  				r, _, err := types.GeometryType{}.Convert(row[i].([]byte))
   322  				if err != nil {
   323  					//t.Skip(fmt.Sprintf("received error converting returned geometry result"))
   324  				} else {
   325  					row[i] = r
   326  				}
   327  			}
   328  		case query.Type_JSON:
   329  			if row[i] != nil {
   330  				// TODO: dolt returns the json result without escaped quotes and backslashes, which does not Unmarshall
   331  				r, err := attemptUnmarshalJSON(string(row[i].([]byte)))
   332  				if err != nil {
   333  					//t.Skip(fmt.Sprintf("received error unmarshalling returned json result"))
   334  					row[i] = nil
   335  				} else {
   336  					row[i] = r
   337  				}
   338  			}
   339  		case query.Type_TIME:
   340  			if row[i] != nil {
   341  				r, _, err := types.TimespanType_{}.Convert(string(row[i].([]byte)))
   342  				if err != nil {
   343  					//t.Skip(fmt.Sprintf("received error converting returned timespan result"))
   344  				} else {
   345  					row[i] = r
   346  				}
   347  			}
   348  		case query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32, query.Type_UINT64:
   349  			// TODO: check todo in 'emptyValuePointerForType' method
   350  			//  we try to cast any value we got to uint64
   351  			if row[i] != nil {
   352  				r, err := castToUint64(row[i])
   353  				if err != nil {
   354  					//t.Skip(fmt.Sprintf("received error converting returned unsigned int result"))
   355  				} else {
   356  					row[i] = r
   357  				}
   358  			}
   359  		}
   360  	}
   361  	return row
   362  }
   363  
   364  // attemptUnmarshalJSON is returns error if the result cannot be unmarshalled
   365  // instead of panicking from using `types.MustJSON()` method.
   366  func attemptUnmarshalJSON(s string) (types.JSONDocument, error) {
   367  	var doc interface{}
   368  	if err := json.Unmarshal([]byte(s), &doc); err != nil {
   369  		return types.JSONDocument{}, err
   370  	}
   371  	return types.JSONDocument{Val: doc}, nil
   372  }
   373  
   374  func castToUint64(v any) (uint64, error) {
   375  	switch val := v.(type) {
   376  	case int8:
   377  		return uint64(val), nil
   378  	case int16:
   379  		return uint64(val), nil
   380  	case int32:
   381  		return uint64(val), nil
   382  	case int64:
   383  		return uint64(val), nil
   384  	case uint8:
   385  		return uint64(val), nil
   386  	case uint16:
   387  		return uint64(val), nil
   388  	case uint32:
   389  		return uint64(val), nil
   390  	case uint64:
   391  		return val, nil
   392  	case []byte:
   393  		u, err := strconv.ParseUint(string(val), 10, 64)
   394  		if err != nil {
   395  			return 0, fmt.Errorf("expected uint64 number, but received: %s", string(val))
   396  		}
   397  		return u, nil
   398  	default:
   399  		return 0, fmt.Errorf("expected uint64 number, but received unexpected type: %T", v)
   400  	}
   401  }
   402  
   403  func derefRow(r []any) (sql.Row, error) {
   404  	row := make(sql.Row, len(r))
   405  	for i, v := range r {
   406  		var err error
   407  		row[i], err = deref(v)
   408  		if err != nil {
   409  			return nil, err
   410  		}
   411  	}
   412  	return row, nil
   413  }
   414  
   415  func deref(val any) (any, error) {
   416  	switch v := val.(type) {
   417  	case *int8:
   418  		return *v, nil
   419  	case *int16:
   420  		return *v, nil
   421  	case *int32:
   422  		return *v, nil
   423  	case *int64:
   424  		return *v, nil
   425  	case *uint8:
   426  		return *v, nil
   427  	case *uint16:
   428  		return *v, nil
   429  	case *uint32:
   430  		return *v, nil
   431  	case *uint64:
   432  		return *v, nil
   433  	case *gosql.NullInt32:
   434  		if v.Valid {
   435  			return v.Int32, nil
   436  		}
   437  		return nil, nil
   438  	case *gosql.NullInt64:
   439  		if v.Valid {
   440  			return v.Int64, nil
   441  		}
   442  		return nil, nil
   443  	case *float32:
   444  		return *v, nil
   445  	case *float64:
   446  		return *v, nil
   447  	case *gosql.NullFloat64:
   448  		if v.Valid {
   449  			return v.Float64, nil
   450  		}
   451  		return nil, nil
   452  	case *string:
   453  		return *v, nil
   454  	case *gosql.NullString:
   455  		if v.Valid {
   456  			return v.String, nil
   457  		}
   458  		return nil, nil
   459  	case *[]byte:
   460  		if *v == nil {
   461  			return nil, nil
   462  		}
   463  		return *v, nil
   464  	case *bool:
   465  		return *v, nil
   466  	case *time.Time:
   467  		return *v, nil
   468  	case *gosql.NullTime:
   469  		if v.Valid {
   470  			return v.Time, nil
   471  		}
   472  		return nil, nil
   473  	case *gosql.NullByte:
   474  		if v.Valid {
   475  			return v.Byte, nil
   476  		}
   477  		return nil, nil
   478  	case *any:
   479  		if *v == nil {
   480  			return nil, nil
   481  		}
   482  		return *v, nil
   483  	default:
   484  		return nil, fmt.Errorf("unhandled type %T", v)
   485  	}
   486  }
   487  
   488  func emptyRowForSchema(sch sql.Schema) ([]any, error) {
   489  	result := make([]any, len(sch))
   490  	for i, col := range sch {
   491  		var err error
   492  		result[i], err = emptyValuePointerForType(col.Type)
   493  		if err != nil {
   494  			return nil, err
   495  		}
   496  	}
   497  	return result, nil
   498  }
   499  
   500  func emptyValuePointerForType(t sql.Type) (any, error) {
   501  	switch t.Type() {
   502  	case query.Type_INT8, query.Type_INT16, query.Type_INT24, query.Type_INT64,
   503  		query.Type_BIT, query.Type_YEAR:
   504  		var i gosql.NullInt64
   505  		return &i, nil
   506  	case query.Type_INT32:
   507  		var i gosql.NullInt32
   508  		return &i, nil
   509  	case query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32, query.Type_UINT64:
   510  		//var i uint64
   511  		// TODO: currently there is no gosql.NullUint64 type, so null value for unsigned integer value cannot be scanned.
   512  		//  this might be resolved in Go 1.22, that is not out yet, https://github.com/go-sql-driver/mysql/issues/1433
   513  		var i any
   514  		return &i, nil
   515  	case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP:
   516  		var t gosql.NullTime
   517  		return &t, nil
   518  	case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR, query.Type_BINARY, query.Type_VARBINARY,
   519  		query.Type_ENUM, query.Type_SET, query.Type_DECIMAL:
   520  		// We have DECIMAL type results in enginetests be checked in STRING format.
   521  		var s gosql.NullString
   522  		return &s, nil
   523  	case query.Type_FLOAT32, query.Type_FLOAT64:
   524  		var f gosql.NullFloat64
   525  		return &f, nil
   526  	case query.Type_JSON, query.Type_BLOB, query.Type_TIME, query.Type_GEOMETRY:
   527  		var f []byte
   528  		return &f, nil
   529  	case query.Type_NULL_TYPE:
   530  		var f gosql.NullByte
   531  		return &f, nil
   532  	default:
   533  		return nil, fmt.Errorf("unsupported type %v", t.Type())
   534  	}
   535  }
   536  
   537  func schemaForRows(rows *gosql.Rows) (sql.Schema, error) {
   538  	types, err := rows.ColumnTypes()
   539  	if err != nil {
   540  		return nil, err
   541  	}
   542  
   543  	names, err := rows.Columns()
   544  	if err != nil {
   545  		return nil, err
   546  	}
   547  
   548  	schema := make(sql.Schema, len(types))
   549  	for i, columnType := range types {
   550  		typ, err := convertGoSqlType(columnType)
   551  		if err != nil {
   552  			return nil, err
   553  		}
   554  		schema[i] = &sql.Column{
   555  			Name: names[i],
   556  			Type: typ,
   557  		}
   558  	}
   559  
   560  	return schema, nil
   561  }
   562  
   563  func convertGoSqlType(columnType *gosql.ColumnType) (sql.Type, error) {
   564  	switch strings.ToLower(columnType.DatabaseTypeName()) {
   565  	case "tinyint", "smallint", "mediumint", "int", "bigint", "bit":
   566  		return types.Int64, nil
   567  	case "unsigned tinyint", "unsigned smallint", "unsigned mediumint", "unsigned int", "unsigned bigint":
   568  		return types.Uint64, nil
   569  	case "float", "double":
   570  		return types.Float64, nil
   571  	case "decimal":
   572  		precision, scale, ok := columnType.DecimalSize()
   573  		if !ok {
   574  			return nil, fmt.Errorf("could not get decimal size for column %s", columnType.Name())
   575  		}
   576  		decimalType, err := types.CreateDecimalType(uint8(precision), uint8(scale))
   577  		if err != nil {
   578  			return nil, err
   579  		}
   580  		return decimalType, nil
   581  	case "date":
   582  		return types.Date, nil
   583  	case "datetime":
   584  		precision, _, ok := columnType.DecimalSize()
   585  		if !ok {
   586  			return nil, fmt.Errorf("could not get precision size for column %s", columnType.Name())
   587  		}
   588  		dtType, err := types.CreateDatetimeType(sqltypes.Datetime, int(precision))
   589  		if err != nil {
   590  			return nil, err
   591  		}
   592  		return dtType, nil
   593  	case "timestamp":
   594  		return types.Timestamp, nil
   595  	case "time":
   596  		return types.Time, nil
   597  	case "year":
   598  		return types.Year, nil
   599  	case "char", "varchar":
   600  		length, _ := columnType.Length()
   601  		if length == 0 {
   602  			length = 255
   603  		}
   604  		return types.CreateString(query.Type_VARCHAR, length, sql.Collation_Default)
   605  	case "tinytext", "text", "mediumtext", "longtext":
   606  		return types.Text, nil
   607  	case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob":
   608  		return types.Blob, nil
   609  	case "json":
   610  		return types.JSON, nil
   611  	case "enum":
   612  		return types.EnumType{}, nil
   613  	case "set":
   614  		return types.SetType{}, nil
   615  	case "null":
   616  		return types.Null, nil
   617  	case "geometry":
   618  		return types.GeometryType{}, nil
   619  	default:
   620  		return nil, fmt.Errorf("unhandled type %s", columnType.DatabaseTypeName())
   621  	}
   622  }
   623  
   624  // prepareBindingArgs returns an array of the binding variable converted from given map.
   625  // The binding variables need to be sorted in order of position in the query. The variable in binding map
   626  // is in random order. The function expects binding variables starting with `:v1` and do not skip number.
   627  // It cannot sort user-defined binding variables (e.g. :var, :foo)
   628  func prepareBindingArgs(bindings map[string]*query.BindVariable) []any {
   629  	numBindVars := len(bindings)
   630  	args := make([]any, numBindVars)
   631  	for i := 0; i < numBindVars; i++ {
   632  		k := fmt.Sprintf("v%d", i+1)
   633  		args[i] = convertVtQueryTypeToGoTypeValue(bindings[k])
   634  	}
   635  	return args
   636  }
   637  
   638  // convertValue converts the row value scanned from go sql driver client to type that we expect.
   639  // This method helps with testing existing enginetests that expects specific type as returned value.
   640  func convertVtQueryTypeToGoTypeValue(b *query.BindVariable) any {
   641  	val := string(b.Value)
   642  	switch b.Type {
   643  	case query.Type_INT8, query.Type_INT16, query.Type_INT24, query.Type_INT32, query.Type_INT64,
   644  		query.Type_BIT, query.Type_YEAR:
   645  		i, err := strconv.ParseInt(val, 10, 64)
   646  		if err != nil {
   647  			return val
   648  		}
   649  		return i
   650  	case query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32, query.Type_UINT64:
   651  		i, err := strconv.ParseUint(val, 10, 64)
   652  		if err != nil {
   653  			return val
   654  		}
   655  		return i
   656  	case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP:
   657  		return val
   658  	case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR, query.Type_BINARY, query.Type_VARBINARY,
   659  		query.Type_ENUM, query.Type_SET, query.Type_DECIMAL:
   660  		return val
   661  	case query.Type_FLOAT32, query.Type_FLOAT64:
   662  		// TODO: maybe not?
   663  		return val
   664  	case query.Type_JSON, query.Type_BLOB, query.Type_TIME, query.Type_GEOMETRY:
   665  		return val
   666  	case query.Type_NULL_TYPE:
   667  		return nil
   668  	default:
   669  		return val
   670  	}
   671  
   672  }
   673  
   674  func findEmptyPort() (int, error) {
   675  	listener, err := net.Listen("tcp", ":0")
   676  	if err != nil {
   677  		return -1, err
   678  	}
   679  	port := listener.Addr().(*net.TCPAddr).Port
   680  	if err = listener.Close(); err != nil {
   681  		return -1, err
   682  
   683  	}
   684  	return port, nil
   685  }
   686  
// CloseSession is not implemented for the server engine yet; it currently does nothing with |connID|.
func (s *ServerQueryEngine) CloseSession(connID uint32) {
	// TODO
}
   690  
   691  func (s *ServerQueryEngine) Close() error {
   692  	return s.server.Close()
   693  }
   694  
   695  // MySQLPersister is an example struct which handles the persistence of the data in the "mysql" database.
   696  type MySQLPersister struct {
   697  	Data []byte
   698  }
   699  
   700  var _ mysql_db.MySQLDbPersistence = (*MySQLPersister)(nil)
   701  
   702  // Persist implements the interface mysql_db.MySQLDbPersistence. This function is simple, in that it simply stores
   703  // the given data inside itself. A real application would persist to the file system.
   704  func (m *MySQLPersister) Persist(ctx *sql.Context, data []byte) error {
   705  	m.Data = data
   706  	return nil
   707  }
   708  
   709  func enableUserAccounts(ctx *sql.Context, engine *sqle.Engine) error {
   710  	mysqlDb := engine.Analyzer.Catalog.MySQLDb
   711  
   712  	// The functions "AddRootAccount" and "LoadData" both automatically enable the "mysql" database, but this is just
   713  	// to explicitly show how one can manually enable (or disable) the database.
   714  	mysqlDb.SetEnabled(true)
   715  	// The persister here simply stands-in for your provided persistence function. The database calls this whenever it
   716  	// needs to save any changes to any of the "mysql" database's tables.
   717  	persister := &MySQLPersister{}
   718  	mysqlDb.SetPersister(persister)
   719  
   720  	// AddRootAccount creates a password-less account named "root" that has all privileges. This is intended for use
   721  	// with testing, and also to set up the initial user accounts. A real application may want to check that a
   722  	// persisted file exists, and call "LoadData" if one does. If a file does not exist, it would call
   723  	// "AddRootAccount".
   724  	mysqlDb.AddRootAccount()
   725  
   726  	return nil
   727  }
   728  
// cannotBePrepared lists queries that Query always sends verbatim instead of running them through
// injectBindVarsAndPrepare(). Those queries fail because injectBindVarsAndPrepare() turns non-string SQL values into
// string values. Other queries affected the same way only produce results of the wrong type, which the ServerEngine
// tests do not check for now.
// TODO: remove this map once this issue is fixed.
var cannotBePrepared = map[string]bool{
	`select """""foo""""";`: true,
}