github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/server_conn_test.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"bufio"
    19  	"context"
    20  	"crypto/tls"
    21  	"fmt"
    22  	"net"
    23  	"os"
    24  	"strings"
    25  	"sync"
    26  	"sync/atomic"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/fagongzi/goetty/v2"
    31  	"github.com/lni/goutils/leaktest"
    32  	"github.com/stretchr/testify/require"
    33  
    34  	"github.com/matrixorigin/matrixone/pkg/config"
    35  	"github.com/matrixorigin/matrixone/pkg/container/types"
    36  	"github.com/matrixorigin/matrixone/pkg/frontend"
    37  	"github.com/matrixorigin/matrixone/pkg/pb/proxy"
    38  	"github.com/matrixorigin/matrixone/pkg/sql/plan"
    39  )
    40  
    41  var testSlat = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0}
    42  var testPacket = &frontend.Packet{
    43  	Length:     1,
    44  	SequenceID: 0,
    45  	Payload:    []byte{1},
    46  }
    47  
    48  func testMakeCNServer(
    49  	uuid string, addr string, connID uint32, hash LabelHash, reqLabel labelInfo,
    50  ) *CNServer {
    51  	if strings.Contains(addr, "sock") {
    52  		addr = "unix://" + addr
    53  	}
    54  	return &CNServer{
    55  		connID:   connID,
    56  		addr:     addr,
    57  		uuid:     uuid,
    58  		salt:     testSlat,
    59  		hash:     hash,
    60  		reqLabel: reqLabel,
    61  	}
    62  }
    63  
    64  type mockServerConn struct {
    65  	conn net.Conn
    66  }
    67  
    68  var _ ServerConn = (*mockServerConn)(nil)
    69  
    70  func newMockServerConn(conn net.Conn) *mockServerConn {
    71  	m := &mockServerConn{
    72  		conn: conn,
    73  	}
    74  	return m
    75  }
    76  
    77  func (s *mockServerConn) ConnID() uint32    { return 0 }
    78  func (s *mockServerConn) RawConn() net.Conn { return s.conn }
    79  func (s *mockServerConn) HandleHandshake(_ *frontend.Packet, _ time.Duration) (*frontend.Packet, error) {
    80  	return nil, nil
    81  }
    82  func (s *mockServerConn) ExecStmt(stmt internalStmt, resp chan<- []byte) (bool, error) {
    83  	sendResp(makeOKPacket(8), resp)
    84  	return true, nil
    85  }
    86  func (s *mockServerConn) Close() error {
    87  	if s.conn != nil {
    88  		_ = s.conn.Close()
    89  	}
    90  	return nil
    91  }
    92  
// baseConnID is the process-wide counter from which testCNServer.Start draws
// a fresh connection ID (via Add(1)) for every accepted connection.
var baseConnID atomic.Uint32
    94  
// tlsConfig carries the TLS settings a test hands to startTestCNServer.
type tlsConfig struct {
	// enabled makes Start build a TLS config and is forwarded to the
	// frontend parameters as EnableTls.
	enabled  bool
	// caFile, certFile and keyFile are passed, in that order, to
	// frontend.ConstructTLSConfig when enabled is true.
	caFile   string
	certFile string
	keyFile  string
}
   101  
// testCNServer is a fake CN server used by the proxy tests. It accepts
// connections on addr and serves each one with testHandle.
type testCNServer struct {
	// The embedded mutex guards started.
	sync.Mutex
	// ctx ends the accept loop when cancelled and bounds waitCNServerReady.
	ctx      context.Context
	// scheme is the network passed to net.Listen/net.Dial: "tcp" or "unix".
	scheme   string
	addr     string
	listener net.Listener
	// started is set once the listener has been created.
	started  bool
	// quit is closed by Stop to mark a listener failure as a shutdown.
	quit     chan interface{}

	// globalVars backs "show global variables"; "kill connection" stores
	// killed=yes in it.
	globalVars map[string]string
	// tlsCfg holds the raw settings; tlsConfig is what Start builds from them.
	tlsCfg     tlsConfig
	tlsConfig  *tls.Config

	// beforeHandle, when set, is called for every accepted connection just
	// before its handler goroutine is started.
	beforeHandle func()
}
   117  
