vitess.io/vitess@v0.16.2/go/mysql/handshake_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mysql
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"net"
    23  	"path"
    24  	"strings"
    25  	"testing"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  
    30  	"vitess.io/vitess/go/test/utils"
    31  
    32  	"vitess.io/vitess/go/vt/tlstest"
    33  	"vitess.io/vitess/go/vt/vttls"
    34  )
    35  
    36  // This file tests the handshake scenarios between our client and our server.
    37  
    38  func TestClearTextClientAuth(t *testing.T) {
    39  	th := &testHandler{}
    40  
    41  	authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword)
    42  	authServer.entries["user1"] = []*AuthServerStaticEntry{
    43  		{Password: "password1"},
    44  	}
    45  	defer authServer.close()
    46  
    47  	// Create the listener.
    48  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
    49  	require.NoError(t, err, "NewListener failed: %v", err)
    50  	defer l.Close()
    51  	host := l.Addr().(*net.TCPAddr).IP.String()
    52  	port := l.Addr().(*net.TCPAddr).Port
    53  	go func() {
    54  		l.Accept()
    55  	}()
    56  
    57  	// Setup the right parameters.
    58  	params := &ConnParams{
    59  		Host:    host,
    60  		Port:    port,
    61  		Uname:   "user1",
    62  		Pass:    "password1",
    63  		SslMode: vttls.Disabled,
    64  	}
    65  
    66  	// Connection should fail, as server requires SSL for clear text auth.
    67  	ctx := context.Background()
    68  	_, err = Connect(ctx, params)
    69  	if err == nil || !strings.Contains(err.Error(), "Cannot use clear text authentication over non-SSL connections") {
    70  		t.Fatalf("unexpected connection error: %v", err)
    71  	}
    72  
    73  	// Change server side to allow clear text without auth.
    74  	l.AllowClearTextWithoutTLS.Set(true)
    75  	conn, err := Connect(ctx, params)
    76  	require.NoError(t, err, "unexpected connection error: %v", err)
    77  
    78  	defer conn.Close()
    79  
    80  	// Run a 'select rows' command with results.
    81  	result, err := conn.ExecuteFetch("select rows", 10000, true)
    82  	require.NoError(t, err, "ExecuteFetch failed: %v", err)
    83  
    84  	utils.MustMatch(t, result, selectRowsResult)
    85  
    86  	// Send a ComQuit to avoid the error message on the server side.
    87  	conn.writeComQuit()
    88  }
    89  
    90  // TestSSLConnection creates a server with TLS support, a client that
    91  // also has SSL support, and connects them.
    92  func TestSSLConnection(t *testing.T) {
    93  	th := &testHandler{}
    94  
    95  	authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword)
    96  	authServer.entries["user1"] = []*AuthServerStaticEntry{
    97  		{Password: "password1"},
    98  	}
    99  	defer authServer.close()
   100  
   101  	// Create the listener, so we can get its host.
   102  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   103  	require.NoError(t, err, "NewListener failed: %v", err)
   104  	defer l.Close()
   105  	host := l.Addr().(*net.TCPAddr).IP.String()
   106  	port := l.Addr().(*net.TCPAddr).Port
   107  
   108  	// Create the certs.
   109  	root := t.TempDir()
   110  	tlstest.CreateCA(root)
   111  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
   112  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   113  
   114  	// Create the server with TLS config.
   115  	serverConfig, err := vttls.ServerConfig(
   116  		path.Join(root, "server-cert.pem"),
   117  		path.Join(root, "server-key.pem"),
   118  		path.Join(root, "ca-cert.pem"),
   119  		"",
   120  		"",
   121  		tls.VersionTLS12)
   122  	require.NoError(t, err, "TLSServerConfig failed: %v", err)
   123  
   124  	l.TLSConfig.Store(serverConfig)
   125  	go func() {
   126  		l.Accept()
   127  	}()
   128  
   129  	// Setup the right parameters.
   130  	params := &ConnParams{
   131  		Host:  host,
   132  		Port:  port,
   133  		Uname: "user1",
   134  		Pass:  "password1",
   135  		// SSL flags.
   136  		SslMode:    vttls.VerifyIdentity,
   137  		SslCa:      path.Join(root, "ca-cert.pem"),
   138  		SslCert:    path.Join(root, "client-cert.pem"),
   139  		SslKey:     path.Join(root, "client-key.pem"),
   140  		ServerName: "server.example.com",
   141  	}
   142  
   143  	t.Run("Basics", func(t *testing.T) {
   144  		testSSLConnectionBasics(t, params)
   145  	})
   146  
   147  	// Make sure clear text auth works over SSL.
   148  	t.Run("ClearText", func(t *testing.T) {
   149  		testSSLConnectionClearText(t, params)
   150  	})
   151  }
   152  
   153  func testSSLConnectionClearText(t *testing.T, params *ConnParams) {
   154  	// Create a client connection, connect.
   155  	ctx := context.Background()
   156  	conn, err := Connect(ctx, params)
   157  	require.NoError(t, err, "Connect failed: %v", err)
   158  
   159  	defer conn.Close()
   160  	assert.Equal(t, "user1", conn.User, "Invalid conn.User, got %v was expecting user1", conn.User)
   161  
   162  	// Make sure this went through SSL.
   163  	result, err := conn.ExecuteFetch("ssl echo", 10000, true)
   164  	require.NoError(t, err, "ExecuteFetch failed: %v", err)
   165  	assert.Equal(t, "ON", result.Rows[0][0].ToString(), "Got wrong result from ExecuteFetch(ssl echo): %v", result)
   166  
   167  	// Send a ComQuit to avoid the error message on the server side.
   168  	conn.writeComQuit()
   169  }
   170  
   171  func testSSLConnectionBasics(t *testing.T, params *ConnParams) {
   172  	// Create a client connection, connect.
   173  	ctx := context.Background()
   174  	conn, err := Connect(ctx, params)
   175  	require.NoError(t, err, "Connect failed: %v", err)
   176  
   177  	defer conn.Close()
   178  	assert.Equal(t, "user1", conn.User, "Invalid conn.User, got %v was expecting user1", conn.User)
   179  
   180  	// Run a 'select rows' command with results.
   181  	result, err := conn.ExecuteFetch("select rows", 10000, true)
   182  	require.NoError(t, err, "ExecuteFetch failed: %v", err)
   183  
   184  	utils.MustMatch(t, result, selectRowsResult)
   185  
   186  	// Make sure this went through SSL.
   187  	result, err = conn.ExecuteFetch("ssl echo", 10000, true)
   188  	require.NoError(t, err, "ExecuteFetch failed: %v", err)
   189  	assert.Equal(t, "ON", result.Rows[0][0].ToString(), "Got wrong result from ExecuteFetch(ssl echo): %v", result)
   190  
   191  	// Send a ComQuit to avoid the error message on the server side.
   192  	conn.writeComQuit()
   193  }