github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/mysql_conn_buf_test.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"math"
    19  	"net"
    20  	"testing"
    21  
    22  	"github.com/lni/goutils/leaktest"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  func TestMySQLConnPreRecv(t *testing.T) {
    27  	defer leaktest.AfterTest(t)()
    28  
    29  	t.Run("protocol_error/length-0", func(t *testing.T) {
    30  		var data [10]byte
    31  		src, dst := net.Pipe()
    32  		go func() {
    33  			n, err := dst.Write(data[:])
    34  			require.NoError(t, err)
    35  			require.Equal(t, 10, n)
    36  		}()
    37  		sc := newMySQLConn("source", src, 10, nil, nil, 0)
    38  		size, err := sc.preRecv()
    39  		require.NoError(t, err, "mysql protocol error")
    40  		require.Equal(t, 4, size)
    41  	})
    42  
    43  	t.Run("protocol_error/length-max", func(t *testing.T) {
    44  		var data [10]byte
    45  		l := math.MaxInt32
    46  		data[0] = byte(l)
    47  		data[1] = byte(l >> 8)
    48  		data[2] = byte(l >> 16)
    49  		src, dst := net.Pipe()
    50  		go func() {
    51  			n, err := dst.Write(data[:])
    52  			require.NoError(t, err)
    53  			require.Equal(t, 10, n)
    54  		}()
    55  		// sc := newMySQLConn("source", src, 10, nil, nil, nil)
    56  		sc := newMySQLConn("source", src, 10, nil, nil, 0)
    57  		size, err := sc.preRecv()
    58  		require.NoError(t, err)
    59  		require.Equal(t, mysqlHeadLen+(1<<24-1), size)
    60  	})
    61  
    62  	t.Run("protocol_error/length-normal", func(t *testing.T) {
    63  		var data [10]byte
    64  		l := 6
    65  		data[0] = byte(l)
    66  		data[1] = byte(l >> 8)
    67  		data[2] = byte(l >> 16)
    68  		src, dst := net.Pipe()
    69  		go func() {
    70  			n, err := dst.Write(data[:])
    71  			require.NoError(t, err)
    72  			require.Equal(t, 10, n)
    73  		}()
    74  		sc := newMySQLConn("source", src, 10, nil, nil, 0)
    75  		size, err := sc.preRecv()
    76  		require.NoError(t, err)
    77  		require.Equal(t, 10, size)
    78  	})
    79  
    80  	t.Run("protocol_error/begin", func(t *testing.T) {
    81  		data := makeSimplePacket("begin")
    82  		src, dst := net.Pipe()
    83  		go func() {
    84  			n, err := dst.Write(data[:])
    85  			require.NoError(t, err)
    86  			require.Equal(t, 10, n)
    87  		}()
    88  		sc := newMySQLConn("source", src, 10, nil, nil, 0)
    89  		size, err := sc.preRecv()
    90  		require.NoError(t, err)
    91  		require.Equal(t, 10, size)
    92  	})
    93  
    94  	t.Run("protocol_error/commit", func(t *testing.T) {
    95  		data := makeSimplePacket("commit")
    96  		src, dst := net.Pipe()
    97  		go func() {
    98  			n, err := dst.Write(data[:])
    99  			require.NoError(t, err)
   100  			require.Equal(t, 11, n)
   101  		}()
   102  		sc := newMySQLConn("source", src, 20, nil, nil, 0)
   103  		size, err := sc.preRecv()
   104  		require.NoError(t, err)
   105  		require.Equal(t, 11, size)
   106  	})
   107  
   108  	t.Run("protocol_error/rollback", func(t *testing.T) {
   109  		data := makeSimplePacket("rollback")
   110  		src, dst := net.Pipe()
   111  		go func() {
   112  			n, err := dst.Write(data[:])
   113  			require.NoError(t, err)
   114  			require.Equal(t, 13, n)
   115  		}()
   116  		sc := newMySQLConn("source", src, 20, nil, nil, 0)
   117  		size, err := sc.preRecv()
   118  		require.NoError(t, err)
   119  		require.Equal(t, 13, size)
   120  	})
   121  }
   122  
   123  func TestMySQLConnReceive(t *testing.T) {
   124  	defer leaktest.AfterTest(t)()
   125  
   126  	t.Run("small buffer", func(t *testing.T) {
   127  		q := "select 1"
   128  		data := makeSimplePacket(q)
   129  		src, dst := net.Pipe()
   130  		go func() {
   131  			n, err := dst.Write(data[:])
   132  			require.NoError(t, err)
   133  			require.Equal(t, 13, n)
   134  		}()
   135  		sc := newMySQLConn("source", src, 6, nil, nil, 0)
   136  		res, err := sc.receive()
   137  		require.NoError(t, err)
   138  		require.Equal(t, q, string(res[5:]))
   139  	})
   140  
   141  	t.Run("enough buffer", func(t *testing.T) {
   142  		q := "select 1"
   143  		data := makeSimplePacket(q)
   144  		src, dst := net.Pipe()
   145  		go func() {
   146  			n, err := dst.Write(data[:])
   147  			require.NoError(t, err)
   148  			require.Equal(t, 13, n)
   149  		}()
   150  		sc := newMySQLConn("source", src, 100, nil, nil, 0)
   151  		res, err := sc.receive()
   152  		require.NoError(t, err)
   153  		require.Equal(t, q, string(res[5:]))
   154  	})
   155  }
   156  
   157  func TestMySQLConnSend(t *testing.T) {
   158  	defer leaktest.AfterTest(t)()
   159  
   160  	t.Run("small buffer", func(t *testing.T) {
   161  		q := "select 1"
   162  		data := makeSimplePacket(q)
   163  		// data write to src1, src1 pipe to dst1,
   164  		// dst1 sendTo src2, src2 pipe to dst2, dst2 read data.
   165  		src1, dst1 := net.Pipe()
   166  		src2, dst2 := net.Pipe()
   167  
   168  		go func() {
   169  			n, err := src1.Write(data[:])
   170  			require.NoError(t, err)
   171  			require.Equal(t, 13, n)
   172  		}()
   173  		go func() {
   174  			var res [30]byte
   175  			n, err := dst2.Read(res[:])
   176  			require.NoError(t, err)
   177  			require.Equal(t, 8, n)
   178  			require.Equal(t, "sel", string(res[5:n]))
   179  
   180  			n, err = dst2.Read(res[:])
   181  			require.NoError(t, err)
   182  			require.Equal(t, 5, n)
   183  			require.Equal(t, "ect 1", string(res[:n]))
   184  		}()
   185  		d1 := newMySQLConn("source", dst1, 8, nil, nil, 0)
   186  		err := d1.sendTo(src2)
   187  		require.NoError(t, err)
   188  	})
   189  
   190  	t.Run("enough buffer", func(t *testing.T) {
   191  		q := "select 1"
   192  		data := makeSimplePacket(q)
   193  		// data write to src1, src1 pipe to dst1,
   194  		// dst1 sendTo src2, src2 pipe to dst2, dst2 read data.
   195  		src1, dst1 := net.Pipe()
   196  		src2, dst2 := net.Pipe()
   197  
   198  		go func() {
   199  			n, err := src1.Write(data[:])
   200  			require.NoError(t, err)
   201  			require.Equal(t, 13, n)
   202  		}()
   203  		go func() {
   204  			var res [30]byte
   205  			n, err := dst2.Read(res[:])
   206  			require.NoError(t, err)
   207  			require.Equal(t, 13, n)
   208  			require.Equal(t, q, string(res[5:n]))
   209  		}()
   210  		d1 := newMySQLConn("source", dst1, 30, nil, nil, 0)
   211  		err := d1.sendTo(src2)
   212  		require.NoError(t, err)
   213  	})
   214  }
   215  
   216  func TestMySQLConnSize(t *testing.T) {
   217  	defer leaktest.AfterTest(t)()
   218  
   219  	t.Run("available read", func(t *testing.T) {
   220  		src, dst := net.Pipe()
   221  		go func() {
   222  			n, err := src.Write([]byte("123456789012"))
   223  			require.NoError(t, err)
   224  			require.Equal(t, 12, n)
   225  		}()
   226  		s := newMySQLConn("source", dst, 30, nil, nil, 0)
   227  		require.Equal(t, 0, s.readAvail())
   228  
   229  		require.NoError(t, s.receiveAtLeast(3))
   230  		require.Equal(t, 12, s.readAvail())
   231  	})
   232  
   233  	t.Run("available write", func(t *testing.T) {
   234  		src, dst := net.Pipe()
   235  		go func() {
   236  			n, err := src.Write([]byte("123456789012"))
   237  			require.NoError(t, err)
   238  			require.Equal(t, 12, n)
   239  		}()
   240  		s := newMySQLConn("source", dst, 30, nil, nil, 0)
   241  		require.Equal(t, 30, s.writeAvail())
   242  
   243  		require.NoError(t, s.receiveAtLeast(3))
   244  		require.Equal(t, 18, s.writeAvail())
   245  	})
   246  }