vitess.io/vitess@v0.16.2/go/mysql/endtoend/main_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package endtoend
    18  
    19  import (
    20  	"flag"
    21  	"fmt"
    22  	"os"
    23  	"os/exec"
    24  	"path"
    25  	"strings"
    26  	"testing"
    27  
    28  	"github.com/stretchr/testify/require"
    29  
    30  	"vitess.io/vitess/go/mysql"
    31  	vtenv "vitess.io/vitess/go/vt/env"
    32  	"vitess.io/vitess/go/vt/mysqlctl"
    33  	"vitess.io/vitess/go/vt/tlstest"
    34  	"vitess.io/vitess/go/vt/vttest"
    35  
    36  	vttestpb "vitess.io/vitess/go/vt/proto/vttest"
    37  )
    38  
    39  var (
    40  	connParams mysql.ConnParams
    41  )
    42  
    43  // assertSQLError makes sure we get the right error.
    44  func assertSQLError(t *testing.T, err error, code int, sqlState string, subtext string, query string) {
    45  	t.Helper()
    46  	require.Error(t, err, "was expecting SQLError %v / %v / %v but got no error.", code, sqlState, subtext)
    47  
    48  	serr, ok := err.(*mysql.SQLError)
    49  	require.True(t, ok, "was expecting SQLError %v / %v / %v but got: %v", code, sqlState, subtext, err)
    50  	require.Equal(t, code, serr.Num, "was expecting SQLError %v / %v / %v but got code %v", code, sqlState, subtext, serr.Num)
    51  	require.Equal(t, sqlState, serr.State, "was expecting SQLError %v / %v / %v but got state %v", code, sqlState, subtext, serr.State)
    52  	require.True(t, subtext == "" || strings.Contains(serr.Message, subtext), "was expecting SQLError %v / %v / %v but got message %v", code, sqlState, subtext, serr.Message)
    53  	require.Equal(t, query, serr.Query, "was expecting SQLError %v / %v / %v with Query '%v' but got query '%v'", code, sqlState, subtext, query, serr.Query)
    54  
    55  }
    56  
    57  // runMysql forks a mysql command line process connecting to the provided server.
    58  func runMysql(t *testing.T, params *mysql.ConnParams, command string) (string, bool) {
    59  	dir, err := vtenv.VtMysqlRoot()
    60  	require.NoError(t, err, "vtenv.VtMysqlRoot failed: %v", err)
    61  
    62  	name, err := binaryPath(dir, "mysql")
    63  	require.NoError(t, err, "binaryPath failed: %v", err)
    64  
    65  	// The args contain '-v' 3 times, to switch to very verbose output.
    66  	// In particular, it has the message:
    67  	// Query OK, 1 row affected (0.00 sec)
    68  
    69  	version, getErr := mysqlctl.GetVersionString()
    70  	f, v, err := mysqlctl.ParseVersionString(version)
    71  
    72  	if getErr != nil || err != nil {
    73  		f, v, err = mysqlctl.GetVersionFromEnv()
    74  		if err != nil {
    75  			vtenvMysqlRoot, _ := vtenv.VtMysqlRoot()
    76  			message := fmt.Sprintf(`could not auto-detect MySQL version. You may need to set your PATH so a mysqld binary can be found, or set the environment variable MYSQL_FLAVOR if mysqld is not available locally:
    77  	PATH: %s
    78  	VT_MYSQL_ROOT: %s
    79  	VTROOT: %s
    80  	vtenv.VtMysqlRoot(): %s
    81  	MYSQL_FLAVOR: %s
    82  	`,
    83  				os.Getenv("PATH"),
    84  				os.Getenv("VT_MYSQL_ROOT"),
    85  				os.Getenv("VTROOT"),
    86  				vtenvMysqlRoot,
    87  				os.Getenv("MYSQL_FLAVOR"))
    88  			panic(message)
    89  		}
    90  	}
    91  
    92  	t.Logf("Using flavor: %v, version: %v", f, v)
    93  
    94  	args := []string{
    95  		"-v", "-v", "-v",
    96  	}
    97  	args = append(args, "-e", command)
    98  	if params.UnixSocket != "" {
    99  		args = append(args, "-S", params.UnixSocket)
   100  	} else {
   101  		args = append(args,
   102  			"-h", params.Host,
   103  			"-P", fmt.Sprintf("%v", params.Port))
   104  	}
   105  	if params.Uname != "" {
   106  		args = append(args, "-u", params.Uname)
   107  	}
   108  	if params.Pass != "" {
   109  		args = append(args, "-p"+params.Pass)
   110  	}
   111  	if params.DbName != "" {
   112  		args = append(args, "-D", params.DbName)
   113  	}
   114  	if params.SslEnabled() {
   115  		if f == mysqlctl.FlavorMySQL && v.Major == 5 && v.Minor == 7 || v.Major == 8 {
   116  			args = append(args,
   117  				fmt.Sprintf("--ssl-mode=%s", params.EffectiveSslMode()))
   118  		} else {
   119  			args = append(args,
   120  				"--ssl",
   121  				"--ssl-verify-server-cert")
   122  		}
   123  		args = append(args,
   124  			"--ssl-ca", params.SslCa,
   125  			"--ssl-cert", params.SslCert,
   126  			"--ssl-key", params.SslKey)
   127  	}
   128  	env := []string{
   129  		"LD_LIBRARY_PATH=" + path.Join(dir, "lib/mysql"),
   130  	}
   131  
   132  	cmd := exec.Command(name, args...)
   133  	cmd.Env = env
   134  	cmd.Dir = dir
   135  	out, err := cmd.CombinedOutput()
   136  	output := string(out)
   137  	if err != nil {
   138  		return output, false
   139  	}
   140  	return output, true
   141  }
   142  
   143  // binaryPath does a limited path lookup for a command,
   144  // searching only within sbin and bin in the given root.
   145  //
   146  // FIXME(alainjobart) move this to vt/env, and use it from
   147  // go/vt/mysqlctl too.
   148  func binaryPath(root, binary string) (string, error) {
   149  	subdirs := []string{"sbin", "bin"}
   150  	for _, subdir := range subdirs {
   151  		binPath := path.Join(root, subdir, binary)
   152  		if _, err := os.Stat(binPath); err == nil {
   153  			return binPath, nil
   154  		}
   155  	}
   156  	return "", fmt.Errorf("%s not found in any of %s/{%s}",
   157  		binary, root, strings.Join(subdirs, ","))
   158  }
   159  
   160  func TestMain(m *testing.M) {
   161  	flag.Parse() // Do not remove this comment, import into google3 depends on it
   162  
   163  	exitCode := func() int {
   164  		// Create the certs.
   165  		root, err := os.MkdirTemp("", "TestTLSServer")
   166  		if err != nil {
   167  			fmt.Fprintf(os.Stderr, "TempDir failed: %v", err)
   168  			return 1
   169  		}
   170  		defer os.RemoveAll(root)
   171  		tlstest.CreateCA(root)
   172  		tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "localhost")
   173  		tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
   174  
   175  		// Create the extra SSL my.cnf lines.
   176  		cnf := fmt.Sprintf(`
   177  ssl-ca=%v/ca-cert.pem
   178  ssl-cert=%v/server-cert.pem
   179  ssl-key=%v/server-key.pem
   180  `, root, root, root)
   181  		extraMyCnf := path.Join(root, "ssl_my.cnf")
   182  		if err := os.WriteFile(extraMyCnf, []byte(cnf), os.ModePerm); err != nil {
   183  			fmt.Fprintf(os.Stderr, "os.WriteFile(%v) failed: %v", extraMyCnf, err)
   184  			return 1
   185  		}
   186  
   187  		// For LargeQuery tests
   188  		cnf = "max_allowed_packet=100M\n"
   189  		maxPacketMyCnf := path.Join(root, "max_packet.cnf")
   190  		if err := os.WriteFile(maxPacketMyCnf, []byte(cnf), os.ModePerm); err != nil {
   191  			fmt.Fprintf(os.Stderr, "os.WriteFile(%v) failed: %v", maxPacketMyCnf, err)
   192  			return 1
   193  		}
   194  
   195  		// Launch MySQL.
   196  		// We need a Keyspace in the topology, so the DbName is set.
   197  		// We need a Shard too, so the database 'vttest' is created.
   198  		cfg := vttest.Config{
   199  			Topology: &vttestpb.VTTestTopology{
   200  				Keyspaces: []*vttestpb.Keyspace{
   201  					{
   202  						Name: "vttest",
   203  						Shards: []*vttestpb.Shard{
   204  							{
   205  								Name:           "0",
   206  								DbNameOverride: "vttest",
   207  							},
   208  						},
   209  					},
   210  				},
   211  			},
   212  			OnlyMySQL:  true,
   213  			ExtraMyCnf: []string{extraMyCnf, maxPacketMyCnf},
   214  		}
   215  		cluster := vttest.LocalCluster{
   216  			Config: cfg,
   217  		}
   218  		if err := cluster.Setup(); err != nil {
   219  			fmt.Fprintf(os.Stderr, "could not launch mysql: %v\n", err)
   220  			return 1
   221  		}
   222  		defer cluster.TearDown()
   223  		connParams = cluster.MySQLConnParams()
   224  
   225  		// Add the SSL parts, but they're not enabled until
   226  		// the flag is set.
   227  		connParams.SslCa = path.Join(root, "ca-cert.pem")
   228  		connParams.SslCert = path.Join(root, "client-cert.pem")
   229  		connParams.SslKey = path.Join(root, "client-key.pem")
   230  
   231  		// Uncomment to sleep and be able to connect to MySQL
   232  		// fmt.Printf("Connect to MySQL using parameters:\n")
   233  		// json.NewEncoder(os.Stdout).Encode(connParams)
   234  		// time.Sleep(10 * time.Minute)
   235  
   236  		return m.Run()
   237  	}()
   238  	os.Exit(exitCode)
   239  }