vitess.io/vitess@v0.16.2/go/mysql/client_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package mysql 18 19 import ( 20 "context" 21 "crypto/tls" 22 "fmt" 23 "net" 24 "os" 25 "path" 26 "regexp" 27 "strings" 28 "sync" 29 "testing" 30 "time" 31 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 35 "vitess.io/vitess/go/vt/tlstest" 36 "vitess.io/vitess/go/vt/vttls" 37 ) 38 39 // assertSQLError makes sure we get the right error. 40 func assertSQLError(t *testing.T, err error, code int, sqlState, subtext, query, pattern string) { 41 t.Helper() 42 43 require.Error(t, err, "was expecting SQLError %v / %v / %v but got no error.", code, sqlState, subtext) 44 serr, ok := err.(*SQLError) 45 require.True(t, ok, "was expecting SQLError %v / %v / %v but got: %v", code, sqlState, subtext, err) 46 require.Equal(t, code, serr.Num, "was expecting SQLError %v / %v / %v but got code %v", code, sqlState, subtext, serr.Num) 47 require.Equal(t, sqlState, serr.State, "was expecting SQLError %v / %v / %v but got state %v", code, sqlState, subtext, serr.State) 48 if pattern != "" { 49 require.Regexp(t, regexp.MustCompile(pattern), serr.Message) 50 } else { 51 require.True(t, subtext == "" || strings.Contains(serr.Message, subtext), "was expecting SQLError %v / %v / %v but got message %v", code, sqlState, subtext, serr.Message) 52 } 53 require.Equal(t, query, serr.Query, "was expecting SQLError %v / %v / %v with Query '%v' but got query '%v'", code, sqlState, subtext, query, serr.Query) 54 } 55 56 // TestConnectTimeout runs connection failure scenarios against a 57 // server that's not listening or has trouble. This test is not meant 58 // to use a valid server. So we do not test bad handshakes here. 59 func TestConnectTimeout(t *testing.T) { 60 // Create a socket, but it's not accepting. So all Dial 61 // attempts will timeout. 62 listener, err := net.Listen("tcp", "127.0.0.1:") 63 require.NoError(t, err, "cannot listen: %v", err) 64 host, port := getHostPort(t, listener.Addr()) 65 params := &ConnParams{ 66 Host: host, 67 Port: port, 68 } 69 defer listener.Close() 70 71 // Test that canceling the context really interrupts the Connect. 72 ctx, cancel := context.WithCancel(context.Background()) 73 done := make(chan struct{}) 74 go func() { 75 _, err := Connect(ctx, params) 76 assert.Equal(t, context.Canceled, err, "Was expecting context.Canceled but got: %v", err) 77 close(done) 78 }() 79 time.Sleep(100 * time.Millisecond) 80 cancel() 81 <-done 82 83 // Tests a connection timeout works. 84 ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) 85 _, err = Connect(ctx, params) 86 cancel() 87 assert.Equal(t, context.DeadlineExceeded, err, "Was expecting context.DeadlineExceeded but got: %v", err) 88 89 // Tests a connection timeout through params 90 ctx = context.Background() 91 paramsWithTimeout := *params 92 paramsWithTimeout.ConnectTimeoutMs = 1 93 _, err = Connect(ctx, ¶msWithTimeout) 94 cancel() 95 assert.Equal(t, context.DeadlineExceeded, err, "Was expecting context.DeadlineExceeded but got: %v", err) 96 97 // Now the server will listen, but close all connections on accept. 98 wg := sync.WaitGroup{} 99 wg.Add(1) 100 go func() { 101 defer wg.Done() 102 for { 103 conn, err := listener.Accept() 104 if err != nil { 105 // Listener was closed. 106 return 107 } 108 conn.Close() 109 } 110 }() 111 ctx = context.Background() 112 _, err = Connect(ctx, params) 113 assertSQLError(t, err, CRServerLost, SSUnknownSQLState, "initial packet read failed", "", "") 114 115 // Now close the listener. Connect should fail right away, 116 // check the error. 117 listener.Close() 118 wg.Wait() 119 _, err = Connect(ctx, params) 120 assertSQLError(t, err, CRConnHostError, SSUnknownSQLState, "connection refused", "", "") 121 122 // Tests a connection where Dial to a unix socket fails 123 // properly returns the right error. To simulate exactly the 124 // right failure, try to dial a Unix socket that's just a temp file. 125 fd, err := os.CreateTemp("", "mysql") 126 require.NoError(t, err, "cannot create TempFile: %v", err) 127 name := fd.Name() 128 fd.Close() 129 params.UnixSocket = name 130 ctx = context.Background() 131 _, err = Connect(ctx, params) 132 os.Remove(name) 133 t.Log(err) 134 assertSQLError(t, err, CRConnectionError, SSUnknownSQLState, "connection refused", "", "net\\.Dial\\(([a-z0-9A-Z_\\/]*)\\) to local server failed:") 135 } 136 137 // TestTLSClientDisabled creates a Server with TLS support, then connects 138 // with a client with TLS disabled. 139 func TestTLSClientDisabled(t *testing.T) { 140 th := &testHandler{} 141 142 authServer := NewAuthServerStatic("", "", 0) 143 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 144 Password: "password1", 145 }} 146 defer authServer.close() 147 148 // Create the listener, so we can get its host. 149 // Below, we are enabling --ssl-verify-server-cert, which adds 150 // a check that the common name of the certificate matches the 151 // server host name we connect to. 152 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 153 require.NoError(t, err) 154 defer l.Close() 155 156 host := l.Addr().(*net.TCPAddr).IP.String() 157 port := l.Addr().(*net.TCPAddr).Port 158 159 // Create the certs. 160 root := t.TempDir() 161 tlstest.CreateCA(root) 162 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", host) 163 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 164 165 // Create the server with TLS config. 166 serverConfig, err := vttls.ServerConfig( 167 path.Join(root, "server-cert.pem"), 168 path.Join(root, "server-key.pem"), 169 "", 170 "", 171 "", 172 tls.VersionTLS12) 173 require.NoError(t, err) 174 l.TLSConfig.Store(serverConfig) 175 176 var wg sync.WaitGroup 177 wg.Add(1) 178 go func(l *Listener) { 179 wg.Done() 180 l.Accept() 181 }(l) 182 // This is ensure the listener is called 183 wg.Wait() 184 // Sleep so that the Accept function is called as well.' 185 time.Sleep(3 * time.Second) 186 187 // Setup the right parameters. 188 params := &ConnParams{ 189 Host: host, 190 Port: port, 191 Uname: "user1", 192 Pass: "password1", 193 SslMode: vttls.Disabled, 194 } 195 196 conn, err := Connect(context.Background(), params) 197 require.NoError(t, err) 198 199 // make sure this went through SSL 200 results, err := conn.ExecuteFetch("ssl echo", 1000, true) 201 require.NoError(t, err) 202 assert.Equal(t, "OFF", results.Rows[0][0].ToString()) 203 204 if conn != nil { 205 conn.Close() 206 } 207 } 208 209 // TestTLSClientDisabled creates a Server with TLS support, then connects 210 // with a client with TLS preferred. 211 func TestTLSClientPreferredDefault(t *testing.T) { 212 th := &testHandler{} 213 214 authServer := NewAuthServerStatic("", "", 0) 215 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 216 Password: "password1", 217 }} 218 defer authServer.close() 219 220 // Create the listener, so we can get its host. 221 // Below, we are enabling --ssl-verify-server-cert, which adds 222 // a check that the common name of the certificate matches the 223 // server host name we connect to. 224 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 225 require.NoError(t, err) 226 defer l.Close() 227 228 host := l.Addr().(*net.TCPAddr).IP.String() 229 port := l.Addr().(*net.TCPAddr).Port 230 231 // Create the certs. 232 root := t.TempDir() 233 tlstest.CreateCA(root) 234 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 235 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 236 237 // Create the server with TLS config. 238 serverConfig, err := vttls.ServerConfig( 239 path.Join(root, "server-cert.pem"), 240 path.Join(root, "server-key.pem"), 241 "", 242 "", 243 "", 244 tls.VersionTLS12) 245 require.NoError(t, err) 246 l.TLSConfig.Store(serverConfig) 247 248 var wg sync.WaitGroup 249 wg.Add(1) 250 go func(l *Listener) { 251 wg.Done() 252 l.Accept() 253 }(l) 254 // This is ensure the listener is called 255 wg.Wait() 256 // Sleep so that the Accept function is called as well.' 257 time.Sleep(3 * time.Second) 258 259 // Setup the right parameters. 260 params := &ConnParams{ 261 Host: host, 262 Port: port, 263 Uname: "user1", 264 Pass: "password1", 265 SslMode: vttls.Preferred, 266 ServerName: "server.example.com", 267 } 268 269 conn, err := Connect(context.Background(), params) 270 require.NoError(t, err) 271 272 // make sure this went through SSL 273 results, err := conn.ExecuteFetch("ssl echo", 1000, true) 274 require.NoError(t, err) 275 assert.Equal(t, "ON", results.Rows[0][0].ToString()) 276 277 if conn != nil { 278 conn.Close() 279 } 280 } 281 282 // TestTLSClientRequired creates a Server with no TLS support, then connects 283 // with a client with TLS required. 284 func TestTLSClientRequired(t *testing.T) { 285 th := &testHandler{} 286 287 authServer := NewAuthServerStatic("", "", 0) 288 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 289 Password: "password1", 290 }} 291 defer authServer.close() 292 293 // Create the listener, so we can get its host. 294 // Below, we are enabling --ssl-verify-server-cert, which adds 295 // a check that the common name of the certificate matches the 296 // server host name we connect to. 297 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 298 require.NoError(t, err) 299 defer l.Close() 300 301 host := l.Addr().(*net.TCPAddr).IP.String() 302 port := l.Addr().(*net.TCPAddr).Port 303 304 var wg sync.WaitGroup 305 wg.Add(1) 306 go func(l *Listener) { 307 wg.Done() 308 l.Accept() 309 }(l) 310 // This is ensure the listener is called 311 wg.Wait() 312 // Sleep so that the Accept function is called as well.' 313 time.Sleep(3 * time.Second) 314 315 // Setup the right parameters. 316 params := &ConnParams{ 317 Host: host, 318 Port: port, 319 Uname: "user1", 320 Pass: "password1", 321 SslMode: vttls.Required, 322 } 323 324 _, err = Connect(context.Background(), params) 325 require.Error(t, err) 326 assert.Contains(t, err.Error(), "server doesn't support SSL but client asked for it") 327 } 328 329 // TestTLSClientVerifyCA creates a Server with TLS support, then connects 330 // with a client with TLS enabled on a wrong hostname but with verify CA on. 331 func TestTLSClientVerifyCA(t *testing.T) { 332 th := &testHandler{} 333 334 authServer := NewAuthServerStatic("", "", 0) 335 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 336 Password: "password1", 337 }} 338 defer authServer.close() 339 340 // Create the listener, so we can get its host. 341 // Below, we are enabling --ssl-verify-server-cert, which adds 342 // a check that the common name of the certificate matches the 343 // server host name we connect to. 344 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 345 require.NoError(t, err) 346 defer l.Close() 347 348 host := l.Addr().(*net.TCPAddr).IP.String() 349 port := l.Addr().(*net.TCPAddr).Port 350 351 // Create the certs. 352 root := t.TempDir() 353 tlstest.CreateCA(root) 354 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 355 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 356 357 // Create the server with TLS config. 358 serverConfig, err := vttls.ServerConfig( 359 path.Join(root, "server-cert.pem"), 360 path.Join(root, "server-key.pem"), 361 "", 362 "", 363 "", 364 tls.VersionTLS12) 365 require.NoError(t, err) 366 l.TLSConfig.Store(serverConfig) 367 368 var wg sync.WaitGroup 369 wg.Add(1) 370 go func(l *Listener) { 371 wg.Done() 372 l.Accept() 373 }(l) 374 // This is ensure the listener is called 375 wg.Wait() 376 // Sleep so that the Accept function is called as well.' 377 time.Sleep(3 * time.Second) 378 379 // Setup the right parameters. 380 params := &ConnParams{ 381 Host: host, 382 Port: port, 383 Uname: "user1", 384 Pass: "password1", 385 // SSL flags. 386 SslMode: vttls.VerifyCA, 387 ServerName: "server.example.com", 388 } 389 390 _, err = Connect(context.Background(), params) 391 require.Error(t, err) 392 393 fmt.Printf("Error: %s", err) 394 395 assert.Contains(t, err.Error(), "cannot send HandshakeResponse41: x509:") 396 397 // Now setup proper CA that is valid to verify 398 params.SslCa = path.Join(root, "ca-cert.pem") 399 conn, err := Connect(context.Background(), params) 400 require.NoError(t, err) 401 402 // make sure this went through SSL 403 results, err := conn.ExecuteFetch("ssl echo", 1000, true) 404 require.NoError(t, err) 405 assert.Equal(t, "ON", results.Rows[0][0].ToString()) 406 407 if conn != nil { 408 conn.Close() 409 } 410 } 411 412 // TestTLSClientVerifyIdentity creates a Server with TLS support, then connects 413 // with a client with TLS enabled on a wrong hostname but with verify CA on. 414 func TestTLSClientVerifyIdentity(t *testing.T) { 415 th := &testHandler{} 416 417 authServer := NewAuthServerStatic("", "", 0) 418 authServer.entries["user1"] = []*AuthServerStaticEntry{{ 419 Password: "password1", 420 }} 421 defer authServer.close() 422 423 // Create the listener, so we can get its host. 424 // Below, we are enabling --ssl-verify-server-cert, which adds 425 // a check that the common name of the certificate matches the 426 // server host name we connect to. 427 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 428 require.NoError(t, err) 429 defer l.Close() 430 431 host := l.Addr().(*net.TCPAddr).IP.String() 432 port := l.Addr().(*net.TCPAddr).Port 433 434 // Create the certs. 435 root := t.TempDir() 436 tlstest.CreateCA(root) 437 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 438 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") 439 440 // Create the server with TLS config. 441 serverConfig, err := vttls.ServerConfig( 442 path.Join(root, "server-cert.pem"), 443 path.Join(root, "server-key.pem"), 444 "", 445 "", 446 "", 447 tls.VersionTLS12) 448 require.NoError(t, err) 449 l.TLSConfig.Store(serverConfig) 450 451 var wg sync.WaitGroup 452 wg.Add(1) 453 go func(l *Listener) { 454 wg.Done() 455 l.Accept() 456 }(l) 457 // This is ensure the listener is called 458 wg.Wait() 459 // Sleep so that the Accept function is called as well.' 460 time.Sleep(3 * time.Second) 461 462 // Setup the right parameters. 463 params := &ConnParams{ 464 Host: host, 465 Port: port, 466 Uname: "user1", 467 Pass: "password1", 468 // SSL flags. 469 SslMode: vttls.VerifyIdentity, 470 ServerName: "server.example.com", 471 } 472 473 _, err = Connect(context.Background(), params) 474 require.Error(t, err) 475 476 fmt.Printf("Error: %s", err) 477 478 assert.Contains(t, err.Error(), "cannot send HandshakeResponse41: tls:") 479 480 // Now setup proper CA that is valid to verify 481 params.SslCa = path.Join(root, "ca-cert.pem") 482 conn, err := Connect(context.Background(), params) 483 require.NoError(t, err) 484 485 // make sure this went through SSL 486 results, err := conn.ExecuteFetch("ssl echo", 1000, true) 487 require.NoError(t, err) 488 assert.Equal(t, "ON", results.Rows[0][0].ToString()) 489 490 if conn != nil { 491 conn.Close() 492 } 493 494 // Now revoke the server certificate and make sure we can't connect 495 tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "server") 496 497 params.SslCrl = path.Join(root, "ca-crl.pem") 498 _, err = Connect(context.Background(), params) 499 require.Error(t, err) 500 require.Contains(t, err.Error(), "Certificate revoked: CommonName=server.example.com") 501 }