github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/ppv2_test.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/fagongzi/goetty/v2/buf"
    21  	"github.com/matrixorigin/matrixone/pkg/frontend"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
    25  func TestProxyProtocolOptions(t *testing.T) {
    26  	pp := &proxyProtocolCodec{}
    27  	ret := WithProxyProtocolCodec(pp)
    28  	require.NotNil(t, ret)
    29  }
    30  
    31  func TestProxyProtocolCodec_Decode(t *testing.T) {
    32  	t.Run("short header", func(t *testing.T) {
    33  		data := buf.NewByteBuf(100)
    34  		n, err := data.Write([]byte("12345"))
    35  		require.NoError(t, err)
    36  		require.Equal(t, 5, n)
    37  
    38  		pp := WithProxyProtocolCodec(frontend.NewSqlCodec())
    39  		res, ok, err := pp.Decode(data)
    40  		require.NoError(t, err)
    41  		require.False(t, ok)
    42  		require.Nil(t, res)
    43  	})
    44  
    45  	t.Run("long header", func(t *testing.T) {
    46  		data := buf.NewByteBuf(100)
    47  		n, err := data.Write([]byte("12345678901234567890"))
    48  		require.NoError(t, err)
    49  		require.Equal(t, 20, n)
    50  
    51  		pp := WithProxyProtocolCodec(frontend.NewSqlCodec())
    52  		res, ok, err := pp.Decode(data)
    53  		require.NoError(t, err)
    54  		require.False(t, ok)
    55  		require.Nil(t, res)
    56  	})
    57  
    58  	t.Run("local address", func(t *testing.T) {
    59  		data := buf.NewByteBuf(100)
    60  		n, err := data.Write([]byte(ProxyProtocolV2Signature))
    61  		require.NoError(t, err)
    62  		require.Equal(t, len(ProxyProtocolV2Signature), n)
    63  
    64  		// skip 2 bytes
    65  		n, err = data.Write([]byte{0, 0})
    66  		require.NoError(t, err)
    67  		require.Equal(t, 2, n)
    68  		data.WriteUint16(0)
    69  
    70  		pp := WithProxyProtocolCodec(frontend.NewSqlCodec())
    71  		res, ok, err := pp.Decode(data)
    72  		require.NoError(t, err)
    73  		require.False(t, ok)
    74  		require.Nil(t, res)
    75  	})
    76  
    77  	t.Run("ipv4 address", func(t *testing.T) {
    78  		data := buf.NewByteBuf(100)
    79  		n, err := data.Write([]byte(ProxyProtocolV2Signature))
    80  		require.NoError(t, err)
    81  		require.Equal(t, len(ProxyProtocolV2Signature), n)
    82  
    83  		// ipv4
    84  		n, err = data.Write([]byte{0, tcpOverIPv4})
    85  		require.NoError(t, err)
    86  		require.Equal(t, 2, n)
    87  
    88  		// ipv4 length
    89  		data.WriteUint16(12)
    90  		// source address
    91  		err = data.WriteByte(10)
    92  		require.NoError(t, err)
    93  		err = data.WriteByte(11)
    94  		require.NoError(t, err)
    95  		err = data.WriteByte(12)
    96  		require.NoError(t, err)
    97  		err = data.WriteByte(13)
    98  		require.NoError(t, err)
    99  		// target address
   100  		err = data.WriteByte(20)
   101  		require.NoError(t, err)
   102  		err = data.WriteByte(21)
   103  		require.NoError(t, err)
   104  		err = data.WriteByte(22)
   105  		require.NoError(t, err)
   106  		err = data.WriteByte(23)
   107  		require.NoError(t, err)
   108  		// source port
   109  		data.WriteUint16(8000)
   110  		// target port
   111  		data.WriteUint16(9000)
   112  
   113  		pp := &proxyProtocolCodec{}
   114  		res, ok, err := pp.Decode(data)
   115  		require.NoError(t, err)
   116  		require.True(t, ok)
   117  		require.NotNil(t, res)
   118  		addr, ok := res.(*ProxyAddr)
   119  		require.True(t, ok)
   120  		require.Equal(t, "10.11.12.13", addr.SourceAddress.To4().String())
   121  		require.Equal(t, 8000, int(addr.SourcePort))
   122  		require.Equal(t, "20.21.22.23", addr.TargetAddress.To4().String())
   123  		require.Equal(t, 9000, int(addr.TargetPort))
   124  	})
   125  
   126  	t.Run("ipv6 address", func(t *testing.T) {
   127  		data := buf.NewByteBuf(100)
   128  		n, err := data.Write([]byte(ProxyProtocolV2Signature))
   129  		require.NoError(t, err)
   130  		require.Equal(t, len(ProxyProtocolV2Signature), n)
   131  
   132  		// ipv6
   133  		n, err = data.Write([]byte{0, tcpOverIPv6})
   134  		require.NoError(t, err)
   135  		require.Equal(t, 2, n)
   136  
   137  		// ipv4 length
   138  		data.WriteUint16(36)
   139  		// source address
   140  		for i := 0; i < 16; i++ {
   141  			err = data.WriteByte(byte(10 + i))
   142  			require.NoError(t, err)
   143  		}
   144  		// target address
   145  		for i := 0; i < 16; i++ {
   146  			err = data.WriteByte(byte(50 + i))
   147  			require.NoError(t, err)
   148  		}
   149  		// source port
   150  		data.WriteUint16(8000)
   151  		// target port
   152  		data.WriteUint16(9000)
   153  
   154  		pp := &proxyProtocolCodec{}
   155  		res, ok, err := pp.Decode(data)
   156  		require.NoError(t, err)
   157  		require.True(t, ok)
   158  		require.NotNil(t, res)
   159  		addr, ok := res.(*ProxyAddr)
   160  		require.True(t, ok)
   161  		require.Equal(t, "a0b:c0d:e0f:1011:1213:1415:1617:1819", addr.SourceAddress.To16().String())
   162  		require.Equal(t, 8000, int(addr.SourcePort))
   163  		require.Equal(t, "3233:3435:3637:3839:3a3b:3c3d:3e3f:4041", addr.TargetAddress.To16().String())
   164  		require.Equal(t, 9000, int(addr.TargetPort))
   165  	})
   166  
   167  	t.Run("error length", func(t *testing.T) {
   168  		data := buf.NewByteBuf(100)
   169  		n, err := data.Write([]byte(ProxyProtocolV2Signature))
   170  		require.NoError(t, err)
   171  		require.Equal(t, len(ProxyProtocolV2Signature), n)
   172  
   173  		// skip 2 bytes
   174  		n, err = data.Write([]byte{0, 0})
   175  		require.NoError(t, err)
   176  		require.Equal(t, 2, n)
   177  		data.WriteUint16(33)
   178  
   179  		pp := &proxyProtocolCodec{}
   180  		res, ok, err := pp.Decode(data)
   181  		require.Error(t, err)
   182  		require.False(t, ok)
   183  		require.Nil(t, res)
   184  	})
   185  }