vitess.io/vitess@v0.16.2/go/vt/vtgate/plugin_mysql_server_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package vtgate 18 19 import ( 20 "context" 21 "crypto/tls" 22 "fmt" 23 "os" 24 "path" 25 "strings" 26 "syscall" 27 "testing" 28 "time" 29 30 "github.com/stretchr/testify/assert" 31 32 "vitess.io/vitess/go/trace" 33 34 "vitess.io/vitess/go/mysql" 35 "vitess.io/vitess/go/sqltypes" 36 querypb "vitess.io/vitess/go/vt/proto/query" 37 "vitess.io/vitess/go/vt/tlstest" 38 ) 39 40 type testHandler struct { 41 mysql.UnimplementedHandler 42 lastConn *mysql.Conn 43 } 44 45 func (th *testHandler) NewConnection(c *mysql.Conn) { 46 th.lastConn = c 47 } 48 49 func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error { 50 // when creating a connection, we send a query to MySQL to set the connection's collation, 51 // this query usually returns us something. however, we use testHandler which is a fake 52 // implementation of MySQL that returns no results and no error for set queries, Vitess 53 // interprets this as an error, we do not want to fail if we see such error. 54 // for this reason, we send back an empty result to the caller. 55 return callback(&sqltypes.Result{Fields: []*querypb.Field{}, Rows: [][]sqltypes.Value{}}) 56 } 57 58 func (th *testHandler) ComPrepare(c *mysql.Conn, q string, b map[string]*querypb.BindVariable) ([]*querypb.Field, error) { 59 return nil, nil 60 } 61 62 func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { 63 return nil 64 } 65 66 func (th *testHandler) ComRegisterReplica(c *mysql.Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error { 67 return nil 68 } 69 70 func (th *testHandler) ComBinlogDump(c *mysql.Conn, logFile string, binlogPos uint32) error { 71 return nil 72 } 73 74 func (th *testHandler) ComBinlogDumpGTID(c *mysql.Conn, logFile string, logPos uint64, gtidSet mysql.GTIDSet) error { 75 return nil 76 } 77 78 func (th *testHandler) WarningCount(c *mysql.Conn) uint16 { 79 return 0 80 } 81 82 func TestConnectionUnixSocket(t *testing.T) { 83 th := &testHandler{} 84 85 authServer := newTestAuthServerStatic() 86 87 // Use tmp file to reserve a path, remove it immediately, we only care about 88 // name in this context 89 unixSocket, err := os.CreateTemp("", "mysql_vitess_test.sock") 90 if err != nil { 91 t.Fatalf("Failed to create temp file") 92 } 93 os.Remove(unixSocket.Name()) 94 95 l, err := newMysqlUnixSocket(unixSocket.Name(), authServer, th) 96 if err != nil { 97 t.Fatalf("NewUnixSocket failed: %v", err) 98 } 99 defer l.Close() 100 go l.Accept() 101 102 params := &mysql.ConnParams{ 103 UnixSocket: unixSocket.Name(), 104 Uname: "user1", 105 Pass: "password1", 106 } 107 108 c, err := mysql.Connect(context.Background(), params) 109 if err != nil { 110 t.Errorf("Should be able to connect to server but found error: %v", err) 111 } 112 c.Close() 113 } 114 115 func TestConnectionStaleUnixSocket(t *testing.T) { 116 th := &testHandler{} 117 118 authServer := newTestAuthServerStatic() 119 120 // First let's create a file. In this way, we simulate 121 // having a stale socket on disk that needs to be cleaned up. 122 unixSocket, err := os.CreateTemp("", "mysql_vitess_test.sock") 123 if err != nil { 124 t.Fatalf("Failed to create temp file") 125 } 126 127 l, err := newMysqlUnixSocket(unixSocket.Name(), authServer, th) 128 if err != nil { 129 t.Fatalf("NewListener failed: %v", err) 130 } 131 defer l.Close() 132 go l.Accept() 133 134 params := &mysql.ConnParams{ 135 UnixSocket: unixSocket.Name(), 136 Uname: "user1", 137 Pass: "password1", 138 } 139 140 c, err := mysql.Connect(context.Background(), params) 141 if err != nil { 142 t.Errorf("Should be able to connect to server but found error: %v", err) 143 } 144 c.Close() 145 } 146 147 func TestConnectionRespectsExistingUnixSocket(t *testing.T) { 148 th := &testHandler{} 149 150 authServer := newTestAuthServerStatic() 151 152 unixSocket, err := os.CreateTemp("", "mysql_vitess_test.sock") 153 if err != nil { 154 t.Fatalf("Failed to create temp file") 155 } 156 os.Remove(unixSocket.Name()) 157 158 l, err := newMysqlUnixSocket(unixSocket.Name(), authServer, th) 159 if err != nil { 160 t.Errorf("NewListener failed: %v", err) 161 } 162 defer l.Close() 163 go l.Accept() 164 _, err = newMysqlUnixSocket(unixSocket.Name(), authServer, th) 165 want := "listen unix" 166 if err == nil || !strings.HasPrefix(err.Error(), want) { 167 t.Errorf("Error: %v, want prefix %s", err, want) 168 } 169 } 170 171 var newSpanOK = func(ctx context.Context, label string) (trace.Span, context.Context) { 172 return trace.NoopSpan{}, context.Background() 173 } 174 175 var newFromStringOK = func(ctx context.Context, spanContext, label string) (trace.Span, context.Context, error) { 176 return trace.NoopSpan{}, context.Background(), nil 177 } 178 179 func newFromStringFail(t *testing.T) func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) { 180 return func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) { 181 t.Fatalf("we didn't provide a parent span in the sql query. this should not have been called. got: %v", parentSpan) 182 return trace.NoopSpan{}, context.Background(), nil 183 } 184 } 185 186 func newFromStringError(t *testing.T) func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) { 187 return func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) { 188 return trace.NoopSpan{}, context.Background(), fmt.Errorf("") 189 } 190 } 191 192 func newFromStringExpect(t *testing.T, expected string) func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) { 193 return func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) { 194 assert.Equal(t, expected, parentSpan) 195 return trace.NoopSpan{}, context.Background(), nil 196 } 197 } 198 199 func newSpanFail(t *testing.T) func(ctx context.Context, label string) (trace.Span, context.Context) { 200 return func(ctx context.Context, label string) (trace.Span, context.Context) { 201 t.Fatalf("we provided a span context but newFromString was not used as expected") 202 return trace.NoopSpan{}, context.Background() 203 } 204 } 205 206 func TestNoSpanContextPassed(t *testing.T) { 207 _, _, err := startSpanTestable(context.Background(), "sql without comments", "someLabel", newSpanOK, newFromStringFail(t)) 208 assert.NoError(t, err) 209 } 210 211 func TestSpanContextNoPassedInButExistsInString(t *testing.T) { 212 _, _, err := startSpanTestable(context.Background(), "SELECT * FROM SOMETABLE WHERE COL = \"/*VT_SPAN_CONTEXT=123*/", "someLabel", newSpanOK, newFromStringFail(t)) 213 assert.NoError(t, err) 214 } 215 216 func TestSpanContextPassedIn(t *testing.T) { 217 _, _, err := startSpanTestable(context.Background(), "/*VT_SPAN_CONTEXT=123*/SQL QUERY", "someLabel", newSpanFail(t), newFromStringOK) 218 assert.NoError(t, err) 219 } 220 221 func TestSpanContextPassedInEvenAroundOtherComments(t *testing.T) { 222 _, _, err := startSpanTestable(context.Background(), "/*VT_SPAN_CONTEXT=123*/SELECT /*vt+ SCATTER_ERRORS_AS_WARNINGS */ col1, col2 FROM TABLE ", "someLabel", 223 newSpanFail(t), 224 newFromStringExpect(t, "123")) 225 assert.NoError(t, err) 226 } 227 228 func TestSpanContextNotParsable(t *testing.T) { 229 hasRun := false 230 _, _, err := startSpanTestable(context.Background(), "/*VT_SPAN_CONTEXT=123*/SQL QUERY", "someLabel", 231 func(c context.Context, s string) (trace.Span, context.Context) { 232 hasRun = true 233 return trace.NoopSpan{}, context.Background() 234 }, 235 newFromStringError(t)) 236 assert.NoError(t, err) 237 assert.True(t, hasRun, "Should have continued execution despite failure to parse VT_SPAN_CONTEXT") 238 } 239 240 func newTestAuthServerStatic() *mysql.AuthServerStatic { 241 jsonConfig := "{\"user1\":{\"Password\":\"password1\", \"UserData\":\"userData1\", \"SourceHost\":\"localhost\"}}" 242 return mysql.NewAuthServerStatic("", jsonConfig, 0) 243 } 244 245 func TestDefaultWorkloadEmpty(t *testing.T) { 246 vh := &vtgateHandler{} 247 sess := vh.session(&mysql.Conn{}) 248 if sess.Options.Workload != querypb.ExecuteOptions_OLTP { 249 t.Fatalf("Expected default workload OLTP") 250 } 251 } 252 253 func TestDefaultWorkloadOLAP(t *testing.T) { 254 vh := &vtgateHandler{} 255 mysqlDefaultWorkload = int32(querypb.ExecuteOptions_OLAP) 256 sess := vh.session(&mysql.Conn{}) 257 if sess.Options.Workload != querypb.ExecuteOptions_OLAP { 258 t.Fatalf("Expected default workload OLAP") 259 } 260 } 261 262 func TestInitTLSConfigWithoutServerCA(t *testing.T) { 263 testInitTLSConfig(t, false) 264 } 265 266 func TestInitTLSConfigWithServerCA(t *testing.T) { 267 testInitTLSConfig(t, true) 268 } 269 270 func testInitTLSConfig(t *testing.T, serverCA bool) { 271 // Create the certs. 272 root := t.TempDir() 273 tlstest.CreateCA(root) 274 tlstest.CreateCRL(root, tlstest.CA) 275 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 276 277 serverCACert := "" 278 if serverCA { 279 serverCACert = path.Join(root, "ca-cert.pem") 280 } 281 282 listener := &mysql.Listener{} 283 if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), path.Join(root, "ca-crl.pem"), serverCACert, true, tls.VersionTLS12); err != nil { 284 t.Fatalf("init tls config failure due to: +%v", err) 285 } 286 287 serverConfig := listener.TLSConfig.Load() 288 if serverConfig == nil { 289 t.Fatalf("init tls config shouldn't create nil server config") 290 } 291 292 sigChan <- syscall.SIGHUP 293 time.Sleep(100 * time.Millisecond) // wait for signal handler 294 295 if listener.TLSConfig.Load() == serverConfig { 296 t.Fatalf("init tls config should have been recreated after SIGHUP") 297 } 298 }