vitess.io/vitess@v0.16.2/go/cmd/vttestserver/vttestserver_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"io"
    23  	"math/rand"
    24  	"os"
    25  	"os/exec"
    26  	"path"
    27  	"strings"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/hashicorp/consul/api"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"google.golang.org/protobuf/encoding/protojson"
    35  
    36  	"vitess.io/vitess/go/mysql"
    37  	"vitess.io/vitess/go/sqltypes"
    38  	"vitess.io/vitess/go/vt/log"
    39  	"vitess.io/vitess/go/vt/logutil"
    40  	"vitess.io/vitess/go/vt/tlstest"
    41  	"vitess.io/vitess/go/vt/vtctl/vtctlclient"
    42  	"vitess.io/vitess/go/vt/vttest"
    43  
    44  	logutilpb "vitess.io/vitess/go/vt/proto/logutil"
    45  	vschemapb "vitess.io/vitess/go/vt/proto/vschema"
    46  )
    47  
    48  type columnVindex struct {
    49  	keyspace   string
    50  	table      string
    51  	vindex     string
    52  	vindexType string
    53  	column     string
    54  }
    55  
    56  func TestRunsVschemaMigrations(t *testing.T) {
    57  	args := os.Args
    58  	conf := config
    59  	defer resetFlags(args, conf)
    60  
    61  	cluster, err := startCluster()
    62  	defer cluster.TearDown()
    63  
    64  	assert.NoError(t, err)
    65  	assertColumnVindex(t, cluster, columnVindex{keyspace: "test_keyspace", table: "test_table", vindex: "my_vdx", vindexType: "hash", column: "id"})
    66  	assertColumnVindex(t, cluster, columnVindex{keyspace: "app_customer", table: "customers", vindex: "hash", vindexType: "hash", column: "id"})
    67  
    68  	// Add Hash vindex via vtgate execution on table
    69  	err = addColumnVindex(cluster, "test_keyspace", "alter vschema on test_table1 add vindex my_vdx (id)")
    70  	assert.NoError(t, err)
    71  	assertColumnVindex(t, cluster, columnVindex{keyspace: "test_keyspace", table: "test_table1", vindex: "my_vdx", vindexType: "hash", column: "id"})
    72  }
    73  
    74  func TestPersistentMode(t *testing.T) {
    75  	args := os.Args
    76  	conf := config
    77  	defer resetFlags(args, conf)
    78  
    79  	dir := t.TempDir()
    80  
    81  	cluster, err := startPersistentCluster(dir)
    82  	assert.NoError(t, err)
    83  
    84  	// basic sanity checks similar to TestRunsVschemaMigrations
    85  	assertColumnVindex(t, cluster, columnVindex{keyspace: "test_keyspace", table: "test_table", vindex: "my_vdx", vindexType: "hash", column: "id"})
    86  	assertColumnVindex(t, cluster, columnVindex{keyspace: "app_customer", table: "customers", vindex: "hash", vindexType: "hash", column: "id"})
    87  
    88  	// insert some data to ensure persistence across teardowns
    89  	err = execOnCluster(cluster, "app_customer", func(conn *mysql.Conn) error {
    90  		_, err := conn.ExecuteFetch("insert into customers (id, name) values (1, 'gopherson')", 1, false)
    91  		return err
    92  	})
    93  	assert.NoError(t, err)
    94  
    95  	expectedRows := [][]sqltypes.Value{
    96  		{sqltypes.NewInt64(1), sqltypes.NewVarChar("gopherson"), sqltypes.NULL},
    97  	}
    98  
    99  	// ensure data was actually inserted
   100  	var res *sqltypes.Result
   101  	err = execOnCluster(cluster, "app_customer", func(conn *mysql.Conn) (err error) {
   102  		res, err = conn.ExecuteFetch("SELECT * FROM customers", 1, false)
   103  		return err
   104  	})
   105  	assert.NoError(t, err)
   106  	assert.Equal(t, expectedRows, res.Rows)
   107  
   108  	// reboot the persistent cluster
   109  	cluster.TearDown()
   110  	cluster, err = startPersistentCluster(dir)
   111  	defer cluster.TearDown()
   112  	assert.NoError(t, err)
   113  
   114  	// rerun our sanity checks to make sure vschema migrations are run during every startup
   115  	assertColumnVindex(t, cluster, columnVindex{keyspace: "test_keyspace", table: "test_table", vindex: "my_vdx", vindexType: "hash", column: "id"})
   116  	assertColumnVindex(t, cluster, columnVindex{keyspace: "app_customer", table: "customers", vindex: "hash", vindexType: "hash", column: "id"})
   117  
   118  	// ensure previous data was successfully persisted
   119  	err = execOnCluster(cluster, "app_customer", func(conn *mysql.Conn) (err error) {
   120  		res, err = conn.ExecuteFetch("SELECT * FROM customers", 1, false)
   121  		return err
   122  	})
   123  	assert.NoError(t, err)
   124  	assert.Equal(t, expectedRows, res.Rows)
   125  }
   126  
   127  func TestForeignKeysAndDDLModes(t *testing.T) {
   128  	args := os.Args
   129  	conf := config
   130  	defer resetFlags(args, conf)
   131  
   132  	cluster, err := startCluster("--foreign_key_mode=allow", "--enable_online_ddl=true", "--enable_direct_ddl=true")
   133  	assert.NoError(t, err)
   134  	defer cluster.TearDown()
   135  
   136  	err = execOnCluster(cluster, "test_keyspace", func(conn *mysql.Conn) error {
   137  		_, err := conn.ExecuteFetch(`CREATE TABLE test_table_2 (
   138  			id BIGINT,
   139  			test_table_id BIGINT,
   140  			FOREIGN KEY (test_table_id) REFERENCES test_table(id)
   141  		)`, 1, false)
   142  		assert.NoError(t, err)
   143  		_, err = conn.ExecuteFetch("SET @@ddl_strategy='online'", 1, false)
   144  		assert.NoError(t, err)
   145  		_, err = conn.ExecuteFetch("ALTER TABLE test_table ADD COLUMN something_else VARCHAR(255) NOT NULL DEFAULT ''", 1, false)
   146  		assert.NoError(t, err)
   147  		_, err = conn.ExecuteFetch("SET @@ddl_strategy='direct'", 1, false)
   148  		assert.NoError(t, err)
   149  		_, err = conn.ExecuteFetch("ALTER TABLE test_table ADD COLUMN something_else_2 VARCHAR(255) NOT NULL DEFAULT ''", 1, false)
   150  		assert.NoError(t, err)
   151  		_, err = conn.ExecuteFetch("SELECT something_else_2 FROM test_table", 1, false)
   152  		assert.NoError(t, err)
   153  		return nil
   154  	})
   155  	assert.NoError(t, err)
   156  
   157  	cluster.TearDown()
   158  	cluster, err = startCluster("--foreign_key_mode=disallow", "--enable_online_ddl=false", "--enable_direct_ddl=false")
   159  	assert.NoError(t, err)
   160  	defer cluster.TearDown()
   161  
   162  	err = execOnCluster(cluster, "test_keyspace", func(conn *mysql.Conn) error {
   163  		_, err := conn.ExecuteFetch(`CREATE TABLE test_table_2 (
   164  			id BIGINT,
   165  			test_table_id BIGINT,
   166  			FOREIGN KEY (test_table_id) REFERENCES test_table(id)
   167  		)`, 1, false)
   168  		assert.Error(t, err)
   169  		_, err = conn.ExecuteFetch("SET @@ddl_strategy='online'", 1, false)
   170  		assert.NoError(t, err)
   171  		_, err = conn.ExecuteFetch("ALTER TABLE test_table ADD COLUMN something_else VARCHAR(255) NOT NULL DEFAULT ''", 1, false)
   172  		assert.Error(t, err)
   173  		_, err = conn.ExecuteFetch("SET @@ddl_strategy='direct'", 1, false)
   174  		assert.NoError(t, err)
   175  		_, err = conn.ExecuteFetch("ALTER TABLE test_table ADD COLUMN something_else VARCHAR(255) NOT NULL DEFAULT ''", 1, false)
   176  		assert.Error(t, err)
   177  		return nil
   178  	})
   179  	assert.NoError(t, err)
   180  }
   181  
   182  func TestCanGetKeyspaces(t *testing.T) {
   183  	args := os.Args
   184  	conf := config
   185  	defer resetFlags(args, conf)
   186  
   187  	cluster, err := startCluster()
   188  	assert.NoError(t, err)
   189  	defer cluster.TearDown()
   190  
   191  	assertGetKeyspaces(t, cluster)
   192  }
   193  
   194  func TestExternalTopoServerConsul(t *testing.T) {
   195  	args := os.Args
   196  	conf := config
   197  	defer resetFlags(args, conf)
   198  
   199  	// Start a single consul in the background.
   200  	cmd, serverAddr := startConsul(t)
   201  	defer func() {
   202  		// Alerts command did not run successful
   203  		if err := cmd.Process.Kill(); err != nil {
   204  			log.Errorf("cmd process kill has an error: %v", err)
   205  		}
   206  		// Alerts command did not run successful
   207  		if err := cmd.Wait(); err != nil {
   208  			log.Errorf("cmd process wait has an error: %v", err)
   209  		}
   210  	}()
   211  
   212  	cluster, err := startCluster("--external_topo_implementation=consul",
   213  		fmt.Sprintf("--external_topo_global_server_address=%s", serverAddr), "--external_topo_global_root=consul_test/global")
   214  	assert.NoError(t, err)
   215  	defer cluster.TearDown()
   216  
   217  	assertGetKeyspaces(t, cluster)
   218  }
   219  
   220  func TestMtlsAuth(t *testing.T) {
   221  	args := os.Args
   222  	conf := config
   223  	defer resetFlags(args, conf)
   224  
   225  	// Our test root.
   226  	root := t.TempDir()
   227  
   228  	// Create the certs and configs.
   229  	tlstest.CreateCA(root)
   230  	caCert := path.Join(root, "ca-cert.pem")
   231  
   232  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "vtctld", "vtctld.example.com")
   233  	cert := path.Join(root, "vtctld-cert.pem")
   234  	key := path.Join(root, "vtctld-key.pem")
   235  
   236  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "ClientApp")
   237  	clientCert := path.Join(root, "client-cert.pem")
   238  	clientKey := path.Join(root, "client-key.pem")
   239  
   240  	// When cluster starts it will apply SQL and VSchema migrations in the configured schema_dir folder
   241  	// With mtls authorization enabled, the authorized CN must match the certificate's CN
   242  	cluster, err := startCluster(
   243  		"--grpc_auth_mode=mtls",
   244  		fmt.Sprintf("--grpc_key=%s", key),
   245  		fmt.Sprintf("--grpc_cert=%s", cert),
   246  		fmt.Sprintf("--grpc_ca=%s", caCert),
   247  		fmt.Sprintf("--vtctld_grpc_key=%s", clientKey),
   248  		fmt.Sprintf("--vtctld_grpc_cert=%s", clientCert),
   249  		fmt.Sprintf("--vtctld_grpc_ca=%s", caCert),
   250  		fmt.Sprintf("--grpc_auth_mtls_allowed_substrings=%s", "CN=ClientApp"))
   251  	assert.NoError(t, err)
   252  	defer cluster.TearDown()
   253  
   254  	// startCluster will apply vschema migrations using vtctl grpc and the clientCert.
   255  	assertColumnVindex(t, cluster, columnVindex{keyspace: "test_keyspace", table: "test_table", vindex: "my_vdx", vindexType: "hash", column: "id"})
   256  	assertColumnVindex(t, cluster, columnVindex{keyspace: "app_customer", table: "customers", vindex: "hash", vindexType: "hash", column: "id"})
   257  }
   258  
   259  func TestMtlsAuthUnauthorizedFails(t *testing.T) {
   260  	args := os.Args
   261  	conf := config
   262  	defer resetFlags(args, conf)
   263  
   264  	// Our test root.
   265  	root := t.TempDir()
   266  
   267  	// Create the certs and configs.
   268  	tlstest.CreateCA(root)
   269  	caCert := path.Join(root, "ca-cert.pem")
   270  
   271  	tlstest.CreateSignedCert(root, tlstest.CA, "01", "vtctld", "vtctld.example.com")
   272  	cert := path.Join(root, "vtctld-cert.pem")
   273  	key := path.Join(root, "vtctld-key.pem")
   274  
   275  	tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "AnotherApp")
   276  	clientCert := path.Join(root, "client-cert.pem")
   277  	clientKey := path.Join(root, "client-key.pem")
   278  
   279  	// When cluster starts it will apply SQL and VSchema migrations in the configured schema_dir folder
   280  	// For mtls authorization failure by providing a client certificate with different CN thant the
   281  	// authorized in the configuration
   282  	cluster, err := startCluster(
   283  		"--grpc_auth_mode=mtls",
   284  		fmt.Sprintf("--grpc_key=%s", key),
   285  		fmt.Sprintf("--grpc_cert=%s", cert),
   286  		fmt.Sprintf("--grpc_ca=%s", caCert),
   287  		fmt.Sprintf("--vtctld_grpc_key=%s", clientKey),
   288  		fmt.Sprintf("--vtctld_grpc_cert=%s", clientCert),
   289  		fmt.Sprintf("--vtctld_grpc_ca=%s", caCert),
   290  		fmt.Sprintf("--grpc_auth_mtls_allowed_substrings=%s", "CN=ClientApp"))
   291  	defer cluster.TearDown()
   292  
   293  	assert.Error(t, err)
   294  	assert.Contains(t, err.Error(), "code = Unauthenticated desc = client certificate not authorized")
   295  }
   296  
   297  func startPersistentCluster(dir string, flags ...string) (vttest.LocalCluster, error) {
   298  	flags = append(flags, []string{
   299  		"--persistent_mode",
   300  		// FIXME: if port is not provided, data_dir is not respected
   301  		fmt.Sprintf("--port=%d", randomPort()),
   302  		fmt.Sprintf("--data_dir=%s", dir),
   303  	}...)
   304  	return startCluster(flags...)
   305  }
   306  
   307  var clusterKeyspaces = []string{
   308  	"test_keyspace",
   309  	"app_customer",
   310  }
   311  
   312  func startCluster(flags ...string) (vttest.LocalCluster, error) {
   313  	os.Args = []string{"vttestserver"}
   314  	schemaDirArg := "--schema_dir=data/schema"
   315  	tabletHostname := "--tablet_hostname=localhost"
   316  	keyspaceArg := "--keyspaces=" + strings.Join(clusterKeyspaces, ",")
   317  	numShardsArg := "--num_shards=2,2"
   318  	vschemaDDLAuthorizedUsers := "--vschema_ddl_authorized_users=%"
   319  	os.Args = append(os.Args, []string{schemaDirArg, keyspaceArg, numShardsArg, tabletHostname, vschemaDDLAuthorizedUsers}...)
   320  	os.Args = append(os.Args, flags...)
   321  	return runCluster()
   322  }
   323  
   324  func addColumnVindex(cluster vttest.LocalCluster, keyspace string, vschemaMigration string) error {
   325  	return execOnCluster(cluster, keyspace, func(conn *mysql.Conn) error {
   326  		_, err := conn.ExecuteFetch(vschemaMigration, 1, false)
   327  		return err
   328  	})
   329  }
   330  
   331  func execOnCluster(cluster vttest.LocalCluster, keyspace string, f func(*mysql.Conn) error) error {
   332  	ctx := context.Background()
   333  	vtParams := mysql.ConnParams{
   334  		Host:   "localhost",
   335  		DbName: keyspace,
   336  		Port:   cluster.Env.PortForProtocol("vtcombo_mysql_port", ""),
   337  	}
   338  
   339  	conn, err := mysql.Connect(ctx, &vtParams)
   340  	if err != nil {
   341  		return err
   342  	}
   343  	defer conn.Close()
   344  	return f(conn)
   345  }
   346  
   347  func assertColumnVindex(t *testing.T, cluster vttest.LocalCluster, expected columnVindex) {
   348  	server := fmt.Sprintf("localhost:%v", cluster.GrpcPort())
   349  	args := []string{"GetVSchema", expected.keyspace}
   350  	ctx := context.Background()
   351  
   352  	err := vtctlclient.RunCommandAndWait(ctx, server, args, func(e *logutilpb.Event) {
   353  		var keyspace vschemapb.Keyspace
   354  		if err := protojson.Unmarshal([]byte(e.Value), &keyspace); err != nil {
   355  			t.Error(err)
   356  		}
   357  
   358  		columnVindex := keyspace.Tables[expected.table].ColumnVindexes[0]
   359  		actualVindex := keyspace.Vindexes[expected.vindex]
   360  		assertEqual(t, actualVindex.Type, expected.vindexType, "Actual vindex type different from expected")
   361  		assertEqual(t, columnVindex.Name, expected.vindex, "Actual vindex name different from expected")
   362  		assertEqual(t, columnVindex.Columns[0], expected.column, "Actual vindex column different from expected")
   363  	})
   364  	require.NoError(t, err)
   365  }
   366  
   367  func assertEqual(t *testing.T, actual string, expected string, message string) {
   368  	if actual != expected {
   369  		t.Errorf("%s: actual %s, expected %s", message, actual, expected)
   370  	}
   371  }
   372  
   373  func resetFlags(args []string, conf vttest.Config) {
   374  	os.Args = args
   375  	config = conf
   376  }
   377  
   378  func randomPort() int {
   379  	v := rand.Int31n(20000)
   380  	return int(v + 10000)
   381  }
   382  
   383  func assertGetKeyspaces(t *testing.T, cluster vttest.LocalCluster) {
   384  	client, err := vtctlclient.New(fmt.Sprintf("localhost:%v", cluster.GrpcPort()))
   385  	assert.NoError(t, err)
   386  	defer client.Close()
   387  	stream, err := client.ExecuteVtctlCommand(
   388  		context.Background(),
   389  		[]string{
   390  			"GetKeyspaces",
   391  			"--server",
   392  			fmt.Sprintf("localhost:%v", cluster.GrpcPort()),
   393  		},
   394  		30*time.Second,
   395  	)
   396  	assert.NoError(t, err)
   397  
   398  	resp, err := consumeEventStream(stream)
   399  	require.NoError(t, err)
   400  
   401  	keyspaces := strings.Split(resp, "\n")
   402  	if keyspaces[len(keyspaces)-1] == "" { // trailing newlines make Split annoying
   403  		keyspaces = keyspaces[:len(keyspaces)-1]
   404  	}
   405  
   406  	assert.ElementsMatch(t, clusterKeyspaces, keyspaces)
   407  }
   408  
   409  func consumeEventStream(stream logutil.EventStream) (string, error) {
   410  	var buf strings.Builder
   411  	for {
   412  		switch e, err := stream.Recv(); err {
   413  		case nil:
   414  			buf.WriteString(e.Value)
   415  		case io.EOF:
   416  			return buf.String(), nil
   417  		default:
   418  			return "", err
   419  		}
   420  	}
   421  }
   422  
   423  // startConsul starts a consul subprocess, and waits for it to be ready.
   424  // Returns the exec.Cmd forked, and the server address to RPC-connect to.
   425  func startConsul(t *testing.T) (*exec.Cmd, string) {
   426  	// pick a random port to make sure things work with non-default port
   427  	port := randomPort()
   428  
   429  	cmd := exec.Command("consul",
   430  		"agent",
   431  		"-dev",
   432  		"-http-port", fmt.Sprintf("%d", port))
   433  	err := cmd.Start()
   434  	if err != nil {
   435  		t.Fatalf("failed to start consul: %v", err)
   436  	}
   437  
   438  	// Create a client to connect to the created consul.
   439  	serverAddr := fmt.Sprintf("localhost:%v", port)
   440  	cfg := api.DefaultConfig()
   441  	cfg.Address = serverAddr
   442  	c, err := api.NewClient(cfg)
   443  	if err != nil {
   444  		t.Fatalf("api.NewClient(%v) failed: %v", serverAddr, err)
   445  	}
   446  
   447  	// Wait until we can list "/", or timeout.
   448  	start := time.Now()
   449  	kv := c.KV()
   450  	for {
   451  		_, _, err := kv.List("/", nil)
   452  		if err == nil {
   453  			break
   454  		}
   455  		if time.Since(start) > 10*time.Second {
   456  			t.Fatalf("Failed to start consul daemon in time. Consul is returning error: %v", err)
   457  		}
   458  		time.Sleep(10 * time.Millisecond)
   459  	}
   460  
   461  	return cmd, serverAddr
   462  }