vitess.io/vitess@v0.16.2/go/mysql/client_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mysql
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"fmt"
    23  	"net"
    24  	"os"
    25  	"path"
    26  	"regexp"
    27  	"strings"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  
    35  	"vitess.io/vitess/go/vt/tlstest"
    36  	"vitess.io/vitess/go/vt/vttls"
    37  )
    38  
    39  // assertSQLError makes sure we get the right error.
    40  func assertSQLError(t *testing.T, err error, code int, sqlState, subtext, query, pattern string) {
    41  	t.Helper()
    42  
    43  	require.Error(t, err, "was expecting SQLError %v / %v / %v but got no error.", code, sqlState, subtext)
    44  	serr, ok := err.(*SQLError)
    45  	require.True(t, ok, "was expecting SQLError %v / %v / %v but got: %v", code, sqlState, subtext, err)
    46  	require.Equal(t, code, serr.Num, "was expecting SQLError %v / %v / %v but got code %v", code, sqlState, subtext, serr.Num)
    47  	require.Equal(t, sqlState, serr.State, "was expecting SQLError %v / %v / %v but got state %v", code, sqlState, subtext, serr.State)
    48  	if pattern != "" {
    49  		require.Regexp(t, regexp.MustCompile(pattern), serr.Message)
    50  	} else {
    51  		require.True(t, subtext == "" || strings.Contains(serr.Message, subtext), "was expecting SQLError %v / %v / %v but got message %v", code, sqlState, subtext, serr.Message)
    52  	}
    53  	require.Equal(t, query, serr.Query, "was expecting SQLError %v / %v / %v with Query '%v' but got query '%v'", code, sqlState, subtext, query, serr.Query)
    54  }
    55  
    56  // TestConnectTimeout runs connection failure scenarios against a
    57  // server that's not listening or has trouble.  This test is not meant
    58  // to use a valid server. So we do not test bad handshakes here.
    59  func TestConnectTimeout(t *testing.T) {
    60  	// Create a socket, but it's not accepting. So all Dial
    61  	// attempts will timeout.
    62  	listener, err := net.Listen("tcp", "127.0.0.1:")
    63  	require.NoError(t, err, "cannot listen: %v", err)
    64  	host, port := getHostPort(t, listener.Addr())
    65  	params := &ConnParams{
    66  		Host: host,
    67  		Port: port,
    68  	}
    69  	defer listener.Close()
    70  
    71  	// Test that canceling the context really interrupts the Connect.
    72  	ctx, cancel := context.WithCancel(context.Background())
    73  	done := make(chan struct{})
    74  	go func() {
    75  		_, err := Connect(ctx, params)
    76  		assert.Equal(t, context.Canceled, err, "Was expecting context.Canceled but got: %v", err)
    77  		close(done)
    78  	}()
    79  	time.Sleep(100 * time.Millisecond)
    80  	cancel()
    81  	<-done
    82  
    83  	// Tests a connection timeout works.
    84  	ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
    85  	_, err = Connect(ctx, params)
    86  	cancel()
    87  	assert.Equal(t, context.DeadlineExceeded, err, "Was expecting context.DeadlineExceeded but got: %v", err)
    88  
    89  	// Tests a connection timeout through params
    90  	ctx = context.Background()
    91  	paramsWithTimeout := *params
    92  	paramsWithTimeout.ConnectTimeoutMs = 1
    93  	_, err = Connect(ctx, &paramsWithTimeout)
    94  	cancel()
    95  	assert.Equal(t, context.DeadlineExceeded, err, "Was expecting context.DeadlineExceeded but got: %v", err)
    96  
    97  	// Now the server will listen, but close all connections on accept.
    98  	wg := sync.WaitGroup{}
    99  	wg.Add(1)
   100  	go func() {
   101  		defer wg.Done()
   102  		for {
   103  			conn, err := listener.Accept()
   104  			if err != nil {
   105  				// Listener was closed.
   106  				return
   107  			}
   108  			conn.Close()
   109  		}
   110  	}()
   111  	ctx = context.Background()
   112  	_, err = Connect(ctx, params)
   113  	assertSQLError(t, err, CRServerLost, SSUnknownSQLState, "initial packet read failed", "", "")
   114  
   115  	// Now close the listener. Connect should fail right away,
   116  	// check the error.
   117  	listener.Close()
   118  	wg.Wait()
   119  	_, err = Connect(ctx, params)
   120  	assertSQLError(t, err, CRConnHostError, SSUnknownSQLState, "connection refused", "", "")
   121  
   122  	// Tests a connection where Dial to a unix socket fails
   123  	// properly returns the right error. To simulate exactly the
   124  	// right failure, try to dial a Unix socket that's just a temp file.
   125  	fd, err := os.CreateTemp("", "mysql")
   126  	require.NoError(t, err, "cannot create TempFile: %v", err)
   127  	name := fd.Name()
   128  	fd.Close()
   129  	params.UnixSocket = name
   130  	ctx = context.Background()
   131  	_, err = Connect(ctx, params)
   132  	os.Remove(name)
   133  	t.Log(err)
   134  	assertSQLError(t, err, CRConnectionError, SSUnknownSQLState, "connection refused", "", "net\\.Dial\\(([a-z0-9A-Z_\\/]*)\\) to local server failed:")
   135  }
   136  
   137  // TestTLSClientDisabled creates a Server with TLS support, then connects
   138  // with a client with TLS disabled.
   139  func TestTLSClientDisabled(t *testing.T) {
   140  	th := &testHandler{}
   141  
   142  	authServer := NewAuthServerStatic("", "", 0)
   143  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   144  		Password: "password1",
   145  	}}
   146  	defer authServer.close()
   147  
   148  	// Create the listener, so we can get its host.
   149  	// Below, we are enabling --ssl-verify-server-cert, which adds
   150  	// a check that the common name of the certificate matches the
   151  	// server host name we connect to.
   152  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   153  	require.NoError(t, err)
   154  	defer l.Close()
   155  
   156  	host := l.Addr().(*net.TCPAddr).IP.String()
   157  	port := l.Addr().(*net.TCPAddr).Port
   158  
   159  	// Create the certs.
   160  	root := t.TempDir()
   161  	tlstest.CreateCA(root)
   162  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", host)
   163  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   164  
   165  	// Create the server with TLS config.
   166  	serverConfig, err := vttls.ServerConfig(
   167  		path.Join(root, "server-cert.pem"),
   168  		path.Join(root, "server-key.pem"),
   169  		"",
   170  		"",
   171  		"",
   172  		tls.VersionTLS12)
   173  	require.NoError(t, err)
   174  	l.TLSConfig.Store(serverConfig)
   175  
   176  	var wg sync.WaitGroup
   177  	wg.Add(1)
   178  	go func(l *Listener) {
   179  		wg.Done()
   180  		l.Accept()
   181  	}(l)
   182  	// This is ensure the listener is called
   183  	wg.Wait()
   184  	// Sleep so that the Accept function is called as well.'
   185  	time.Sleep(3 * time.Second)
   186  
   187  	// Setup the right parameters.
   188  	params := &ConnParams{
   189  		Host:    host,
   190  		Port:    port,
   191  		Uname:   "user1",
   192  		Pass:    "password1",
   193  		SslMode: vttls.Disabled,
   194  	}
   195  
   196  	conn, err := Connect(context.Background(), params)
   197  	require.NoError(t, err)
   198  
   199  	// make sure this went through SSL
   200  	results, err := conn.ExecuteFetch("ssl echo", 1000, true)
   201  	require.NoError(t, err)
   202  	assert.Equal(t, "OFF", results.Rows[0][0].ToString())
   203  
   204  	if conn != nil {
   205  		conn.Close()
   206  	}
   207  }
   208  
   209  // TestTLSClientDisabled creates a Server with TLS support, then connects
   210  // with a client with TLS preferred.
   211  func TestTLSClientPreferredDefault(t *testing.T) {
   212  	th := &testHandler{}
   213  
   214  	authServer := NewAuthServerStatic("", "", 0)
   215  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   216  		Password: "password1",
   217  	}}
   218  	defer authServer.close()
   219  
   220  	// Create the listener, so we can get its host.
   221  	// Below, we are enabling --ssl-verify-server-cert, which adds
   222  	// a check that the common name of the certificate matches the
   223  	// server host name we connect to.
   224  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   225  	require.NoError(t, err)
   226  	defer l.Close()
   227  
   228  	host := l.Addr().(*net.TCPAddr).IP.String()
   229  	port := l.Addr().(*net.TCPAddr).Port
   230  
   231  	// Create the certs.
   232  	root := t.TempDir()
   233  	tlstest.CreateCA(root)
   234  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
   235  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   236  
   237  	// Create the server with TLS config.
   238  	serverConfig, err := vttls.ServerConfig(
   239  		path.Join(root, "server-cert.pem"),
   240  		path.Join(root, "server-key.pem"),
   241  		"",
   242  		"",
   243  		"",
   244  		tls.VersionTLS12)
   245  	require.NoError(t, err)
   246  	l.TLSConfig.Store(serverConfig)
   247  
   248  	var wg sync.WaitGroup
   249  	wg.Add(1)
   250  	go func(l *Listener) {
   251  		wg.Done()
   252  		l.Accept()
   253  	}(l)
   254  	// This is ensure the listener is called
   255  	wg.Wait()
   256  	// Sleep so that the Accept function is called as well.'
   257  	time.Sleep(3 * time.Second)
   258  
   259  	// Setup the right parameters.
   260  	params := &ConnParams{
   261  		Host:       host,
   262  		Port:       port,
   263  		Uname:      "user1",
   264  		Pass:       "password1",
   265  		SslMode:    vttls.Preferred,
   266  		ServerName: "server.example.com",
   267  	}
   268  
   269  	conn, err := Connect(context.Background(), params)
   270  	require.NoError(t, err)
   271  
   272  	// make sure this went through SSL
   273  	results, err := conn.ExecuteFetch("ssl echo", 1000, true)
   274  	require.NoError(t, err)
   275  	assert.Equal(t, "ON", results.Rows[0][0].ToString())
   276  
   277  	if conn != nil {
   278  		conn.Close()
   279  	}
   280  }
   281  
   282  // TestTLSClientRequired creates a Server with no TLS support, then connects
   283  // with a client with TLS required.
   284  func TestTLSClientRequired(t *testing.T) {
   285  	th := &testHandler{}
   286  
   287  	authServer := NewAuthServerStatic("", "", 0)
   288  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   289  		Password: "password1",
   290  	}}
   291  	defer authServer.close()
   292  
   293  	// Create the listener, so we can get its host.
   294  	// Below, we are enabling --ssl-verify-server-cert, which adds
   295  	// a check that the common name of the certificate matches the
   296  	// server host name we connect to.
   297  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   298  	require.NoError(t, err)
   299  	defer l.Close()
   300  
   301  	host := l.Addr().(*net.TCPAddr).IP.String()
   302  	port := l.Addr().(*net.TCPAddr).Port
   303  
   304  	var wg sync.WaitGroup
   305  	wg.Add(1)
   306  	go func(l *Listener) {
   307  		wg.Done()
   308  		l.Accept()
   309  	}(l)
   310  	// This is ensure the listener is called
   311  	wg.Wait()
   312  	// Sleep so that the Accept function is called as well.'
   313  	time.Sleep(3 * time.Second)
   314  
   315  	// Setup the right parameters.
   316  	params := &ConnParams{
   317  		Host:    host,
   318  		Port:    port,
   319  		Uname:   "user1",
   320  		Pass:    "password1",
   321  		SslMode: vttls.Required,
   322  	}
   323  
   324  	_, err = Connect(context.Background(), params)
   325  	require.Error(t, err)
   326  	assert.Contains(t, err.Error(), "server doesn't support SSL but client asked for it")
   327  }
   328  
   329  // TestTLSClientVerifyCA creates a Server with TLS support, then connects
   330  // with a client with TLS enabled on a wrong hostname but with verify CA on.
   331  func TestTLSClientVerifyCA(t *testing.T) {
   332  	th := &testHandler{}
   333  
   334  	authServer := NewAuthServerStatic("", "", 0)
   335  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   336  		Password: "password1",
   337  	}}
   338  	defer authServer.close()
   339  
   340  	// Create the listener, so we can get its host.
   341  	// Below, we are enabling --ssl-verify-server-cert, which adds
   342  	// a check that the common name of the certificate matches the
   343  	// server host name we connect to.
   344  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   345  	require.NoError(t, err)
   346  	defer l.Close()
   347  
   348  	host := l.Addr().(*net.TCPAddr).IP.String()
   349  	port := l.Addr().(*net.TCPAddr).Port
   350  
   351  	// Create the certs.
   352  	root := t.TempDir()
   353  	tlstest.CreateCA(root)
   354  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
   355  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   356  
   357  	// Create the server with TLS config.
   358  	serverConfig, err := vttls.ServerConfig(
   359  		path.Join(root, "server-cert.pem"),
   360  		path.Join(root, "server-key.pem"),
   361  		"",
   362  		"",
   363  		"",
   364  		tls.VersionTLS12)
   365  	require.NoError(t, err)
   366  	l.TLSConfig.Store(serverConfig)
   367  
   368  	var wg sync.WaitGroup
   369  	wg.Add(1)
   370  	go func(l *Listener) {
   371  		wg.Done()
   372  		l.Accept()
   373  	}(l)
   374  	// This is ensure the listener is called
   375  	wg.Wait()
   376  	// Sleep so that the Accept function is called as well.'
   377  	time.Sleep(3 * time.Second)
   378  
   379  	// Setup the right parameters.
   380  	params := &ConnParams{
   381  		Host:  host,
   382  		Port:  port,
   383  		Uname: "user1",
   384  		Pass:  "password1",
   385  		// SSL flags.
   386  		SslMode:    vttls.VerifyCA,
   387  		ServerName: "server.example.com",
   388  	}
   389  
   390  	_, err = Connect(context.Background(), params)
   391  	require.Error(t, err)
   392  
   393  	fmt.Printf("Error: %s", err)
   394  
   395  	assert.Contains(t, err.Error(), "cannot send HandshakeResponse41: x509:")
   396  
   397  	// Now setup proper CA that is valid to verify
   398  	params.SslCa = path.Join(root, "ca-cert.pem")
   399  	conn, err := Connect(context.Background(), params)
   400  	require.NoError(t, err)
   401  
   402  	// make sure this went through SSL
   403  	results, err := conn.ExecuteFetch("ssl echo", 1000, true)
   404  	require.NoError(t, err)
   405  	assert.Equal(t, "ON", results.Rows[0][0].ToString())
   406  
   407  	if conn != nil {
   408  		conn.Close()
   409  	}
   410  }
   411  
   412  // TestTLSClientVerifyIdentity creates a Server with TLS support, then connects
   413  // with a client with TLS enabled on a wrong hostname but with verify CA on.
   414  func TestTLSClientVerifyIdentity(t *testing.T) {
   415  	th := &testHandler{}
   416  
   417  	authServer := NewAuthServerStatic("", "", 0)
   418  	authServer.entries["user1"] = []*AuthServerStaticEntry{{
   419  		Password: "password1",
   420  	}}
   421  	defer authServer.close()
   422  
   423  	// Create the listener, so we can get its host.
   424  	// Below, we are enabling --ssl-verify-server-cert, which adds
   425  	// a check that the common name of the certificate matches the
   426  	// server host name we connect to.
   427  	l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
   428  	require.NoError(t, err)
   429  	defer l.Close()
   430  
   431  	host := l.Addr().(*net.TCPAddr).IP.String()
   432  	port := l.Addr().(*net.TCPAddr).Port
   433  
   434  	// Create the certs.
   435  	root := t.TempDir()
   436  	tlstest.CreateCA(root)
   437  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
   438  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   439  
   440  	// Create the server with TLS config.
   441  	serverConfig, err := vttls.ServerConfig(
   442  		path.Join(root, "server-cert.pem"),
   443  		path.Join(root, "server-key.pem"),
   444  		"",
   445  		"",
   446  		"",
   447  		tls.VersionTLS12)
   448  	require.NoError(t, err)
   449  	l.TLSConfig.Store(serverConfig)
   450  
   451  	var wg sync.WaitGroup
   452  	wg.Add(1)
   453  	go func(l *Listener) {
   454  		wg.Done()
   455  		l.Accept()
   456  	}(l)
   457  	// This is ensure the listener is called
   458  	wg.Wait()
   459  	// Sleep so that the Accept function is called as well.'
   460  	time.Sleep(3 * time.Second)
   461  
   462  	// Setup the right parameters.
   463  	params := &ConnParams{
   464  		Host:  host,
   465  		Port:  port,
   466  		Uname: "user1",
   467  		Pass:  "password1",
   468  		// SSL flags.
   469  		SslMode:    vttls.VerifyIdentity,
   470  		ServerName: "server.example.com",
   471  	}
   472  
   473  	_, err = Connect(context.Background(), params)
   474  	require.Error(t, err)
   475  
   476  	fmt.Printf("Error: %s", err)
   477  
   478  	assert.Contains(t, err.Error(), "cannot send HandshakeResponse41: tls:")
   479  
   480  	// Now setup proper CA that is valid to verify
   481  	params.SslCa = path.Join(root, "ca-cert.pem")
   482  	conn, err := Connect(context.Background(), params)
   483  	require.NoError(t, err)
   484  
   485  	// make sure this went through SSL
   486  	results, err := conn.ExecuteFetch("ssl echo", 1000, true)
   487  	require.NoError(t, err)
   488  	assert.Equal(t, "ON", results.Rows[0][0].ToString())
   489  
   490  	if conn != nil {
   491  		conn.Close()
   492  	}
   493  
   494  	// Now revoke the server certificate and make sure we can't connect
   495  	tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "server")
   496  
   497  	params.SslCrl = path.Join(root, "ca-crl.pem")
   498  	_, err = Connect(context.Background(), params)
   499  	require.Error(t, err)
   500  	require.Contains(t, err.Error(), "Certificate revoked: CommonName=server.example.com")
   501  }