github.com/hashicorp/memberlist@v0.5.0/label_test.go (about)

     1  package memberlist
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"net"
     7  	"strings"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  func TestAddLabelHeaderToPacket(t *testing.T) {
    14  	type testcase struct {
    15  		buf          []byte
    16  		label        string
    17  		expectPacket []byte
    18  		expectErr    string
    19  	}
    20  
    21  	run := func(t *testing.T, tc testcase) {
    22  		got, err := AddLabelHeaderToPacket(tc.buf, tc.label)
    23  		if tc.expectErr != "" {
    24  			require.Error(t, err)
    25  			require.Contains(t, err.Error(), tc.expectErr)
    26  		} else {
    27  			require.NoError(t, err)
    28  			require.Equal(t, tc.expectPacket, got)
    29  		}
    30  	}
    31  
    32  	longLabel := strings.Repeat("a", 255)
    33  
    34  	cases := map[string]testcase{
    35  		"nil buf with no label": testcase{
    36  			buf:          nil,
    37  			label:        "",
    38  			expectPacket: nil,
    39  		},
    40  		"nil buf with label": testcase{
    41  			buf:          nil,
    42  			label:        "foo",
    43  			expectPacket: append([]byte{byte(hasLabelMsg), 3}, []byte("foo")...),
    44  		},
    45  		"message with label": testcase{
    46  			buf:          []byte("something"),
    47  			label:        "foo",
    48  			expectPacket: append([]byte{byte(hasLabelMsg), 3}, []byte("foosomething")...),
    49  		},
    50  		"message with no label": testcase{
    51  			buf:          []byte("something"),
    52  			label:        "",
    53  			expectPacket: []byte("something"),
    54  		},
    55  		"message with almost too long label": testcase{
    56  			buf:          []byte("something"),
    57  			label:        longLabel,
    58  			expectPacket: append([]byte{byte(hasLabelMsg), 255}, []byte(longLabel+"something")...),
    59  		},
    60  		"label too long by one byte": testcase{
    61  			buf:       []byte("something"),
    62  			label:     longLabel + "x",
    63  			expectErr: `label "` + longLabel + `x" is too long`,
    64  		},
    65  	}
    66  
    67  	for name, tc := range cases {
    68  		t.Run(name, func(t *testing.T) {
    69  			run(t, tc)
    70  		})
    71  	}
    72  }
    73  
    74  func TestRemoveLabelHeaderFromPacket(t *testing.T) {
    75  	type testcase struct {
    76  		buf          []byte
    77  		expectLabel  string
    78  		expectPacket []byte
    79  		expectErr    string
    80  	}
    81  
    82  	run := func(t *testing.T, tc testcase) {
    83  		gotBuf, gotLabel, err := RemoveLabelHeaderFromPacket(tc.buf)
    84  		if tc.expectErr != "" {
    85  			require.Error(t, err)
    86  			require.Contains(t, err.Error(), tc.expectErr)
    87  		} else {
    88  			require.NoError(t, err)
    89  			require.Equal(t, tc.expectPacket, gotBuf)
    90  			require.Equal(t, tc.expectLabel, gotLabel)
    91  		}
    92  	}
    93  
    94  	cases := map[string]testcase{
    95  		"empty buf": testcase{
    96  			buf:          []byte{},
    97  			expectLabel:  "",
    98  			expectPacket: []byte{},
    99  		},
   100  		"ping with no label": testcase{
   101  			buf:          buildBuffer(t, pingMsg, "blah"),
   102  			expectLabel:  "",
   103  			expectPacket: buildBuffer(t, pingMsg, "blah"),
   104  		},
   105  		"error with no label": testcase{ // 2021-10: largest standard message type
   106  			buf:          buildBuffer(t, errMsg, "blah"),
   107  			expectLabel:  "",
   108  			expectPacket: buildBuffer(t, errMsg, "blah"),
   109  		},
   110  		"v1 encrypt with no label": testcase{ // 2021-10: highest encryption version
   111  			buf:          buildBuffer(t, maxEncryptionVersion, "blah"),
   112  			expectLabel:  "",
   113  			expectPacket: buildBuffer(t, maxEncryptionVersion, "blah"),
   114  		},
   115  		"buf too small for label": testcase{
   116  			buf:       buildBuffer(t, hasLabelMsg, "x"),
   117  			expectErr: `cannot decode label; packet has been truncated`,
   118  		},
   119  		"buf too small for label size": testcase{
   120  			buf:       buildBuffer(t, hasLabelMsg),
   121  			expectErr: `cannot decode label; packet has been truncated`,
   122  		},
   123  		"label empty": testcase{
   124  			buf:       buildBuffer(t, hasLabelMsg, 0, "x"),
   125  			expectErr: `label header cannot be empty when present`,
   126  		},
   127  		"label truncated": testcase{
   128  			buf:       buildBuffer(t, hasLabelMsg, 2, "x"),
   129  			expectErr: `cannot decode label; packet has been truncated`,
   130  		},
   131  		"ping with label": testcase{
   132  			buf:          buildBuffer(t, hasLabelMsg, 3, "abc", pingMsg, "blah"),
   133  			expectLabel:  "abc",
   134  			expectPacket: buildBuffer(t, pingMsg, "blah"),
   135  		},
   136  		"error with label": testcase{ // 2021-10: largest standard message type
   137  			buf:          buildBuffer(t, hasLabelMsg, 3, "abc", errMsg, "blah"),
   138  			expectLabel:  "abc",
   139  			expectPacket: buildBuffer(t, errMsg, "blah"),
   140  		},
   141  		"v1 encrypt with label": testcase{ // 2021-10: highest encryption version
   142  			buf:          buildBuffer(t, hasLabelMsg, 3, "abc", maxEncryptionVersion, "blah"),
   143  			expectLabel:  "abc",
   144  			expectPacket: buildBuffer(t, maxEncryptionVersion, "blah"),
   145  		},
   146  	}
   147  
   148  	for name, tc := range cases {
   149  		t.Run(name, func(t *testing.T) {
   150  			run(t, tc)
   151  		})
   152  	}
   153  }
   154  
   155  func TestAddLabelHeaderToStream(t *testing.T) {
   156  	type testcase struct {
   157  		label      string
   158  		expectData []byte
   159  		expectErr  string
   160  	}
   161  
   162  	suffixData := "EXTRA DATA"
   163  
   164  	run := func(t *testing.T, tc testcase) {
   165  		server, client := net.Pipe()
   166  		defer server.Close()
   167  		defer client.Close()
   168  
   169  		var (
   170  			dataCh = make(chan []byte, 1)
   171  			errCh  = make(chan error, 1)
   172  		)
   173  		go func() {
   174  			var buf bytes.Buffer
   175  			_, err := io.Copy(&buf, server)
   176  			if err != nil {
   177  				errCh <- err
   178  			}
   179  			dataCh <- buf.Bytes()
   180  		}()
   181  
   182  		err := AddLabelHeaderToStream(client, tc.label)
   183  		if tc.expectErr != "" {
   184  			require.Error(t, err)
   185  			require.Contains(t, err.Error(), tc.expectErr)
   186  			return
   187  		}
   188  		require.NoError(t, err)
   189  
   190  		client.Write([]byte(suffixData))
   191  		client.Close()
   192  
   193  		expect := make([]byte, 0, len(suffixData)+len(tc.expectData))
   194  		expect = append(expect, tc.expectData...)
   195  		expect = append(expect, suffixData...)
   196  
   197  		select {
   198  		case err := <-errCh:
   199  			require.NoError(t, err)
   200  		case got := <-dataCh:
   201  			require.Equal(t, expect, got)
   202  		}
   203  	}
   204  
   205  	longLabel := strings.Repeat("a", 255)
   206  
   207  	cases := map[string]testcase{
   208  		"no label": testcase{
   209  			label:      "",
   210  			expectData: nil,
   211  		},
   212  		"with label": testcase{
   213  			label:      "foo",
   214  			expectData: buildBuffer(t, hasLabelMsg, 3, "foo"),
   215  		},
   216  		"almost too long label": testcase{
   217  			label:      longLabel,
   218  			expectData: buildBuffer(t, hasLabelMsg, 255, longLabel),
   219  		},
   220  		"label too long by one byte": testcase{
   221  			label:     longLabel + "x",
   222  			expectErr: `label "` + longLabel + `x" is too long`,
   223  		},
   224  	}
   225  
   226  	for name, tc := range cases {
   227  		t.Run(name, func(t *testing.T) {
   228  			run(t, tc)
   229  		})
   230  	}
   231  }
   232  
   233  func TestRemoveLabelHeaderFromStream(t *testing.T) {
   234  	type testcase struct {
   235  		buf         []byte
   236  		expectLabel string
   237  		expectData  []byte
   238  		expectErr   string
   239  	}
   240  
   241  	run := func(t *testing.T, tc testcase) {
   242  		server, client := net.Pipe()
   243  		defer server.Close()
   244  		defer client.Close()
   245  
   246  		var (
   247  			errCh = make(chan error, 1)
   248  		)
   249  		go func() {
   250  			_, err := server.Write(tc.buf)
   251  			if err != nil {
   252  				errCh <- err
   253  			}
   254  			server.Close()
   255  		}()
   256  
   257  		newConn, gotLabel, err := RemoveLabelHeaderFromStream(client)
   258  		if tc.expectErr != "" {
   259  			require.Error(t, err)
   260  			require.Contains(t, err.Error(), tc.expectErr)
   261  			return
   262  		}
   263  		require.NoError(t, err)
   264  
   265  		gotBuf, err := io.ReadAll(newConn)
   266  		require.NoError(t, err)
   267  
   268  		require.Equal(t, tc.expectData, gotBuf)
   269  		require.Equal(t, tc.expectLabel, gotLabel)
   270  	}
   271  
   272  	cases := map[string]testcase{
   273  		"empty buf": testcase{
   274  			buf:         []byte{},
   275  			expectLabel: "",
   276  			expectData:  []byte{},
   277  		},
   278  		"ping with no label": testcase{
   279  			buf:         buildBuffer(t, pingMsg, "blah"),
   280  			expectLabel: "",
   281  			expectData:  buildBuffer(t, pingMsg, "blah"),
   282  		},
   283  		"error with no label": testcase{ // 2021-10: largest standard message type
   284  			buf:         buildBuffer(t, errMsg, "blah"),
   285  			expectLabel: "",
   286  			expectData:  buildBuffer(t, errMsg, "blah"),
   287  		},
   288  		"v1 encrypt with no label": testcase{ // 2021-10: highest encryption version
   289  			buf:         buildBuffer(t, maxEncryptionVersion, "blah"),
   290  			expectLabel: "",
   291  			expectData:  buildBuffer(t, maxEncryptionVersion, "blah"),
   292  		},
   293  		"buf too small for label": testcase{
   294  			buf:       buildBuffer(t, hasLabelMsg, "x"),
   295  			expectErr: `cannot decode label; stream has been truncated`,
   296  		},
   297  		"buf too small for label size": testcase{
   298  			buf:       buildBuffer(t, hasLabelMsg),
   299  			expectErr: `cannot decode label; stream has been truncated`,
   300  		},
   301  		"label empty": testcase{
   302  			buf:       buildBuffer(t, hasLabelMsg, 0, "x"),
   303  			expectErr: `label header cannot be empty when present`,
   304  		},
   305  		"label truncated": testcase{
   306  			buf:       buildBuffer(t, hasLabelMsg, 2, "x"),
   307  			expectErr: `cannot decode label; stream has been truncated`,
   308  		},
   309  		"ping with label": testcase{
   310  			buf:         buildBuffer(t, hasLabelMsg, 3, "abc", pingMsg, "blah"),
   311  			expectLabel: "abc",
   312  			expectData:  buildBuffer(t, pingMsg, "blah"),
   313  		},
   314  		"error with label": testcase{ // 2021-10: largest standard message type
   315  			buf:         buildBuffer(t, hasLabelMsg, 3, "abc", errMsg, "blah"),
   316  			expectLabel: "abc",
   317  			expectData:  buildBuffer(t, errMsg, "blah"),
   318  		},
   319  		"v1 encrypt with label": testcase{ // 2021-10: highest encryption version
   320  			buf:         buildBuffer(t, hasLabelMsg, 3, "abc", maxEncryptionVersion, "blah"),
   321  			expectLabel: "abc",
   322  			expectData:  buildBuffer(t, maxEncryptionVersion, "blah"),
   323  		},
   324  	}
   325  
   326  	for name, tc := range cases {
   327  		t.Run(name, func(t *testing.T) {
   328  			run(t, tc)
   329  		})
   330  	}
   331  }
   332  
   333  func buildBuffer(t *testing.T, stuff ...interface{}) []byte {
   334  	var buf bytes.Buffer
   335  	for _, item := range stuff {
   336  		switch x := item.(type) {
   337  		case int:
   338  			x2 := uint(x)
   339  			if x2 > 255 {
   340  				t.Fatalf("int is too big")
   341  			}
   342  			buf.WriteByte(byte(x2))
   343  		case byte:
   344  			buf.WriteByte(byte(x))
   345  		case messageType:
   346  			buf.WriteByte(byte(x))
   347  		case encryptionVersion:
   348  			buf.WriteByte(byte(x))
   349  		case string:
   350  			buf.Write([]byte(x))
   351  		case []byte:
   352  			buf.Write(x)
   353  		default:
   354  			t.Fatalf("unexpected type %T", item)
   355  		}
   356  	}
   357  	return buf.Bytes()
   358  }
   359  
   360  func TestLabelOverhead(t *testing.T) {
   361  	require.Equal(t, 0, labelOverhead(""))
   362  	require.Equal(t, 3, labelOverhead("a"))
   363  	require.Equal(t, 9, labelOverhead("abcdefg"))
   364  }