github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/ppv2_test.go (about) 1 // Copyright 2021 - 2023 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package proxy 16 17 import ( 18 "testing" 19 20 "github.com/fagongzi/goetty/v2/buf" 21 "github.com/matrixorigin/matrixone/pkg/frontend" 22 "github.com/stretchr/testify/require" 23 ) 24 25 func TestProxyProtocolOptions(t *testing.T) { 26 pp := &proxyProtocolCodec{} 27 ret := WithProxyProtocolCodec(pp) 28 require.NotNil(t, ret) 29 } 30 31 func TestProxyProtocolCodec_Decode(t *testing.T) { 32 t.Run("short header", func(t *testing.T) { 33 data := buf.NewByteBuf(100) 34 n, err := data.Write([]byte("12345")) 35 require.NoError(t, err) 36 require.Equal(t, 5, n) 37 38 pp := WithProxyProtocolCodec(frontend.NewSqlCodec()) 39 res, ok, err := pp.Decode(data) 40 require.NoError(t, err) 41 require.False(t, ok) 42 require.Nil(t, res) 43 }) 44 45 t.Run("long header", func(t *testing.T) { 46 data := buf.NewByteBuf(100) 47 n, err := data.Write([]byte("12345678901234567890")) 48 require.NoError(t, err) 49 require.Equal(t, 20, n) 50 51 pp := WithProxyProtocolCodec(frontend.NewSqlCodec()) 52 res, ok, err := pp.Decode(data) 53 require.NoError(t, err) 54 require.False(t, ok) 55 require.Nil(t, res) 56 }) 57 58 t.Run("local address", func(t *testing.T) { 59 data := buf.NewByteBuf(100) 60 n, err := data.Write([]byte(ProxyProtocolV2Signature)) 61 require.NoError(t, err) 62 require.Equal(t, len(ProxyProtocolV2Signature), n) 63 64 // skip 2 bytes 65 n, err = data.Write([]byte{0, 0}) 66 require.NoError(t, err) 67 require.Equal(t, 2, n) 68 data.WriteUint16(0) 69 70 pp := WithProxyProtocolCodec(frontend.NewSqlCodec()) 71 res, ok, err := pp.Decode(data) 72 require.NoError(t, err) 73 require.False(t, ok) 74 require.Nil(t, res) 75 }) 76 77 t.Run("ipv4 address", func(t *testing.T) { 78 data := buf.NewByteBuf(100) 79 n, err := data.Write([]byte(ProxyProtocolV2Signature)) 80 require.NoError(t, err) 81 require.Equal(t, len(ProxyProtocolV2Signature), n) 82 83 // ipv4 84 n, err = data.Write([]byte{0, tcpOverIPv4}) 85 require.NoError(t, err) 86 require.Equal(t, 2, n) 87 88 // ipv4 length 89 data.WriteUint16(12) 90 // source address 91 err = data.WriteByte(10) 92 require.NoError(t, err) 93 err = data.WriteByte(11) 94 require.NoError(t, err) 95 err = data.WriteByte(12) 96 require.NoError(t, err) 97 err = data.WriteByte(13) 98 require.NoError(t, err) 99 // target address 100 err = data.WriteByte(20) 101 require.NoError(t, err) 102 err = data.WriteByte(21) 103 require.NoError(t, err) 104 err = data.WriteByte(22) 105 require.NoError(t, err) 106 err = data.WriteByte(23) 107 require.NoError(t, err) 108 // source port 109 data.WriteUint16(8000) 110 // target port 111 data.WriteUint16(9000) 112 113 pp := &proxyProtocolCodec{} 114 res, ok, err := pp.Decode(data) 115 require.NoError(t, err) 116 require.True(t, ok) 117 require.NotNil(t, res) 118 addr, ok := res.(*ProxyAddr) 119 require.True(t, ok) 120 require.Equal(t, "10.11.12.13", addr.SourceAddress.To4().String()) 121 require.Equal(t, 8000, int(addr.SourcePort)) 122 require.Equal(t, "20.21.22.23", addr.TargetAddress.To4().String()) 123 require.Equal(t, 9000, int(addr.TargetPort)) 124 }) 125 126 t.Run("ipv6 address", func(t *testing.T) { 127 data := buf.NewByteBuf(100) 128 n, err := data.Write([]byte(ProxyProtocolV2Signature)) 129 require.NoError(t, err) 130 require.Equal(t, len(ProxyProtocolV2Signature), n) 131 132 // ipv6 133 n, err = data.Write([]byte{0, tcpOverIPv6}) 134 require.NoError(t, err) 135 require.Equal(t, 2, n) 136 137 // ipv4 length 138 data.WriteUint16(36) 139 // source address 140 for i := 0; i < 16; i++ { 141 err = data.WriteByte(byte(10 + i)) 142 require.NoError(t, err) 143 } 144 // target address 145 for i := 0; i < 16; i++ { 146 err = data.WriteByte(byte(50 + i)) 147 require.NoError(t, err) 148 } 149 // source port 150 data.WriteUint16(8000) 151 // target port 152 data.WriteUint16(9000) 153 154 pp := &proxyProtocolCodec{} 155 res, ok, err := pp.Decode(data) 156 require.NoError(t, err) 157 require.True(t, ok) 158 require.NotNil(t, res) 159 addr, ok := res.(*ProxyAddr) 160 require.True(t, ok) 161 require.Equal(t, "a0b:c0d:e0f:1011:1213:1415:1617:1819", addr.SourceAddress.To16().String()) 162 require.Equal(t, 8000, int(addr.SourcePort)) 163 require.Equal(t, "3233:3435:3637:3839:3a3b:3c3d:3e3f:4041", addr.TargetAddress.To16().String()) 164 require.Equal(t, 9000, int(addr.TargetPort)) 165 }) 166 167 t.Run("error length", func(t *testing.T) { 168 data := buf.NewByteBuf(100) 169 n, err := data.Write([]byte(ProxyProtocolV2Signature)) 170 require.NoError(t, err) 171 require.Equal(t, len(ProxyProtocolV2Signature), n) 172 173 // skip 2 bytes 174 n, err = data.Write([]byte{0, 0}) 175 require.NoError(t, err) 176 require.Equal(t, 2, n) 177 data.WriteUint16(33) 178 179 pp := &proxyProtocolCodec{} 180 res, ok, err := pp.Decode(data) 181 require.Error(t, err) 182 require.False(t, ok) 183 require.Nil(t, res) 184 }) 185 }