github.com/m3db/m3@v1.5.0/src/msg/protocol/proto/roundtrip_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package proto
    22  
    23  import (
    24  	"bufio"
    25  	"bytes"
    26  	"net"
    27  	"testing"
    28  
    29  	"github.com/m3db/m3/src/msg/generated/proto/msgpb"
    30  	"github.com/m3db/m3/src/x/pool"
    31  
    32  	"github.com/stretchr/testify/require"
    33  )
    34  
    35  func TestBaseEncodeDecodeRoundTripWithoutPool(t *testing.T) {
    36  	enc := NewEncoder(NewOptions()).(*encoder)
    37  	require.Equal(t, 4, len(enc.buffer))
    38  	require.Equal(t, 4, cap(enc.buffer))
    39  	require.Empty(t, enc.Bytes())
    40  
    41  	r := bytes.NewReader(nil)
    42  	buf := bufio.NewReader(r)
    43  	dec := NewDecoder(buf, NewOptions(), 10).(*decoder)
    44  	require.Equal(t, 4, len(dec.buffer))
    45  	require.Equal(t, 4, cap(dec.buffer))
    46  	encodeMsg := msgpb.Message{
    47  		Metadata: msgpb.Metadata{
    48  			Shard: 1,
    49  			Id:    2,
    50  		},
    51  		Value: make([]byte, 80),
    52  	}
    53  	decodeMsg := msgpb.Message{}
    54  
    55  	err := enc.Encode(&encodeMsg)
    56  	require.NoError(t, err)
    57  	require.Equal(t, sizeEncodingLength+encodeMsg.Size(), len(enc.buffer))
    58  	require.Equal(t, sizeEncodingLength+encodeMsg.Size(), cap(enc.buffer))
    59  
    60  	r.Reset(enc.Bytes())
    61  	require.NoError(t, dec.Decode(&decodeMsg))
    62  	require.Equal(t, sizeEncodingLength+decodeMsg.Size(), len(dec.buffer))
    63  	require.Equal(t, sizeEncodingLength+encodeMsg.Size(), cap(dec.buffer))
    64  }
    65  
    66  func TestBaseEncodeDecodeRoundTripWithPool(t *testing.T) {
    67  	p := getBytesPool(2, []int{2, 8, 100})
    68  	p.Init()
    69  
    70  	enc := NewEncoder(NewOptions().SetBytesPool(p)).(*encoder)
    71  	require.Equal(t, 8, len(enc.buffer))
    72  	require.Equal(t, 8, cap(enc.buffer))
    73  
    74  	r := bytes.NewReader(nil)
    75  	buf := bufio.NewReader(r)
    76  	dec := NewDecoder(buf, NewOptions().SetBytesPool(p), 10).(*decoder)
    77  	require.Equal(t, 8, len(dec.buffer))
    78  	require.Equal(t, 8, cap(dec.buffer))
    79  	encodeMsg := msgpb.Message{
    80  		Metadata: msgpb.Metadata{
    81  			Shard: 1,
    82  			Id:    2,
    83  		},
    84  		Value: make([]byte, 80),
    85  	}
    86  	decodeMsg := msgpb.Message{}
    87  
    88  	err := enc.Encode(&encodeMsg)
    89  	require.NoError(t, err)
    90  	require.Equal(t, 100, len(enc.buffer))
    91  	require.Equal(t, 100, cap(enc.buffer))
    92  
    93  	r.Reset(enc.Bytes())
    94  	require.NoError(t, dec.Decode(&decodeMsg))
    95  	require.Equal(t, 100, len(dec.buffer))
    96  	require.Equal(t, 100, cap(dec.buffer))
    97  }
    98  
    99  func TestResetReader(t *testing.T) {
   100  	enc := NewEncoder(nil)
   101  	r := bytes.NewReader(nil)
   102  	dec := NewDecoder(r, nil, 10)
   103  	encodeMsg := msgpb.Message{
   104  		Metadata: msgpb.Metadata{
   105  			Shard: 1,
   106  			Id:    2,
   107  		},
   108  		Value: make([]byte, 200),
   109  	}
   110  	decodeMsg := msgpb.Message{}
   111  
   112  	err := enc.Encode(&encodeMsg)
   113  	require.NoError(t, err)
   114  	require.Error(t, dec.Decode(&decodeMsg))
   115  
   116  	r2 := bytes.NewReader(enc.Bytes())
   117  	dec.(*decoder).ResetReader(r2)
   118  	require.NoError(t, dec.Decode(&decodeMsg))
   119  }
   120  
   121  func TestEncodeMessageLargerThanMaxSize(t *testing.T) {
   122  	opts := NewOptions().SetMaxMessageSize(4)
   123  	enc := NewEncoder(opts)
   124  	encodeMsg := msgpb.Message{
   125  		Metadata: msgpb.Metadata{
   126  			Shard: 1,
   127  			Id:    2,
   128  		},
   129  		Value: make([]byte, 10),
   130  	}
   131  
   132  	err := enc.Encode(&encodeMsg)
   133  	require.Error(t, err)
   134  	require.Contains(t, err.Error(), "larger than maximum supported size")
   135  }
   136  
   137  func TestDecodeMessageLargerThanMaxSize(t *testing.T) {
   138  	enc := NewEncoder(nil)
   139  	encodeMsg := msgpb.Message{
   140  		Metadata: msgpb.Metadata{
   141  			Shard: 1,
   142  			Id:    2,
   143  		},
   144  		Value: make([]byte, 10),
   145  	}
   146  
   147  	err := enc.Encode(&encodeMsg)
   148  	require.NoError(t, err)
   149  
   150  	decodeMsg := msgpb.Message{}
   151  	opts := NewOptions().SetMaxMessageSize(8)
   152  	buf := bufio.NewReader(bytes.NewReader(enc.Bytes()))
   153  	dec := NewDecoder(buf, opts, 10)
   154  
   155  	// NB(r): We need to make sure does not grow the buffer
   156  	// if over max size, so going to take size of buffer, make
   157  	// sure its sizeEncodingLength so we can measure if it increases at all.
   158  	require.Equal(t, sizeEncodingLength, cap(dec.(*decoder).buffer))
   159  
   160  	err = dec.Decode(&decodeMsg)
   161  	require.Error(t, err)
   162  	require.Contains(t, err.Error(), "larger than maximum supported size")
   163  
   164  	// Make sure did not grow buffer before returning error.
   165  	require.Equal(t, sizeEncodingLength, cap(dec.(*decoder).buffer))
   166  }
   167  
   168  func TestEncodeDecodeRoundTrip(t *testing.T) {
   169  	r := bytes.NewReader(nil)
   170  	buf := bufio.NewReader(r)
   171  
   172  	enc := NewEncoder(nil)
   173  	dec := NewDecoder(buf, nil, 10)
   174  
   175  	clientConn, serverConn := net.Pipe()
   176  	dec.ResetReader(serverConn)
   177  
   178  	testMsg := msgpb.Message{
   179  		Metadata: msgpb.Metadata{
   180  			Shard: 1,
   181  			Id:    2,
   182  		},
   183  		Value: make([]byte, 10),
   184  	}
   185  	go func() {
   186  		require.NoError(t, enc.Encode(&testMsg))
   187  		_, err := clientConn.Write(enc.Bytes())
   188  		require.NoError(t, err)
   189  	}()
   190  	var msg msgpb.Message
   191  	require.NoError(t, dec.Decode(&msg))
   192  	require.Equal(t, testMsg, msg)
   193  }
   194  
   195  // nolint: unparam
   196  func getBytesPool(bucketSizes int, bucketCaps []int) pool.BytesPool {
   197  	buckets := make([]pool.Bucket, len(bucketCaps))
   198  	for i, cap := range bucketCaps {
   199  		buckets[i] = pool.Bucket{
   200  			Count:    pool.Size(bucketSizes),
   201  			Capacity: cap,
   202  		}
   203  	}
   204  
   205  	return pool.NewBytesPool(buckets, nil)
   206  }