// testHandler holds the per-connection state of the fake CN server.
type testHandler struct {
	// mysqlProto builds and writes the MySQL packets sent back to the peer.
	mysqlProto  *frontend.MysqlProtocolImpl
	// connID is the ID allocated from baseConnID; it is also echoed as the
	// last-insert-id of OK packets so tests can tell connections apart.
	connID      uint32
	conn        goetty.IOSession
	// sessionVars is filled by "set session ..." and listed by
	// "show session variables".
	sessionVars map[string]string
	// labels is allocated by Start; nothing else in this file reads it.
	labels      map[string]string
	// server points back to the owning testCNServer (for globalVars).
	server      *testCNServer
	// status is the server status word sent in OK/EOF packets; begin and
	// commit/rollback toggle SERVER_STATUS_IN_TRANS in it.
	status      uint16
}
   127  
   128  type option func(s *testCNServer)
   129  
   130  func withBeforeHandle(f func()) option {
   131  	return func(s *testCNServer) {
   132  		s.beforeHandle = f
   133  	}
   134  }
   135  
   136  func startTestCNServer(t *testing.T, ctx context.Context, addr string, cfg *tlsConfig, opts ...option) func() error {
   137  	b := &testCNServer{
   138  		ctx:        ctx,
   139  		scheme:     "tcp",
   140  		addr:       addr,
   141  		quit:       make(chan interface{}),
   142  		globalVars: make(map[string]string),
   143  	}
   144  	for _, opt := range opts {
   145  		opt(b)
   146  	}
   147  	if cfg != nil {
   148  		b.tlsCfg = *cfg
   149  	}
   150  	if strings.Contains(addr, "sock") {
   151  		b.scheme = "unix"
   152  	}
   153  	go func() {
   154  		err := b.Start()
   155  		require.NoError(t, err)
   156  	}()
   157  	require.True(t, b.waitCNServerReady())
   158  	return func() error {
   159  		return b.Stop()
   160  	}
   161  }
   162  
   163  func (s *testCNServer) waitCNServerReady() bool {
   164  	ctx, cancel := context.WithTimeout(s.ctx, time.Second*3)
   165  	defer cancel()
   166  	tick := time.NewTicker(time.Millisecond * 100)
   167  	for {
   168  		select {
   169  		case <-ctx.Done():
   170  			return false
   171  		case <-tick.C:
   172  			s.Lock()
   173  			started := s.started
   174  			s.Unlock()
   175  			conn, err := net.Dial(s.scheme, s.addr)
   176  			if err == nil && started {
   177  				_ = conn.Close()
   178  				return true
   179  			}
   180  			if conn != nil {
   181  				_ = conn.Close()
   182  			}
   183  		}
   184  	}
   185  }
   186  
   187  func (s *testCNServer) Start() error {
   188  	var err error
   189  	if s.tlsCfg.enabled {
   190  		s.tlsConfig, err = frontend.ConstructTLSConfig(
   191  			context.TODO(),
   192  			s.tlsCfg.caFile,
   193  			s.tlsCfg.certFile,
   194  			s.tlsCfg.keyFile,
   195  		)
   196  		if err != nil {
   197  			return err
   198  		}
   199  	}
   200  	s.listener, err = net.Listen(s.scheme, s.addr)
   201  	if err != nil {
   202  		return err
   203  	}
   204  	s.Lock()
   205  	s.started = true
   206  	s.Unlock()
   207  
   208  	for {
   209  		select {
   210  		case <-s.ctx.Done():
   211  			return nil
   212  		default:
   213  			conn, err := s.listener.Accept()
   214  			if conn == nil {
   215  				continue
   216  			}
   217  			if err != nil {
   218  				select {
   219  				case <-s.quit:
   220  					return nil
   221  				default:
   222  					return err
   223  				}
   224  			} else {
   225  				fp := config.FrontendParameters{
   226  					EnableTls: s.tlsCfg.enabled,
   227  				}
   228  				fp.SetDefaultValues()
   229  				cid := baseConnID.Add(1)
   230  				c := goetty.NewIOSession(goetty.WithSessionCodec(frontend.NewSqlCodec()),
   231  					goetty.WithSessionConn(uint64(cid), conn))
   232  				h := &testHandler{
   233  					connID: cid,
   234  					conn:   c,
   235  					mysqlProto: frontend.NewMysqlClientProtocol(
   236  						cid, c, 0, &fp),
   237  					sessionVars: make(map[string]string),
   238  					labels:      make(map[string]string),
   239  					server:      s,
   240  				}
   241  				if s.beforeHandle != nil {
   242  					s.beforeHandle()
   243  				}
   244  				go func(h *testHandler) {
   245  					testHandle(h)
   246  				}(h)
   247  			}
   248  		}
   249  	}
   250  }
   251  
   252  func testHandle(h *testHandler) {
   253  	// read extra info from proxy.
   254  	extraInfo := proxy.ExtraInfo{}
   255  	reader := bufio.NewReader(h.conn.RawConn())
   256  	_ = extraInfo.Decode(reader)
   257  	// server writes init handshake.
   258  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeHandshakePayload())
   259  	// server reads auth information from client.
   260  	_, _ = h.conn.Read(goetty.ReadOptions{})
   261  	// server writes ok packet.
   262  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), 0, 0, ""))
   263  	for {
   264  		msg, err := h.conn.Read(goetty.ReadOptions{})
   265  		if err != nil {
   266  			break
   267  		}
   268  		packet, ok := msg.(*frontend.Packet)
   269  		if !ok {
   270  			return
   271  		}
   272  		if packet.Length > 1 && packet.Payload[0] == 3 {
   273  			if strings.HasPrefix(string(packet.Payload[1:]), "set session") {
   274  				h.handleSetVar(packet)
   275  			} else if string(packet.Payload[1:]) == "show session variables" {
   276  				h.handleShowVar()
   277  			} else if string(packet.Payload[1:]) == "show global variables" {
   278  				h.handleShowGlobalVar()
   279  			} else if string(packet.Payload[1:]) == "begin" {
   280  				h.handleStartTxn()
   281  			} else if string(packet.Payload[1:]) == "commit" || string(packet.Payload[1:]) == "rollback" {
   282  				h.handleStopTxn()
   283  			} else if strings.HasPrefix(string(packet.Payload[1:]), "kill connection") {
   284  				h.handleKillConn()
   285  			} else {
   286  				h.handleCommon()
   287  			}
   288  		} else {
   289  			h.handleCommon()
   290  		}
   291  	}
   292  }
   293  
   294  func (h *testHandler) handleCommon() {
   295  	h.mysqlProto.SetSequenceID(1)
   296  	// set last insert id as connection id to do test more easily.
   297  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), h.status, 0, ""))
   298  }
   299  
   300  func (h *testHandler) handleSetVar(packet *frontend.Packet) {
   301  	words := strings.Split(string(packet.Payload[1:]), " ")
   302  	v := strings.Split(words[2], "=")
   303  	h.sessionVars[v[0]] = strings.Trim(v[1], "'")
   304  	h.mysqlProto.SetSequenceID(1)
   305  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), h.status, 0, ""))
   306  }
   307  
   308  func (h *testHandler) handleKillConn() {
   309  	h.server.globalVars["killed"] = "yes"
   310  	h.mysqlProto.SetSequenceID(1)
   311  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), h.status, 0, ""))
   312  }
   313  
   314  func (h *testHandler) handleShowVar() {
   315  	h.mysqlProto.SetSequenceID(1)
   316  	err := h.mysqlProto.SendColumnCountPacket(2)
   317  	if err != nil {
   318  		_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error()))
   319  		return
   320  	}
   321  	cols := []*plan.ColDef{
   322  		{Typ: plan.Type{Id: int32(types.T_char)}, Name: "Variable_name"},
   323  		{Typ: plan.Type{Id: int32(types.T_char)}, Name: "Value"},
   324  	}
   325  	columns := make([]interface{}, len(cols))
   326  	res := &frontend.MysqlResultSet{}
   327  	for i, col := range cols {
   328  		c := new(frontend.MysqlColumn)
   329  		c.SetName(col.Name)
   330  		c.SetOrgName(col.Name)
   331  		c.SetTable(col.Typ.Table)
   332  		c.SetOrgTable(col.Typ.Table)
   333  		c.SetAutoIncr(col.Typ.AutoIncr)
   334  		c.SetSchema("")
   335  		c.SetDecimal(col.Typ.Scale)
   336  		columns[i] = c
   337  		res.AddColumn(c)
   338  	}
   339  	for _, c := range columns {
   340  		if err := h.mysqlProto.SendColumnDefinitionPacket(context.TODO(), c.(frontend.Column), 3); err != nil {
   341  			_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error()))
   342  			return
   343  		}
   344  	}
   345  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, h.status))
   346  	for k, v := range h.sessionVars {
   347  		row := make([]interface{}, 2)
   348  		row[0] = k
   349  		row[1] = v
   350  		res.AddRow(row)
   351  	}
   352  	ses := &frontend.Session{}
   353  	h.mysqlProto.SetSession(ses)
   354  	if err := h.mysqlProto.SendResultSetTextBatchRow(res, res.GetRowCount()); err != nil {
   355  		_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error()))
   356  		return
   357  	}
   358  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, h.status))
   359  }
   360  
   361  func (h *testHandler) handleShowGlobalVar() {
   362  	h.mysqlProto.SetSequenceID(1)
   363  	err := h.mysqlProto.SendColumnCountPacket(2)
   364  	if err != nil {
   365  		_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error()))
   366  		return
   367  	}
   368  	cols := []*plan.ColDef{
   369  		{Typ: plan.Type{Id: int32(types.T_char)}, Name: "Variable_name"},
   370  		{Typ: plan.Type{Id: int32(types.T_char)}, Name: "Value"},
   371  	}
   372  	columns := make([]interface{}, len(cols))
   373  	res := &frontend.MysqlResultSet{}
   374  	for i, col := range cols {
   375  		c := new(frontend.MysqlColumn)
   376  		c.SetName(col.Name)
   377  		c.SetOrgName(col.Name)
   378  		c.SetTable(col.Typ.Table)
   379  		c.SetOrgTable(col.Typ.Table)
   380  		c.SetAutoIncr(col.Typ.AutoIncr)
   381  		c.SetSchema("")
   382  		c.SetDecimal(col.Typ.Scale)
   383  		columns[i] = c
   384  		res.AddColumn(c)
   385  	}
   386  	for _, c := range columns {
   387  		if err := h.mysqlProto.SendColumnDefinitionPacket(context.TODO(), c.(frontend.Column), 3); err != nil {
   388  			_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error()))
   389  			return
   390  		}
   391  	}
   392  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, h.status))
   393  	for k, v := range h.server.globalVars {
   394  		row := make([]interface{}, 2)
   395  		row[0] = k
   396  		row[1] = v
   397  		res.AddRow(row)
   398  	}
   399  	ses := &frontend.Session{}
   400  	h.mysqlProto.SetSession(ses)
   401  	if err := h.mysqlProto.SendResultSetTextBatchRow(res, res.GetRowCount()); err != nil {
   402  		_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error()))
   403  		return
   404  	}
   405  	_ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, h.status))
   406  }
   407  
   408  func (h *testHandler) handleStartTxn() {
   409  	h.status |= frontend.SERVER_STATUS_IN_TRANS
   410  	h.handleCommon()
   411  }
   412  
   413  func (h *testHandler) handleStopTxn() {
   414  	h.status &= ^frontend.SERVER_STATUS_IN_TRANS
   415  	h.handleCommon()
   416  }
   417  
