github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/handler_test.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"context"
    19  	"crypto/rand"
    20  	"crypto/rsa"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"crypto/x509/pkix"
    24  	"database/sql"
    25  	"encoding/pem"
    26  	"fmt"
    27  	"math/big"
    28  	"os"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/go-sql-driver/mysql"
    33  	"github.com/lni/goutils/leaktest"
    34  	"github.com/matrixorigin/matrixone/pkg/clusterservice"
    35  	"github.com/matrixorigin/matrixone/pkg/common/log"
    36  	"github.com/matrixorigin/matrixone/pkg/common/runtime"
    37  	"github.com/matrixorigin/matrixone/pkg/common/stopper"
    38  	"github.com/matrixorigin/matrixone/pkg/pb/metadata"
    39  	"github.com/stretchr/testify/require"
    40  )
    41  
// testProxyHandler bundles the components used by proxy handler tests.
// Instances are built by newTestProxyHandler; call closeFn when the test is
// done to release them.
type testProxyHandler struct {
	// ctx is a cancellable context; closeFn cancels it.
	ctx        context.Context
	// st is the stopper handed to the rebalancer; closeFn stops it.
	st         *stopper.Stopper
	// logger is the process runtime's logger.
	logger     *log.MOLogger
	// hc is the mock HAKeeper client that backs mc.
	hc         *mockHAKeeperClient
	// mc is the cluster view built on hc; closeFn closes it.
	mc         clusterservice.MOCluster
	// re is the rebalancer created for this handler.
	re         *rebalancer
	// ru is the router built from mc and re.
	ru         Router
	// closeFn closes mc, stops st and cancels ctx, in that order.
	closeFn    func()
	// counterSet holds the handler's counters, freshly created per instance.
	counterSet *counterSet
}
    53  
    54  func newTestProxyHandler(t *testing.T) *testProxyHandler {
    55  	rt := runtime.DefaultRuntime()
    56  	runtime.SetupProcessLevelRuntime(rt)
    57  	ctx, cancel := context.WithCancel(context.TODO())
    58  	hc := &mockHAKeeperClient{}
    59  	mc := clusterservice.NewMOCluster(hc, 3*time.Second)
    60  	rt.SetGlobalVariables(runtime.ClusterService, mc)
    61  	logger := rt.Logger()
    62  	st := stopper.NewStopper("test-proxy", stopper.WithLogger(rt.Logger().RawLogger()))
    63  	re := testRebalancer(t, st, logger, mc)
    64  	return &testProxyHandler{
    65  		ctx:    ctx,
    66  		st:     st,
    67  		logger: logger,
    68  		hc:     hc,
    69  		mc:     mc,
    70  		re:     re,
    71  		ru:     newRouter(mc, re, false),
    72  		closeFn: func() {
    73  			mc.Close()
    74  			st.Stop()
    75  			cancel()
    76  		},
    77  		counterSet: newCounterSet(),
    78  	}
    79  }
    80  
    81  func certGen(basePath string) (*tlsConfig, error) {
    82  	max := new(big.Int).Lsh(big.NewInt(1), 128)
    83  	serialNumber, _ := rand.Int(rand.Reader, max)
    84  	subject := pkix.Name{
    85  		Country:            []string{"CN"},
    86  		Province:           []string{"SH"},
    87  		Organization:       []string{"MO"},
    88  		OrganizationalUnit: []string{"Dev"},
    89  	}
    90  
    91  	// set up CA certificate
    92  	ca := &x509.Certificate{
    93  		SerialNumber: serialNumber,
    94  		Subject:      subject,
    95  		NotBefore:    time.Now(),
    96  		NotAfter:     time.Now().Add(365 * 24 * time.Hour),
    97  		IsCA:         true,
    98  		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
    99  	}
   100  
   101  	// create our private and public key
   102  	caPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	// create the CA
   108  	caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	// pem encode
   114  	caFile := basePath + "/ca.pem"
   115  	caOut, _ := os.Create(caFile)
   116  	if err := pem.Encode(caOut, &pem.Block{
   117  		Type:  "CERTIFICATE",
   118  		Bytes: caBytes,
   119  	}); err != nil {
   120  		return nil, err
   121  	}
   122  	defer func() {
   123  		_ = caOut.Close()
   124  	}()
   125  
   126  	// set up server certificate
   127  	cert := &x509.Certificate{
   128  		SerialNumber: serialNumber,
   129  		Subject:      subject,
   130  		NotBefore:    time.Now(),
   131  		NotAfter:     time.Now().Add(365 * 24 * time.Hour),
   132  		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   133  	}
   134  
   135  	certPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	certFile := basePath + "/server-cert.pem"
   146  	certOut, _ := os.Create(certFile)
   147  	if err := pem.Encode(certOut, &pem.Block{
   148  		Type:  "CERTIFICATE",
   149  		Bytes: certBytes,
   150  	}); err != nil {
   151  		return nil, err
   152  	}
   153  	defer func() {
   154  		_ = certOut.Close()
   155  	}()
   156  
   157  	keyFile := basePath + "/server-key.pem"
   158  	keyOut, _ := os.Create(keyFile)
   159  	if err := pem.Encode(keyOut, &pem.Block{
   160  		Type:  "RSA PRIVATE KEY",
   161  		Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
   162  	}); err != nil {
   163  		return nil, err
   164  	}
   165  	defer func() {
   166  		_ = keyOut.Close()
   167  	}()
   168  
   169  	return &tlsConfig{
   170  		caFile:   caFile,
   171  		certFile: certFile,
   172  		keyFile:  keyFile,
   173  	}, nil
   174  }
   175  
   176  func TestHandler_Handle(t *testing.T) {
   177  	defer leaktest.AfterTest(t)()
   178  
   179  	temp := os.TempDir()
   180  	ctx, cancel := context.WithCancel(context.Background())
   181  	defer cancel()
   182  	rt := runtime.DefaultRuntime()
   183  	runtime.SetupProcessLevelRuntime(rt)
   184  	listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   185  	require.NoError(t, os.RemoveAll(listenAddr))
   186  	cfg := Config{
   187  		ListenAddress:     "unix://" + listenAddr,
   188  		RebalanceDisabled: true,
   189  	}
   190  	hc := &mockHAKeeperClient{}
   191  	mc := clusterservice.NewMOCluster(hc, 3*time.Second)
   192  	defer mc.Close()
   193  	rt.SetGlobalVariables(runtime.ClusterService, mc)
   194  	addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   195  	require.NoError(t, os.RemoveAll(addr))
   196  	cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{})
   197  	hc.updateCN(cn1.uuid, cn1.addr, map[string]metadata.LabelList{})
   198  	// start backend server.
   199  	stopFn := startTestCNServer(t, ctx, addr, nil)
   200  	defer func() {
   201  		require.NoError(t, stopFn())
   202  	}()
   203  	mc.ForceRefresh(true)
   204  
   205  	// start proxy.
   206  	s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()),
   207  		WithHAKeeperClient(hc))
   208  	defer func() {
   209  		err := s.Close()
   210  		require.NoError(t, err)
   211  	}()
   212  	require.NoError(t, err)
   213  	require.NotNil(t, s)
   214  	err = s.Start()
   215  	require.NoError(t, err)
   216  
   217  	db, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", listenAddr))
   218  	// connect to server.
   219  	require.NoError(t, err)
   220  	require.NotNil(t, db)
   221  	defer func() {
   222  		_ = db.Close()
   223  		timeout := time.NewTimer(time.Second * 15)
   224  		tick := time.NewTicker(time.Millisecond * 100)
   225  		var connTotal int64
   226  		tt := false
   227  		for {
   228  			select {
   229  			case <-tick.C:
   230  				connTotal = s.counterSet.connTotal.Load()
   231  			case <-timeout.C:
   232  				tt = true
   233  			}
   234  			if connTotal == 0 || tt {
   235  				break
   236  			}
   237  		}
   238  		tick.Stop()
   239  		timeout.Stop()
   240  		require.Equal(t, int64(0), connTotal)
   241  	}()
   242  	_, err = db.Exec("anystmt")
   243  	require.NoError(t, err)
   244  
   245  	require.Equal(t, int64(1), s.counterSet.connAccepted.Load())
   246  	require.Equal(t, int64(1), s.counterSet.connTotal.Load())
   247  }
   248  
   249  func TestHandler_HandleErr(t *testing.T) {
   250  	defer leaktest.AfterTest(t)()
   251  
   252  	temp := os.TempDir()
   253  	ctx, cancel := context.WithCancel(context.Background())
   254  	defer cancel()
   255  	rt := runtime.DefaultRuntime()
   256  	runtime.SetupProcessLevelRuntime(rt)
   257  	listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   258  	require.NoError(t, os.RemoveAll(listenAddr))
   259  	cfg := Config{
   260  		ListenAddress:     "unix://" + listenAddr,
   261  		RebalanceDisabled: true,
   262  	}
   263  	hc := &mockHAKeeperClient{}
   264  	mc := clusterservice.NewMOCluster(hc, 3*time.Second)
   265  	defer mc.Close()
   266  	rt.SetGlobalVariables(runtime.ClusterService, mc)
   267  	addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   268  	require.NoError(t, os.RemoveAll(addr))
   269  
   270  	// start proxy.
   271  	s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()),
   272  		WithHAKeeperClient(hc))
   273  	defer func() {
   274  		err := s.Close()
   275  		require.NoError(t, err)
   276  	}()
   277  	require.NoError(t, err)
   278  	require.NotNil(t, s)
   279  	err = s.Start()
   280  	require.NoError(t, err)
   281  
   282  	db, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", listenAddr))
   283  	// connect to server.
   284  	require.NoError(t, err)
   285  	require.NotNil(t, db)
   286  	defer func() {
   287  		_ = db.Close()
   288  		timeout := time.NewTimer(time.Second * 15)
   289  		tick := time.NewTicker(time.Millisecond * 100)
   290  		var connTotal int64
   291  		tt := false
   292  		for {
   293  			select {
   294  			case <-tick.C:
   295  				connTotal = s.counterSet.connTotal.Load()
   296  			case <-timeout.C:
   297  				tt = true
   298  			}
   299  			if connTotal == 0 || tt {
   300  				break
   301  			}
   302  		}
   303  		tick.Stop()
   304  		timeout.Stop()
   305  		require.Equal(t, int64(0), connTotal)
   306  	}()
   307  	_, err = db.Exec("anystmt")
   308  	require.Error(t, err)
   309  
   310  	require.Equal(t, int64(1), s.counterSet.connAccepted.Load())
   311  }
   312  
   313  func TestHandler_HandleWithSSL(t *testing.T) {
   314  	defer leaktest.AfterTest(t)()
   315  
   316  	temp := os.TempDir()
   317  	ctx, cancel := context.WithCancel(context.Background())
   318  	defer cancel()
   319  	rt := runtime.DefaultRuntime()
   320  	runtime.SetupProcessLevelRuntime(rt)
   321  	listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   322  	require.NoError(t, os.RemoveAll(listenAddr))
   323  	cfg := Config{
   324  		ListenAddress:     "unix://" + listenAddr,
   325  		RebalanceDisabled: true,
   326  	}
   327  	hc := &mockHAKeeperClient{}
   328  	mc := clusterservice.NewMOCluster(hc, 3*time.Second)
   329  	defer mc.Close()
   330  	rt.SetGlobalVariables(runtime.ClusterService, mc)
   331  	addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   332  	require.NoError(t, os.RemoveAll(addr))
   333  	cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{})
   334  	hc.updateCN(cn1.uuid, cn1.addr, map[string]metadata.LabelList{})
   335  
   336  	tlsC, err := certGen(temp)
   337  	require.NoError(t, err)
   338  	tlsC.enabled = true
   339  	// start backend server.
   340  	stopFn := startTestCNServer(t, ctx, addr, tlsC)
   341  	defer func() {
   342  		require.NoError(t, stopFn())
   343  	}()
   344  	mc.ForceRefresh(true)
   345  
   346  	// start proxy.
   347  	s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()),
   348  		WithHAKeeperClient(hc),
   349  		WithTLSEnabled(),
   350  		WithTLSCAFile(tlsC.caFile),
   351  		WithTLSCertFile(tlsC.certFile),
   352  		WithTLSKeyFile(tlsC.keyFile))
   353  	defer func() {
   354  		err := s.Close()
   355  		require.NoError(t, err)
   356  	}()
   357  	require.NoError(t, err)
   358  	require.NotNil(t, s)
   359  	err = s.Start()
   360  	require.NoError(t, err)
   361  
   362  	rootCertPool := x509.NewCertPool()
   363  
   364  	pem1, err := os.ReadFile(tlsC.caFile)
   365  	require.NoError(t, err)
   366  
   367  	ok := rootCertPool.AppendCertsFromPEM(pem1)
   368  	require.True(t, ok)
   369  
   370  	err = mysql.RegisterTLSConfig("custom", &tls.Config{
   371  		RootCAs:            rootCertPool,
   372  		InsecureSkipVerify: true,
   373  	})
   374  	require.NoError(t, err)
   375  
   376  	db, err := sql.Open("mysql",
   377  		fmt.Sprintf("dump:111@unix(%s)/db1?tls=custom", listenAddr))
   378  	// connect to server.
   379  	require.NoError(t, err)
   380  	require.NotNil(t, db)
   381  	defer func() {
   382  		_ = db.Close()
   383  	}()
   384  	_, _ = db.Exec("any stmt")
   385  	_, err = db.Exec("any stmt")
   386  	require.NoError(t, err)
   387  	require.Equal(t, int64(1), s.counterSet.connAccepted.Load())
   388  	require.Equal(t, int64(1), s.counterSet.connTotal.Load())
   389  }
   390  
   391  func testWithServer(t *testing.T, fn func(*testing.T, string, *Server)) {
   392  	defer leaktest.AfterTest(t)()
   393  
   394  	temp := os.TempDir()
   395  	ctx, cancel := context.WithCancel(context.Background())
   396  	defer cancel()
   397  	rt := runtime.DefaultRuntime()
   398  	runtime.SetupProcessLevelRuntime(rt)
   399  	listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   400  	require.NoError(t, os.RemoveAll(listenAddr))
   401  	cfg := Config{
   402  		ListenAddress:     "unix://" + listenAddr,
   403  		RebalanceDisabled: true,
   404  	}
   405  	hc := &mockHAKeeperClient{}
   406  	mc := clusterservice.NewMOCluster(hc, 3*time.Second)
   407  	defer mc.Close()
   408  	rt.SetGlobalVariables(runtime.ClusterService, mc)
   409  	addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   410  	require.NoError(t, os.RemoveAll(addr))
   411  	cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{})
   412  	hc.updateCN(cn1.uuid, cn1.addr, map[string]metadata.LabelList{})
   413  	// start backend server.
   414  	stopFn := startTestCNServer(t, ctx, addr, nil)
   415  	defer func() {
   416  		require.NoError(t, stopFn())
   417  	}()
   418  	mc.ForceRefresh(true)
   419  
   420  	// start proxy.
   421  	s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()),
   422  		WithHAKeeperClient(hc))
   423  	defer func() {
   424  		err := s.Close()
   425  		require.NoError(t, err)
   426  	}()
   427  	require.NoError(t, err)
   428  	require.NotNil(t, s)
   429  	err = s.Start()
   430  	require.NoError(t, err)
   431  
   432  	fn(t, listenAddr, s)
   433  }
   434  
   435  func TestHandler_HandleEventKillQuery(t *testing.T) {
   436  	testWithServer(t, func(t *testing.T, addr string, s *Server) {
   437  		db1, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr))
   438  		// connect to server.
   439  		require.NoError(t, err)
   440  		require.NotNil(t, db1)
   441  		defer func() {
   442  			_ = db1.Close()
   443  		}()
   444  		res, err := db1.Exec("select 1")
   445  		require.NoError(t, err)
   446  		connID, _ := res.LastInsertId() // fake connection id
   447  
   448  		db2, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr))
   449  		// connect to server.
   450  		require.NoError(t, err)
   451  		require.NotNil(t, db2)
   452  		defer func() {
   453  			_ = db2.Close()
   454  		}()
   455  
   456  		_, err = db2.Exec(fmt.Sprintf("kill query %d", connID))
   457  		require.NoError(t, err)
   458  
   459  		require.Equal(t, int64(2), s.counterSet.connAccepted.Load())
   460  	})
   461  }
   462  
   463  func TestHandler_HandleEventSetVar(t *testing.T) {
   464  	testWithServer(t, func(t *testing.T, addr string, s *Server) {
   465  		db1, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr))
   466  		// connect to server.
   467  		require.NoError(t, err)
   468  		require.NotNil(t, db1)
   469  		defer func() {
   470  			_ = db1.Close()
   471  		}()
   472  		_, err = db1.Exec("set session cn_label='acc1'")
   473  		require.NoError(t, err)
   474  
   475  		res, err := db1.Query("show session variables")
   476  		require.NoError(t, err)
   477  		defer res.Close()
   478  		var varName, varValue string
   479  		for res.Next() {
   480  			err := res.Scan(&varName, &varValue)
   481  			require.NoError(t, err)
   482  			require.Equal(t, "cn_label", varName)
   483  			require.Equal(t, "acc1", varValue)
   484  		}
   485  		err = res.Err()
   486  		require.NoError(t, err)
   487  
   488  		require.Equal(t, int64(1), s.counterSet.connAccepted.Load())
   489  	})
   490  }
   491  
   492  func TestHandler_HandleTxn(t *testing.T) {
   493  	testWithServer(t, func(t *testing.T, addr string, s *Server) {
   494  		db1, err := sql.Open("mysql", fmt.Sprintf("a1#root:111@unix(%s)/db1", addr))
   495  		// connect to server.
   496  		require.NoError(t, err)
   497  		require.NotNil(t, db1)
   498  		defer func() {
   499  			_ = db1.Close()
   500  		}()
   501  		_, err = db1.Exec("select 1")
   502  		require.NoError(t, err)
   503  	})
   504  }