github.com/dolthub/go-mysql-server@v0.18.0/enginetest/server_engine_test.go (about) 1 package enginetest_test 2 3 import ( 4 "context" 5 gosql "database/sql" 6 "fmt" 7 "math" 8 "net" 9 "testing" 10 11 "github.com/dolthub/vitess/go/mysql" 12 _ "github.com/go-sql-driver/mysql" 13 "github.com/gocraft/dbr/v2" 14 "github.com/stretchr/testify/require" 15 16 sqle "github.com/dolthub/go-mysql-server" 17 "github.com/dolthub/go-mysql-server/memory" 18 "github.com/dolthub/go-mysql-server/server" 19 "github.com/dolthub/go-mysql-server/sql" 20 ) 21 22 var ( 23 address = "localhost" 24 noUserFmt = "no_user:@tcp(%s:%d)/" 25 ) 26 27 func findEmptyPort() (int, error) { 28 listener, err := net.Listen("tcp", ":0") 29 if err != nil { 30 return -1, err 31 } 32 port := listener.Addr().(*net.TCPAddr).Port 33 if err = listener.Close(); err != nil { 34 return -1, err 35 36 } 37 return port, nil 38 } 39 40 // initTestServer initializes an in-memory server with the given port, but does not start it. 41 func initTestServer(port int) (*server.Server, error) { 42 pro := memory.NewDBProvider() 43 engine := sqle.NewDefault(pro) 44 config := server.Config{ 45 Protocol: "tcp", 46 Address: fmt.Sprintf("%s:%d", address, port), 47 } 48 sessBuilder := func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { 49 return memory.NewSession(sql.NewBaseSession(), pro), nil 50 } 51 s, err := server.NewServer(config, engine, sessBuilder, nil) 52 if err != nil { 53 return nil, err 54 } 55 return s, nil 56 } 57 58 // TestSmoke checks that an in-memory server can be started and stopped without error. 59 func TestSmoke(t *testing.T) { 60 port, err := findEmptyPort() 61 require.NoError(t, err) 62 63 s, err := initTestServer(port) 64 require.NoError(t, err) 65 go s.Start() 66 defer s.Close() 67 68 conn, err := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil) 69 require.NoError(t, err) 70 defer conn.Close() 71 72 require.NoError(t, conn.Ping()) 73 } 74 75 type serverScriptTestAssertion struct { 76 query string 77 isExec bool 78 args []any 79 skip bool 80 81 expectErr bool 82 expectedRowsAffected int64 83 expectedRows []any 84 85 // can't avoid writing custom comparator because of how gosql.Rows.Scan() works 86 checkRows func(rows *gosql.Rows, expectedRows []any) (bool, error) 87 } 88 89 type serverScriptTest struct { 90 name string 91 setup []string 92 assertions []serverScriptTestAssertion 93 } 94 95 func TestServerPreparedStatements(t *testing.T) { 96 tests := []serverScriptTest{ 97 { 98 name: "prepared inserts with big ints", 99 setup: []string{ 100 "create database test_db;", 101 "use test_db;", 102 "create table signed_tbl (i bigint signed);", 103 "create table unsigned_tbl (i bigint unsigned);", 104 }, 105 assertions: []serverScriptTestAssertion{ 106 { 107 query: "insert into unsigned_tbl values (?)", 108 args: []any{uint64(math.MaxInt64)}, 109 isExec: true, 110 expectedRowsAffected: 1, 111 }, 112 { 113 query: "insert into unsigned_tbl values (?)", 114 args: []any{uint64(math.MaxInt64 + 1)}, 115 isExec: true, 116 expectedRowsAffected: 1, 117 }, 118 { 119 query: "insert into unsigned_tbl values (?)", 120 args: []any{uint64(math.MaxUint64)}, 121 isExec: true, 122 expectedRowsAffected: 1, 123 }, 124 { 125 query: "insert into unsigned_tbl values (?)", 126 args: []any{int64(-1)}, 127 isExec: true, 128 expectErr: true, 129 }, 130 { 131 query: "insert into unsigned_tbl values (?)", 132 args: []any{int64(math.MinInt64)}, 133 isExec: true, 134 expectErr: true, 135 }, 136 { 137 query: "select * from unsigned_tbl order by i", 138 expectedRows: []any{ 139 []uint64{uint64(math.MaxInt64)}, 140 []uint64{uint64(math.MaxInt64 + 1)}, 141 []uint64{uint64(math.MaxUint64)}, 142 }, 143 checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) { 144 var i uint64 145 var rowNum int 146 for rows.Next() { 147 if err := rows.Scan(&i); err != nil { 148 return false, err 149 } 150 if rowNum >= len(expectedRows) { 151 return false, nil 152 } 153 if i != expectedRows[rowNum].([]uint64)[0] { 154 return false, nil 155 } 156 rowNum++ 157 } 158 return true, nil 159 }, 160 }, 161 162 { 163 query: "insert into signed_tbl values (?)", 164 args: []any{uint64(math.MaxInt64)}, 165 isExec: true, 166 expectedRowsAffected: 1, 167 }, 168 { 169 query: "insert into signed_tbl values (?)", 170 args: []any{uint64(math.MaxInt64 + 1)}, 171 isExec: true, 172 expectErr: true, 173 }, 174 { 175 query: "insert into signed_tbl values (?)", 176 args: []any{int64(-1)}, 177 isExec: true, 178 expectedRowsAffected: 1, 179 }, 180 { 181 query: "insert into signed_tbl values (?)", 182 args: []any{int64(math.MinInt64)}, 183 isExec: true, 184 expectedRowsAffected: 1, 185 }, 186 { 187 query: "select * from signed_tbl order by i", 188 expectedRows: []any{ 189 []int64{int64(math.MinInt64)}, 190 []int64{int64(-1)}, 191 []int64{int64(math.MaxInt64)}, 192 }, 193 checkRows: func(rows *gosql.Rows, expectedRows []any) (bool, error) { 194 var i int64 195 var rowNum int 196 for rows.Next() { 197 if err := rows.Scan(&i); err != nil { 198 return false, err 199 } 200 if rowNum >= len(expectedRows) { 201 return false, fmt.Errorf("expected %d rows, got more", len(expectedRows)) 202 } 203 if i != expectedRows[rowNum].([]int64)[0] { 204 return false, fmt.Errorf("expected %d, got %d", expectedRows[rowNum].([]int64)[0], i) 205 } 206 rowNum++ 207 } 208 return true, nil 209 }, 210 }, 211 }, 212 }, 213 } 214 215 port, perr := findEmptyPort() 216 require.NoError(t, perr) 217 218 s, serr := initTestServer(port) 219 require.NoError(t, serr) 220 go s.Start() 221 defer s.Close() 222 223 for _, test := range tests { 224 t.Run(test.name, func(t *testing.T) { 225 conn, cerr := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil) 226 require.NoError(t, cerr) 227 defer conn.Close() 228 229 for _, stmt := range test.setup { 230 _, err := conn.Exec(stmt) 231 require.NoError(t, err) 232 } 233 for _, assertion := range test.assertions { 234 t.Run(assertion.query, func(t *testing.T) { 235 if assertion.skip { 236 t.Skip() 237 } 238 if assertion.isExec { 239 res, err := conn.Exec(assertion.query, assertion.args...) 240 if assertion.expectErr { 241 require.Error(t, err) 242 return 243 } 244 require.NoError(t, err) 245 rowsAffected, err := res.RowsAffected() 246 require.NoError(t, err) 247 require.Equal(t, assertion.expectedRowsAffected, rowsAffected) 248 return 249 } 250 rows, err := conn.Query(assertion.query, assertion.args...) 251 if assertion.expectErr { 252 require.Error(t, err) 253 return 254 } 255 ok, err := assertion.checkRows(rows, assertion.expectedRows) 256 require.NoError(t, err) 257 require.True(t, ok) 258 }) 259 } 260 }) 261 } 262 }