github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/event_test.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"net"
    21  	"os"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/lni/goutils/leaktest"
    26  	"github.com/matrixorigin/matrixone/pkg/common/stopper"
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  func TestMakeEvent(t *testing.T) {
    31  	e, r := makeEvent(nil, nil)
    32  	require.Nil(t, e)
    33  	require.False(t, r)
    34  
    35  	t.Run("kill query", func(t *testing.T) {
    36  		e, r = makeEvent(makeSimplePacket("kill quer8y 12"), nil)
    37  		require.Nil(t, e)
    38  		require.False(t, r)
    39  
    40  		e, r = makeEvent(makeSimplePacket("kill query 123"), nil)
    41  		require.NotNil(t, e)
    42  		require.True(t, r)
    43  
    44  		e, r = makeEvent(makeSimplePacket("kiLL Query 12"), nil)
    45  		require.NotNil(t, e)
    46  		require.True(t, r)
    47  
    48  		e, r = makeEvent(makeSimplePacket("set "), nil)
    49  		require.Nil(t, e)
    50  		require.False(t, r)
    51  	})
    52  
    53  	t.Run("set var", func(t *testing.T) {
    54  		stmtsValid := []string{
    55  			"set session a=1",
    56  			"set session a='1'",
    57  			"set local a=1",
    58  			"set local a='1'",
    59  			"set @@session.a=1",
    60  			"set @@session.a='1'",
    61  			"set @@local.a=1",
    62  			"set @@local.a='1'",
    63  			"set @@a=1",
    64  			"set @@a='1'",
    65  			"set a=1",
    66  			"set a='1'",
    67  			// user variables.
    68  			"set @a=1",
    69  			"set @a='1'",
    70  			// session variables.
    71  			"set session a:=1",
    72  			"set session a:='1'",
    73  			"set local a:=1",
    74  			"set local a:='1'",
    75  			"set @@session.a:=1",
    76  			"set @@session.a:='1'",
    77  			"set @@local.a:=1",
    78  			"set @@local.a:='1'",
    79  			"set @@a:=1",
    80  			"set @@a:='1'",
    81  			"set a:=1",
    82  			"set a:='1'",
    83  			// user variables.
    84  			"set @a:=1",
    85  			"set @a:='1'",
    86  		}
    87  		stmtsInvalid := []string{
    88  			"set '1'",
    89  			"set _'1'",
    90  			"set _a'1'",
    91  			"set @a@:='1'",
    92  			"set @a:='1",
    93  		}
    94  		for _, stmt := range stmtsValid {
    95  			e, r = makeEvent(makeSimplePacket(stmt), nil)
    96  			require.NotNil(t, e)
    97  			require.False(t, r)
    98  		}
    99  		for _, stmt := range stmtsInvalid {
   100  			e, r = makeEvent(makeSimplePacket(stmt), nil)
   101  			require.Nil(t, e)
   102  			require.False(t, r)
   103  		}
   104  	})
   105  }
   106  
   107  func TestKillQueryEvent(t *testing.T) {
   108  	defer leaktest.AfterTest(t)()
   109  
   110  	tp := newTestProxyHandler(t)
   111  	defer tp.closeFn()
   112  
   113  	temp := os.TempDir()
   114  	addr1 := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   115  	require.NoError(t, os.RemoveAll(addr1))
   116  	cn1 := testMakeCNServer("uuid1", addr1, 10, "", labelInfo{})
   117  	stopFn1 := startTestCNServer(t, tp.ctx, addr1, nil)
   118  	defer func() {
   119  		require.NoError(t, stopFn1())
   120  	}()
   121  
   122  	addr2 := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   123  	require.NoError(t, os.RemoveAll(addr2))
   124  	cn2 := testMakeCNServer("uuid2", addr2, 20, "", labelInfo{})
   125  	stopFn2 := startTestCNServer(t, tp.ctx, addr2, nil)
   126  	defer func() {
   127  		require.NoError(t, stopFn2())
   128  	}()
   129  
   130  	tu1 := newTunnel(tp.ctx, tp.logger, nil)
   131  	defer func() { _ = tu1.Close() }()
   132  	tu2 := newTunnel(tp.ctx, tp.logger, nil)
   133  	defer func() { _ = tu2.Close() }()
   134  
   135  	// Client2 will send "kill query 10", which will route to the server which
   136  	// has connection ID 10. In this case, the connection is server1.
   137  	clientProxy1, _ := net.Pipe()
   138  	serverProxy1, _ := net.Pipe()
   139  
   140  	cc1 := newMockClientConn(clientProxy1, "t1", clientInfo{}, tp.ru, tu1)
   141  	require.NotNil(t, cc1)
   142  	sc1 := newMockServerConn(serverProxy1)
   143  	require.NotNil(t, sc1)
   144  
   145  	clientProxy2, client2 := net.Pipe()
   146  	serverProxy2, _ := net.Pipe()
   147  
   148  	cc2 := newMockClientConn(clientProxy2, "t1", clientInfo{}, tp.ru, tu2)
   149  	require.NotNil(t, cc2)
   150  	sc2 := newMockServerConn(serverProxy2)
   151  	require.NotNil(t, sc2)
   152  
   153  	res := make(chan []byte)
   154  	st := stopper.NewStopper("test-event", stopper.WithLogger(tp.logger.RawLogger()))
   155  	defer st.Stop()
   156  	err := st.RunNamedTask("test-event-handler", func(ctx context.Context) {
   157  		for {
   158  			select {
   159  			case e := <-tu2.reqC:
   160  				err := cc2.HandleEvent(ctx, e, tu2.respC)
   161  				require.NoError(t, err)
   162  			case r := <-tu2.respC:
   163  				if len(r) > 0 {
   164  					res <- r
   165  				}
   166  			case <-ctx.Done():
   167  				return
   168  			}
   169  		}
   170  	})
   171  	require.NoError(t, err)
   172  
   173  	// tunnel1 is on cn1, connection ID is 10.
   174  	_, ret, err := tp.ru.Connect(cn1, testPacket, tu1)
   175  	require.NoError(t, err)
   176  	// get connection id from result.
   177  	connID := ret[6]
   178  
   179  	// tunnel2 is on cn2, connection ID is 20.
   180  	_, _, err = tp.ru.Connect(cn2, testPacket, tu2)
   181  	require.NoError(t, err)
   182  
   183  	err = tu1.run(cc1, sc1)
   184  	require.NoError(t, err)
   185  	require.Nil(t, tu1.ctx.Err())
   186  
   187  	func() {
   188  		tu1.mu.Lock()
   189  		defer tu1.mu.Unlock()
   190  		require.True(t, tu1.mu.started)
   191  	}()
   192  
   193  	err = tu2.run(cc2, sc2)
   194  	require.NoError(t, err)
   195  	require.Nil(t, tu2.ctx.Err())
   196  
   197  	func() {
   198  		tu2.mu.Lock()
   199  		defer tu2.mu.Unlock()
   200  		require.True(t, tu2.mu.started)
   201  	}()
   202  
   203  	tu1.mu.Lock()
   204  	csp1 := tu1.mu.csp
   205  	scp1 := tu1.mu.scp
   206  	tu1.mu.Unlock()
   207  
   208  	tu2.mu.Lock()
   209  	csp2 := tu2.mu.csp
   210  	scp2 := tu2.mu.scp
   211  	tu2.mu.Unlock()
   212  
   213  	barrierStart1, barrierEnd1 := make(chan struct{}), make(chan struct{})
   214  	barrierStart2, barrierEnd2 := make(chan struct{}), make(chan struct{})
   215  	csp1.testHelper.beforeSend = func() {
   216  		<-barrierStart1
   217  		<-barrierEnd1
   218  	}
   219  	csp2.testHelper.beforeSend = func() {
   220  		<-barrierStart2
   221  		<-barrierEnd2
   222  	}
   223  
   224  	csp1.mu.Lock()
   225  	require.True(t, csp1.mu.started)
   226  	csp1.mu.Unlock()
   227  
   228  	scp1.mu.Lock()
   229  	require.True(t, scp1.mu.started)
   230  	scp1.mu.Unlock()
   231  
   232  	csp2.mu.Lock()
   233  	require.True(t, csp2.mu.started)
   234  	csp2.mu.Unlock()
   235  
   236  	scp2.mu.Lock()
   237  	require.True(t, scp2.mu.started)
   238  	scp2.mu.Unlock()
   239  
   240  	// Client2 writes some MySQL packets.
   241  	sendEventCh := make(chan struct{}, 1)
   242  	errChan := make(chan error, 1)
   243  	go func() {
   244  		<-sendEventCh
   245  		// client2 send kill query 10, which is on server1.
   246  		if _, err := client2.Write(makeSimplePacket(fmt.Sprintf("kill query %d", connID))); err != nil {
   247  			errChan <- err
   248  			return
   249  		}
   250  	}()
   251  
   252  	sendEventCh <- struct{}{}
   253  	barrierStart2 <- struct{}{}
   254  	barrierEnd2 <- struct{}{}
   255  
   256  	addr := string(<-res)
   257  	// This test case is mainly focus on if the query is route to the
   258  	// right cn server, but not the result of the query. So we just
   259  	// check the address which is handled is equal to cn1, but not cn2.
   260  	require.Equal(t, cn1.addr, addr)
   261  
   262  	select {
   263  	case err = <-errChan:
   264  		t.Fatalf("require no error, but got %v", err)
   265  	default:
   266  	}
   267  }
   268  
   269  func TestSetVarEvent(t *testing.T) {
   270  	defer leaktest.AfterTest(t)()
   271  
   272  	temp := os.TempDir()
   273  	tp := newTestProxyHandler(t)
   274  	defer tp.closeFn()
   275  
   276  	addr1 := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   277  	require.NoError(t, os.RemoveAll(addr1))
   278  	cn1 := testMakeCNServer("uuid1", addr1, 10, "", labelInfo{})
   279  	stopFn1 := startTestCNServer(t, tp.ctx, addr1, nil)
   280  	defer func() {
   281  		require.NoError(t, stopFn1())
   282  	}()
   283  
   284  	addr2 := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond())
   285  	require.NoError(t, os.RemoveAll(addr2))
   286  	cn2 := testMakeCNServer("uuid2", addr2, 20, "", labelInfo{})
   287  	stopFn2 := startTestCNServer(t, tp.ctx, addr2, nil)
   288  	defer func() {
   289  		require.NoError(t, stopFn2())
   290  	}()
   291  
   292  	tu1 := newTunnel(tp.ctx, tp.logger, nil)
   293  	defer func() { _ = tu1.Close() }()
   294  	tu2 := newTunnel(tp.ctx, tp.logger, nil)
   295  	defer func() { _ = tu2.Close() }()
   296  
   297  	// Client2 will send "kill query 10", which will route to the server which
   298  	// has connection ID 10. In this case, the connection is server1.
   299  	clientProxy1, _ := net.Pipe()
   300  	serverProxy1, _ := net.Pipe()
   301  
   302  	cc1 := newMockClientConn(clientProxy1, "t1", clientInfo{}, tp.ru, tu1)
   303  	require.NotNil(t, cc1)
   304  	sc1 := newMockServerConn(serverProxy1)
   305  	require.NotNil(t, sc1)
   306  
   307  	clientProxy2, client2 := net.Pipe()
   308  	serverProxy2, _ := net.Pipe()
   309  
   310  	cc2 := newMockClientConn(clientProxy2, "t1", clientInfo{}, tp.ru, tu2)
   311  	require.NotNil(t, cc2)
   312  	sc2 := newMockServerConn(serverProxy2)
   313  	require.NotNil(t, sc2)
   314  
   315  	res := make(chan []byte)
   316  	st := stopper.NewStopper("test-event", stopper.WithLogger(tp.logger.RawLogger()))
   317  	defer st.Stop()
   318  	err := st.RunNamedTask("test-event-handler", func(ctx context.Context) {
   319  		for {
   320  			select {
   321  			case e := <-tu2.reqC:
   322  				_ = cc2.HandleEvent(ctx, e, tu2.respC)
   323  			case r := <-tu2.respC:
   324  				if len(r) > 0 {
   325  					res <- r
   326  				}
   327  			case <-ctx.Done():
   328  				return
   329  			}
   330  		}
   331  	})
   332  	require.NoError(t, err)
   333  
   334  	// tunnel1 is on cn1, connection ID is 10.
   335  	_, _, err = tp.ru.Connect(cn1, testPacket, tu1)
   336  	require.NoError(t, err)
   337  
   338  	// tunnel2 is on cn2, connection ID is 20.
   339  	_, _, err = tp.ru.Connect(cn2, testPacket, tu2)
   340  	require.NoError(t, err)
   341  
   342  	err = tu1.run(cc1, sc1)
   343  	require.NoError(t, err)
   344  	require.Nil(t, tu1.ctx.Err())
   345  
   346  	func() {
   347  		tu1.mu.Lock()
   348  		defer tu1.mu.Unlock()
   349  		require.True(t, tu1.mu.started)
   350  	}()
   351  
   352  	err = tu2.run(cc2, sc2)
   353  	require.NoError(t, err)
   354  	require.Nil(t, tu2.ctx.Err())
   355  
   356  	func() {
   357  		tu2.mu.Lock()
   358  		defer tu2.mu.Unlock()
   359  		require.True(t, tu2.mu.started)
   360  	}()
   361  
   362  	tu1.mu.Lock()
   363  	csp1 := tu1.mu.csp
   364  	scp1 := tu1.mu.scp
   365  	tu1.mu.Unlock()
   366  
   367  	tu2.mu.Lock()
   368  	csp2 := tu2.mu.csp
   369  	scp2 := tu2.mu.scp
   370  	tu2.mu.Unlock()
   371  
   372  	barrierStart1, barrierEnd1 := make(chan struct{}), make(chan struct{})
   373  	barrierStart2, barrierEnd2 := make(chan struct{}), make(chan struct{})
   374  	csp1.testHelper.beforeSend = func() {
   375  		<-barrierStart1
   376  		<-barrierEnd1
   377  	}
   378  	csp2.testHelper.beforeSend = func() {
   379  		<-barrierStart2
   380  		<-barrierEnd2
   381  	}
   382  
   383  	csp1.mu.Lock()
   384  	require.True(t, csp1.mu.started)
   385  	csp1.mu.Unlock()
   386  
   387  	scp1.mu.Lock()
   388  	require.True(t, scp1.mu.started)
   389  	scp1.mu.Unlock()
   390  
   391  	csp2.mu.Lock()
   392  	require.True(t, csp2.mu.started)
   393  	csp2.mu.Unlock()
   394  
   395  	scp2.mu.Lock()
   396  	require.True(t, scp2.mu.started)
   397  	scp2.mu.Unlock()
   398  
   399  	// Client2 writes some MySQL packets.
   400  	sendEventCh := make(chan struct{}, 1)
   401  	errChan := make(chan error, 1)
   402  	go func() {
   403  		<-sendEventCh
   404  		if _, err := client2.Write(makeSimplePacket("set session cn_label='account=acc1'")); err != nil {
   405  			errChan <- err
   406  			return
   407  		}
   408  	}()
   409  
   410  	sendEventCh <- struct{}{}
   411  	barrierStart2 <- struct{}{}
   412  	barrierEnd2 <- struct{}{}
   413  
   414  	// wait for result
   415  	<-res
   416  	require.Equal(t, 1, len(cc2.(*mockClientConn).redoStmts))
   417  
   418  	select {
   419  	case err = <-errChan:
   420  		t.Fatalf("require no error, but got %v", err)
   421  	default:
   422  	}
   423  }
   424  
   425  func TestEventType_String(t *testing.T) {
   426  	e1 := baseEvent{}
   427  	require.Equal(t, "Unknown", e1.eventType().String())
   428  
   429  	e2 := killQueryEvent{}
   430  	require.Equal(t, "KillQuery", e2.eventType().String())
   431  
   432  	e3 := setVarEvent{}
   433  	require.Equal(t, "SetVar", e3.eventType().String())
   434  }