github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/client_conn_test.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"context"
    19  	"encoding/binary"
    20  	"net"
    21  	"strings"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/fagongzi/goetty/v2"
    27  	"github.com/fagongzi/goetty/v2/buf"
    28  	"github.com/lni/goutils/leaktest"
    29  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    30  	"github.com/matrixorigin/matrixone/pkg/common/runtime"
    31  	"github.com/matrixorigin/matrixone/pkg/frontend"
    32  	"github.com/stretchr/testify/require"
    33  )
    34  
    35  type mockNetConn struct {
    36  	localIP    string
    37  	localPort  int
    38  	remoteIP   string
    39  	remotePort int
    40  	c          net.Conn
    41  }
    42  
    43  func newMockNetConn(
    44  	localIP string, localPort int, remoteIP string, remotePort int, c net.Conn,
    45  ) *mockNetConn {
    46  	return &mockNetConn{
    47  		localIP:    localIP,
    48  		localPort:  localPort,
    49  		remoteIP:   remoteIP,
    50  		remotePort: remotePort,
    51  		c:          c,
    52  	}
    53  }
    54  
    55  func (c *mockNetConn) SetRemote(addr string) {
    56  	c.remoteIP = addr
    57  }
    58  
    59  func (c *mockNetConn) Read(b []byte) (n int, err error) {
    60  	return c.c.Read(b)
    61  }
    62  
    63  func (c *mockNetConn) Write(b []byte) (n int, err error) {
    64  	return c.c.Write(b)
    65  }
    66  
    67  func (c *mockNetConn) Close() error {
    68  	return nil
    69  }
    70  
    71  func (c *mockNetConn) LocalAddr() net.Addr {
    72  	return &net.TCPAddr{
    73  		IP:   []byte(c.localIP),
    74  		Port: c.localPort,
    75  	}
    76  }
    77  
    78  func (c *mockNetConn) RemoteAddr() net.Addr {
    79  	return &net.TCPAddr{
    80  		IP:   []byte(c.remoteIP),
    81  		Port: c.remotePort,
    82  	}
    83  }
    84  
    85  func (c *mockNetConn) SetDeadline(t time.Time) error {
    86  	return nil
    87  }
    88  
    89  func (c *mockNetConn) SetReadDeadline(t time.Time) error {
    90  	return nil
    91  }
    92  
    93  func (c *mockNetConn) SetWriteDeadline(t time.Time) error {
    94  	return nil
    95  }
    96  
    97  type mockClientConn struct {
    98  	conn       net.Conn
    99  	tenant     Tenant
   100  	clientInfo clientInfo // need to set it explicitly
   101  	router     Router
   102  	tun        *tunnel
   103  	redoStmts  []internalStmt
   104  }
   105  
   106  var _ ClientConn = (*mockClientConn)(nil)
   107  
   108  func newMockClientConn(
   109  	conn net.Conn, tenant Tenant, ci clientInfo, router Router, tun *tunnel,
   110  ) ClientConn {
   111  	c := &mockClientConn{
   112  		conn:       conn,
   113  		tenant:     tenant,
   114  		clientInfo: ci,
   115  		router:     router,
   116  		tun:        tun,
   117  	}
   118  	return c
   119  }
   120  
   121  func (c *mockClientConn) ConnID() uint32                     { return 0 }
   122  func (c *mockClientConn) GetSalt() []byte                    { return nil }
   123  func (c *mockClientConn) GetHandshakePack() *frontend.Packet { return nil }
   124  func (c *mockClientConn) RawConn() net.Conn                  { return c.conn }
   125  func (c *mockClientConn) GetTenant() Tenant                  { return c.tenant }
   126  func (c *mockClientConn) SendErrToClient(err error)          {}
   127  func (c *mockClientConn) BuildConnWithServer(_ string) (ServerConn, error) {
   128  	cn, err := c.router.Route(context.TODO(), c.clientInfo, nil)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	cn.salt = testSlat
   133  	sc, _, err := c.router.Connect(cn, testPacket, c.tun)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	// Set the use defined variables, including session variables and user variables.
   138  	for _, stmt := range c.redoStmts {
   139  		if _, err := sc.ExecStmt(stmt, nil); err != nil {
   140  			return nil, err
   141  		}
   142  	}
   143  	return sc, nil
   144  }
   145  
   146  func (c *mockClientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<- []byte) error {
   147  	switch ev := e.(type) {
   148  	case *killQueryEvent:
   149  		cn, err := c.router.SelectByConnID(ev.connID)
   150  		if err != nil {
   151  			sendResp([]byte(err.Error()), resp)
   152  			return err
   153  		}
   154  		sendResp([]byte(cn.addr), resp)
   155  		return nil
   156  	case *setVarEvent:
   157  		c.redoStmts = append(c.redoStmts, internalStmt{cmdType: cmdQuery, s: ev.stmt})
   158  		sendResp([]byte("ok"), resp)
   159  		return nil
   160  	default:
   161  		sendResp([]byte("type not supported"), resp)
   162  		return moerr.NewInternalErrorNoCtx("type not supported")
   163  	}
   164  }
   165  func (c *mockClientConn) Close() error { return nil }
   166  
   167  func testStartClient(t *testing.T, tp *testProxyHandler, ci clientInfo, cn *CNServer) func() {
   168  	if cn.salt == nil || len(cn.salt) != 20 {
   169  		cn.salt = testSlat
   170  	}
   171  	clientProxy, client := net.Pipe()
   172  	go func(ctx context.Context) {
   173  		b := make([]byte, 10)
   174  		for {
   175  			select {
   176  			case <-ctx.Done():
   177  				return
   178  			default:
   179  			}
   180  			_, _ = client.Read(b)
   181  		}
   182  	}(tp.ctx)
   183  	tu := newTunnel(tp.ctx, tp.logger, tp.counterSet)
   184  	sc, _, err := tp.ru.Connect(cn, testPacket, tu)
   185  	require.NoError(t, err)
   186  	cc := newMockClientConn(clientProxy, "t1", ci, tp.ru, tu)
   187  	err = tu.run(cc, sc)
   188  	require.NoError(t, err)
   189  	select {
   190  	case err := <-tu.errC:
   191  		t.Fatalf("tunnel error: %v", err)
   192  	default:
   193  	}
   194  	return func() {
   195  		_ = tu.Close()
   196  	}
   197  }
   198  
   199  func testStartNClients(t *testing.T, tp *testProxyHandler, ci clientInfo, cn *CNServer, n int) func() {
   200  	var cleanFns []func()
   201  	for i := 0; i < n; i++ {
   202  		c := testStartClient(t, tp, ci, cn)
   203  		cleanFns = append(cleanFns, c)
   204  	}
   205  	return func() {
   206  		for _, f := range cleanFns {
   207  			f()
   208  		}
   209  	}
   210  }
   211  
   212  func TestAccountParser(t *testing.T) {
   213  	cases := []struct {
   214  		str      string
   215  		tenant   string
   216  		username string
   217  		hasErr   bool
   218  	}{
   219  		{
   220  			str:      "t1:u1",
   221  			tenant:   "t1",
   222  			username: "u1",
   223  			hasErr:   false,
   224  		},
   225  		{
   226  			str:      "t1#u1",
   227  			tenant:   "t1",
   228  			username: "u1",
   229  			hasErr:   false,
   230  		},
   231  		{
   232  			str:      ":u1",
   233  			tenant:   "",
   234  			username: "",
   235  			hasErr:   true,
   236  		},
   237  		{
   238  			str:      "a:",
   239  			tenant:   "",
   240  			username: "",
   241  			hasErr:   true,
   242  		},
   243  		{
   244  			str:      "u1",
   245  			tenant:   frontend.GetDefaultTenant(),
   246  			username: "u1",
   247  			hasErr:   false,
   248  		},
   249  		{
   250  			str:      "t1:u1?a=1",
   251  			tenant:   "t1",
   252  			username: "u1",
   253  			hasErr:   false,
   254  		},
   255  	}
   256  	for _, item := range cases {
   257  		a := clientInfo{}
   258  		err := a.parse(item.str)
   259  		if item.hasErr {
   260  			require.Error(t, err)
   261  		} else {
   262  			require.NoError(t, err)
   263  		}
   264  		require.Equal(t, string(a.labelInfo.Tenant), item.tenant)
   265  		require.Equal(t, a.username, item.username)
   266  	}
   267  }
   268  
   269  func createNewClientConn(t *testing.T) (ClientConn, func()) {
   270  	s := goetty.NewIOSession(goetty.WithSessionConn(1,
   271  		newMockNetConn("127.0.0.1", 30001,
   272  			"127.0.0.1", 30010, nil)),
   273  		goetty.WithSessionCodec(WithProxyProtocolCodec(frontend.NewSqlCodec())))
   274  	ctx, cancel := context.WithCancel(context.Background())
   275  	clientBaseConnID = 90
   276  	rt := runtime.DefaultRuntime()
   277  	logger := rt.Logger()
   278  	cs := newCounterSet()
   279  	cc, err := newClientConn(ctx, &Config{}, logger, cs, s, nil, nil, nil, nil, nil)
   280  	require.NoError(t, err)
   281  	require.NotNil(t, cc)
   282  	return cc, func() {
   283  		cancel()
   284  		_ = cc.Close()
   285  	}
   286  }
   287  
   288  func TestNewClientConn(t *testing.T) {
   289  	cc, cleanup := createNewClientConn(t)
   290  	defer cleanup()
   291  	require.Equal(t, 91, int(cc.ConnID()))
   292  	require.Equal(t, 20, len(cc.GetSalt()))
   293  	require.NotNil(t, cc.RawConn())
   294  }
   295  
   296  func makeClientHandshakeResp() []byte {
   297  	payload := make([]byte, 200)
   298  	pos := 0
   299  	copy(payload[pos:], []byte{141, 162, 10, 0}) // Capabilities Flags
   300  	pos += 4
   301  	copy(payload[pos:], []byte{0, 0, 0, 0}) // maximum packet size
   302  	pos += 4
   303  	payload[pos] = 45 // client charset
   304  	pos += 1
   305  	pos += 23 // filler
   306  	username := "tenant1:user1"
   307  	copy(payload[pos:], username) // login username
   308  	pos += len(username)
   309  	payload[pos] = 0 // the end of username
   310  	pos += 1
   311  	payload[pos] = 20 // length of auth response
   312  	pos += 1
   313  	pos += 20 // auth response
   314  	dbname := "db1"
   315  	copy(payload[pos:], dbname) // db name
   316  	pos += len(dbname)
   317  	payload[pos] = 0 // end of db name
   318  	pos += 1
   319  	plugin := "mysql_native_password"
   320  	copy(payload[pos:], plugin)
   321  	pos += 1 + len(plugin)
   322  	data := make([]byte, pos+4)
   323  	data[0] = uint8(pos)
   324  	data[1] = uint8(pos >> 8)
   325  	data[2] = uint8(pos >> 16)
   326  	data[3] = 1
   327  	copy(data[4:], payload)
   328  	return data
   329  }
   330  
   331  func TestClientConn_ConnectToBackend(t *testing.T) {
   332  	defer leaktest.AfterTest(t)()
   333  
   334  	runtime.SetupProcessLevelRuntime(runtime.DefaultRuntime())
   335  	rt := runtime.DefaultRuntime()
   336  	logger := rt.Logger()
   337  
   338  	t.Run("cannot connect", func(t *testing.T) {
   339  		nilC := (*clientConn)(nil)
   340  		require.Equal(t, "", string(nilC.GetTenant()))
   341  		require.Nil(t, nilC.RawConn())
   342  
   343  		cc := &clientConn{
   344  			log: logger,
   345  		}
   346  		cc.testHelper.connectToBackend = func() (ServerConn, error) {
   347  			return nil, moerr.NewInternalErrorNoCtx("123 456")
   348  		}
   349  
   350  		sc, err := cc.BuildConnWithServer("aaa")
   351  		require.ErrorContains(t, err, "123 456")
   352  		require.Nil(t, sc)
   353  	})
   354  
   355  	t.Run("ok connect", func(t *testing.T) {
   356  		local, remote := net.Pipe()
   357  		require.NotNil(t, local)
   358  		require.NotNil(t, remote)
   359  
   360  		cc, cleanup := createNewClientConn(t)
   361  		defer cleanup()
   362  		c, ok := cc.(*clientConn)
   363  		require.True(t, ok)
   364  		require.NotNil(t, c)
   365  		c.conn.UseConn(local)
   366  		require.Equal(t, "", string(cc.GetTenant()))
   367  
   368  		var wg sync.WaitGroup
   369  		wg.Add(1)
   370  		go func() {
   371  			defer wg.Done()
   372  			b := make([]byte, 100)
   373  			// client reads init handshake.
   374  			n, err := remote.Read(b)
   375  			require.NoError(t, err)
   376  			require.NotEqual(t, 0, n)
   377  
   378  			// client sends handshake resp.
   379  			resp := makeClientHandshakeResp()
   380  			n, err = remote.Write(resp)
   381  			require.NoError(t, err)
   382  			require.Equal(t, len(resp), n)
   383  		}()
   384  
   385  		_, err := cc.BuildConnWithServer("")
   386  		require.Error(t, err) // just test client, no router set
   387  		require.Equal(t, "tenant1", string(cc.GetTenant()))
   388  		require.NotNil(t, cc.GetHandshakePack())
   389  		wg.Wait()
   390  	})
   391  }
   392  
   393  func TestClientConn_ReadPacket(t *testing.T) {
   394  	defer leaktest.AfterTest(t)()
   395  
   396  	cc, cleanup := createNewClientConn(t)
   397  	defer cleanup()
   398  	c, ok := cc.(*clientConn)
   399  	require.True(t, ok)
   400  	require.NotNil(t, c)
   401  
   402  	local, remote := net.Pipe()
   403  	require.NotNil(t, local)
   404  	require.NotNil(t, remote)
   405  
   406  	var wg sync.WaitGroup
   407  	wg.Add(1)
   408  	go func() {
   409  		defer wg.Done()
   410  		addr := &ProxyAddr{
   411  			SourceAddress: []byte{10, 10, 10, 10},
   412  			SourcePort:    1000,
   413  			TargetAddress: []byte{20, 20, 20, 20},
   414  			TargetPort:    2000,
   415  		}
   416  
   417  		b := buf.NewByteBuf(1000)
   418  
   419  		b.WriteString(ProxyProtocolV2Signature)
   420  		err := b.WriteByte(0)
   421  		require.NoError(t, err)
   422  		err = b.WriteByte(0)
   423  		require.NoError(t, err)
   424  		b.WriteUint16(12)
   425  		n, err := b.Write(addr.SourceAddress)
   426  		require.Equal(t, 4, n)
   427  		require.NoError(t, err)
   428  		n, err = b.Write(addr.TargetAddress)
   429  		require.Equal(t, 4, n)
   430  		require.NoError(t, err)
   431  		b.WriteUint16(addr.SourcePort)
   432  		b.WriteUint16(addr.TargetPort)
   433  
   434  		n, d := b.ReadAll()
   435  		require.Equal(t, 28, n)
   436  		err = binary.Write(remote, binary.BigEndian, d)
   437  		require.NoError(t, err)
   438  
   439  		// little endian
   440  		err = b.WriteByte(9)
   441  		require.NoError(t, err)
   442  		err = b.WriteByte(0)
   443  		require.NoError(t, err)
   444  		err = b.WriteByte(0)
   445  		require.NoError(t, err)
   446  		err = b.WriteByte(0)
   447  		require.NoError(t, err)
   448  		err = b.WriteByte(3)
   449  		require.NoError(t, err)
   450  		b.WriteString("select 1")
   451  
   452  		n, d = b.ReadAll()
   453  		require.Equal(t, 13, n)
   454  		err = binary.Write(remote, binary.LittleEndian, d)
   455  		require.NoError(t, err)
   456  	}()
   457  
   458  	c.conn.UseConn(local)
   459  	ret, err := c.readPacket()
   460  	require.NoError(t, err)
   461  	require.NotNil(t, ret)
   462  	require.Equal(t, 9, int(ret.Length))
   463  	require.Equal(t, 0, int(ret.SequenceID))
   464  	require.Equal(t, 3, int(ret.Payload[0]))
   465  	require.Equal(t, "select 1", string(ret.Payload[1:]))
   466  
   467  	wg.Wait()
   468  }
   469  
   470  func TestClientConn_ConnID(t *testing.T) {
   471  	parallel := 100
   472  	clientBaseConnID = 1
   473  	var wg sync.WaitGroup
   474  	for i := 0; i < parallel; i++ {
   475  		wg.Add(1)
   476  		go func() {
   477  			nextClientConnID()
   478  			defer wg.Done()
   479  		}()
   480  	}
   481  	wg.Wait()
   482  	require.Equal(t, 101, int(clientBaseConnID))
   483  }
   484  
   485  func TestClientConn_SendErrToClient(t *testing.T) {
   486  	local, remote := net.Pipe()
   487  	require.NotNil(t, local)
   488  	require.NotNil(t, remote)
   489  
   490  	cc, cleanup := createNewClientConn(t)
   491  	defer cleanup()
   492  	c, ok := cc.(*clientConn)
   493  	require.True(t, ok)
   494  	require.NotNil(t, c)
   495  	c.conn.UseConn(local)
   496  	require.Equal(t, "", string(cc.GetTenant()))
   497  
   498  	var wg sync.WaitGroup
   499  	wg.Add(1)
   500  	go func() {
   501  		defer wg.Done()
   502  		b := make([]byte, 100)
   503  		// client reads init handshake.
   504  		n, err := remote.Read(b)
   505  		require.NoError(t, err)
   506  		require.NotEqual(t, 0, n)
   507  
   508  		// client sends handshake resp.
   509  		resp := makeClientHandshakeResp()
   510  		n, err = remote.Write(resp)
   511  		require.NoError(t, err)
   512  		require.Equal(t, len(resp), n)
   513  
   514  		n, err = remote.Read(b)
   515  		require.NoError(t, err)
   516  		require.Equal(t, 33, n)
   517  		require.True(t, strings.Contains(string(b[4+1+2+1+5:n]), "internal error: msg1"))
   518  	}()
   519  
   520  	_, err := cc.BuildConnWithServer("")
   521  	require.Error(t, err) // just test client, no router set
   522  	require.Equal(t, "tenant1", string(cc.GetTenant()))
   523  	require.NotNil(t, cc.GetHandshakePack())
   524  	cc.SendErrToClient(moerr.NewInternalErrorNoCtx("msg1"))
   525  	wg.Wait()
   526  }