vitess.io/vitess@v0.16.2/go/vt/vtgate/plugin_mysql_server_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vtgate
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"fmt"
    23  	"os"
    24  	"path"
    25  	"strings"
    26  	"syscall"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/stretchr/testify/assert"
    31  
    32  	"vitess.io/vitess/go/trace"
    33  
    34  	"vitess.io/vitess/go/mysql"
    35  	"vitess.io/vitess/go/sqltypes"
    36  	querypb "vitess.io/vitess/go/vt/proto/query"
    37  	"vitess.io/vitess/go/vt/tlstest"
    38  )
    39  
    40  type testHandler struct {
    41  	mysql.UnimplementedHandler
    42  	lastConn *mysql.Conn
    43  }
    44  
    45  func (th *testHandler) NewConnection(c *mysql.Conn) {
    46  	th.lastConn = c
    47  }
    48  
    49  func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error {
    50  	// when creating a connection, we send a query to MySQL to set the connection's collation,
    51  	// this query usually returns us something. however, we use testHandler which is a fake
    52  	// implementation of MySQL that returns no results and no error for set queries, Vitess
    53  	// interprets this as an error, we do not want to fail if we see such error.
    54  	// for this reason, we send back an empty result to the caller.
    55  	return callback(&sqltypes.Result{Fields: []*querypb.Field{}, Rows: [][]sqltypes.Value{}})
    56  }
    57  
    58  func (th *testHandler) ComPrepare(c *mysql.Conn, q string, b map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
    59  	return nil, nil
    60  }
    61  
    62  func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
    63  	return nil
    64  }
    65  
    66  func (th *testHandler) ComRegisterReplica(c *mysql.Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error {
    67  	return nil
    68  }
    69  
    70  func (th *testHandler) ComBinlogDump(c *mysql.Conn, logFile string, binlogPos uint32) error {
    71  	return nil
    72  }
    73  
    74  func (th *testHandler) ComBinlogDumpGTID(c *mysql.Conn, logFile string, logPos uint64, gtidSet mysql.GTIDSet) error {
    75  	return nil
    76  }
    77  
    78  func (th *testHandler) WarningCount(c *mysql.Conn) uint16 {
    79  	return 0
    80  }
    81  
    82  func TestConnectionUnixSocket(t *testing.T) {
    83  	th := &testHandler{}
    84  
    85  	authServer := newTestAuthServerStatic()
    86  
    87  	// Use tmp file to reserve a path, remove it immediately, we only care about
    88  	// name in this context
    89  	unixSocket, err := os.CreateTemp("", "mysql_vitess_test.sock")
    90  	if err != nil {
    91  		t.Fatalf("Failed to create temp file")
    92  	}
    93  	os.Remove(unixSocket.Name())
    94  
    95  	l, err := newMysqlUnixSocket(unixSocket.Name(), authServer, th)
    96  	if err != nil {
    97  		t.Fatalf("NewUnixSocket failed: %v", err)
    98  	}
    99  	defer l.Close()
   100  	go l.Accept()
   101  
   102  	params := &mysql.ConnParams{
   103  		UnixSocket: unixSocket.Name(),
   104  		Uname:      "user1",
   105  		Pass:       "password1",
   106  	}
   107  
   108  	c, err := mysql.Connect(context.Background(), params)
   109  	if err != nil {
   110  		t.Errorf("Should be able to connect to server but found error: %v", err)
   111  	}
   112  	c.Close()
   113  }
   114  
   115  func TestConnectionStaleUnixSocket(t *testing.T) {
   116  	th := &testHandler{}
   117  
   118  	authServer := newTestAuthServerStatic()
   119  
   120  	// First let's create a file. In this way, we simulate
   121  	// having a stale socket on disk that needs to be cleaned up.
   122  	unixSocket, err := os.CreateTemp("", "mysql_vitess_test.sock")
   123  	if err != nil {
   124  		t.Fatalf("Failed to create temp file")
   125  	}
   126  
   127  	l, err := newMysqlUnixSocket(unixSocket.Name(), authServer, th)
   128  	if err != nil {
   129  		t.Fatalf("NewListener failed: %v", err)
   130  	}
   131  	defer l.Close()
   132  	go l.Accept()
   133  
   134  	params := &mysql.ConnParams{
   135  		UnixSocket: unixSocket.Name(),
   136  		Uname:      "user1",
   137  		Pass:       "password1",
   138  	}
   139  
   140  	c, err := mysql.Connect(context.Background(), params)
   141  	if err != nil {
   142  		t.Errorf("Should be able to connect to server but found error: %v", err)
   143  	}
   144  	c.Close()
   145  }
   146  
   147  func TestConnectionRespectsExistingUnixSocket(t *testing.T) {
   148  	th := &testHandler{}
   149  
   150  	authServer := newTestAuthServerStatic()
   151  
   152  	unixSocket, err := os.CreateTemp("", "mysql_vitess_test.sock")
   153  	if err != nil {
   154  		t.Fatalf("Failed to create temp file")
   155  	}
   156  	os.Remove(unixSocket.Name())
   157  
   158  	l, err := newMysqlUnixSocket(unixSocket.Name(), authServer, th)
   159  	if err != nil {
   160  		t.Errorf("NewListener failed: %v", err)
   161  	}
   162  	defer l.Close()
   163  	go l.Accept()
   164  	_, err = newMysqlUnixSocket(unixSocket.Name(), authServer, th)
   165  	want := "listen unix"
   166  	if err == nil || !strings.HasPrefix(err.Error(), want) {
   167  		t.Errorf("Error: %v, want prefix %s", err, want)
   168  	}
   169  }
   170  
   171  var newSpanOK = func(ctx context.Context, label string) (trace.Span, context.Context) {
   172  	return trace.NoopSpan{}, context.Background()
   173  }
   174  
   175  var newFromStringOK = func(ctx context.Context, spanContext, label string) (trace.Span, context.Context, error) {
   176  	return trace.NoopSpan{}, context.Background(), nil
   177  }
   178  
   179  func newFromStringFail(t *testing.T) func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) {
   180  	return func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) {
   181  		t.Fatalf("we didn't provide a parent span in the sql query. this should not have been called. got: %v", parentSpan)
   182  		return trace.NoopSpan{}, context.Background(), nil
   183  	}
   184  }
   185  
   186  func newFromStringError(t *testing.T) func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) {
   187  	return func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) {
   188  		return trace.NoopSpan{}, context.Background(), fmt.Errorf("")
   189  	}
   190  }
   191  
   192  func newFromStringExpect(t *testing.T, expected string) func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) {
   193  	return func(ctx context.Context, parentSpan string, label string) (trace.Span, context.Context, error) {
   194  		assert.Equal(t, expected, parentSpan)
   195  		return trace.NoopSpan{}, context.Background(), nil
   196  	}
   197  }
   198  
   199  func newSpanFail(t *testing.T) func(ctx context.Context, label string) (trace.Span, context.Context) {
   200  	return func(ctx context.Context, label string) (trace.Span, context.Context) {
   201  		t.Fatalf("we provided a span context but newFromString was not used as expected")
   202  		return trace.NoopSpan{}, context.Background()
   203  	}
   204  }
   205  
   206  func TestNoSpanContextPassed(t *testing.T) {
   207  	_, _, err := startSpanTestable(context.Background(), "sql without comments", "someLabel", newSpanOK, newFromStringFail(t))
   208  	assert.NoError(t, err)
   209  }
   210  
   211  func TestSpanContextNoPassedInButExistsInString(t *testing.T) {
   212  	_, _, err := startSpanTestable(context.Background(), "SELECT * FROM SOMETABLE WHERE COL = \"/*VT_SPAN_CONTEXT=123*/", "someLabel", newSpanOK, newFromStringFail(t))
   213  	assert.NoError(t, err)
   214  }
   215  
   216  func TestSpanContextPassedIn(t *testing.T) {
   217  	_, _, err := startSpanTestable(context.Background(), "/*VT_SPAN_CONTEXT=123*/SQL QUERY", "someLabel", newSpanFail(t), newFromStringOK)
   218  	assert.NoError(t, err)
   219  }
   220  
   221  func TestSpanContextPassedInEvenAroundOtherComments(t *testing.T) {
   222  	_, _, err := startSpanTestable(context.Background(), "/*VT_SPAN_CONTEXT=123*/SELECT /*vt+ SCATTER_ERRORS_AS_WARNINGS */ col1, col2 FROM TABLE ", "someLabel",
   223  		newSpanFail(t),
   224  		newFromStringExpect(t, "123"))
   225  	assert.NoError(t, err)
   226  }
   227  
   228  func TestSpanContextNotParsable(t *testing.T) {
   229  	hasRun := false
   230  	_, _, err := startSpanTestable(context.Background(), "/*VT_SPAN_CONTEXT=123*/SQL QUERY", "someLabel",
   231  		func(c context.Context, s string) (trace.Span, context.Context) {
   232  			hasRun = true
   233  			return trace.NoopSpan{}, context.Background()
   234  		},
   235  		newFromStringError(t))
   236  	assert.NoError(t, err)
   237  	assert.True(t, hasRun, "Should have continued execution despite failure to parse VT_SPAN_CONTEXT")
   238  }
   239  
   240  func newTestAuthServerStatic() *mysql.AuthServerStatic {
   241  	jsonConfig := "{\"user1\":{\"Password\":\"password1\", \"UserData\":\"userData1\", \"SourceHost\":\"localhost\"}}"
   242  	return mysql.NewAuthServerStatic("", jsonConfig, 0)
   243  }
   244  
   245  func TestDefaultWorkloadEmpty(t *testing.T) {
   246  	vh := &vtgateHandler{}
   247  	sess := vh.session(&mysql.Conn{})
   248  	if sess.Options.Workload != querypb.ExecuteOptions_OLTP {
   249  		t.Fatalf("Expected default workload OLTP")
   250  	}
   251  }
   252  
   253  func TestDefaultWorkloadOLAP(t *testing.T) {
   254  	vh := &vtgateHandler{}
   255  	mysqlDefaultWorkload = int32(querypb.ExecuteOptions_OLAP)
   256  	sess := vh.session(&mysql.Conn{})
   257  	if sess.Options.Workload != querypb.ExecuteOptions_OLAP {
   258  		t.Fatalf("Expected default workload OLAP")
   259  	}
   260  }
   261  
   262  func TestInitTLSConfigWithoutServerCA(t *testing.T) {
   263  	testInitTLSConfig(t, false)
   264  }
   265  
   266  func TestInitTLSConfigWithServerCA(t *testing.T) {
   267  	testInitTLSConfig(t, true)
   268  }
   269  
   270  func testInitTLSConfig(t *testing.T, serverCA bool) {
   271  	// Create the certs.
   272  	root := t.TempDir()
   273  	tlstest.CreateCA(root)
   274  	tlstest.CreateCRL(root, tlstest.CA)
   275  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
   276  
   277  	serverCACert := ""
   278  	if serverCA {
   279  		serverCACert = path.Join(root, "ca-cert.pem")
   280  	}
   281  
   282  	listener := &mysql.Listener{}
   283  	if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), path.Join(root, "ca-crl.pem"), serverCACert, true, tls.VersionTLS12); err != nil {
   284  		t.Fatalf("init tls config failure due to: +%v", err)
   285  	}
   286  
   287  	serverConfig := listener.TLSConfig.Load()
   288  	if serverConfig == nil {
   289  		t.Fatalf("init tls config shouldn't create nil server config")
   290  	}
   291  
   292  	sigChan <- syscall.SIGHUP
   293  	time.Sleep(100 * time.Millisecond) // wait for signal handler
   294  
   295  	if listener.TLSConfig.Load() == serverConfig {
   296  		t.Fatalf("init tls config should have been recreated after SIGHUP")
   297  	}
   298  }