vitess.io/vitess@v0.16.2/go/mysql/auth_server_clientcert_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package mysql 18 19 import ( 20 "context" 21 "crypto/tls" 22 "net" 23 "path" 24 "reflect" 25 "testing" 26 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 30 "vitess.io/vitess/go/vt/tlstest" 31 "vitess.io/vitess/go/vt/vttls" 32 ) 33 34 const clientCertUsername = "Client Cert" 35 36 func init() { 37 // These tests do not invoke the servenv.Parse codepaths, so this default 38 // does not get set by the OnParseFor hook. 39 clientcertAuthMethod = string(MysqlClearPassword) 40 } 41 42 func TestValidCert(t *testing.T) { 43 th := &testHandler{} 44 45 authServer := newAuthServerClientCert() 46 47 // Create the listener, so we can get its host. 48 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 49 require.NoError(t, err, "NewListener failed: %v", err) 50 defer l.Close() 51 host := l.Addr().(*net.TCPAddr).IP.String() 52 port := l.Addr().(*net.TCPAddr).Port 53 54 // Create the certs. 55 root := t.TempDir() 56 tlstest.CreateCA(root) 57 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 58 tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", clientCertUsername) 59 tlstest.CreateCRL(root, tlstest.CA) 60 61 // Create the server with TLS config. 62 serverConfig, err := vttls.ServerConfig( 63 path.Join(root, "server-cert.pem"), 64 path.Join(root, "server-key.pem"), 65 path.Join(root, "ca-cert.pem"), 66 path.Join(root, "ca-crl.pem"), 67 "", 68 tls.VersionTLS12) 69 require.NoError(t, err, "TLSServerConfig failed: %v", err) 70 71 l.TLSConfig.Store(serverConfig) 72 go func() { 73 l.Accept() 74 }() 75 76 // Setup the right parameters. 77 params := &ConnParams{ 78 Host: host, 79 Port: port, 80 Uname: clientCertUsername, 81 Pass: "", 82 // SSL flags. 83 SslMode: vttls.VerifyIdentity, 84 SslCa: path.Join(root, "ca-cert.pem"), 85 SslCert: path.Join(root, "client-cert.pem"), 86 SslKey: path.Join(root, "client-key.pem"), 87 ServerName: "server.example.com", 88 } 89 90 ctx := context.Background() 91 conn, err := Connect(ctx, params) 92 require.NoError(t, err, "Connect failed: %v", err) 93 94 defer conn.Close() 95 96 // Make sure this went through SSL. 97 result, err := conn.ExecuteFetch("ssl echo", 10000, true) 98 require.NoError(t, err, "ExecuteFetch failed: %v", err) 99 assert.Equal(t, "ON", result.Rows[0][0].ToString(), "Got wrong result from ExecuteFetch(ssl echo): %v", result) 100 101 userData := th.LastConn().UserData.Get() 102 assert.Equal(t, clientCertUsername, userData.Username, "userdata username is %v, expected %v", userData.Username, clientCertUsername) 103 104 expectedGroups := []string{"localhost", clientCertUsername} 105 assert.True(t, reflect.DeepEqual(userData.Groups, expectedGroups), "userdata groups is %v, expected %v", userData.Groups, expectedGroups) 106 107 // Send a ComQuit to avoid the error message on the server side. 108 conn.writeComQuit() 109 } 110 111 func TestNoCert(t *testing.T) { 112 th := &testHandler{} 113 114 authServer := newAuthServerClientCert() 115 116 // Create the listener, so we can get its host. 117 l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) 118 require.NoError(t, err, "NewListener failed: %v", err) 119 defer l.Close() 120 host := l.Addr().(*net.TCPAddr).IP.String() 121 port := l.Addr().(*net.TCPAddr).Port 122 123 // Create the certs. 124 root := t.TempDir() 125 tlstest.CreateCA(root) 126 tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") 127 tlstest.CreateCRL(root, tlstest.CA) 128 129 // Create the server with TLS config. 130 serverConfig, err := vttls.ServerConfig( 131 path.Join(root, "server-cert.pem"), 132 path.Join(root, "server-key.pem"), 133 path.Join(root, "ca-cert.pem"), 134 path.Join(root, "ca-crl.pem"), 135 "", 136 tls.VersionTLS12) 137 require.NoError(t, err, "TLSServerConfig failed: %v", err) 138 139 l.TLSConfig.Store(serverConfig) 140 go func() { 141 l.Accept() 142 }() 143 144 // Setup the right parameters. 145 params := &ConnParams{ 146 Host: host, 147 Port: port, 148 Uname: "user1", 149 Pass: "", 150 SslMode: vttls.VerifyIdentity, 151 SslCa: path.Join(root, "ca-cert.pem"), 152 ServerName: "server.example.com", 153 } 154 155 ctx := context.Background() 156 conn, err := Connect(ctx, params) 157 assert.Error(t, err, "Connect() should have errored due to no client cert") 158 159 if conn != nil { 160 conn.Close() 161 } 162 }