// Stop shuts the server down by closing the quit channel and then the
// listener. It always returns nil; the listener's close error is discarded.
//
// quit is closed first so that the shutdown signal is already visible to
// Start's accept-error path by the time the listener goes away. Calling Stop
// a second time would panic on the double close of quit.
func (s *testCNServer) Stop() error {
	close(s.quit)
	_ = s.listener.Close()
	return nil
}
   423  
// TestServerConn_Create checks that newServerConn fails while nothing listens
// on the CN server's address and succeeds once the fake CN server is up.
func TestServerConn_Create(t *testing.T) {
	// NOTE(review): leaktest.AfterTest appears to return the checker
	// function; without a trailing () the check itself is never run. Confirm
	// whether that is intended before enabling it — sc below is never closed.
	defer leaktest.AfterTest(t)

	// Use a fresh unix socket path so the test does not clash with others.
	temp := os.TempDir()
	addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
	require.NoError(t, os.RemoveAll(addr))
	cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{})
	cn1.reqLabel = newLabelInfo("t1", map[string]string{
		"k1": "v1",
		"k2": "v2",
	})
	// server not started.
	sc, err := newServerConn(cn1, nil, nil, 0)
	require.Error(t, err)
	require.Nil(t, sc)

	// start server.
	tp := newTestProxyHandler(t)
	defer tp.closeFn()
	stopFn := startTestCNServer(t, tp.ctx, addr, nil)
	// Deferred calls run last-in first-out: the server is stopped before the
	// proxy handler is closed.
	defer func() {
		require.NoError(t, stopFn())
	}()

	sc, err = newServerConn(cn1, nil, nil, 0)
	require.NoError(t, err)
	require.NotNil(t, sc)
}
   452  
