github.com/dolthub/go-mysql-server@v0.18.0/enginetest/server_engine_test.go (about)

     1  package enginetest_test
     2  
     3  import (
     4  	"context"
     5  	gosql "database/sql"
     6  	"fmt"
     7  	"math"
     8  	"net"
     9  	"testing"
    10  
    11  	"github.com/dolthub/vitess/go/mysql"
    12  	_ "github.com/go-sql-driver/mysql"
    13  	"github.com/gocraft/dbr/v2"
    14  	"github.com/stretchr/testify/require"
    15  
    16  	sqle "github.com/dolthub/go-mysql-server"
    17  	"github.com/dolthub/go-mysql-server/memory"
    18  	"github.com/dolthub/go-mysql-server/server"
    19  	"github.com/dolthub/go-mysql-server/sql"
    20  )
    21  
    22  var (
    23  	address   = "localhost"
    24  	noUserFmt = "no_user:@tcp(%s:%d)/"
    25  )
    26  
    27  func findEmptyPort() (int, error) {
    28  	listener, err := net.Listen("tcp", ":0")
    29  	if err != nil {
    30  		return -1, err
    31  	}
    32  	port := listener.Addr().(*net.TCPAddr).Port
    33  	if err = listener.Close(); err != nil {
    34  		return -1, err
    35  
    36  	}
    37  	return port, nil
    38  }
    39  
    40  // initTestServer initializes an in-memory server with the given port, but does not start it.
    41  func initTestServer(port int) (*server.Server, error) {
    42  	pro := memory.NewDBProvider()
    43  	engine := sqle.NewDefault(pro)
    44  	config := server.Config{
    45  		Protocol: "tcp",
    46  		Address:  fmt.Sprintf("%s:%d", address, port),
    47  	}
    48  	sessBuilder := func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
    49  		return memory.NewSession(sql.NewBaseSession(), pro), nil
    50  	}
    51  	s, err := server.NewServer(config, engine, sessBuilder, nil)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	return s, nil
    56  }
    57  
    58  // TestSmoke checks that an in-memory server can be started and stopped without error.
    59  func TestSmoke(t *testing.T) {
    60  	port, err := findEmptyPort()
    61  	require.NoError(t, err)
    62  
    63  	s, err := initTestServer(port)
    64  	require.NoError(t, err)
    65  	go s.Start()
    66  	defer s.Close()
    67  
    68  	conn, err := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil)
    69  	require.NoError(t, err)
    70  	defer conn.Close()
    71  
    72  	require.NoError(t, conn.Ping())
    73  }
    74  
    75  type serverScriptTestAssertion struct {
    76  	query  string
    77  	isExec bool
    78  	args   []any
    79  	skip   bool
    80  
    81  	expectErr            bool
    82  	expectedRowsAffected int64
    83  	expectedRows         []any
    84  
    85  	// can't avoid writing custom comparator because of how gosql.Rows.Scan() works
    86  	checkRows func(rows *gosql.Rows, expectedRows []any) (bool, error)
    87  }
    88  
    89  type serverScriptTest struct {
    90  	name       string
    91  	setup      []string
    92  	assertions []serverScriptTestAssertion
    93  }
    94  
    95  func TestServerPreparedStatements(t *testing.T) {
    96  	tests := []serverScriptTest{
    97  		{
    98  			name: "prepared inserts with big ints",
    99  			setup: []string{
   100  				"create database test_db;",
   101  				"use test_db;",
   102  				"create table signed_tbl (i bigint signed);",
   103  				"create table unsigned_tbl (i bigint unsigned);",
   104  			},
   105  			assertions: []serverScriptTestAssertion{
   106  				{
   107  					query:                "insert into unsigned_tbl values (?)",
   108  					args:                 []any{uint64(math.MaxInt64)},
   109  					isExec:               true,
   110  					expectedRowsAffected: 1,
   111  				},
   112  				{
   113  					query:                "insert into unsigned_tbl values (?)",
   114  					args:                 []any{uint64(math.MaxInt64 + 1)},
   115  					isExec:               true,
   116  					expectedRowsAffected: 1,
   117  				},
   118  				{
   119  					query:                "insert into unsigned_tbl values (?)",
   120  					args:                 []any{uint64(math.MaxUint64)},
   121  					isExec:               true,
   122  					expectedRowsAffected: 1,
   123  				},
   124  				{
   125  					query:     "insert into unsigned_tbl values (?)",
   126  					args:      []any{int64(-1)},
   127  					isExec:    true,
   128  					expectErr: true,
   129  				},
   130  				{
   131  					query:     "insert into unsigned_tbl values (?)",
   132  					args:      []any{int64(math.MinInt64)},
   133  					isExec:    true,
   134  					expectErr: true,
   135  				},
   136  				{
   137  					query: "select * from unsigned_tbl order by i",
   138  					expectedRows: []any{
   139  						[]uint64{uint64(math.MaxInt64)},
   140  						[]uint64{uint64(math.MaxInt64 + 1)},
   141  						[]uint64{uint64(math.MaxUint64)},
   142  					},
   143  					checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
   144  						var i uint64
   145  						var rowNum int
   146  						for rows.Next() {
   147  							if err := rows.Scan(&i); err != nil {
   148  								return false, err
   149  							}
   150  							if rowNum >= len(expectedRows) {
   151  								return false, nil
   152  							}
   153  							if i != expectedRows[rowNum].([]uint64)[0] {
   154  								return false, nil
   155  							}
   156  							rowNum++
   157  						}
   158  						return true, nil
   159  					},
   160  				},
   161  
   162  				{
   163  					query:                "insert into signed_tbl values (?)",
   164  					args:                 []any{uint64(math.MaxInt64)},
   165  					isExec:               true,
   166  					expectedRowsAffected: 1,
   167  				},
   168  				{
   169  					query:     "insert into signed_tbl values (?)",
   170  					args:      []any{uint64(math.MaxInt64 + 1)},
   171  					isExec:    true,
   172  					expectErr: true,
   173  				},
   174  				{
   175  					query:                "insert into signed_tbl values (?)",
   176  					args:                 []any{int64(-1)},
   177  					isExec:               true,
   178  					expectedRowsAffected: 1,
   179  				},
   180  				{
   181  					query:                "insert into signed_tbl values (?)",
   182  					args:                 []any{int64(math.MinInt64)},
   183  					isExec:               true,
   184  					expectedRowsAffected: 1,
   185  				},
   186  				{
   187  					query: "select * from signed_tbl order by i",
   188  					expectedRows: []any{
   189  						[]int64{int64(math.MinInt64)},
   190  						[]int64{int64(-1)},
   191  						[]int64{int64(math.MaxInt64)},
   192  					},
   193  					checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) {
   194  						var i int64
   195  						var rowNum int
   196  						for rows.Next() {
   197  							if err := rows.Scan(&i); err != nil {
   198  								return false, err
   199  							}
   200  							if rowNum >= len(expectedRows) {
   201  								return false, fmt.Errorf("expected %d rows, got more", len(expectedRows))
   202  							}
   203  							if i != expectedRows[rowNum].([]int64)[0] {
   204  								return false, fmt.Errorf("expected %d, got %d", expectedRows[rowNum].([]int64)[0], i)
   205  							}
   206  							rowNum++
   207  						}
   208  						return true, nil
   209  					},
   210  				},
   211  			},
   212  		},
   213  	}
   214  
   215  	port, perr := findEmptyPort()
   216  	require.NoError(t, perr)
   217  
   218  	s, serr := initTestServer(port)
   219  	require.NoError(t, serr)
   220  	go s.Start()
   221  	defer s.Close()
   222  
   223  	for _, test := range tests {
   224  		t.Run(test.name, func(t *testing.T) {
   225  			conn, cerr := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil)
   226  			require.NoError(t, cerr)
   227  			defer conn.Close()
   228  
   229  			for _, stmt := range test.setup {
   230  				_, err := conn.Exec(stmt)
   231  				require.NoError(t, err)
   232  			}
   233  			for _, assertion := range test.assertions {
   234  				t.Run(assertion.query, func(t *testing.T) {
   235  					if assertion.skip {
   236  						t.Skip()
   237  					}
   238  					if assertion.isExec {
   239  						res, err := conn.Exec(assertion.query, assertion.args...)
   240  						if assertion.expectErr {
   241  							require.Error(t, err)
   242  							return
   243  						}
   244  						require.NoError(t, err)
   245  						rowsAffected, err := res.RowsAffected()
   246  						require.NoError(t, err)
   247  						require.Equal(t, assertion.expectedRowsAffected, rowsAffected)
   248  						return
   249  					}
   250  					rows, err := conn.Query(assertion.query, assertion.args...)
   251  					if assertion.expectErr {
   252  						require.Error(t, err)
   253  						return
   254  					}
   255  					ok, err := assertion.checkRows(rows, assertion.expectedRows)
   256  					require.NoError(t, err)
   257  					require.True(t, ok)
   258  				})
   259  			}
   260  		})
   261  	}
   262  }