// TestServerConn_Connect checks that a server connection can complete the
// handshake with the fake CN server, ends up with a non-zero connection ID
// and can be closed without error.
func TestServerConn_Connect(t *testing.T) {
	// NOTE(review): leaktest.AfterTest appears to return the checker
	// function; without a trailing () the check itself is never run. Confirm
	// whether that is intended.
	defer leaktest.AfterTest(t)
	// Use a fresh unix socket path so the test does not clash with others.
	temp := os.TempDir()
	addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
	require.NoError(t, os.RemoveAll(addr))
	cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{})
	cn1.reqLabel = newLabelInfo("t1", map[string]string{
		"k1": "v1",
		"k2": "v2",
	})
	tp := newTestProxyHandler(t)
	defer tp.closeFn()
	stopFn := startTestCNServer(t, tp.ctx, addr, nil)
	// Deferred calls run last-in first-out: the server is stopped before the
	// proxy handler is closed.
	defer func() {
		require.NoError(t, stopFn())
	}()

	sc, err := newServerConn(cn1, nil, tp.re, 0)
	require.NoError(t, err)
	require.NotNil(t, sc)
	// The fake server does not inspect the auth packet, so a one-byte payload
	// is enough to get through the handshake.
	_, err = sc.HandleHandshake(&frontend.Packet{Payload: []byte{1}}, time.Second*3)
	require.NoError(t, err)
	require.NotEqual(t, 0, int(sc.ConnID()))
	err = sc.Close()
	require.NoError(t, err)
}
   479  
// TestFakeCNServer is a smoke test: it starts the fake CN server and a test
// client against it, then tears both down. It makes no assertions beyond the
// ones inside the helpers.
func TestFakeCNServer(t *testing.T) {
	// NOTE(review): leaktest.AfterTest appears to return the checker
	// function; without a trailing () the check itself is never run. Confirm
	// whether that is intended.
	defer leaktest.AfterTest(t)

	tp := newTestProxyHandler(t)
	defer tp.closeFn()

	// Use a fresh unix socket path so the test does not clash with others.
	temp := os.TempDir()
	addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
	require.NoError(t, os.RemoveAll(addr))
	stopFn := startTestCNServer(t, tp.ctx, addr, nil)
	defer func() {
		require.NoError(t, stopFn())
	}()

	li := labelInfo{}
	cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{})
	cn1.reqLabel = newLabelInfo("t1", map[string]string{
		"k1": "v1",
		"k2": "v2",
	})

	// Deferred last, so the client is cleaned up first, then the server is
	// stopped, then the proxy handler is closed.
	cleanup := testStartClient(t, tp, clientInfo{labelInfo: li}, cn1)
	defer cleanup()
}
   504  
// TestServerConn_ExecStmt checks that, after the handshake, executing an
// internal query statement on the server connection yields an OK packet on
// the response channel.
func TestServerConn_ExecStmt(t *testing.T) {
	// NOTE(review): leaktest.AfterTest appears to return the checker
	// function; without a trailing () the check itself is never run. Confirm
	// whether that is intended before enabling it — sc below is never closed.
	defer leaktest.AfterTest(t)

	// Use a fresh unix socket path so the test does not clash with others.
	temp := os.TempDir()
	addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
	require.NoError(t, os.RemoveAll(addr))
	cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{})
	cn1.reqLabel = newLabelInfo("t1", map[string]string{
		"k1": "v1",
		"k2": "v2",
	})
	tp := newTestProxyHandler(t)
	defer tp.closeFn()
	stopFn := startTestCNServer(t, tp.ctx, addr, nil)
	// Deferred calls run last-in first-out: the server is stopped before the
	// proxy handler is closed.
	defer func() {
		require.NoError(t, stopFn())
	}()

	sc, err := newServerConn(cn1, nil, tp.re, 0)
	require.NoError(t, err)
	require.NotNil(t, sc)
	_, err = sc.HandleHandshake(&frontend.Packet{Payload: []byte{1}}, time.Second*3)
	require.NoError(t, err)
	require.NotEqual(t, 0, int(sc.ConnID()))
	// Buffered so the reply can be delivered before this goroutine receives.
	resp := make(chan []byte, 10)
	// "kill query" matches none of the fake server's special statements, so
	// it is answered by the generic OK reply.
	_, err = sc.ExecStmt(internalStmt{cmdType: cmdQuery, s: "kill query"}, resp)
	require.NoError(t, err)
	res := <-resp
	ok := isOKPacket(res)
	require.True(t, ok)
}
   536  
   537  func TestServerConnParseConnID(t *testing.T) {
   538  	t.Run("too short error", func(t *testing.T) {
   539  		s := &serverConn{}
   540  		p := &frontend.Packet{
   541  			Payload: []byte{10},
   542  		}
   543  		err := s.parseConnID(p)
   544  		require.Error(t, err)
   545  	})
   546  
   547  	t.Run("no string", func(t *testing.T) {
   548  		s := &serverConn{}
   549  		p := &frontend.Packet{
   550  			Length:  8,
   551  			Payload: []byte{10},
   552  		}
   553  		p.Payload = append(p.Payload, []byte("v1")...)
   554  		err := s.parseConnID(p)
   555  		require.Error(t, err)
   556  	})
   557  
   558  	t.Run("no conn id", func(t *testing.T) {
   559  		s := &serverConn{}
   560  		p := &frontend.Packet{
   561  			Length:  5,
   562  			Payload: []byte{10},
   563  		}
   564  		p.Payload = append(p.Payload, []byte("v1")...)
   565  		p.Payload = append(p.Payload, []byte{0}...)
   566  		p.Payload = append(p.Payload, []byte{2, 0, 0, 0}...)
   567  		err := s.parseConnID(p)
   568  		require.Error(t, err)
   569  	})
   570  
   571  	t.Run("no error", func(t *testing.T) {
   572  		s := &serverConn{}
   573  		p := &frontend.Packet{
   574  			Length:  8,
   575  			Payload: []byte{10},
   576  		}
   577  		p.Payload = append(p.Payload, []byte("v1")...)
   578  		p.Payload = append(p.Payload, []byte{0}...)
   579  		p.Payload = append(p.Payload, []byte{2, 0, 0, 0}...)
   580  		err := s.parseConnID(p)
   581  		require.NoError(t, err)
   582  	})
   583  }